diff --git a/mymysql/mysql.go b/mymysql/mysql.go index 5949ddc..caf8f23 100644 --- a/mymysql/mysql.go +++ b/mymysql/mysql.go @@ -58,6 +58,15 @@ func InitDefault(config *Config) error { } func Init(key string, config *Config) error { + db, err := New(config) + if err != nil { + return err + } + instanceMap[key] = db + return nil +} + +func New(config *Config) (*gorm.DB, error) { var ( maxLifeTime, _ = time.ParseDuration(DefaultConfig.MaxLifeTime) maxIdleTime, _ = time.ParseDuration(DefaultConfig.MaxIdleTime) @@ -75,7 +84,7 @@ func Init(key string, config *Config) error { if config.MaxLifeTime != "" { t, err := time.ParseDuration(config.MaxLifeTime) if err != nil { - return fmt.Errorf("parse MaxLifeTime err: %s\n", err) + return nil, fmt.Errorf("parse MaxLifeTime err: %s\n", err) } maxLifeTime = t @@ -84,7 +93,7 @@ func Init(key string, config *Config) error { if config.MaxIdleTime != "" { t, err := time.ParseDuration(config.MaxIdleTime) if err != nil { - return fmt.Errorf("parse MaxIdleTime err: %s\n", err) + return nil, fmt.Errorf("parse MaxIdleTime err: %s\n", err) } maxIdleTime = t } @@ -103,7 +112,7 @@ func Init(key string, config *Config) error { }) if err != nil { - return fmt.Errorf("connect mysql err: %s", err) + return nil, fmt.Errorf("connect mysql err: %s", err) } sqlDb, _ := db.DB() @@ -112,8 +121,7 @@ func Init(key string, config *Config) error { sqlDb.SetConnMaxLifetime(maxLifeTime) sqlDb.SetConnMaxIdleTime(maxIdleTime) - instanceMap[key] = db - return nil + return db, nil } func DefaultGormLogger(level gormLogger.LogLevel) gormLogger.Interface { @@ -130,12 +138,12 @@ func NewGormLogger(writer gormLogger.Writer, gormLoggerConfig gormLogger.Config) return gormLogger.New(writer, gormLoggerConfig) } -func Close(key string) { +func CloseInstance(key string) { db, _ := Instance(key).DB() db.Close() } -func CloseAll() { +func CloseAllInstance() { for _, v := range instanceMap { db, _ := v.DB() db.Close() diff --git a/mymysql/mysql_test.go b/mymysql/mysql_test.go index 9e4648f..af7b4fb 100644 --- a/mymysql/mysql_test.go +++ b/mymysql/mysql_test.go @@ -20,7 +20,7 @@ func TestMysql(t *testing.T) { return } - defer CloseAll() + defer CloseAllInstance() var res = make(map[string]interface{}) err = Instance().Table("image").Limit(1).Take(&res).Error diff --git a/myredis/redis.go b/myredis/redis.go index 24033a3..f3c246c 100644 --- a/myredis/redis.go +++ b/myredis/redis.go @@ -61,6 +61,15 @@ func InitDefault(config *Config) error { } func Init(key string, config *Config) error { + rd, err := New(config) + if err != nil { + return err + } + instanceMap[key] = rd + return nil +} + +func New(config *Config) (*MyRedis, error) { var ( maxConnAge, _ = time.ParseDuration(DefaultConfig.MaxConnAge) idleTimeout, _ = time.ParseDuration(DefaultConfig.IdleTimeout) @@ -77,7 +86,7 @@ func Init(key string, config *Config) error { if config.MaxConnAge != "" { t, err := time.ParseDuration(config.MaxConnAge) if err != nil { - return fmt.Errorf("parse MaxConnAge err: %s\n", err) + return nil, fmt.Errorf("parse MaxConnAge err: %s\n", err) } maxConnAge = t @@ -86,7 +95,7 @@ func Init(key string, config *Config) error { if config.IdleTimeout != "" { t, err := time.ParseDuration(config.IdleTimeout) if err != nil { - return fmt.Errorf("parse IdleTimeout err: %s\n", err) + return nil, fmt.Errorf("parse IdleTimeout err: %s\n", err) } idleTimeout = t @@ -107,11 +116,10 @@ func Init(key string, config *Config) error { rd.Client = client ping := rd.Client.Ping(ctx) if ping.Err() != nil { - return fmt.Errorf("connet redis err: %s", ping.Err()) + return nil, fmt.Errorf("connet redis err: %s", ping.Err()) } - instanceMap[key] = rd - return nil + return rd, nil } // GetSimple 通用get @@ -165,11 +173,11 @@ func (r *MyRedis) SetJson(key string, value interface{}, t ...time.Duration) (st return r.Client.Set(ctx, key, v, t2).Result() } -func Close(key string) { +func CloseInstance(key string) { Instance(key).Client.Close() } -func CloseAll() { +func CloseAllInstance() { for _, v := range instanceMap { v.Client.Close() } diff --git a/myredis/redis_test.go b/myredis/redis_test.go index 0dea500..2c61744 100644 --- a/myredis/redis_test.go +++ b/myredis/redis_test.go @@ -21,7 +21,7 @@ func TestRedis(t *testing.T) { fmt.Println(err) return } - defer CloseAll() + defer CloseAllInstance() set, err := Instance().Set(context.Background(), "name", "qwe123", time.Minute).Result() if err != nil {