From 7e0bf824182e4164cd68bfd9d0ac57c9c1f50a6f Mon Sep 17 00:00:00 2001 From: kzkzzzz Date: Sat, 22 Mar 2025 01:24:57 +0800 Subject: [PATCH] update --- .DS_Store | Bin 0 -> 6148 bytes common.go | 64 +++ go.mod | 88 +++-- go.sum | 742 +++++++++++++---------------------- graceful/graeceful.go | 79 ++++ myconf/conf.go | 389 ++++++++++++++++-- myconf/test.yaml | 8 - mygrpc/error.go | 19 + mygrpc/grpc.go | 41 ++ mygrpc/grpcc/client.go | 154 ++++++++ mygrpc/grpcsr/server.go | 331 ++++++++++++++++ mylog/log.go | 72 +--- mymysql/mysql.go | 281 +++++++------ mymysql/mysql_test.go | 33 -- mymysql/option.go | 17 + myredis/option.go | 19 + myredis/redis.go | 369 ++++++++++------- myredis/redis_test.go | 40 -- myregistry/consul/builder.go | 189 +++++++++ myregistry/consul/consul.go | 153 ++++++++ myregistry/consul/target.go | 101 +++++ myregistry/reigster.go | 21 + 22 files changed, 2291 insertions(+), 919 deletions(-) create mode 100644 .DS_Store create mode 100644 graceful/graeceful.go delete mode 100644 myconf/test.yaml create mode 100644 mygrpc/error.go create mode 100644 mygrpc/grpc.go create mode 100644 mygrpc/grpcc/client.go create mode 100644 mygrpc/grpcsr/server.go delete mode 100644 mymysql/mysql_test.go create mode 100644 mymysql/option.go create mode 100644 myredis/option.go delete mode 100644 myredis/redis_test.go create mode 100644 myregistry/consul/builder.go create mode 100644 myregistry/consul/consul.go create mode 100644 myregistry/consul/target.go create mode 100644 myregistry/reigster.go diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..31513669f34176c1e2d97fbcdae44f88d29d41d4 GIT binary patch literal 6148 zcmeHKyH3L}6umA10!UE7KnEnUvLL!3u>}-X20AdXKnWD2gp~4-*zyU8m4OdIto;Nw z{)KaGE3umdHiVF`WFN=(*gidZxF#a8cOQ3&+C=1|FnaT-YJ&4zD$&wCYe2zsB!h16 zEK7$Krt5GHI0ycs1N`oKv`GWHqC4yT72D@#F3qB-pJaXX`AY}OtPk8bIP z5?jOKYn5T@cO_j&MC&zdTZ-x9~XKGhNSEz=Q z(1#D9R~Gt&BIN3rKag}1p+Z+X2b=@815@g<&HMj;^ZDN% 0 { + return v[0] + } + } + + p, ok := peer.FromContext(ctx) + if ok { + switch v := p.Addr.(type) { + case *net.TCPAddr: + return v.IP.String() + } + } + + return "" +} diff --git a/mygrpc/grpcc/client.go b/mygrpc/grpcc/client.go new file mode 100644 index 0000000..aadee83 --- /dev/null +++ b/mygrpc/grpcc/client.go @@ -0,0 +1,154 @@ +package grpcc + +import ( + "context" + "fmt" + "git.makemake.in/kzkzzzz/mycommon/myconf" + "git.makemake.in/kzkzzzz/mycommon/mygrpc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "log" + "time" +) + +type ClientConf struct { + useDefaultBufferCfg bool + conf *myconf.Config + grpcOpts []grpc.DialOption + unaryMiddlewares []grpc.UnaryClientInterceptor +} + +type Opt func(*ClientConf) + +func UseDefaultBufferCfg(v bool) Opt { + return func(c *ClientConf) { + c.useDefaultBufferCfg = v + } +} + +func WithConf(v *myconf.Config) Opt { + return func(c *ClientConf) { + c.conf = v + } +} + +func WithGrpcOpts(v ...grpc.DialOption) Opt { + return func(c *ClientConf) { + c.grpcOpts = v + } +} + +func WithUnaryMiddlewares(v ...grpc.UnaryClientInterceptor) Opt { + return func(c *ClientConf) { + c.unaryMiddlewares = append(c.unaryMiddlewares, v...) + } +} + +func MustNew(grpcUrl string, opts ...Opt) *grpc.ClientConn { + client, err := New(grpcUrl, opts...) + if err != nil { + panic(err) + } + return client +} + +func New(grpcUrl string, opts ...Opt) (*grpc.ClientConn, error) { + log.Printf("new grpc client url: %s", grpcUrl) + + c := &ClientConf{ + useDefaultBufferCfg: true, + unaryMiddlewares: []grpc.UnaryClientInterceptor{WrapRequestError()}, // 默认加上错误包装 + } + + for _, opt := range opts { + opt(c) + } + + dialOpts := []grpc.DialOption{ + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: time.Second * 20, // 如果没有activity,则每隔N s发送一个ping包 + Timeout: time.Second * 5, // 如果ping ack N s之内未返回则认为连接已断开 + PermitWithoutStream: true, // 如果没有active的stream,是否允许发送ping + }), + // 参考 https://github.com/grpc/grpc-go/tree/master/examples/features/load_balancing 设置轮训策略 + grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), // This sets the initial balancing policy. + grpc.WithTransportCredentials(insecure.NewCredentials()), + } + + // 使用默认的buffer调整配置 + if c.useDefaultBufferCfg { + dialOpts = append(dialOpts, + grpc.WithInitialWindowSize(mygrpc.DefaultWindowSize), + grpc.WithInitialConnWindowSize(mygrpc.DefaultWindowSize), + + grpc.WithReadBufferSize(mygrpc.DefaultReadBufferSize), + grpc.WithWriteBufferSize(mygrpc.DefaultWriteBufferSize), + + grpc.WithUnaryInterceptor(WrapRequestError()), + ) + } + + if len(c.unaryMiddlewares) > 0 { + grpc.WithChainUnaryInterceptor(c.unaryMiddlewares...) + } + + if len(c.grpcOpts) > 0 { + dialOpts = append(dialOpts, c.grpcOpts...) + } + + conn, err := grpc.NewClient( + grpcUrl, + dialOpts..., + ) + if err != nil { + return nil, err + } + + return conn, nil +} + +func WrapRequestError() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + // 提取服务名称 + var serviceName string + if md, ok := metadata.FromOutgoingContext(ctx); ok { + if v := md.Get(mygrpc.HeaderServiceName); len(v) > 0 { + serviceName = v[0] + } + } + + err := invoker(ctx, method, req, reply, cc, opts...) + if err != nil { + + if serviceName != "" { + return fmt.Errorf("request grpc err: [%s - %s] %s", serviceName, method, err) + } + + return fmt.Errorf("request grpc err: [%s] %s", method, err) + } + + return nil + } +} + +// Timeout 客户端超时 +func Timeout(timeout time.Duration) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + tCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + err := invoker(tCtx, method, req, reply, cc, opts...) + + if v, ok := status.FromError(err); ok && v.Code() == codes.DeadlineExceeded { + //return status.Errorf(grpcserver.Timeout, "call %s timeout %s", method, timeout) + return mygrpc.GrpcClientTimeout("request timeout: %s", timeout) + } + + return err + } + +} diff --git a/mygrpc/grpcsr/server.go b/mygrpc/grpcsr/server.go new file mode 100644 index 0000000..13441ee --- /dev/null +++ b/mygrpc/grpcsr/server.go @@ -0,0 +1,331 @@ +package grpcsr + +import ( + "context" + "fmt" + "git.makemake.in/kzkzzzz/mycommon" + "git.makemake.in/kzkzzzz/mycommon/graceful" + "git.makemake.in/kzkzzzz/mycommon/myconf" + "git.makemake.in/kzkzzzz/mycommon/mygrpc" + "git.makemake.in/kzkzzzz/mycommon/mylog" + "git.makemake.in/kzkzzzz/mycommon/myregistry" + "github.com/spf13/pflag" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/status" + "log" + "net" + + "runtime/debug" + "time" +) + +//const DefaultInstanceName = "grpc" + +var _ graceful.IRunner = (*Server)(nil) + +type Conf struct { + Addr string + Port int + Ip string + Log bool +} + +type Opt func(server *Server) + +type Server struct { + gs *grpc.Server + serviceId string + serviceName string + serverConf *Conf + reg myregistry.IRegister + grpcOpts []grpc.ServerOption + logger mylog.ILogger + registerGrpcFn func(*grpc.Server) // 注册grpc服务, 使用函数延迟调用, 便于先初始化中间件等操作 + + unaryMiddlewares []grpc.UnaryServerInterceptor // grpc一元服务端中间件 + + useDefaultBufferCfg bool + delayStopMs int + + serviceRegInfo *myregistry.ServiceInfo +} + +func UseDefaultBufferCfg(v bool) Opt { + return func(server *Server) { + server.useDefaultBufferCfg = v + } +} + +func WithRegistry(serviceName string, reg myregistry.IRegister) Opt { + return func(server *Server) { + server.serviceName = serviceName + server.reg = reg + } +} + +func WithGrpcOpts(v ...grpc.ServerOption) Opt { + return func(server *Server) { + server.grpcOpts = v + } +} + +func WithDelayStopMs(v int) Opt { + return func(server *Server) { + server.delayStopMs = v + } +} + +func SetFlag() { + pflag.Int("grpc.port", 18082, "listen port") + pflag.String("grpc.log", "true", "enable request log") +} + +func New(cfg *myconf.Config, opts ...Opt) *Server { + cf := &Conf{} + err := cfg.UnmarshalKey("grpc", cf) + if err != nil { + panic(err) + } + + // 命令行的参数覆盖一次, Unmarshal解析的时候, 不会用命令行的参数覆盖 https://github.com/spf13/viper/issues/190 + cf.Port = cfg.GetInt(fmt.Sprintf("grpc.port")) + cf.Log = cfg.GetBool(fmt.Sprintf("grpc.log")) + + return NewByConf(cf, opts...) +} + +func NewByConf(conf *Conf, opts ...Opt) *Server { + s := &Server{ + serverConf: conf, + useDefaultBufferCfg: true, + } + for _, opt := range opts { + opt(s) + } + + if s.logger == nil { + s.logger = mylog.GetLogger() + } + + if s.reg != nil && s.serviceName == "" { + panic("service name is empty") + } + + s.unaryMiddlewares = []grpc.UnaryServerInterceptor{ + s.grpcRecover(), // 默认启用recover中间件 + } + + if s.serverConf.Log { + s.unaryMiddlewares = append(s.unaryMiddlewares, s.requestLog()) + } + + return s +} + +func (s *Server) Use(middlewares ...grpc.UnaryServerInterceptor) { + s.unaryMiddlewares = append(s.unaryMiddlewares, middlewares...) +} + +func (s *Server) RegisterGrpc(fn func(*grpc.Server)) { + s.registerGrpcFn = fn +} + +func (s *Server) initServer() { + grpcOpts := []grpc.ServerOption{ + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: time.Second * 5, // 如果客户端两次 ping 的间隔小于 N,则关闭连接 + PermitWithoutStream: true, // 即使没有 active stream, 也允许 ping + }), + grpc.KeepaliveParams(keepalive.ServerParameters{ + MaxConnectionIdle: time.Hour * 2, // 空闲连接时间 + MaxConnectionAgeGrace: time.Second * 30, // 在强制关闭连接之间, 允许有 N 的时间完成 pending 的 rpc 请求 + Time: time.Second * 20, // 如果一个连接空闲超过 N, 则发送一个 ping 请求 + Timeout: time.Second * 5, // 如果 ping 请求 N 内未收到回复, 则认为该连接已断开 + }), + } + + if s.useDefaultBufferCfg { + grpcOpts = append(grpcOpts, + grpc.InitialWindowSize(mygrpc.DefaultWindowSize), + grpc.InitialConnWindowSize(mygrpc.DefaultWindowSize), + + grpc.ReadBufferSize(mygrpc.DefaultReadBufferSize), + grpc.WriteBufferSize(mygrpc.DefaultWriteBufferSize), + ) + } + + if len(s.unaryMiddlewares) > 0 { + grpcOpts = append(grpcOpts, grpc.ChainUnaryInterceptor(s.unaryMiddlewares...)) + } + + if len(s.grpcOpts) > 0 { + grpcOpts = append(grpcOpts, s.grpcOpts...) + } + + s.gs = grpc.NewServer(grpcOpts...) + + // 注册grpc服务 + s.registerGrpcFn(s.gs) +} + +func (s *Server) Run(ctx context.Context) error { + s.initServer() + + // 端口如果=0, 监听随机端口 + addr0 := fmt.Sprintf("%s:%d", s.serverConf.Addr, s.serverConf.Port) + lis, err := net.Listen("tcp", addr0) + if err != nil { + return err + } + + // 获取监听的端口 + port := lis.Addr().(*net.TCPAddr).Port + + // 健康服务 + healthServer := health.NewServer() + grpc_health_v1.RegisterHealthServer(s.gs, healthServer) + + // 服务反射, 方便调试 + reflection.Register(s.gs) + + var svcIp = s.serverConf.Ip + if svcIp == "" { + svcIp = mycommon.GetOutboundIP() + } + + // 注册服务 + if s.reg != nil { + s.serviceRegInfo = &myregistry.ServiceInfo{ + ServiceName: s.serviceName, + Ip: svcIp, + Port: port, + } + + err = s.reg.Register(s.serviceRegInfo) + if err != nil { + return err + } + } + + addr := fmt.Sprintf("%s:%d", s.serverConf.Addr, port) + log.Printf("grpc server listen on %s", addr) + + err = s.gs.Serve(lis) + if err != nil { + log.Printf("start grpc server err: %s", err) + return err + } + + return nil +} + +func (s *Server) Stop() { + if s.reg != nil { + err := s.reg.Deregister(s.serviceRegInfo) + if err != nil { + s.logger.Errorf("grpc server deregister err: %s", err) + } + } + + // 如果使用k8s service, 关闭pod和往service注销ip是同时进行的, 如果退出服务比注销ip先完成, 可能有流量继续进来, 导致请求失败 + // 延迟一段时间, 确保服务已经注销ip, 再关闭服务 + + // 如何使用注册中心, 先从中心退出ip, 也延迟一段时间, 等上游网关更新ip完成(正常不会太久), 不会有流量进来旧服务, 再退出服务 + if s.delayStopMs > 0 { + delayTime := time.Millisecond * time.Duration(s.delayStopMs) + log.Printf("grpc server delay stop: %s", delayTime) + time.Sleep(delayTime) + } + + s.gs.GracefulStop() + + log.Printf("grpc server stop") +} + +type handleResp struct { + resp interface{} + err error +} + +func (s *Server) requestLog() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + start := time.Now() + + resp, err := handler(ctx, req) + + var ( + code codes.Code + codeMsg = "OK" + ) + if err != nil { + fromError, ok := status.FromError(err) + if ok { + code = fromError.Code() + } else { + code = status.New(codes.Unknown, err.Error()).Code() + } + + codeMsg = fmt.Sprintf("Error Code: %s(%d)", code.String(), uint32(code)) + } + + s.logger.Infof( + "%s - %s - %s - %s", + codeMsg, time.Since(start), mygrpc.ClientIP(ctx), info.FullMethod, + ) + + return resp, err + } +} + +func (s *Server) grpcRecover() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + defer func() { + if err0 := recover(); err0 != nil { + log.Printf("%s - panic: %v\n%s", info.FullMethod, err0, debug.Stack()) + err = fmt.Errorf("server err: %s - system err: %s", info.FullMethod, err0) + } + }() + + return handler(ctx, req) + } +} + +// Timeout 控制服务端超时 +func Timeout(timeout time.Duration) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + tCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ch := make(chan *handleResp, 1) + + go func() { + defer func() { + if err0 := recover(); err0 != nil { + log.Printf("%s - panic: %v\n%s", info.FullMethod, err0, debug.Stack()) + ch <- &handleResp{nil, fmt.Errorf("server err: %s - system err: %s", info.FullMethod, err0)} + } + }() + //start := time.Now() + r := &handleResp{} + r.resp, r.err = handler(tCtx, req) + //log.Printf("rta server time: %s", time.Since(start)) + ch <- r + }() + + select { + case <-tCtx.Done(): + return nil, mygrpc.GrpcServerTimeout("server err: grpc handle timeout %s %s", timeout, info.FullMethod) + //return nil, status.Errorf(codes.DeadlineExceeded, "grpc handle timeout %s %s", timeout, info.FullMethod) + + case res := <-ch: + return res.resp, res.err + } + + //return nil, fmt.Errorf("handle err %s.%s", info.Server, info.FullMethod) + } +} diff --git a/mylog/log.go b/mylog/log.go index 27a52df..0d7d081 100644 --- a/mylog/log.go +++ b/mylog/log.go @@ -3,13 +3,18 @@ package mylog import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" - "gopkg.in/natefinch/lumberjack.v2" "io" "os" - "path/filepath" "strings" ) +type ILogger interface { + Debugf(format string, v ...any) + Infof(format string, v ...any) + Warnf(format string, v ...any) + Errorf(format string, v ...any) +} + const ( DebugLevel = "DEBUG" InfoLevel = "INFO" @@ -36,13 +41,7 @@ var ( ConsoleWriter: os.Stdout, } - DefaultLogFile = &LogFile{ - LogFilePath: "logs", - MaxSize: 200, - MaxAge: 0, - MaxBackups: 0, - } - globalLog = NewLogger("debug", DefaultConfig) + globalLog = NewLogger(DefaultConfig) ) type ( @@ -71,7 +70,7 @@ func SetLogLevel(level string) { } func Init() { - globalLog = NewLogger("app", &Config{ + globalLog = NewLogger(&Config{ Level: defaultLogLevel, NeedLogFile: false, ConsoleWriter: os.Stdout, @@ -79,11 +78,11 @@ func Init() { } // InitWithConfig 覆盖默认日志 -func InitWithConfig(serverName string, config *Config) { - globalLog = NewLogger(serverName, config) +func InitWithConfig(config *Config) { + globalLog = NewLogger(config) } -func NewLogger(serverName string, config *Config) *ZapLog { +func NewLogger(config *Config) *ZapLog { if config == nil { config = DefaultConfig } @@ -96,29 +95,15 @@ func NewLogger(serverName string, config *Config) *ZapLog { cores := make([]zapcore.Core, 0) // 使用控制台输出 - if config.ConsoleWriter != nil { - cfg := zap.NewProductionEncoderConfig() - cfg.EncodeLevel = zapcore.CapitalColorLevelEncoder - cfg.ConsoleSeparator = " | " - // 指定日志时间格式 - cfg.EncodeTime = zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05.000") - cfg.EncodeCaller = zapcore.ShortCallerEncoder - encoder := zapcore.NewConsoleEncoder(cfg) - core := zapcore.NewCore(encoder, zapcore.AddSync(config.ConsoleWriter), level) - cores = append(cores, core) - } - - if config.NeedLogFile { - cfg := zap.NewProductionEncoderConfig() - cfg.EncodeLevel = zapcore.CapitalLevelEncoder - cfg.ConsoleSeparator = " | " - // 指定日志时间格式 - cfg.EncodeTime = zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05.000") - cfg.EncodeCaller = zapcore.ShortCallerEncoder - encoder := zapcore.NewConsoleEncoder(cfg) - core := zapcore.NewCore(encoder, zapcore.AddSync(getRollingFileWriter(serverName, config)), level) - cores = append(cores, core) - } + cfg := zap.NewProductionEncoderConfig() + cfg.EncodeLevel = zapcore.CapitalColorLevelEncoder + cfg.ConsoleSeparator = " | " + // 指定日志时间格式 + cfg.EncodeTime = zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05.000") + cfg.EncodeCaller = zapcore.ShortCallerEncoder + encoder := zapcore.NewConsoleEncoder(cfg) + core := zapcore.NewCore(encoder, zapcore.AddSync(config.ConsoleWriter), level) + cores = append(cores, core) opts := make([]zap.Option, 0) if config.ZapOpt != nil { @@ -134,21 +119,6 @@ func NewLogger(serverName string, config *Config) *ZapLog { } } -func getRollingFileWriter(serverName string, config *Config) *lumberjack.Logger { - if config.LogFile == nil { - config.LogFile = DefaultLogFile - } - - return &lumberjack.Logger{ - Filename: filepath.Join(config.LogFile.LogFilePath, serverName+".log"), - MaxSize: config.LogFile.MaxSize, - MaxAge: config.LogFile.MaxAge, - MaxBackups: config.LogFile.MaxBackups, - LocalTime: true, - Compress: false, - } -} - func (z *ZapLog) Debug(args ...interface{}) { z.sugarLog.Debug(args...) } diff --git a/mymysql/mysql.go b/mymysql/mysql.go index bf39564..0384f38 100644 --- a/mymysql/mysql.go +++ b/mymysql/mysql.go @@ -1,150 +1,203 @@ package mymysql import ( + "database/sql" "fmt" - "git.makemake.in/kzkzzzz/mycommon/mylog" + "git.makemake.in/kzkzzzz/mycommon/myconf" + driverMysql "github.com/go-sql-driver/mysql" + "github.com/google/uuid" "gorm.io/driver/mysql" "gorm.io/gorm" - gormLogger "gorm.io/gorm/logger" + "gorm.io/gorm/logger" + "log" + "os" + "sync" "time" ) -const DefaultKey = "default" +const ( + DefaultInstance = "mysql" +) + +type MysqlDb struct { + *gorm.DB + SqlDB *sql.DB + gormConfig *gorm.Config + disablePing bool +} + +type Conf struct { + Dsn string + MaxOpenConn int // 最大连接数 + MaxIdleConn int // 最大空闲连接数 + MaxIdleTime string // 空闲时间 + MaxLifeTime string // 连接最大有效时间 + Debug bool + LogSqlSlowTimeMs int + LogDisableColor bool +} var ( - DefaultConfig = &Config{ - Dsn: "root:root@tcp(127.0.0.1:3306)/?loc=Local&charset=utf8mb4&parseTime=true", - MaxOpenConn: 32, - MaxIdleConn: 8, - MaxLifeTime: "4h", - MaxIdleTime: "15m", - Debug: true, - GormLogger: gormLogger.Default.LogMode(gormLogger.Info), - } - - instanceMap = make(map[string]*gorm.DB) + instanceMap = &sync.Map{} ) -type ( - Config struct { - Dsn string - MaxOpenConn int - MaxIdleConn int - MaxIdleTime string - MaxLifeTime string - Debug bool - GormLogger gormLogger.Interface - } -) - -func DB(key ...string) *gorm.DB { - var key0 string - - if len(key) > 0 { - key0 = key[0] +func GetDb(name ...string) *MysqlDb { + var instanceName string + if len(name) > 0 { + instanceName = name[0] } else { - key0 = DefaultKey + instanceName = DefaultInstance } - instance, ok := instanceMap[key0] + v, ok := instanceMap.Load(instanceName) if !ok { - panic(fmt.Errorf("mysql %s not config", key0)) + panic(fmt.Errorf("mysql instance [%s] not init", instanceName)) } - return instance + + return v.(*MysqlDb) } -func InitDefault(config *Config) { - Init(DefaultKey, config) -} - -func Init(key string, config *Config) { - db, err := New(config) +// InitDb 初始化全局默认db +func InitDb(config *myconf.Config, opts ...Opt) { + client, err := NewDb(DefaultInstance, config, opts...) if err != nil { panic(err) } - instanceMap[key] = db + instanceMap.Store(DefaultInstance, client) } -func New(config *Config) (*gorm.DB, error) { - var ( - maxLifeTime, _ = time.ParseDuration(DefaultConfig.MaxLifeTime) - maxIdleTime, _ = time.ParseDuration(DefaultConfig.MaxIdleTime) - logger gormLogger.Interface - ) - - if config.MaxOpenConn <= 0 { - config.MaxOpenConn = DefaultConfig.MaxOpenConn - } - - if config.MaxIdleConn <= 0 { - config.MaxIdleConn = DefaultConfig.MaxIdleConn - } - - if config.MaxLifeTime != "" { - t, err := time.ParseDuration(config.MaxLifeTime) - if err != nil { - return nil, fmt.Errorf("parse MaxLifeTime err: %s\n", err) - - } - maxLifeTime = t - } - - if config.MaxIdleTime != "" { - t, err := time.ParseDuration(config.MaxIdleTime) - if err != nil { - return nil, fmt.Errorf("parse MaxIdleTime err: %s\n", err) - } - maxIdleTime = t - } - - if config.GormLogger == nil { - level := gormLogger.Warn - if config.Debug { - level = gormLogger.Info - } - logger = DefaultGormLogger(level) - } - - db, err := gorm.Open(mysql.Open(config.Dsn), &gorm.Config{ - SkipDefaultTransaction: true, - Logger: logger, - }) - +// InitDbInstance 初始化全局的db +func InitDbInstance(instanceName string, config *myconf.Config, opts ...Opt) { + client, err := NewDb(instanceName, config, opts...) if err != nil { - return nil, fmt.Errorf("connect mysql err: %s", err) + panic(err) } - sqlDb, _ := db.DB() + instanceMap.Store(instanceName, client) +} - sqlDb.SetMaxOpenConns(config.MaxOpenConn) - sqlDb.SetMaxIdleConns(config.MaxIdleConn) - sqlDb.SetConnMaxLifetime(maxLifeTime) - sqlDb.SetConnMaxIdleTime(maxIdleTime) +func NewDb(instanceName string, config *myconf.Config, opts ...Opt) (*MysqlDb, error) { + cf := &Conf{Debug: true} + err := config.UnmarshalKey(instanceName, &cf) + if err != nil { + return nil, err + } + db, err := NewDbFromConf(cf, opts...) + if err != nil { + return nil, err + } + + instanceMap.Store(instanceName, db) + return db, nil +} + +func NewDbFromConf(cf *Conf, opts ...Opt) (*MysqlDb, error) { + parseDsn, err := driverMysql.ParseDSN(cf.Dsn) + if err != nil { + return nil, fmt.Errorf("mysql parse dsn error: %s", err) + } + + db := &MysqlDb{} + for _, opt := range opts { + opt(db) + } + + if db.gormConfig == nil { + db.gormConfig = &gorm.Config{ + SkipDefaultTransaction: true, + } + + lCfg := logger.Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: logger.Warn, + IgnoreRecordNotFoundError: false, + Colorful: true, + } + + if cf.LogSqlSlowTimeMs > 0 { + lCfg.SlowThreshold = time.Duration(cf.LogSqlSlowTimeMs) * time.Millisecond + } + + if cf.LogDisableColor { + lCfg.Colorful = false + } + + l := newGormLogger(lCfg) + + if cf.Debug { + db.gormConfig.Logger = l.LogMode(logger.Info) + } else { + db.gormConfig.Logger = l + } + } + + gormDB, err := gorm.Open(mysql.Open(cf.Dsn), db.gormConfig) + if err != nil { + return nil, err + } + + sqlDB, err := gormDB.DB() + if err != nil { + return nil, err + } + + if db.disablePing == false { + err = sqlDB.Ping() + if err != nil { + return nil, err + } + } + + if cf.MaxOpenConn <= 0 { + cf.MaxOpenConn = 1024 + } + + if cf.MaxIdleConn <= 0 { + // 默认最大空闲数等于最大连接数 + cf.MaxIdleConn = cf.MaxOpenConn + } + + if cf.MaxIdleTime == "" { + cf.MaxIdleTime = "10m" + } + + sqlDB.SetMaxOpenConns(cf.MaxOpenConn) + sqlDB.SetMaxIdleConns(cf.MaxIdleConn) + + if dv, err := time.ParseDuration(cf.MaxIdleTime); err != nil { + return nil, fmt.Errorf("parse MaxIdleTime err: %s", err) + } else { + sqlDB.SetConnMaxIdleTime(dv) + } + + // max life time默认暂不设置, 使用idle time控制即可 + if cf.MaxLifeTime != "" { + if dv, err := time.ParseDuration(cf.MaxLifeTime); err != nil { + return nil, fmt.Errorf("parse MaxLifeTime err: %s", err) + } else { + sqlDB.SetConnMaxLifetime(dv) + } + } + + db.DB = gormDB + + db.SqlDB = sqlDB + instanceMap.Store(uuid.New().String(), db) + + log.Printf("connect db success [addr:%s - db:%s]", parseDsn.Addr, parseDsn.DBName) return db, nil } -func DefaultGormLogger(level gormLogger.LogLevel) gormLogger.Interface { - return gormLogger.New(mylog.NewLogger("gorm", mylog.DefaultConfig), gormLogger.Config{ - SlowThreshold: time.Second * 2, - Colorful: true, - IgnoreRecordNotFoundError: false, - ParameterizedQueries: false, - LogLevel: level, +func CloseAll() { + instanceMap.Range(func(k, v any) bool { + db, err := (v.(*MysqlDb)).DB.DB() + if err != nil { + db.Close() + } + return true }) } -func NewGormLogger(writer gormLogger.Writer, gormLoggerConfig gormLogger.Config) gormLogger.Interface { - return gormLogger.New(writer, gormLoggerConfig) -} - -func CloseDB(key string) { - db, _ := DB(key).DB() - db.Close() -} - -func CloseAllDB() { - for _, v := range instanceMap { - db, _ := v.DB() - db.Close() - } +func newGormLogger(cfg logger.Config) logger.Interface { + return logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), cfg) } diff --git a/mymysql/mysql_test.go b/mymysql/mysql_test.go deleted file mode 100644 index 38b37ea..0000000 --- a/mymysql/mysql_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package mymysql - -import ( - "fmt" - "testing" -) - -func TestMysql(t *testing.T) { - err := InitDefault(&Config{ - Dsn: "root:Tqa129126@tcp(119.29.187.200:3306)/site?loc=Local&charset=utf8mb4&writeTimeout=3s&readTimeout=3s&timeout=2s&parseTime=true", - MaxOpenConn: 16, - MaxIdleConn: 4, - MaxIdleTime: "5m", - MaxLifeTime: "30m", - Debug: false, - GormLogger: nil, - }) - if err != nil { - fmt.Println(err) - return - } - - defer CloseAllDB() - - var res = make(map[string]interface{}) - err = DB().Table("image").Limit(1).Take(&res).Error - if err != nil { - fmt.Println(err) - return - } - - fmt.Printf("%+v\n", res) -} diff --git a/mymysql/option.go b/mymysql/option.go new file mode 100644 index 0000000..0695831 --- /dev/null +++ b/mymysql/option.go @@ -0,0 +1,17 @@ +package mymysql + +import "gorm.io/gorm" + +type Opt func(m *MysqlDb) + +func WithDisablePing(v bool) Opt { + return func(m *MysqlDb) { + m.disablePing = v + } +} + +func WithGormConfig(v *gorm.Config) Opt { + return func(m *MysqlDb) { + m.gormConfig = v + } +} diff --git a/myredis/option.go b/myredis/option.go new file mode 100644 index 0000000..0dc2c90 --- /dev/null +++ b/myredis/option.go @@ -0,0 +1,19 @@ +package myredis + +import ( + "github.com/redis/go-redis/v9" +) + +type Opt func(*Client) + +func WithRedisOpt(v *redis.Options) Opt { + return func(r *Client) { + r.redisOpt = v + } +} + +func WithDisablePing(v bool) Opt { + return func(r *Client) { + r.disablePing = v + } +} diff --git a/myredis/redis.go b/myredis/redis.go index 3783ee0..58b1aa5 100644 --- a/myredis/redis.go +++ b/myredis/redis.go @@ -3,178 +3,245 @@ package myredis import ( "context" "fmt" - "github.com/go-redis/redis/v8" - jsoniter "github.com/json-iterator/go" + "git.makemake.in/kzkzzzz/mycommon/myconf" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "log" + "sync" "time" ) -const DefaultKey = "default" +const ( + DefaultInstance = "redis" +) + +type Client struct { + *redis.Client + redisOpt *redis.Options + disablePing bool +} + +//type Conf struct { +// Addr string +// Password string +// Db int +// PoolSize int // 连接池数量 如果不够用会继续新建连接 可以用 MaxActiveConns 限制 +// MinIdleConn int // 最小空闲连接数 预热可能用到 +// MaxIdleConn int // 最大空闲连接数 0不限制 +// MaxActiveConn int // 0 不限制, 如果不限制, 超出pool size可以继续新建立连接, 但使用完之后不会放回连接池 +// PoolTimeout string // 等待连接池超时时间 Default is ReadTimeout + 1 second. +// DialTimeout string // 拨号超时时间 +// ReadTimeout string // 读取超时 +// WriteTimeout string // 写入超时 +// MaxRetries int // 重试次数 +// ConnMaxIdleTime string // 连接空闲时间 +//} var ( - DefaultConfig = &Config{ - Addr: "127.0.0.1:6379", - Password: "", - DB: 0, - PoolSize: 16, - MinIdleConn: 4, - MaxConnAge: "4h", - IdleTimeout: "15m", - } - - instanceMap = make(map[string]*MyRedis) + defaultClient *redis.Client + instanceMap = &sync.Map{} ) -type ( - MyRedis struct { - *redis.Client - } - - Config struct { - Addr string - Password string - DB int - PoolSize int - MinIdleConn int - MaxConnAge string - IdleTimeout string - } -) - -func DB(key ...string) *MyRedis { - var key0 string - - if len(key) > 0 { - key0 = key[0] +func GetClient(name ...string) *Client { + var instanceName string + if len(name) > 0 { + instanceName = name[0] } else { - key0 = DefaultKey + instanceName = DefaultInstance } - instance, ok := instanceMap[key0] + v, ok := instanceMap.Load(instanceName) if !ok { - panic(fmt.Errorf("redis %s not config", key0)) + panic(fmt.Errorf("redis instance [%s] not init", instanceName)) } - return instance + + return v.(*Client) } -func InitDefault(config *Config) error { - return Init(DefaultKey, config) -} - -func Init(key string, config *Config) error { - rd, err := New(config) +// InitClient 初始化全局默认client +func InitClient(config *myconf.Config, opts ...Opt) { + client, err := NewClient(DefaultInstance, config, opts...) if err != nil { - return err + panic(err) } - instanceMap[key] = rd + instanceMap.Store(DefaultInstance, client) +} + +// InitClientInstance 初始化全局的client +func InitClientInstance(instanceName string, config *myconf.Config, opts ...Opt) { + client, err := NewClient(instanceName, config, opts...) + if err != nil { + panic(err) + } + instanceMap.Store(instanceName, client) +} + +func NewClient(instanceName string, config *myconf.Config, opts ...Opt) (*Client, error) { + cf := &Conf{} + err := config.UnmarshalKey(instanceName, &cf) + if err != nil { + return nil, err + } + client, err := NewClientFromConf(cf, opts...) + if err != nil { + return nil, err + } + + instanceMap.Store(instanceMap, client) + return client, nil +} + +type Conf struct { + Addr string + Password string + Db int + PoolSize int // 连接池数量 如果不够用会继续新建连接 可以用 MaxActiveConns 限制 + MinIdleConn int // 最小空闲连接数 预热可能用到 + MaxIdleConn int // 最大空闲连接数 0不限制 + MaxActiveConn int // 0 不限制, 如果不限制, 超出pool size可以继续新建立连接, 但使用完之后不会放回连接池 + DialTimeout string // 拨号超时时间, 时间单位格式 10ms, 60s, 15m, 2h... + ReadTimeout string // 读取超时 + WriteTimeout string // 写入超时 + ConnMaxIdleTime string // 连接空闲时间 + PoolTimeout string // 等待连接池超时时间 Default is ReadTimeout + 1 second. + MaxRetries int // 重试次数 +} + +func NewClientFromConf(cf *Conf, opts ...Opt) (*Client, error) { + c := &Client{ + redisOpt: DefaultRedisOpt(), + } + + for _, opt := range opts { + opt(c) + } + + err := c.parseConfToRedisOpt(cf) + if err != nil { + return nil, err + } + + client := redis.NewClient(c.redisOpt) + + if c.disablePing == false { + _, err := client.Ping(context.Background()).Result() + if err != nil { + return nil, fmt.Errorf("redis ping err: %s", err) + } + } + + c.Client = client + + instanceMap.Store(uuid.New().String(), c) + + log.Printf("connect redis success [addr:%s - db:%d]", cf.Addr, cf.Db) + + return c, nil +} + +func DefaultRedisOpt() *redis.Options { + return &redis.Options{ + Addr: "", + Password: "", + DB: 0, + MaxRetries: 3, // 重试次数 + DialTimeout: time.Millisecond * 500, // 拨号超时时间 + ReadTimeout: time.Second, // 读取超时 + WriteTimeout: time.Second * 3, // 写入超时 + PoolSize: 1024, // 连接池数量 如果不够用会继续新建连接 可以用 MaxActiveConns 限制 + PoolTimeout: time.Second + time.Second, // 等待连接池超时时间 Default is ReadTimeout + 1 second. + MinIdleConns: 0, // 最小空闲连接数 预热可能用到 + MaxIdleConns: 0, // 最大空闲连接数 0不限制 + MaxActiveConns: 0, // 0 不限制, 如果不限制, 超出pool size可以继续新建立连接, 但使用完之后不会放回连接池 + ConnMaxIdleTime: time.Minute * 10, // 连接空闲时间 + ContextTimeoutEnabled: true, // context控制超时用到 + } +} + +func (c *Client) parseTime(v string) (time.Duration, error) { + if v == "" { + return 0, nil + } + + d, err := time.ParseDuration(v) + if err != nil { + return 0, fmt.Errorf("parse time %s err: %s", v, err) + } + + return d, nil +} + +func (c *Client) parseConfToRedisOpt(cf *Conf) error { + c.redisOpt.Addr = cf.Addr + c.redisOpt.Password = cf.Password + c.redisOpt.DB = cf.Db + + if v := cf.PoolSize; v > 0 { + c.redisOpt.PoolSize = v + } + + if v := cf.MinIdleConn; v > 0 { + c.redisOpt.MinIdleConns = v + } + + if v := cf.MaxIdleConn; v > 0 { + c.redisOpt.MaxIdleConns = v + } + + if v := cf.MaxActiveConn; v > 0 { + c.redisOpt.MaxActiveConns = v + } + + if v := cf.MaxRetries; v > 0 { + c.redisOpt.MaxRetries = v + } + + if v, err := c.parseTime(cf.DialTimeout); err != nil { + return err + } else if v > 0 { + c.redisOpt.DialTimeout = v + } + + if v, err := c.parseTime(cf.ReadTimeout); err != nil { + return err + } else if v > 0 { + c.redisOpt.ReadTimeout = v + } + + if v, err := c.parseTime(cf.WriteTimeout); err != nil { + return err + } else if v > 0 { + c.redisOpt.WriteTimeout = v + } + + if v, err := c.parseTime(cf.ConnMaxIdleTime); err != nil { + return err + } else if v > 0 { + c.redisOpt.ConnMaxIdleTime = v + } + + if v, err := c.parseTime(cf.PoolTimeout); err != nil { + return err + } else if v > 0 { + c.redisOpt.PoolTimeout = v + } + return nil } -func New(config *Config) (*MyRedis, error) { - var ( - maxConnAge, _ = time.ParseDuration(DefaultConfig.MaxConnAge) - idleTimeout, _ = time.ParseDuration(DefaultConfig.IdleTimeout) - ) +const luaSetOnce = `if redis.call('setnx',KEYS[1],ARGV[1]) == 1 then redis.call('expire',KEYS[1],ARGV[2]) return 1 else return 0 end` - if config.PoolSize <= 0 { - config.MinIdleConn = DefaultConfig.PoolSize +// SetOnce 设置一次并设置过期时间, key不存在则设置成功返回1, key已存在返回0 +func (c *Client) SetOnce(ctx context.Context, key, value string, t time.Duration) (int, error) { + if t < time.Second { + return 0, fmt.Errorf("time must >= 1s") } + return c.Client.Eval(ctx, luaSetOnce, []string{key}, value, t).Int() +} - if config.MinIdleConn <= 0 { - config.MinIdleConn = DefaultConfig.MinIdleConn - } - - if config.MaxConnAge != "" { - t, err := time.ParseDuration(config.MaxConnAge) - if err != nil { - return nil, fmt.Errorf("parse MaxConnAge err: %s\n", err) - - } - maxConnAge = t - } - - if config.IdleTimeout != "" { - t, err := time.ParseDuration(config.IdleTimeout) - if err != nil { - return nil, fmt.Errorf("parse IdleTimeout err: %s\n", err) - - } - idleTimeout = t - } - - client := redis.NewClient(&redis.Options{ - Addr: config.Addr, - Password: config.Password, - DB: config.DB, - PoolSize: config.PoolSize, - MinIdleConns: config.MinIdleConn, - MaxConnAge: maxConnAge, - IdleTimeout: idleTimeout, +func CloseAllClient() { + instanceMap.Range(func(k, v any) bool { + v.(*Client).Close() + return true }) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - defer cancel() - rd := &MyRedis{} - rd.Client = client - ping := rd.Client.Ping(ctx) - if ping.Err() != nil { - return nil, fmt.Errorf("connet redis err: %s", ping.Err()) - } - - return rd, nil -} - -// GetSimple 通用get -func (r *MyRedis) GetSimple(key string) (string, error) { - ctx := context.Background() - return r.Client.Get(ctx, key).Result() -} - -// SetSimple 通用set -func (r *MyRedis) SetSimple(key string, value interface{}, t ...time.Duration) (string, error) { - ctx := context.Background() - var t2 time.Duration - if len(t) > 0 { - t2 = t[0] - } - return r.Client.Set(ctx, key, value, t2).Result() -} - -// GetJson json序列化 -func (r *MyRedis) GetJson(key string, result interface{}) error { - ctx := context.Background() - res, err := r.Client.Get(ctx, key).Bytes() - if err != nil { - return err - } - - err = jsoniter.Unmarshal(res, &result) - if err != nil { - return fmt.Errorf("get key:%s 反序列化json失败(-2)", key) - } - return nil -} - -// SetJson json序列化set -func (r *MyRedis) SetJson(key string, value interface{}, t ...time.Duration) (string, error) { - ctx := context.Background() - - var t2 time.Duration - if len(t) > 0 { - t2 = t[0] - } - v, err := jsoniter.Marshal(value) - if err != nil { - return "", fmt.Errorf("set key:%s 序列化json失败", key) - } - return r.Client.Set(ctx, key, v, t2).Result() -} - -func CloseDB(key string) { - DB(key).Client.Close() -} - -func CloseAllDB() { - for _, v := range instanceMap { - v.Client.Close() - } } diff --git a/myredis/redis_test.go b/myredis/redis_test.go deleted file mode 100644 index 1c71f05..0000000 --- a/myredis/redis_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package myredis - -import ( - "context" - "fmt" - "testing" - "time" -) - -func TestRedis(t *testing.T) { - err := InitDefault(&Config{ - Addr: "192.168.244.128:6379", - Password: "", - DB: 15, - PoolSize: 16, - MinIdleConn: 4, - MaxConnAge: "1h", - IdleTimeout: "10m", - }) - if err != nil { - fmt.Println(err) - return - } - defer CloseAllDB() - - set, err := DB().Set(context.Background(), "name", "qwe123", time.Minute).Result() - if err != nil { - fmt.Println(err) - return - } - fmt.Println(set) - - get, err := DB().Get(context.Background(), "name").Result() - if err != nil { - fmt.Println(err) - return - } - - fmt.Println(get) -} diff --git a/myregistry/consul/builder.go b/myregistry/consul/builder.go new file mode 100644 index 0000000..cfcbf86 --- /dev/null +++ b/myregistry/consul/builder.go @@ -0,0 +1,189 @@ +package consul + +import ( + "context" + "fmt" + "github.com/jpillora/backoff" + "google.golang.org/grpc/grpclog" + "sort" + "time" + + "github.com/hashicorp/consul/api" + "github.com/pkg/errors" + "google.golang.org/grpc/resolver" +) + +// schemeName for the urls +// All target URLs like 'consul://.../...' will be resolved by this resolver +const schemeName = "consul" + +// builder implements resolver.Builder and use for constructing all consul resolvers +type builder struct{} + +func (b *builder) Build(url resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { + tgt, err := parseURL(url.URL.String()) + if err != nil { + return nil, errors.Wrap(err, "Wrong consul URL") + } + cli, err := api.NewClient(tgt.consulConfig()) + if err != nil { + return nil, errors.Wrap(err, "Couldn't connect to the Consul API") + } + + ctx, cancel := context.WithCancel(context.Background()) + pipe := make(chan []string) + go watchConsulService(ctx, cli.Health(), tgt, pipe) + go populateEndpoints(ctx, cc, pipe) + + return &resolvr{cancelFunc: cancel}, nil +} + +// Scheme returns the scheme supported by this resolver. +// Scheme is defined at https://github.com/grpc/grpc/blob/master/doc/naming.md. +func (b *builder) Scheme() string { + return schemeName +} + +// init function needs for auto-register in resolvers registry +func init() { + resolver.Register(&builder{}) +} + +// resolvr implements resolver.Resolver from the gRPC package. +// It watches for endpoints changes and pushes them to the underlying gRPC connection. +type resolvr struct { + cancelFunc context.CancelFunc +} + +// ResolveNow will be skipped due unnecessary in this case +func (r *resolvr) ResolveNow(resolver.ResolveNowOptions) {} + +// Close closes the resolver. +func (r *resolvr) Close() { + r.cancelFunc() +} + +//go:generate ./bin/moq -out mocks_test.go . servicer +type servicer interface { + Service(string, string, bool, *api.QueryOptions) ([]*api.ServiceEntry, *api.QueryMeta, error) +} + +func watchConsulService(ctx context.Context, s servicer, tgt target, out chan<- []string) { + res := make(chan []string) + quit := make(chan struct{}) + bck := &backoff.Backoff{ + Factor: 2, + Jitter: true, + Min: 10 * time.Millisecond, + Max: tgt.MaxBackoff, + } + go func() { + var lastIndex uint64 + for { + ss, meta, err := s.Service( + tgt.Service, + tgt.Tag, + tgt.Healthy, + &api.QueryOptions{ + WaitIndex: lastIndex, + Near: tgt.Near, + WaitTime: tgt.Wait, + Datacenter: tgt.Dc, + AllowStale: tgt.AllowStale, + RequireConsistent: tgt.RequireConsistent, + }, + ) + if err != nil { + // No need to continue if the context is done/cancelled. + // We check that here directly because the check for the closed quit channel + // at the end of the loop is not reached when calling continue here. + select { + case <-quit: + return + default: + grpclog.Errorf("[Consul resolver] Couldn't fetch endpoints. target={%s}; error={%v}", tgt.String(), err) + time.Sleep(bck.Duration()) + continue + } + } + bck.Reset() + lastIndex = meta.LastIndex + grpclog.Infof("[Consul resolver] %d endpoints fetched in(+wait) %s for target={%s}", + len(ss), + meta.RequestTime, + tgt.String(), + ) + + ee := make([]string, 0, len(ss)) + for _, s := range ss { + address := s.Service.Address + if s.Service.Address == "" { + address = s.Node.Address + } + ee = append(ee, fmt.Sprintf("%s:%d", address, s.Service.Port)) + } + + if tgt.Limit != 0 && len(ee) > tgt.Limit { + ee = ee[:tgt.Limit] + } + select { + case res <- ee: + continue + case <-quit: + return + } + } + }() + + for { + // If in the below select both channels have values that can be read, + // Go picks one pseudo-randomly. + // But when the context is canceled we want to act upon it immediately. + if ctx.Err() != nil { + // Close quit so the goroutine returns and doesn't leak. + // Do NOT close res because that can lead to panics in the goroutine. + // res will be garbage collected at some point. + close(quit) + return + } + select { + case ee := <-res: + out <- ee + case <-ctx.Done(): + close(quit) + return + } + } +} + +func populateEndpoints(ctx context.Context, clientConn resolver.ClientConn, input <-chan []string) { + for { + select { + case cc := <-input: + connsSet := make(map[string]struct{}, len(cc)) + for _, c := range cc { + connsSet[c] = struct{}{} + } + conns := make([]resolver.Address, 0, len(connsSet)) + for c := range connsSet { + conns = append(conns, resolver.Address{Addr: c}) + } + sort.Sort(byAddressString(conns)) // Don't replace the same address list in the balancer + err := clientConn.UpdateState(resolver.State{Addresses: conns}) + if err != nil { + grpclog.Errorf("[Consul resolver] Couldn't update client connection. error={%v}", err) + continue + } + case <-ctx.Done(): + grpclog.Info("[Consul resolver] Watch has been finished") + return + } + } +} + +// byAddressString sorts resolver.Address by Address Field sorting in increasing order. +type byAddressString []resolver.Address + +func (p byAddressString) Len() int { return len(p) } +func (p byAddressString) Less(i, j int) bool { return p[i].Addr < p[j].Addr } +func (p byAddressString) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/myregistry/consul/consul.go b/myregistry/consul/consul.go new file mode 100644 index 0000000..ac44409 --- /dev/null +++ b/myregistry/consul/consul.go @@ -0,0 +1,153 @@ +package consul + +import ( + "fmt" + "git.makemake.in/kzkzzzz/mycommon/myconf" + "git.makemake.in/kzkzzzz/mycommon/myregistry" + "github.com/google/uuid" + api "github.com/hashicorp/consul/api" + "log" + "net" + "net/url" + "time" +) + +var _ myregistry.IRegister = (*Consul)(nil) + +type Consul struct { + client *api.Client + serviceIds map[string][]string + serviceTags []string +} + +func (c *Consul) Name() string { + return "consul" +} + +func (c *Consul) Register(service *myregistry.ServiceInfo) error { + // 健康检查 + serviceId := uuid.New().String() + + c.serviceIds[service.ServiceName] = append(c.serviceIds[service.ServiceName], serviceId) + + check := &api.AgentServiceCheck{ + CheckID: serviceId, + TCP: fmt.Sprintf("%s:%d", service.Ip, service.Port), + Timeout: "5s", // 超时时间 + Interval: "20s", // 运行检查的频率 + // 指定时间后自动注销不健康的服务节点 + // 最小超时时间为1分钟,收获不健康服务的进程每30秒运行一次,因此触发注销的时间可能略长于配置的超时时间。 + DeregisterCriticalServiceAfter: "5m", + Status: "passing", + } + srv := &api.AgentServiceRegistration{ + ID: serviceId, // 服务唯一ID + Name: service.ServiceName, // 服务名称 + Tags: c.serviceTags, // 为服务打标签 + Address: service.Ip, + Port: service.Port, + Check: check, + } + + return c.client.Agent().ServiceRegister(srv) +} + +func (c *Consul) Deregister(service *myregistry.ServiceInfo) error { + for _, svcId := range c.serviceIds[service.ServiceName] { + err := c.client.Agent().ServiceDeregister(svcId) + if err != nil { + log.Printf("Failed to deregister service %s: %s\n", service, err) + } + } + return nil +} + +type Conf struct { + Addr string + Token string +} + +func MustNew(conf *myconf.Config) *Consul { + consul, err := New(conf) + if err != nil { + panic(err) + } + return consul +} + +func New(conf *myconf.Config) (*Consul, error) { + cfg := api.DefaultConfig() + cfg.Address = conf.GetString("addr") + cfg.Transport.DialContext = (&net.Dialer{ + Timeout: 3 * time.Second, + KeepAlive: 20 * time.Second, + DualStack: true, + }).DialContext + cfg.Token = conf.GetString("token") + + username := conf.GetString("username") + password := conf.GetString("password") + + if username != "" && password != "" { + cfg.HttpAuth = &api.HttpBasicAuth{ + Username: username, + Password: password, + } + } + + client, err := api.NewClient(cfg) + if err != nil { + return nil, err + } + cl := &Consul{ + client: client, + serviceIds: make(map[string][]string), + serviceTags: make([]string, 0), + } + + if v := conf.GetStringSlice("serviceTags"); len(v) > 0 { + cl.serviceTags = v + } else { + cl.serviceTags = []string{} + } + + return cl, nil +} + +func (c *Consul) Client() *api.Client { + return c.client +} + +func GrpcUrl(serviceName string, conf *myconf.Config) string { + return GrpcUrlWithTag("", serviceName, conf) +} + +func GrpcUrlWithTag(tag string, serviceName string, conf *myconf.Config) string { + u := &url.URL{ + Scheme: schemeName, + Host: conf.GetString("addr"), + Path: serviceName, + } + + query := u.Query() + query.Set("healthy", "true") + + if v := conf.GetString("token"); v != "" { + query.Set("token", v) + } + + if tag != "" { + query.Set("tag", tag) + } + + username := conf.GetString("username") + password := conf.GetString("password") + + if username != "" && password != "" { + u.User = url.UserPassword(username, password) + } + + u.RawQuery = query.Encode() + + return u.String() +} diff --git a/myregistry/consul/target.go b/myregistry/consul/target.go new file mode 100644 index 0000000..703bbad --- /dev/null +++ b/myregistry/consul/target.go @@ -0,0 +1,101 @@ +package consul + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-playground/form" + "github.com/hashicorp/consul/api" + "github.com/pkg/errors" +) + +type target struct { + Addr string `form:"-"` + User string `form:"-"` + Password string `form:"-"` + Service string `form:"-"` + Wait time.Duration `form:"wait"` + Timeout time.Duration `form:"timeout"` + MaxBackoff time.Duration `form:"max-backoff"` + Tag string `form:"tag"` + Near string `form:"near"` + Limit int `form:"limit"` + Healthy bool `form:"healthy"` + TLSInsecure bool `form:"insecure"` + Token string `form:"token"` + Dc string `form:"dc"` + AllowStale bool `form:"allow-stale"` + RequireConsistent bool `form:"require-consistent"` + // TODO(mbobakov): custom parameters for the http-transport + // TODO(mbobakov): custom parameters for the TLS subsystem +} + +func (t *target) String() string { + return fmt.Sprintf("service='%s' healthy='%t' tag='%s'", t.Service, t.Healthy, t.Tag) +} + +// parseURL with parameters +// see README.md for the actual format +// URL schema will stay stable in the future for backward compatibility +func parseURL(u string) (target, error) { + rawURL, err := url.Parse(u) + if err != nil { + return target{}, errors.Wrap(err, "Malformed URL") + } + + if rawURL.Scheme != schemeName || + len(rawURL.Host) == 0 || len(strings.TrimLeft(rawURL.Path, "/")) == 0 { + return target{}, + errors.Errorf("Malformed URL('%s'). Must be in the next format: 'consul://[user:passwd]@host/service?param=value'", u) + } + + var tgt target + tgt.User = rawURL.User.Username() + tgt.Password, _ = rawURL.User.Password() + tgt.Addr = rawURL.Host + tgt.Service = strings.TrimLeft(rawURL.Path, "/") + decoder := form.NewDecoder() + decoder.RegisterCustomTypeFunc(func(vals []string) (interface{}, error) { + return time.ParseDuration(vals[0]) + }, time.Duration(0)) + + err = decoder.Decode(&tgt, rawURL.Query()) + if err != nil { + return target{}, errors.Wrap(err, "Malformed URL parameters") + } + if len(tgt.Near) == 0 { + tgt.Near = "_agent" + } + if tgt.MaxBackoff == 0 { + tgt.MaxBackoff = time.Second + } + return tgt, nil +} + +// consulConfig returns config based on the parsed target. +// It uses custom http-client. +func (t *target) consulConfig() *api.Config { + var creds *api.HttpBasicAuth + if len(t.User) > 0 && len(t.Password) > 0 { + creds = new(api.HttpBasicAuth) + creds.Password = t.Password + creds.Username = t.User + } + // custom http.Client + c := &http.Client{ + Timeout: t.Timeout, + } + return &api.Config{ + Address: t.Addr, + HttpAuth: creds, + WaitTime: t.Wait, + HttpClient: c, + TLSConfig: api.TLSConfig{ + InsecureSkipVerify: t.TLSInsecure, + }, + Token: t.Token, + } +} diff --git a/myregistry/reigster.go b/myregistry/reigster.go new file mode 100644 index 0000000..a5a4c22 --- /dev/null +++ b/myregistry/reigster.go @@ -0,0 +1,21 @@ +package myregistry + +import "fmt" + +type ServiceInfo struct { + ServiceName string + Ip string + Port int + Extend map[string]string +} + +func (s *ServiceInfo) String() string { + return fmt.Sprintf("%s - %s:%d", s.ServiceName, s.Ip, s.Port) +} + +// IRegister 注册中心 服务注册发现 +type IRegister interface { + Name() string + Register(service *ServiceInfo) error + Deregister(service *ServiceInfo) error +}