diff --git a/mymysql/mysql.go b/mymysql/mysql.go index 7c77787..5a01088 100644 --- a/mymysql/mysql.go +++ b/mymysql/mysql.go @@ -1,7 +1,6 @@ package mymysql import ( - "database/sql" "fmt" "git.makemake.in/test/mycommon/mylog" "gorm.io/driver/mysql" @@ -23,14 +22,10 @@ var ( GormLogger: gormLogger.Default.LogMode(gormLogger.Info), } - instanceMap = make(map[string]*MyDB) + instanceMap = make(map[string]*gorm.DB) ) type ( - MyDB struct { - db *gorm.DB - sqlDb *sql.DB - } Config struct { Dsn string MaxOpenConn int @@ -42,7 +37,7 @@ type ( } ) -func Instance(key ...string) *MyDB { +func Instance(key ...string) *gorm.DB { var key0 string if len(key) > 0 { @@ -58,11 +53,11 @@ func Instance(key ...string) *MyDB { return instance } -func NewDefault(config *Config) (*MyDB, error) { +func NewDefault(config *Config) (*gorm.DB, error) { return New(DefaultKey, config) } -func New(key string, config *Config) (*MyDB, error) { +func New(key string, config *Config) (*gorm.DB, error) { var ( maxLifeTime, _ = time.ParseDuration(DefaultConfig.MaxLifeTime) maxIdleTime, _ = time.ParseDuration(DefaultConfig.MaxIdleTime) @@ -117,12 +112,8 @@ func New(key string, config *Config) (*MyDB, error) { sqlDb.SetConnMaxLifetime(maxLifeTime) sqlDb.SetConnMaxIdleTime(maxIdleTime) - myDb := &MyDB{ - db: db, - sqlDb: sqlDb, - } - instanceMap[key] = myDb - return myDb, nil + instanceMap[key] = db + return db, nil } func DefaultGormLogger(level gormLogger.LogLevel) gormLogger.Interface { @@ -139,16 +130,14 @@ func NewGormLogger(writer gormLogger.Writer, gormLoggerConfig gormLogger.Config) return gormLogger.New(writer, gormLoggerConfig) } -func (m *MyDB) Close() { - if m.sqlDb != nil { - m.sqlDb.Close() +func Close(key string) { + db, _ := Instance(key).DB() + db.Close() +} + +func CloseAll() { + for _, v := range instanceMap { + db, _ := v.DB() + db.Close() } } - -func (m *MyDB) DB() *gorm.DB { - return m.db -} - -func (m *MyDB) SqlDB() *sql.DB { - return m.sqlDb -} diff --git a/myredis/redis.go b/myredis/redis.go index db465f2..beb3f1a 100644 --- a/myredis/redis.go +++ b/myredis/redis.go @@ -26,8 +26,8 @@ var ( type ( MyRedis struct { - redis *redis.Client - ctx context.Context + *redis.Client + ctx context.Context } Config struct { @@ -106,8 +106,8 @@ func New(key string, config *Config) (*MyRedis, error) { ctx := context.Background() rd := &MyRedis{} rd.ctx = ctx - rd.redis = client - ping := rd.redis.Ping(ctx) + rd.Client = client + ping := rd.Client.Ping(ctx) if ping.Err() != nil { return nil, fmt.Errorf("connet redis err: %s", ping.Err()) } @@ -116,23 +116,23 @@ func New(key string, config *Config) (*MyRedis, error) { return rd, nil } -// Get 通用get -func (r *MyRedis) Get(key string) (string, error) { - return r.redis.Get(r.ctx, key).Result() +// GetSimple 通用get +func (r *MyRedis) GetSimple(key string) (string, error) { + return r.Client.Get(r.ctx, key).Result() } -// Set 通用set -func (r *MyRedis) Set(key string, value interface{}, t ...time.Duration) (string, error) { +// SetSimple 通用set +func (r *MyRedis) SetSimple(key string, value interface{}, t ...time.Duration) (string, error) { var t2 time.Duration if len(t) > 0 { t2 = t[0] } - return r.redis.Set(r.ctx, key, value, t2).Result() + return r.Client.Set(r.ctx, key, value, t2).Result() } // GetJson json序列化 func (r *MyRedis) GetJson(key string) (interface{}, error) { - res := r.redis.Get(r.ctx, key) + res := r.Client.Get(r.ctx, key) if res.Err() != nil { return nil, res.Err() } @@ -159,15 +159,25 @@ func (r *MyRedis) SetJson(key string, value interface{}, t ...time.Duration) (st if err != nil { return "", fmt.Errorf("set key:%s 序列化json失败", key) } - return r.redis.Set(r.ctx, key, v, t2).Result() + return r.Client.Set(r.ctx, key, v, t2).Result() } func (r *MyRedis) Close() { - if r.redis != nil { - r.redis.Close() + if r.Client != nil { + r.Client.Close() } } -func (r *MyRedis) Redis() *redis.Client { - return r.redis +func (r *MyRedis) GetConn() *redis.Client { + return r.Client +} + +func Close(key string) { + Instance(key).Client.Close() +} + +func CloseAll() { + for _, v := range instanceMap { + v.Client.Close() + } }