diff --git a/go.mod b/go.mod index 32690e2..d2dc655 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/go-playground/validator/v10 v10.11.1 github.com/go-redis/redis/v8 v8.11.5 github.com/json-iterator/go v1.1.12 + github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.14.0 github.com/stretchr/testify v1.8.1 go.uber.org/zap v1.24.0 @@ -36,7 +37,6 @@ require ( github.com/spf13/afero v1.9.3 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.4.2 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.9.0 // indirect diff --git a/myconf/conf.go b/myconf/conf.go index f217400..39ff4a2 100644 --- a/myconf/conf.go +++ b/myconf/conf.go @@ -2,27 +2,65 @@ package myconf import ( "fmt" + "github.com/spf13/pflag" "github.com/spf13/viper" + "log" ) +type Config struct { + *viper.Viper +} + var ( - vp *viper.Viper + conf = &Config{Viper: viper.New()} ) -func Init(confPath string, conf interface{}) { - vp = viper.New() - vp.SetConfigFile(confPath) - err := vp.ReadInConfig() - if err != nil { - panic(fmt.Errorf("读取配置文件失败 %s", err)) +// 加载命令行参数 +func LoadFlag() { + if !pflag.Parsed() { + pflag.Parse() } - - err = vp.Unmarshal(conf) + err := conf.BindPFlags(pflag.CommandLine) if err != nil { - panic(fmt.Errorf("解析配置文件失败 %s", err)) + panic(fmt.Errorf("load command line fail: %s", err)) } } -func Conf() *viper.Viper { - return vp +// 指定文件加载 +func LoadFile(confFile string) { + log.Printf("read conf file: %s", confFile) + + conf.SetConfigFile(confFile) + err := conf.ReadInConfig() + if err != nil { + panic(fmt.Errorf("read file fail: %s", err)) + } +} + +func Conf() *Config { + return conf +} + +func (c *Config) GetStringDefault(key, defaultVal string) string { + v := c.GetString(key) + if v == "" { + return defaultVal + } + return v +} + +func (c *Config) GetIntDefault(key string, defaultVal int) int { + v := c.GetString(key) // 未设置 空字符串 + if v == "" { + return defaultVal + } + return c.GetInt(key) +} + +func (c *Config) GetBoolDefault(key string, defaultVal bool) bool { + v := c.GetString(key) // 未设置 空字符串 + if v == "" { + return defaultVal + } + return c.GetBool(key) }