// Package web serves the proxyport management UI and its JSON API for
// creating, editing, enabling/disabling and deleting port-forward rules.
package web

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/netip"
	"reflect"
	"regexp"
	"strconv"
	"strings"
	"time"

	"git.makemake.in/kzkzzzz/mycommon/mylog"
	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"github.com/go-playground/locales/en"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	enTr "github.com/go-playground/validator/v10/translations/en"
	"gorm.io/gorm"

	"proxyport/app/db"
	"proxyport/app/forward"
	"proxyport/app/model"
	"proxyport/app/static"
)

// Config holds the web server settings; it is read by Start and GetConfig.
var Config = &Cfg{}

// Cfg configures the management web server.
type Cfg struct {
	ListenAddr string // address the HTTP server listens on
	User       string // basic-auth user; auth is enabled only when User and Password are both non-empty
	Password   string // basic-auth password
	WebName    string // returned to the UI by GetConfig as "web_name"
}

const (
	// readHeaderTimeout bounds how long a client may take to send request headers.
	readHeaderTimeout = 10 * time.Second
	// shutdownTimeout bounds the graceful shutdown before remaining connections are closed.
	shutdownTimeout = 5 * time.Second
)

// Start registers the UI and API routes, serves them on Config.ListenAddr in a
// background goroutine, and blocks until ctx is cancelled. It then shuts the
// server down gracefully, force-closing it if that does not finish within
// shutdownTimeout.
func Start(ctx context.Context) {
	InitGinTrans()

	engine := gin.Default()
	if Config.User != "" && Config.Password != "" {
		engine.Use(gin.BasicAuth(gin.Accounts{
			Config.User: Config.Password,
		}))
	}

	static.LoadStatic(engine)
	engine.Use(cors())

	engine.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "index.html", nil)
	})
	engine.GET("/GetConfig", GetConfig)
	engine.GET("/List", List)
	engine.POST("/Create", Create)
	engine.POST("/Update", Update)
	engine.POST("/Delete", Delete)
	engine.POST("/SwitchStatus", SwitchStatus)

	hs := &http.Server{
		Addr:              Config.ListenAddr,
		Handler:           engine,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	go func() {
		log.Printf("http listen on %s", hs.Addr)
		// ErrServerClosed is the expected result of Shutdown/Close below, not a failure.
		if err := hs.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			mylog.Warn(err)
		}
	}()

	<-ctx.Done()

	// ctx is already done, so the shutdown deadline needs a fresh context.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()
	if err := hs.Shutdown(shutdownCtx); err != nil {
		// Graceful shutdown did not finish in time; drop the remaining connections.
		_ = hs.Close()
	}
}

// cors returns a middleware that allows cross-origin GET/POST with a
// Content-Type header from any origin, and answers OPTIONS preflight requests
// with 200 without running later handlers.
func cors() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		ctx.Header("Access-Control-Allow-Origin", "*")
		ctx.Header("Access-Control-Allow-Methods", "GET, POST")
		ctx.Header("Access-Control-Allow-Headers", "Content-Type")
		if ctx.Request.Method == http.MethodOptions {
			ctx.AbortWithStatus(http.StatusOK)
			return
		}
		ctx.Next()
	}
}

// GetConfig returns the UI-facing configuration (currently only the web name).
func GetConfig(ctx *gin.Context) {
	success(ctx, gin.H{
		"web_name": Config.WebName,
	})
}

// List returns every forward rule, most recently updated first.
func List(ctx *gin.Context) {
	data := make([]*model.Forward, 0) // non-nil so an empty result encodes as [] rather than null
	err := db.DB().Table("forward").Select("*").Order("update_time desc").Find(&data).Error
	if err != nil {
		fail(ctx, err)
		return
	}
	success(ctx, data)
}

// Create inserts a new forward rule and starts its listener. The row and the
// listener are handled in one transaction: if the listener cannot be added, the
// insert is rolled back. New rules are always stored enabled (status 1); the
// request's "status" field is validated but not used.
func Create(ctx *gin.Context) {
	type Req struct {
		TargetAddr string `json:"target_addr" binding:"min=1"`
		LocalPort  int    `json:"local_port" binding:"required"`
		Name       string `json:"name" binding:"required"`
		Protocol   int    `json:"protocol" binding:"oneof=0 1"`
		Status     int    `json:"status" binding:"oneof=0 1"`
	}

	req := &Req{}
	if err := ctx.ShouldBindJSON(req); err != nil {
		fail(ctx, err)
		return
	}

	targetAddr, err := validateForwardInput(0, req.LocalPort, req.Protocol, req.TargetAddr)
	if err != nil {
		fail(ctx, err)
		return
	}

	now := time.Now().Format(time.DateTime)
	err = db.DB().Transaction(func(tx *gorm.DB) error {
		mForward := &model.Forward{
			LocalPort:  req.LocalPort,
			TargetAddr: targetAddr,
			Name:       req.Name,
			Protocol:   req.Protocol,
			Status:     1,
			CreateTime: now,
			UpdateTime: now,
		}
		if err := tx.Table("forward").Create(mForward).Error; err != nil {
			return err
		}
		return forward.ListenerManager.Add(mForward)
	})
	if err != nil {
		fail(ctx, err)
		return
	}
	success(ctx, "ok")
}

// Update edits the name, local port, protocol and targets of the rule
// identified by "id". The stored status and create_time are left unchanged. If
// the rule was enabled (status 1), its listener is restarted with the new
// settings. When that restart fails, the row update is rolled back and the
// previous listener is restored on a best-effort basis.
func Update(ctx *gin.Context) {
	type Req struct {
		Id         int    `json:"id" binding:"required"`
		TargetAddr string `json:"target_addr" binding:"min=1"`
		LocalPort  int    `json:"local_port" binding:"required"`
		Name       string `json:"name" binding:"required"`
		Protocol   int    `json:"protocol" binding:"oneof=0 1"`
		Status     int    `json:"status" binding:"oneof=0 1"`
	}

	req := &Req{}
	if err := ctx.ShouldBindJSON(req); err != nil {
		fail(ctx, err)
		return
	}

	targetAddr, err := validateForwardInput(req.Id, req.LocalPort, req.Protocol, req.TargetAddr)
	if err != nil {
		fail(ctx, err)
		return
	}

	current, err := getForwardById(req.Id)
	if err != nil {
		fail(ctx, err)
		return
	}

	err = db.DB().Transaction(func(tx *gorm.DB) error {
		changes := &model.Forward{
			LocalPort:  req.LocalPort,
			TargetAddr: targetAddr,
			Name:       req.Name,
			Protocol:   req.Protocol,
			UpdateTime: time.Now().Format(time.DateTime),
		}
		// Select("*") makes gorm write the struct's columns even when they are
		// zero-valued; create_time and status are omitted so they keep their
		// stored values.
		err := tx.Table("forward").Select("*").Omit("create_time", "status").Where("id = ?", req.Id).
			Limit(1).
			Updates(changes).Error
		if err != nil {
			return err
		}

		// Reload the full row so the listener receives the persisted record
		// (id, status, create_time included), as Create does, rather than the
		// partial update struct.
		updated := &model.Forward{}
		if err := tx.Table("forward").Select("*").Where("id = ?", req.Id).First(updated).Error; err != nil {
			return err
		}

		forward.ListenerManager.Remove(current)
		if current.Status != 1 {
			return nil
		}
		if err := forward.ListenerManager.Add(updated); err != nil {
			// Returning err rolls the row back; try to bring the previous listener
			// back so the running state matches the database again.
			if restoreErr := forward.ListenerManager.Add(current); restoreErr != nil {
				log.Printf("restore forward id %d after failed update: %v", req.Id, restoreErr)
			}
			return err
		}
		return nil
	})
	if err != nil {
		fail(ctx, err)
		return
	}
	success(ctx, "ok")
}

// Delete removes the rule identified by "id" from the database, then stops its
// listener.
func Delete(ctx *gin.Context) {
	type Req struct {
		Id int `json:"id" binding:"required"`
	}

	req := &Req{}
	if err := ctx.ShouldBindJSON(req); err != nil {
		fail(ctx, err)
		return
	}

	mForward, err := getForwardById(req.Id)
	if err != nil {
		fail(ctx, err)
		return
	}

	err = db.DB().Table("forward").Select("*").Where("id = ?", req.Id).Limit(1).Delete(nil).Error
	if err != nil {
		fail(ctx, err)
		return
	}

	forward.ListenerManager.Remove(mForward)
	success(ctx, "ok")
}

var (
	// hostReg accepts host names made of letters, digits, '-' and '.', starting
	// with a letter/digit and ending with a letter.
	hostReg = regexp.MustCompile(`^[a-zA-Z0-9]+[a-zA-Z0-9\-.]*[a-zA-Z]+$`)
	// letterReg detects whether a host contains name characters (letters or '-').
	letterReg = regexp.MustCompile(`[a-zA-Z\-]`)
	// emptyReg detects embedded whitespace.
	emptyReg = regexp.MustCompile(`\s+`)
)

// validateForwardInput runs the checks shared by Create and Update: the local
// port must be in range, the target list must be non-empty and well-formed, and
// no other rule (excluding id when id > 0) may use the same local port and
// protocol. It returns the parsed target address list.
func validateForwardInput(id, localPort, protocol int, targetAddr string) ([]string, error) {
	if err := checkAddr(fmt.Sprintf("0.0.0.0:%d", localPort)); err != nil {
		return nil, err
	}

	targets, err := parseAddrList(targetAddr)
	if err != nil {
		return nil, err
	}

	if err := checkUnique(id, localPort, protocol); err != nil {
		return nil, err
	}
	return targets, nil
}

// parsePort parses a port written as plain decimal digits (at most 5). Signs,
// hex/octal prefixes, underscores and decimals are rejected, because the
// address string is stored and later used as-is.
func parsePort(s string) (int, bool) {
	if s == "" || len(s) > 5 {
		return 0, false
	}
	for _, c := range s {
		if c < '0' || c > '9' {
			return 0, false
		}
	}
	n, err := strconv.Atoi(s)
	return n, err == nil
}

// checkAddr validates a "host:port" address. The host may be an IPv4 literal,
// a bracketed IPv6 literal, or a host name matching hostReg. The port must be a
// decimal number in 10-65535. Surrounding whitespace is ignored; embedded
// whitespace is rejected.
func checkAddr(addr string) error {
	addr = strings.TrimSpace(addr)
	if emptyReg.MatchString(addr) {
		return fmt.Errorf("addr format err: %s", addr)
	}

	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return fmt.Errorf("addr format err: %s", addr)
	}

	intPort, ok := parsePort(port)
	if !ok {
		return fmt.Errorf("port [%s] format err", port)
	}
	if intPort > 65535 || intPort < 10 {
		return fmt.Errorf("port [%s] out of range: 10 - 65535", port)
	}

	_, ipErr := netip.ParseAddr(host)
	switch {
	case ipErr == nil:
		// Valid IPv4/IPv6 literal.
		return nil
	case !letterReg.MatchString(host):
		// No name characters: it can only have been meant as an IP.
		return fmt.Errorf("ip [%s] format err: %s", host, ipErr)
	case !hostReg.MatchString(host):
		return fmt.Errorf("host [%s] format err", host)
	}
	return nil
}

// checkUnique returns an error if a rule other than id (when id > 0) already
// uses localPort with the same protocol.
func checkUnique(id, localPort, forwardType int) error {
	query := db.DB().Table("forward").
		Where("local_port = ?", localPort).
		Where("protocol = ?", forwardType)
	if id > 0 {
		query = query.Where("id != ?", id)
	}

	var res []int
	if err := query.Select("id").Limit(1).Find(&res).Error; err != nil {
		return err
	}
	if len(res) > 0 {
		return fmt.Errorf("%s port: %d already use", forward.Protocol(forwardType).String(), localPort)
	}
	return nil
}

// getForwardById loads the rule with the given id. It rejects id <= 0 and
// returns gorm's error when no row is found.
func getForwardById(id int) (*model.Forward, error) {
	if id <= 0 {
		return nil, fmt.Errorf("id err: %d", id)
	}

	mForward := &model.Forward{}
	err := db.DB().Table("forward").Select("*").Where("id = ?", id).First(mForward).Error
	if err != nil {
		return nil, err
	}
	return mForward, nil
}

// SwitchStatus enables (status 1) or disables (status 0) the rule identified
// by "id". It updates the stored status and starts or stops the listener in one
// transaction, so a failed start rolls the status change back.
func SwitchStatus(ctx *gin.Context) {
	type Req struct {
		Id     int `json:"id" binding:"required"`
		Status int `json:"status" binding:"oneof=0 1"`
	}

	req := &Req{}
	if err := ctx.ShouldBindJSON(req); err != nil {
		fail(ctx, err)
		return
	}

	mForward, err := getForwardById(req.Id)
	if err != nil {
		fail(ctx, err)
		return
	}

	err = db.DB().Transaction(func(tx *gorm.DB) error {
		err := tx.Table("forward").Where("id = ?", req.Id).Updates(map[string]any{
			"status": req.Status,
		}).Error
		if err != nil {
			return err
		}

		if req.Status == 1 {
			// Hand the listener the status that was just written.
			mForward.Status = req.Status
			return forward.ListenerManager.Add(mForward)
		}
		forward.ListenerManager.Remove(mForward)
		return nil
	})
	if err != nil {
		fail(ctx, err)
		return
	}
	success(ctx, "ok")
}

// apiRes is the JSON envelope of every API response: Code 0 means success and
// 1 means failure, with Message describing the error.
type apiRes struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Data    any    `json:"data"`
}

// success writes a 200 response with Code 0 and the given payload.
func success(ctx *gin.Context, data any) {
	ctx.JSON(http.StatusOK, &apiRes{
		Code:    0,
		Message: "ok",
		Data:    data,
	})
}

// ginTrans translates validator errors into English messages; set by InitGinTrans.
var ginTrans ut.Translator

// InitGinTrans configures gin's validator to report fields by their JSON name
// and registers English translations for its error messages. If gin's
// validator engine is not a *validator.Validate, it logs and leaves
// translations unset; fail then falls back to the raw error text.
func InitGinTrans() {
	engine := binding.Validator.Engine()
	v, ok := engine.(*validator.Validate)
	if !ok {
		log.Printf("gin validator engine is %T, skipping translator setup", engine)
		return
	}

	v.RegisterTagNameFunc(func(fld reflect.StructField) string {
		name, _, _ := strings.Cut(fld.Tag.Get("json"), ",")
		switch name {
		case "-":
			return ""
		case "":
			return fld.Name
		}
		return name
	})

	enLocale := en.New()
	uni := ut.New(enLocale, enLocale)
	tr, found := uni.GetTranslator("en")
	if !found {
		log.Printf("validator translator %q not found, using fallback", "en")
	}
	if err := enTr.RegisterDefaultTranslations(v, tr); err != nil {
		log.Printf("register validator translations: %v", err)
	}
	ginTrans = tr
}

// fail writes a Code 1 response with HTTP status 200. For validation errors
// (including wrapped ones) the first field error is translated via ginTrans;
// otherwise err.Error() is used.
func fail(ctx *gin.Context, err error) {
	res := &apiRes{Code: 1, Message: err.Error()}

	var ve validator.ValidationErrors
	if errors.As(err, &ve) && len(ve) > 0 {
		res.Message = ve[0].Translate(ginTrans)
	}
	ctx.JSON(http.StatusOK, res)
}

// emptySplitReg splits on newlines, absorbing surrounding whitespace (including \r).
var emptySplitReg = regexp.MustCompile(`\s*\n\s*`)

// parseAddrList splits a newline-separated list of target addresses, skips
// blank lines, and validates each entry with checkAddr. It returns an error if
// any entry is invalid or if no entries remain.
func parseAddrList(addrStr string) ([]string, error) {
	lines := emptySplitReg.Split(addrStr, -1)

	res := make([]string, 0, len(lines))
	for _, line := range lines {
		addr := strings.TrimSpace(line)
		if addr == "" {
			continue
		}
		if err := checkAddr(addr); err != nil {
			return nil, err
		}
		res = append(res, addr)
	}

	// Split always yields at least one element, so emptiness must be judged on
	// the filtered result (e.g. input consisting only of whitespace).
	if len(res) == 0 {
		return nil, fmt.Errorf("addr list is empty")
	}
	return res, nil
}