345 lines
8.7 KiB
Go
345 lines
8.7 KiB
Go
package grpcsr
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"git.makemake.in/kzkzzzz/mycommon"
|
||
"git.makemake.in/kzkzzzz/mycommon/graceful"
|
||
"git.makemake.in/kzkzzzz/mycommon/myconf"
|
||
"git.makemake.in/kzkzzzz/mycommon/mygrpc"
|
||
"git.makemake.in/kzkzzzz/mycommon/mylog"
|
||
"git.makemake.in/kzkzzzz/mycommon/myregistry"
|
||
"github.com/spf13/pflag"
|
||
"google.golang.org/grpc"
|
||
"google.golang.org/grpc/codes"
|
||
"google.golang.org/grpc/health"
|
||
"google.golang.org/grpc/health/grpc_health_v1"
|
||
"google.golang.org/grpc/keepalive"
|
||
"google.golang.org/grpc/reflection"
|
||
"google.golang.org/grpc/status"
|
||
"log"
|
||
"net"
|
||
|
||
"runtime/debug"
|
||
"time"
|
||
)
|
||
|
||
//const DefaultInstanceName = "grpc"
|
||
|
||
var _ graceful.IRunner = (*Server)(nil)
|
||
|
||
// Conf holds the listen and logging settings of the gRPC server. New fills it
// from the "grpc" configuration key.
type Conf struct {
	Addr string // host part of the listen address; Run listens on "Addr:Port"
	Port int    // listen port; 0 means a random port
	Ip   string // IP announced to the registry; when empty Run uses mycommon.GetOutboundIP()
	Log  bool   // when true, NewByConf installs the request-log interceptor
}
|
||
|
||
// Opt is a functional option that customises a Server while NewByConf builds it.
type Opt func(s *Server)
|
||
|
||
type Server struct {
|
||
gs *grpc.Server
|
||
serviceId string
|
||
serviceName string
|
||
serverConf *Conf
|
||
reg myregistry.IRegister
|
||
grpcOpts []grpc.ServerOption
|
||
logger mylog.ILogger
|
||
registerGrpcFn func(*grpc.Server) // 注册grpc服务, 使用函数延迟调用, 便于先初始化中间件等操作
|
||
|
||
unaryMiddlewares []grpc.UnaryServerInterceptor // grpc一元服务端中间件
|
||
|
||
useDefaultBufferCfg bool
|
||
delayStopMs int
|
||
|
||
serviceRegInfo *myregistry.ServiceInfo
|
||
}
|
||
|
||
func UseDefaultBufferCfg(v bool) Opt {
|
||
return func(server *Server) {
|
||
server.useDefaultBufferCfg = v
|
||
}
|
||
}
|
||
|
||
func WithRegistry(serviceName string, reg myregistry.IRegister) Opt {
|
||
return func(server *Server) {
|
||
server.serviceName = mygrpc.ServicePrefix + serviceName
|
||
server.reg = reg
|
||
}
|
||
}
|
||
|
||
func WithGrpcOpts(v ...grpc.ServerOption) Opt {
|
||
return func(server *Server) {
|
||
server.grpcOpts = v
|
||
}
|
||
}
|
||
|
||
func WithDelayStopMs(v int) Opt {
|
||
return func(server *Server) {
|
||
server.delayStopMs = v
|
||
}
|
||
}
|
||
|
||
func SetFlag() {
|
||
pflag.Int("grpc.port", 0, "listen port, 0 is random port")
|
||
pflag.String("grpc.log", "true", "enable request log")
|
||
}
|
||
|
||
func New(cfg *myconf.Config, opts ...Opt) *Server {
|
||
cf := &Conf{}
|
||
err := cfg.UnmarshalKey("grpc", cf)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
// 命令行的参数覆盖一次, Unmarshal解析的时候, 不会用命令行的参数覆盖 https://github.com/spf13/viper/issues/190
|
||
cf.Port = cfg.GetInt(fmt.Sprintf("grpc.port"))
|
||
cf.Log = cfg.GetBool(fmt.Sprintf("grpc.log"))
|
||
|
||
return NewByConf(cf, opts...)
|
||
}
|
||
|
||
func NewByConf(conf *Conf, opts ...Opt) *Server {
|
||
s := &Server{
|
||
serverConf: conf,
|
||
useDefaultBufferCfg: true,
|
||
}
|
||
for _, opt := range opts {
|
||
opt(s)
|
||
}
|
||
|
||
if s.logger == nil {
|
||
s.logger = mylog.GetLoggerSkip(-1)
|
||
}
|
||
|
||
if s.reg != nil && s.serviceName == "" {
|
||
panic("service name is empty")
|
||
}
|
||
|
||
s.unaryMiddlewares = []grpc.UnaryServerInterceptor{
|
||
s.grpcRecover(), // 默认启用recover中间件
|
||
}
|
||
|
||
if s.serverConf.Log {
|
||
s.unaryMiddlewares = append(s.unaryMiddlewares, s.requestLog())
|
||
}
|
||
|
||
return s
|
||
}
|
||
|
||
func (s *Server) Use(middlewares ...grpc.UnaryServerInterceptor) {
|
||
s.unaryMiddlewares = append(s.unaryMiddlewares, middlewares...)
|
||
}
|
||
|
||
func (s *Server) RegisterGrpc(fn func(*grpc.Server)) {
|
||
s.registerGrpcFn = fn
|
||
}
|
||
|
||
func (s *Server) initServer() {
|
||
grpcOpts := []grpc.ServerOption{
|
||
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
|
||
MinTime: time.Second * 5, // 如果客户端两次 ping 的间隔小于 N,则关闭连接
|
||
PermitWithoutStream: true, // 即使没有 active stream, 也允许 ping
|
||
}),
|
||
grpc.KeepaliveParams(keepalive.ServerParameters{
|
||
MaxConnectionIdle: time.Minute * 15, // 空闲连接时间
|
||
MaxConnectionAgeGrace: time.Second * 30, // 在强制关闭连接之间, 允许有 N 的时间完成 pending 的 rpc 请求
|
||
Time: time.Second * 20, // 如果一个连接空闲超过 N, 则发送一个 ping 请求
|
||
Timeout: time.Second * 5, // 如果 ping 请求 N 内未收到回复, 则认为该连接已断开
|
||
}),
|
||
}
|
||
|
||
if s.useDefaultBufferCfg {
|
||
grpcOpts = append(grpcOpts,
|
||
grpc.InitialWindowSize(mygrpc.DefaultWindowSize),
|
||
grpc.InitialConnWindowSize(mygrpc.DefaultWindowSize),
|
||
|
||
grpc.ReadBufferSize(mygrpc.DefaultReadBufferSize),
|
||
grpc.WriteBufferSize(mygrpc.DefaultWriteBufferSize),
|
||
)
|
||
}
|
||
|
||
if len(s.unaryMiddlewares) > 0 {
|
||
grpcOpts = append(grpcOpts, grpc.ChainUnaryInterceptor(s.unaryMiddlewares...))
|
||
}
|
||
|
||
if len(s.grpcOpts) > 0 {
|
||
grpcOpts = append(grpcOpts, s.grpcOpts...)
|
||
}
|
||
|
||
s.gs = grpc.NewServer(grpcOpts...)
|
||
|
||
// 注册grpc服务
|
||
s.registerGrpcFn(s.gs)
|
||
}
|
||
|
||
func (s *Server) Run(ctx context.Context) error {
|
||
s.initServer()
|
||
|
||
// 端口如果=0, 监听随机端口
|
||
addr0 := fmt.Sprintf("%s:%d", s.serverConf.Addr, s.serverConf.Port)
|
||
lis, err := net.Listen("tcp", addr0)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取监听的端口
|
||
port := lis.Addr().(*net.TCPAddr).Port
|
||
|
||
// 健康服务
|
||
healthServer := health.NewServer()
|
||
grpc_health_v1.RegisterHealthServer(s.gs, healthServer)
|
||
|
||
// 服务反射, 方便调试
|
||
reflection.Register(s.gs)
|
||
|
||
var svcIp = s.serverConf.Ip
|
||
if svcIp == "" {
|
||
svcIp = mycommon.GetOutboundIP()
|
||
}
|
||
|
||
// 注册服务
|
||
if s.reg != nil {
|
||
s.serviceRegInfo = &myregistry.ServiceInfo{
|
||
ServiceName: s.serviceName,
|
||
Ip: svcIp,
|
||
Port: port,
|
||
}
|
||
|
||
err = s.reg.Register(s.serviceRegInfo)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
log.Printf("[%s] register service: %s - %s:%d",
|
||
s.reg.Name(),
|
||
s.serviceRegInfo.ServiceName,
|
||
s.serviceRegInfo.Ip, s.serviceRegInfo.Port,
|
||
)
|
||
}
|
||
|
||
addr := fmt.Sprintf("%s:%d", s.serverConf.Addr, port)
|
||
log.Printf("grpc server listen on %s", addr)
|
||
|
||
err = s.gs.Serve(lis)
|
||
if err != nil {
|
||
log.Printf("start grpc server err: %s", err)
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *Server) Stop() {
|
||
if s.reg != nil {
|
||
err := s.reg.Deregister(s.serviceRegInfo)
|
||
if err != nil {
|
||
s.logger.Errorf("grpc server deregister err: %s", err)
|
||
}
|
||
|
||
log.Printf("[%s] deregister service: %s - %s:%d",
|
||
s.reg.Name(),
|
||
s.serviceRegInfo.ServiceName,
|
||
s.serviceRegInfo.Ip, s.serviceRegInfo.Port,
|
||
)
|
||
|
||
}
|
||
|
||
// 如果使用k8s service, 关闭pod和往service注销ip是同时进行的, 如果退出服务比注销ip先完成, 可能有流量继续进来, 导致请求失败
|
||
// 延迟一段时间, 确保服务已经注销ip, 再关闭服务
|
||
|
||
// 如何使用注册中心, 先从中心退出ip, 也延迟一段时间, 等上游网关更新ip完成(正常不会太久), 不会有流量进来旧服务, 再退出服务
|
||
if s.delayStopMs > 0 {
|
||
delayTime := time.Millisecond * time.Duration(s.delayStopMs)
|
||
log.Printf("grpc server delay stop: %s", delayTime)
|
||
time.Sleep(delayTime)
|
||
}
|
||
|
||
s.gs.GracefulStop()
|
||
|
||
log.Printf("grpc server stop")
|
||
}
|
||
|
||
// handleResp carries a handler's result across the channel used by Timeout.
// The field order matters: Timeout builds it with a positional literal.
type handleResp struct {
	resp interface{} // handler response; nil when the handler panicked
	err  error       // handler error, or the error built from a recovered panic
}
|
||
|
||
func (s *Server) requestLog() grpc.UnaryServerInterceptor {
|
||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||
start := time.Now()
|
||
|
||
resp, err := handler(ctx, req)
|
||
|
||
var (
|
||
code codes.Code
|
||
codeMsg = "OK"
|
||
)
|
||
if err != nil {
|
||
fromError, ok := status.FromError(err)
|
||
if ok {
|
||
code = fromError.Code()
|
||
} else {
|
||
code = status.New(codes.Unknown, err.Error()).Code()
|
||
}
|
||
|
||
codeMsg = fmt.Sprintf("%s(%d)", code.String(), uint32(code))
|
||
}
|
||
|
||
s.logger.Infof(
|
||
"%s - %s - %s - %s",
|
||
codeMsg, time.Since(start), mygrpc.ClientIP(ctx), info.FullMethod,
|
||
)
|
||
|
||
return resp, err
|
||
}
|
||
}
|
||
|
||
func (s *Server) grpcRecover() grpc.UnaryServerInterceptor {
|
||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||
defer func() {
|
||
if err0 := recover(); err0 != nil {
|
||
log.Printf("%s - panic: %v\n%s", info.FullMethod, err0, debug.Stack())
|
||
err = fmt.Errorf("server err %s - %s", info.FullMethod, err0)
|
||
}
|
||
}()
|
||
|
||
return handler(ctx, req)
|
||
}
|
||
}
|
||
|
||
// Timeout 控制服务端超时
|
||
func Timeout(timeout time.Duration) grpc.UnaryServerInterceptor {
|
||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||
tCtx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
ch := make(chan *handleResp, 1)
|
||
|
||
go func() {
|
||
defer func() {
|
||
if err0 := recover(); err0 != nil {
|
||
log.Printf("%s - panic: %v\n%s", info.FullMethod, err0, debug.Stack())
|
||
ch <- &handleResp{nil, fmt.Errorf("server err: %s - system err: %s", info.FullMethod, err0)}
|
||
}
|
||
}()
|
||
//start := time.Now()
|
||
r := &handleResp{}
|
||
r.resp, r.err = handler(tCtx, req)
|
||
//log.Printf("rta server time: %s", time.Since(start))
|
||
ch <- r
|
||
}()
|
||
|
||
select {
|
||
case <-tCtx.Done():
|
||
return nil, mygrpc.GrpcServerTimeout("server err: grpc handle timeout %s %s", timeout, info.FullMethod)
|
||
//return nil, status.Errorf(codes.DeadlineExceeded, "grpc handle timeout %s %s", timeout, info.FullMethod)
|
||
|
||
case res := <-ch:
|
||
return res.resp, res.err
|
||
}
|
||
|
||
//return nil, fmt.Errorf("handle err %s.%s", info.Server, info.FullMethod)
|
||
}
|
||
}
|