package myconf import ( "fmt" "git.makemake.in/kzkzzzz/mycommon" pjson "github.com/knadh/koanf/parsers/json" ptoml "github.com/knadh/koanf/parsers/toml" pyaml "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/confmap" kfile "github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/providers/posflag" "github.com/knadh/koanf/v2" "github.com/spf13/cast" "github.com/spf13/pflag" "github.com/spf13/viper" "log" "os" "path/filepath" "strings" "sync" ) const ( Default = "default" envConfigFile = "APP_CONFIG_FILE" ) var ( configInstanceMap = &sync.Map{} // 记录配置实例, 使用conf.Conf("name") 获取 defaultConf *Config defaultConfigFile = []string{"config.yaml", "config.yml"} // 兼容yaml和yml格式 flagConfigFile string // 命令行传参指定的配置文件 ) type Config struct { InstanceName string ConfigFile string kof *koanf.Koanf lock *sync.RWMutex koanfOpt []koanf.Option } type Opt func(c *Config) func WithKoanfOpt(v ...koanf.Option) Opt { return func(c *Config) { c.koanfOpt = v } } func WithLoadOverwrite(v bool) Opt { return func(c *Config) { if v == true { c.koanfOpt = append(c.koanfOpt, koanf.WithMergeFunc(func(src, dest map[string]interface{}) error { dest = src return nil })) } } } func init() { defaultConf = New(Default) log.SetOutput(os.Stdout) log.SetFlags(log.LstdFlags | log.Lshortfile) // --config 指定配置文件 pflag.StringVar(&flagConfigFile, "config", "", "set config file") } // New 初始化配置 name 只是一个实例标识区分名字, 不一定是文件名称, 文件名称在loadConf的时候定义 // 由于viper在访问key时不区分大小写(会强制转为小写, https://github.com/spf13/viper/issues/373) // 如 Get("mysql.Dsn"), Get("mysql.dsn"), // 或者把 Unmarshal解析到 map[string]string等map格式时, key也是转为小写, 导致部分场景判断可能会有问题 // 改为使用 koanf 这个库 https://github.com/knadh/koanf 可以区分大小写 func New(name string, opts ...Opt) *Config { cfg := &Config{ InstanceName: name, kof: koanf.New("."), lock: &sync.RWMutex{}, // 远程更新用到 koanfOpt: make([]koanf.Option, 0), } for _, opt := range opts { opt(cfg) } configInstanceMap.Store(name, cfg) return cfg } // Conf 根据name获取对应的实例 func Conf(name ...string) *Config { if len(name) == 0 { v, ok := configInstanceMap.Load(Default) if !ok { panic(fmt.Errorf("conf [%s] not exists", Default)) } //vv, ok := v.(*Config) return v.(*Config) } v, ok := configInstanceMap.Load(name[0]) if !ok { panic(fmt.Errorf("conf [%s] not exists", name[0])) } return v.(*Config) } // FromViper 转化viper配置 func FromViper(v *viper.Viper) *Config { kof := koanf.New(".") keys := v.AllKeys() // key是 mysql.dsn, http.port 这种展平的格式, 中间默认是.分割 data := make(map[string]any, len(keys)) for _, key := range keys { data[key] = v.Get(key) } // 把viper的数据转化到koanf err := kof.Load(confmap.Provider(data, "."), nil) if err != nil { log.Printf("load viper map err: %s", err) } cfg := &Config{ kof: kof, lock: &sync.RWMutex{}, // 远程更新用到 } return cfg } // Load 加载命令行参数, 以及默认配置文件 func Load() *Config { parseFlag() // 命令行参数指定文件 --config 指定配置文件 if flagConfigFile != "" { LoadFile(flagConfigFile) } else { if v := getDefaultConfigFile(); v != "" { LoadFile(v) } } // 命令行参数后加载, 可以覆盖文件配置 LoadFlag() return defaultConf } // LoadFlag 加载命令行参数 func LoadFlag() { parseFlag() err := defaultConf.kof.Load(posflag.Provider(pflag.CommandLine, ".", defaultConf.kof), nil) //err := defaultConf.kof.BindPFlags(pflag.CommandLine) if err != nil { panic(fmt.Errorf("load command line fail: %s", err)) } // 如果有环境变量定义, 则覆盖掉 if v := os.Getenv(envConfigFile); v != "" { flagConfigFile = v } } // 解析命令行参数 func parseFlag() { if pflag.Parsed() { return } pflag.Parse() } // LoadFile 加载指定的文件 func LoadFile(configFile string) *Config { return defaultConf.LoadFile(configFile) } func getDefaultConfigFile() string { var cf = "" // 兼容yaml, yml for _, v := range defaultConfigFile { if mycommon.ExistFile(v) { cf = v break } } return cf } func (c *Config) LoadFile(configFile string) *Config { c.ConfigFile = configFile log.Printf("read local config file: %s", configFile) err := c.kof.Load(kfile.Provider(configFile), GetKoanfParserByFileExt(configFile), c.koanfOpt...) if err != nil { panic(fmt.Errorf("read file fail: %s", err)) } return c } func (c *Config) All() map[string]any { c.RLock() defer c.RUnLock() return c.kof.All() } func (c *Config) Raw() map[string]any { c.RLock() defer c.RUnLock() return c.kof.Raw() } func (c *Config) Get(key string) any { c.RLock() defer c.RUnLock() return c.kof.Get(key) } func (c *Config) GetString(key string) string { c.RLock() defer c.RUnLock() return c.kof.String(key) } func (c *Config) GetInt(key string) int { c.RLock() defer c.RUnLock() return c.kof.Int(key) } func (c *Config) GetInt64(key string) int64 { c.RLock() defer c.RUnLock() return c.kof.Int64(key) } func (c *Config) GetBool(key string) bool { c.RLock() defer c.RUnLock() return c.kof.Bool(key) } func (c *Config) GetStringMap(key string) map[string]string { c.RLock() defer c.RUnLock() return c.kof.StringMap(key) } func (c *Config) GetStringsMap(key string) map[string][]string { c.RLock() defer c.RUnLock() return c.kof.StringsMap(key) } func (c *Config) GetMap(key string) map[string]any { c.RLock() defer c.RUnLock() return cast.ToStringMap(c.kof.Get(key)) } func (c *Config) GetIntMap(key string) map[string]int { c.RLock() defer c.RUnLock() return c.kof.IntMap(key) } func (c *Config) GetStringSlice(key string) []string { c.RLock() defer c.RUnLock() return c.kof.Strings(key) } func (c *Config) GetIntSlice(key string) []int { c.RLock() defer c.RUnLock() return c.kof.Ints(key) } func (c *Config) WithRead(fn func(k *koanf.Koanf)) { c.RLock() defer c.RUnLock() fn(c.kof) } func (c *Config) WithDo(fn func(k *koanf.Koanf)) { c.Lock() defer c.UnLock() fn(c.kof) } // Sub 获取子路径的配置 func (c *Config) Sub(key string) *Config { c.RLock() defer c.RUnLock() newCfg := &Config{ InstanceName: fmt.Sprintf("%s.%s", c.InstanceName, key), kof: c.kof.Cut(key), lock: &sync.RWMutex{}, } return newCfg } func (c *Config) UnmarshalKey(key string, toVal interface{}) error { c.RLock() defer c.RUnLock() return c.kof.Unmarshal(key, toVal) } func (c *Config) Unmarshal(toVal interface{}) error { c.RLock() defer c.RUnLock() return c.kof.Unmarshal("", &toVal) } func (c *Config) Set(key string, value interface{}) { c.Lock() defer c.UnLock() c.kof.Set(key, value) } func (c *Config) GetStringDefault(key, defaultVal string) string { c.RLock() defer c.RUnLock() v := c.kof.String(key) if v == "" { return defaultVal } return v } func (c *Config) GetIntDefault(key string, defaultVal int) int { c.RLock() defer c.RUnLock() v := c.kof.String(key) // 未设置 空字符串 if v == "" { return defaultVal } return c.kof.Int(key) } func (c *Config) GetBoolDefault(key string, defaultVal bool) bool { c.RLock() defer c.RUnLock() v := c.kof.String(key) // 未设置 空字符串 if v == "" { return defaultVal } return c.kof.Bool(key) } func (c *Config) RLock() { if c.lock != nil { c.lock.RLock() } } func (c *Config) RUnLock() { if c.lock != nil { c.lock.RUnlock() } } func (c *Config) Lock() { if c.lock != nil { c.lock.Lock() } } func (c *Config) UnLock() { if c.lock != nil { c.lock.Unlock() } } // GetKoanfParserByFileExt 根据文件后缀文件类型 .json .yaml .toml 获取 parser func GetKoanfParserByFileExt(configFile string) koanf.Parser { ext := strings.TrimLeft(filepath.Ext(configFile), ".") ext = strings.ToLower(ext) switch ext { case "yaml", "yml": return pyaml.Parser() case "json": return pjson.Parser() case "toml": return ptoml.Parser() default: return pyaml.Parser() } }