main
lizifeng 2023-03-20 20:53:18 +08:00
parent 9c83b42ccd
commit 2fa24a33ab
2 changed files with 41 additions and 42 deletions

View File

@ -1,7 +1,6 @@
package mymysql package mymysql
import ( import (
"database/sql"
"fmt" "fmt"
"git.makemake.in/test/mycommon/mylog" "git.makemake.in/test/mycommon/mylog"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
@ -23,14 +22,10 @@ var (
GormLogger: gormLogger.Default.LogMode(gormLogger.Info), GormLogger: gormLogger.Default.LogMode(gormLogger.Info),
} }
instanceMap = make(map[string]*MyDB) instanceMap = make(map[string]*gorm.DB)
) )
type ( type (
MyDB struct {
db *gorm.DB
sqlDb *sql.DB
}
Config struct { Config struct {
Dsn string Dsn string
MaxOpenConn int MaxOpenConn int
@ -42,7 +37,7 @@ type (
} }
) )
func Instance(key ...string) *MyDB { func Instance(key ...string) *gorm.DB {
var key0 string var key0 string
if len(key) > 0 { if len(key) > 0 {
@ -58,11 +53,11 @@ func Instance(key ...string) *MyDB {
return instance return instance
} }
func NewDefault(config *Config) (*MyDB, error) { func NewDefault(config *Config) (*gorm.DB, error) {
return New(DefaultKey, config) return New(DefaultKey, config)
} }
func New(key string, config *Config) (*MyDB, error) { func New(key string, config *Config) (*gorm.DB, error) {
var ( var (
maxLifeTime, _ = time.ParseDuration(DefaultConfig.MaxLifeTime) maxLifeTime, _ = time.ParseDuration(DefaultConfig.MaxLifeTime)
maxIdleTime, _ = time.ParseDuration(DefaultConfig.MaxIdleTime) maxIdleTime, _ = time.ParseDuration(DefaultConfig.MaxIdleTime)
@ -117,12 +112,8 @@ func New(key string, config *Config) (*MyDB, error) {
sqlDb.SetConnMaxLifetime(maxLifeTime) sqlDb.SetConnMaxLifetime(maxLifeTime)
sqlDb.SetConnMaxIdleTime(maxIdleTime) sqlDb.SetConnMaxIdleTime(maxIdleTime)
myDb := &MyDB{ instanceMap[key] = db
db: db, return db, nil
sqlDb: sqlDb,
}
instanceMap[key] = myDb
return myDb, nil
} }
func DefaultGormLogger(level gormLogger.LogLevel) gormLogger.Interface { func DefaultGormLogger(level gormLogger.LogLevel) gormLogger.Interface {
@ -139,16 +130,14 @@ func NewGormLogger(writer gormLogger.Writer, gormLoggerConfig gormLogger.Config)
return gormLogger.New(writer, gormLoggerConfig) return gormLogger.New(writer, gormLoggerConfig)
} }
func (m *MyDB) Close() { func Close(key string) {
if m.sqlDb != nil { db, _ := Instance(key).DB()
m.sqlDb.Close() db.Close()
}
func CloseAll() {
for _, v := range instanceMap {
db, _ := v.DB()
db.Close()
} }
} }
func (m *MyDB) DB() *gorm.DB {
return m.db
}
func (m *MyDB) SqlDB() *sql.DB {
return m.sqlDb
}

View File

@ -26,7 +26,7 @@ var (
type ( type (
MyRedis struct { MyRedis struct {
redis *redis.Client *redis.Client
ctx context.Context ctx context.Context
} }
@ -106,8 +106,8 @@ func New(key string, config *Config) (*MyRedis, error) {
ctx := context.Background() ctx := context.Background()
rd := &MyRedis{} rd := &MyRedis{}
rd.ctx = ctx rd.ctx = ctx
rd.redis = client rd.Client = client
ping := rd.redis.Ping(ctx) ping := rd.Client.Ping(ctx)
if ping.Err() != nil { if ping.Err() != nil {
return nil, fmt.Errorf("connet redis err: %s", ping.Err()) return nil, fmt.Errorf("connet redis err: %s", ping.Err())
} }
@ -116,23 +116,23 @@ func New(key string, config *Config) (*MyRedis, error) {
return rd, nil return rd, nil
} }
// Get 通用get // GetSimple 通用get
func (r *MyRedis) Get(key string) (string, error) { func (r *MyRedis) GetSimple(key string) (string, error) {
return r.redis.Get(r.ctx, key).Result() return r.Client.Get(r.ctx, key).Result()
} }
// Set 通用set // SetSimple 通用set
func (r *MyRedis) Set(key string, value interface{}, t ...time.Duration) (string, error) { func (r *MyRedis) SetSimple(key string, value interface{}, t ...time.Duration) (string, error) {
var t2 time.Duration var t2 time.Duration
if len(t) > 0 { if len(t) > 0 {
t2 = t[0] t2 = t[0]
} }
return r.redis.Set(r.ctx, key, value, t2).Result() return r.Client.Set(r.ctx, key, value, t2).Result()
} }
// GetJson json序列化 // GetJson json序列化
func (r *MyRedis) GetJson(key string) (interface{}, error) { func (r *MyRedis) GetJson(key string) (interface{}, error) {
res := r.redis.Get(r.ctx, key) res := r.Client.Get(r.ctx, key)
if res.Err() != nil { if res.Err() != nil {
return nil, res.Err() return nil, res.Err()
} }
@ -159,15 +159,25 @@ func (r *MyRedis) SetJson(key string, value interface{}, t ...time.Duration) (st
if err != nil { if err != nil {
return "", fmt.Errorf("set key:%s 序列化json失败", key) return "", fmt.Errorf("set key:%s 序列化json失败", key)
} }
return r.redis.Set(r.ctx, key, v, t2).Result() return r.Client.Set(r.ctx, key, v, t2).Result()
} }
func (r *MyRedis) Close() { func (r *MyRedis) Close() {
if r.redis != nil { if r.Client != nil {
r.redis.Close() r.Client.Close()
} }
} }
func (r *MyRedis) Redis() *redis.Client { func (r *MyRedis) GetConn() *redis.Client {
return r.redis return r.Client
}
func Close(key string) {
Instance(key).Client.Close()
}
func CloseAll() {
for _, v := range instanceMap {
v.Client.Close()
}
} }