diff --git a/mymysql/mysql.go b/mymysql/mysql.go index c287f1b..7c77787 100644 --- a/mymysql/mysql.go +++ b/mymysql/mysql.go @@ -10,6 +10,8 @@ import ( "time" ) +const DefaultKey = "default" + var ( DefaultConfig = &Config{ Dsn: "root:root@tcp(127.0.0.1:3306)/?loc=Local&charset=utf8mb4&parseTime=true", @@ -20,12 +22,14 @@ var ( Debug: true, GormLogger: gormLogger.Default.LogMode(gormLogger.Info), } + + instanceMap = make(map[string]*MyDB) ) type ( MyDB struct { - DB *gorm.DB - SqlDB *sql.DB + db *gorm.DB + sqlDb *sql.DB } Config struct { Dsn string @@ -38,7 +42,27 @@ type ( } ) -func New(config *Config) (*MyDB, error) { +func Instance(key ...string) *MyDB { + var key0 string + + if len(key) > 0 { + key0 = key[0] + } else { + key0 = DefaultKey + } + + instance, ok := instanceMap[key0] + if !ok { + panic(fmt.Errorf("%s not config", key)) + } + return instance +} + +func NewDefault(config *Config) (*MyDB, error) { + return New(DefaultKey, config) +} + +func New(key string, config *Config) (*MyDB, error) { var ( maxLifeTime, _ = time.ParseDuration(DefaultConfig.MaxLifeTime) maxIdleTime, _ = time.ParseDuration(DefaultConfig.MaxIdleTime) @@ -92,10 +116,13 @@ func New(config *Config) (*MyDB, error) { sqlDb.SetMaxIdleConns(config.MaxIdleConn) sqlDb.SetConnMaxLifetime(maxLifeTime) sqlDb.SetConnMaxIdleTime(maxIdleTime) - return &MyDB{ - DB: db, - SqlDB: sqlDb, - }, nil + + myDb := &MyDB{ + db: db, + sqlDb: sqlDb, + } + instanceMap[key] = myDb + return myDb, nil } func DefaultGormLogger(level gormLogger.LogLevel) gormLogger.Interface { @@ -113,7 +140,15 @@ func NewGormLogger(writer gormLogger.Writer, gormLoggerConfig gormLogger.Config) } func (m *MyDB) Close() { - if m.SqlDB != nil { - m.SqlDB.Close() + if m.sqlDb != nil { + m.sqlDb.Close() } } + +func (m *MyDB) DB() *gorm.DB { + return m.db +} + +func (m *MyDB) SqlDB() *sql.DB { + return m.sqlDb +} diff --git a/mymysql/mysql_test.go b/mymysql/mysql_test.go index 31701c3..25a9e8c 100644 --- a/mymysql/mysql_test.go +++ b/mymysql/mysql_test.go @@ -6,7 +6,7 @@ import ( ) func TestMysql(t *testing.T) { - myDB, err := New(&Config{ + myDB, err := NewDefault(&Config{ Dsn: "root:root@tcp(127.0.0.1:3306)/site?loc=Local&charset=utf8mb4&writeTimeout=3s&readTimeout=3s&timeout=2s&parseTime=true", MaxOpenConn: 16, MaxIdleConn: 4, @@ -23,7 +23,7 @@ func TestMysql(t *testing.T) { defer myDB.Close() var res = make(map[string]interface{}) - err = myDB.DB.Table("image").Limit(1).Take(&res).Error + err = myDB.db.Table("image").Limit(1).Take(&res).Error if err != nil { fmt.Println(err) return diff --git a/myredis/redis.go b/myredis/redis.go index 06540ac..db465f2 100644 --- a/myredis/redis.go +++ b/myredis/redis.go @@ -8,6 +8,8 @@ import ( "time" ) +const DefaultKey = "default" + var ( DefaultConfig = &Config{ Addr: "127.0.0.1:6379", @@ -18,11 +20,13 @@ var ( MaxConnAge: "1h", IdleTimeout: "10m", } + + instanceMap = make(map[string]*MyRedis) ) type ( MyRedis struct { - Redis *redis.Client + redis *redis.Client ctx context.Context } @@ -37,7 +41,27 @@ type ( } ) -func New(config *Config) (*MyRedis, error) { +func Instance(key ...string) *MyRedis { + var key0 string + + if len(key) > 0 { + key0 = key[0] + } else { + key0 = DefaultKey + } + + instance, ok := instanceMap[key0] + if !ok { + panic(fmt.Errorf("%s not config", key)) + } + return instance +} + +func NewDefault(config *Config) (*MyRedis, error) { + return New(DefaultKey, config) +} + +func New(key string, config *Config) (*MyRedis, error) { var ( maxConnAge, _ = time.ParseDuration(DefaultConfig.MaxConnAge) idleTimeout, _ = time.ParseDuration(DefaultConfig.IdleTimeout) @@ -82,18 +106,19 @@ func New(config *Config) (*MyRedis, error) { ctx := context.Background() rd := &MyRedis{} rd.ctx = ctx - rd.Redis = client - ping := rd.Redis.Ping(ctx) + rd.redis = client + ping := rd.redis.Ping(ctx) if ping.Err() != nil { return nil, fmt.Errorf("connet redis err: %s", ping.Err()) } + instanceMap[key] = rd return rd, nil } // Get 通用get func (r *MyRedis) Get(key string) (string, error) { - return r.Redis.Get(r.ctx, key).Result() + return r.redis.Get(r.ctx, key).Result() } // Set 通用set @@ -102,12 +127,12 @@ func (r *MyRedis) Set(key string, value interface{}, t ...time.Duration) (string if len(t) > 0 { t2 = t[0] } - return r.Redis.Set(r.ctx, key, value, t2).Result() + return r.redis.Set(r.ctx, key, value, t2).Result() } // GetJson json序列化 func (r *MyRedis) GetJson(key string) (interface{}, error) { - res := r.Redis.Get(r.ctx, key) + res := r.redis.Get(r.ctx, key) if res.Err() != nil { return nil, res.Err() } @@ -134,11 +159,15 @@ func (r *MyRedis) SetJson(key string, value interface{}, t ...time.Duration) (st if err != nil { return "", fmt.Errorf("set key:%s 序列化json失败", key) } - return r.Redis.Set(r.ctx, key, v, t2).Result() + return r.redis.Set(r.ctx, key, v, t2).Result() } func (r *MyRedis) Close() { - if r.Redis != nil { - r.Redis.Close() + if r.redis != nil { + r.redis.Close() } } + +func (r *MyRedis) Redis() *redis.Client { + return r.redis +} diff --git a/myredis/redis_test.go b/myredis/redis_test.go index 5a39efb..d4cbb39 100644 --- a/myredis/redis_test.go +++ b/myredis/redis_test.go @@ -7,7 +7,7 @@ import ( ) func TestRedis(t *testing.T) { - redis, err := New(&Config{ + redis, err := NewDefault(&Config{ Addr: "127.0.0.1:6379", Password: "", DB: 15,