498 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			498 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Go
		
	
	
package web
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"git.makemake.in/kzkzzzz/mycommon/mylog"
 | 
						|
	"github.com/gin-gonic/gin"
 | 
						|
	"github.com/gin-gonic/gin/binding"
 | 
						|
	"github.com/go-playground/locales/en"
 | 
						|
	ut "github.com/go-playground/universal-translator"
 | 
						|
	"github.com/go-playground/validator/v10"
 | 
						|
	enTr "github.com/go-playground/validator/v10/translations/en"
 | 
						|
	"github.com/spf13/cast"
 | 
						|
	"gorm.io/gorm"
 | 
						|
	"log"
 | 
						|
	"net/http"
 | 
						|
	"net/netip"
 | 
						|
	"proxyport/app/db"
 | 
						|
	"proxyport/app/forward"
 | 
						|
	"proxyport/app/model"
 | 
						|
	"proxyport/app/static"
 | 
						|
	"reflect"
 | 
						|
	"regexp"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
// Cfg holds the settings of the web admin server.
type Cfg struct {
	ListenAddr string // address the admin HTTP server listens on (http.Server.Addr)
	User       string // basic-auth user; auth is enabled only when User and Password are both non-empty
	Password   string // basic-auth password for User
	WebName    string // display name returned by GetConfig as "web_name"
}

// Config is the package-wide configuration read by Start and the handlers.
var Config = new(Cfg)
 | 
						|
 | 
						|
func Start(ctx context.Context) {
 | 
						|
	InitGinTrans()
 | 
						|
 | 
						|
	engine := gin.Default()
 | 
						|
 | 
						|
	if Config.User != "" && Config.Password != "" {
 | 
						|
		engine.Use(gin.BasicAuth(gin.Accounts{
 | 
						|
			Config.User: Config.Password,
 | 
						|
		}))
 | 
						|
	}
 | 
						|
 | 
						|
	static.LoadStatic(engine)
 | 
						|
	engine.Use(cors())
 | 
						|
 | 
						|
	engine.GET("/", func(ctx *gin.Context) {
 | 
						|
		ctx.HTML(http.StatusOK, "index.html", nil)
 | 
						|
	})
 | 
						|
 | 
						|
	engine.GET("/GetConfig", GetConfig)
 | 
						|
	engine.GET("/List", List)
 | 
						|
	engine.POST("/Create", Create)
 | 
						|
	engine.POST("/Update", Update)
 | 
						|
	engine.POST("/Delete", Delete)
 | 
						|
	engine.POST("/SwitchStatus", SwitchStatus)
 | 
						|
 | 
						|
	hs := &http.Server{
 | 
						|
		Addr:    Config.ListenAddr,
 | 
						|
		Handler: engine,
 | 
						|
	}
 | 
						|
 | 
						|
	go func() {
 | 
						|
		log.Printf("http listen on %s", hs.Addr)
 | 
						|
		mylog.Warn(hs.ListenAndServe())
 | 
						|
	}()
 | 
						|
 | 
						|
	<-ctx.Done()
 | 
						|
	hs.Close()
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
func cors() gin.HandlerFunc {
 | 
						|
	return func(ctx *gin.Context) {
 | 
						|
		ctx.Header("Access-Control-Allow-Origin", "*")
 | 
						|
		ctx.Header("Access-Control-Allow-Methods", "GET, POST")
 | 
						|
 | 
						|
		ctx.Header("Access-Control-Allow-Headers", "Content-Type")
 | 
						|
 | 
						|
		if ctx.Request.Method == http.MethodOptions {
 | 
						|
			ctx.AbortWithStatus(http.StatusOK)
 | 
						|
		} else {
 | 
						|
			ctx.Next()
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func GetConfig(ctx *gin.Context) {
 | 
						|
	success(ctx, gin.H{
 | 
						|
		"web_name": Config.WebName,
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func List(ctx *gin.Context) {
 | 
						|
	data := make([]*model.Forward, 0)
 | 
						|
 | 
						|
	err := db.DB().Table("forward").Select("*").Order("update_time desc").Find(&data).Error
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
	success(ctx, data)
 | 
						|
}
 | 
						|
 | 
						|
func Create(ctx *gin.Context) {
 | 
						|
	type Req struct {
 | 
						|
		TargetAddr  string `json:"target_addr" binding:"min=1"`
 | 
						|
		LocalPort   int    `json:"local_port" binding:"required"`
 | 
						|
		Name        string `json:"name" binding:"required"`
 | 
						|
		ForwardType int    `json:"protocol" binding:"oneof=0 1"`
 | 
						|
		Status      int    `json:"status" binding:"oneof=0 1"`
 | 
						|
	}
 | 
						|
 | 
						|
	req := &Req{}
 | 
						|
	err := ctx.ShouldBindJSON(req)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	err = checkAddr(fmt.Sprintf("0.0.0.0:%d", req.LocalPort))
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	targetAddr, err := parseAddrList(req.TargetAddr)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	err = checkUnique(0, req.LocalPort, req.ForwardType)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	now := time.Now()
 | 
						|
 | 
						|
	err = db.DB().Transaction(func(tx *gorm.DB) error {
 | 
						|
		mForward := &model.Forward{
 | 
						|
			LocalPort:  req.LocalPort,
 | 
						|
			TargetAddr: targetAddr,
 | 
						|
			Name:       req.Name,
 | 
						|
			Protocol:   req.ForwardType,
 | 
						|
			Status:     1,
 | 
						|
			CreateTime: now.Format(time.DateTime),
 | 
						|
			UpdateTime: now.Format(time.DateTime),
 | 
						|
		}
 | 
						|
 | 
						|
		err := tx.Table("forward").Create(mForward).Error
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		err = forward.ListenerManager.Add(mForward)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	success(ctx, "ok")
 | 
						|
}
 | 
						|
 | 
						|
func Update(ctx *gin.Context) {
 | 
						|
	type Req struct {
 | 
						|
		Id         int    `json:"id" binding:"required"`
 | 
						|
		TargetAddr string `json:"target_addr" binding:"min=1"`
 | 
						|
		LocalPort  int    `json:"local_port" binding:"required"`
 | 
						|
		Name       string `json:"name" binding:"required"`
 | 
						|
		Protocol   int    `json:"protocol" binding:"oneof=0 1"`
 | 
						|
		Status     int    `json:"status" binding:"oneof=0 1"`
 | 
						|
	}
 | 
						|
 | 
						|
	req := &Req{}
 | 
						|
	err := ctx.ShouldBindJSON(req)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	err = checkAddr(fmt.Sprintf("0.0.0.0:%d", req.LocalPort))
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	targetAddr, err := parseAddrList(req.TargetAddr)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	err = checkUnique(req.Id, req.LocalPort, req.Protocol)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	current, err := getForwardById(req.Id)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	err = db.DB().Transaction(func(tx *gorm.DB) error {
 | 
						|
		mForward := &model.Forward{
 | 
						|
			LocalPort:  req.LocalPort,
 | 
						|
			TargetAddr: targetAddr,
 | 
						|
			Name:       req.Name,
 | 
						|
			Protocol:   req.Protocol,
 | 
						|
			UpdateTime: time.Now().Format(time.DateTime),
 | 
						|
		}
 | 
						|
 | 
						|
		err = tx.Table("forward").Select("*").Omit("create_time", "status").Where("id = ?", req.Id).
 | 
						|
			Limit(1).
 | 
						|
			Updates(mForward).Error
 | 
						|
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		forward.ListenerManager.Remove(current)
 | 
						|
 | 
						|
		if current.Status == 1 {
 | 
						|
			err = forward.ListenerManager.Add(mForward)
 | 
						|
			if err != nil {
 | 
						|
				return err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	success(ctx, "ok")
 | 
						|
}
 | 
						|
 | 
						|
func Delete(ctx *gin.Context) {
 | 
						|
	type Req struct {
 | 
						|
		Id int `json:"id" binding:"required"`
 | 
						|
	}
 | 
						|
 | 
						|
	req := &Req{}
 | 
						|
	err := ctx.ShouldBindJSON(req)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	mForward, err := getForwardById(req.Id)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	err = db.DB().Table("forward").Select("*").Where("id = ?", req.Id).Limit(1).Delete(nil).Error
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	forward.ListenerManager.Remove(mForward)
 | 
						|
 | 
						|
	success(ctx, "ok")
 | 
						|
}
 | 
						|
 | 
						|
var (
 | 
						|
	hostReg   = regexp.MustCompile(`^[a-zA-Z0-9]+[a-zA-Z0-9\-.]*[a-zA-Z]+$`)
 | 
						|
	letterReg = regexp.MustCompile(`[a-zA-Z\-]`)
 | 
						|
	emptyReg  = regexp.MustCompile(`\s+`)
 | 
						|
)
 | 
						|
 | 
						|
func checkAddr(addr string) error {
 | 
						|
	addr = strings.TrimSpace(addr)
 | 
						|
 | 
						|
	if emptyReg.MatchString(addr) {
 | 
						|
		return fmt.Errorf("addr format err: %s", addr)
 | 
						|
	}
 | 
						|
 | 
						|
	sp := strings.SplitN(addr, ":", 2)
 | 
						|
 | 
						|
	if len(sp) != 2 {
 | 
						|
		return fmt.Errorf("addr format err: %s", addr)
 | 
						|
	}
 | 
						|
 | 
						|
	host := sp[0]
 | 
						|
	port := sp[1]
 | 
						|
	intPort := cast.ToInt(port)
 | 
						|
 | 
						|
	if intPort > 65535 || intPort < 10 {
 | 
						|
		return fmt.Errorf("port [%s] out of range: 10 - 65535", port)
 | 
						|
	}
 | 
						|
 | 
						|
	if !letterReg.MatchString(host) {
 | 
						|
		_, err := netip.ParseAddrPort(addr)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("ip [%s] format err: %s", host, err)
 | 
						|
		}
 | 
						|
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	if !hostReg.MatchString(host) {
 | 
						|
		return fmt.Errorf("host [%s] format err", host)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func checkUnique(id, localPort, forwardType int) error {
 | 
						|
	query := db.DB().Table("forward").
 | 
						|
		Where("local_port = ?", localPort).
 | 
						|
		Where("protocol = ?", forwardType)
 | 
						|
 | 
						|
	if id > 0 {
 | 
						|
		query.Where("id != ?", id)
 | 
						|
	}
 | 
						|
 | 
						|
	var res []int
 | 
						|
	err := query.Select("id").Limit(1).Find(&res).Error
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if len(res) > 0 {
 | 
						|
		return fmt.Errorf("%s port: %d already use", forward.Protocol(forwardType).String(), localPort)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func getForwardById(id int) (*model.Forward, error) {
 | 
						|
	if id <= 0 {
 | 
						|
		return nil, fmt.Errorf("id err: %d", id)
 | 
						|
	}
 | 
						|
	mForward := &model.Forward{}
 | 
						|
	err := db.DB().Table("forward").Select("*").Where("id = ?", id).First(mForward).Error
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return mForward, nil
 | 
						|
}
 | 
						|
 | 
						|
func SwitchStatus(ctx *gin.Context) {
 | 
						|
	type Req struct {
 | 
						|
		Id     int `json:"id" binding:"required"`
 | 
						|
		Status int `json:"status" binding:"oneof=0 1"`
 | 
						|
	}
 | 
						|
 | 
						|
	req := &Req{}
 | 
						|
	err := ctx.ShouldBindJSON(req)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	mForward, err := getForwardById(req.Id)
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	err = db.DB().Transaction(func(tx *gorm.DB) error {
 | 
						|
		err = tx.Table("forward").Where("id = ?", req.Id).Updates(map[string]any{
 | 
						|
			"status": req.Status,
 | 
						|
		}).Limit(1).Error
 | 
						|
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		switch req.Status {
 | 
						|
		case 1:
 | 
						|
			err = forward.ListenerManager.Add(mForward)
 | 
						|
			if err != nil {
 | 
						|
				return err
 | 
						|
			}
 | 
						|
		default:
 | 
						|
			forward.ListenerManager.Remove(mForward)
 | 
						|
		}
 | 
						|
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		fail(ctx, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	success(ctx, "ok")
 | 
						|
}
 | 
						|
 | 
						|
type apiRes struct {
 | 
						|
	Code    int    `json:"code"`
 | 
						|
	Message string `json:"message"`
 | 
						|
	Data    any    `json:"data"`
 | 
						|
}
 | 
						|
 | 
						|
func success(ctx *gin.Context, data any) {
 | 
						|
	ctx.JSON(http.StatusOK, &apiRes{
 | 
						|
		Code:    0,
 | 
						|
		Message: "ok",
 | 
						|
		Data:    data,
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
var ginTrans ut.Translator
 | 
						|
 | 
						|
func InitGinTrans() {
 | 
						|
	v := binding.Validator.Engine().(*validator.Validate)
 | 
						|
	v.RegisterTagNameFunc(func(fld reflect.StructField) string {
 | 
						|
		nameSp := strings.SplitN(fld.Tag.Get("json"), ",", 2)
 | 
						|
		if len(nameSp) == 0 {
 | 
						|
			return fld.Name
 | 
						|
		}
 | 
						|
 | 
						|
		name := nameSp[0]
 | 
						|
		if name == "-" {
 | 
						|
			return ""
 | 
						|
		}
 | 
						|
 | 
						|
		if name != "" {
 | 
						|
			return name
 | 
						|
		}
 | 
						|
 | 
						|
		return fld.Name
 | 
						|
	})
 | 
						|
	enT := en.New()
 | 
						|
 | 
						|
	uni := ut.New(enT, enT)
 | 
						|
	tr, _ := uni.GetTranslator("en")
 | 
						|
	_ = enTr.RegisterDefaultTranslations(v, tr)
 | 
						|
 | 
						|
	ginTrans = tr
 | 
						|
}
 | 
						|
 | 
						|
func fail(ctx *gin.Context, err error) {
 | 
						|
	res := &apiRes{Code: 1}
 | 
						|
 | 
						|
	switch ve := err.(type) {
 | 
						|
	case validator.ValidationErrors:
 | 
						|
		if len(ve) > 0 {
 | 
						|
			res.Message = ve[0].Translate(ginTrans)
 | 
						|
		} else {
 | 
						|
			res.Message = ve.Error()
 | 
						|
		}
 | 
						|
 | 
						|
	default:
 | 
						|
		res.Message = err.Error()
 | 
						|
	}
 | 
						|
 | 
						|
	ctx.JSON(http.StatusOK, res)
 | 
						|
}
 | 
						|
 | 
						|
var emptySplitReg = regexp.MustCompile(`\s*\n\s*`)
 | 
						|
 | 
						|
func parseAddrList(addrStr string) ([]string, error) {
 | 
						|
	sp := emptySplitReg.Split(addrStr, -1)
 | 
						|
	if len(sp) == 0 {
 | 
						|
		return nil, fmt.Errorf("addr list is empty")
 | 
						|
	}
 | 
						|
 | 
						|
	res := make([]string, 0)
 | 
						|
	for i, _ := range sp {
 | 
						|
		sp[i] = strings.TrimSpace(sp[i])
 | 
						|
 | 
						|
		if sp[i] == "" {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		err := checkAddr(sp[i])
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		res = append(res, sp[i])
 | 
						|
	}
 | 
						|
 | 
						|
	return res, nil
 | 
						|
}
 |