498 lines
9.2 KiB
Go
498 lines
9.2 KiB
Go
package web
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"git.makemake.in/kzkzzzz/mycommon/mylog"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin/binding"
|
|
"github.com/go-playground/locales/en"
|
|
ut "github.com/go-playground/universal-translator"
|
|
"github.com/go-playground/validator/v10"
|
|
enTr "github.com/go-playground/validator/v10/translations/en"
|
|
"github.com/spf13/cast"
|
|
"gorm.io/gorm"
|
|
"log"
|
|
"net/http"
|
|
"net/netip"
|
|
"proxyport/app/db"
|
|
"proxyport/app/forward"
|
|
"proxyport/app/model"
|
|
"proxyport/app/static"
|
|
"reflect"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Config holds the settings of the management web server. It is read by
// Start and GetConfig; presumably the caller populates it before Start runs.
var Config = new(Cfg)

// Cfg describes how the management web server is exposed.
type Cfg struct {
	// ListenAddr is used as the http.Server address in Start.
	ListenAddr string
	// User is the basic-auth user name; auth is enabled only when both
	// User and Password are non-empty.
	User string
	// Password is the basic-auth password paired with User.
	Password string
	// WebName is returned to the front end by GetConfig as "web_name".
	WebName string
}
|
|
|
|
func Start(ctx context.Context) {
|
|
InitGinTrans()
|
|
|
|
engine := gin.Default()
|
|
|
|
if Config.User != "" && Config.Password != "" {
|
|
engine.Use(gin.BasicAuth(gin.Accounts{
|
|
Config.User: Config.Password,
|
|
}))
|
|
}
|
|
|
|
static.LoadStatic(engine)
|
|
engine.Use(cors())
|
|
|
|
engine.GET("/", func(ctx *gin.Context) {
|
|
ctx.HTML(http.StatusOK, "index.html", nil)
|
|
})
|
|
|
|
engine.GET("/GetConfig", GetConfig)
|
|
engine.GET("/List", List)
|
|
engine.POST("/Create", Create)
|
|
engine.POST("/Update", Update)
|
|
engine.POST("/Delete", Delete)
|
|
engine.POST("/SwitchStatus", SwitchStatus)
|
|
|
|
hs := &http.Server{
|
|
Addr: Config.ListenAddr,
|
|
Handler: engine,
|
|
}
|
|
|
|
go func() {
|
|
log.Printf("http listen on %s", hs.Addr)
|
|
mylog.Warn(hs.ListenAndServe())
|
|
}()
|
|
|
|
<-ctx.Done()
|
|
hs.Close()
|
|
|
|
}
|
|
|
|
func cors() gin.HandlerFunc {
|
|
return func(ctx *gin.Context) {
|
|
ctx.Header("Access-Control-Allow-Origin", "*")
|
|
ctx.Header("Access-Control-Allow-Methods", "GET, POST")
|
|
|
|
ctx.Header("Access-Control-Allow-Headers", "Content-Type")
|
|
|
|
if ctx.Request.Method == http.MethodOptions {
|
|
ctx.AbortWithStatus(http.StatusOK)
|
|
} else {
|
|
ctx.Next()
|
|
}
|
|
}
|
|
}
|
|
|
|
func GetConfig(ctx *gin.Context) {
|
|
success(ctx, gin.H{
|
|
"web_name": Config.WebName,
|
|
})
|
|
}
|
|
|
|
func List(ctx *gin.Context) {
|
|
data := make([]*model.Forward, 0)
|
|
|
|
err := db.DB().Table("forward").Select("*").Order("update_time desc").Find(&data).Error
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
success(ctx, data)
|
|
}
|
|
|
|
func Create(ctx *gin.Context) {
|
|
type Req struct {
|
|
TargetAddr string `json:"target_addr" binding:"min=1"`
|
|
LocalPort int `json:"local_port" binding:"required"`
|
|
Name string `json:"name" binding:"required"`
|
|
ForwardType int `json:"protocol" binding:"oneof=0 1"`
|
|
Status int `json:"status" binding:"oneof=0 1"`
|
|
}
|
|
|
|
req := &Req{}
|
|
err := ctx.ShouldBindJSON(req)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = checkAddr(fmt.Sprintf("0.0.0.0:%d", req.LocalPort))
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
targetAddr, err := parseAddrList(req.TargetAddr)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = checkUnique(0, req.LocalPort, req.ForwardType)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
err = db.DB().Transaction(func(tx *gorm.DB) error {
|
|
mForward := &model.Forward{
|
|
LocalPort: req.LocalPort,
|
|
TargetAddr: targetAddr,
|
|
Name: req.Name,
|
|
Protocol: req.ForwardType,
|
|
Status: 1,
|
|
CreateTime: now.Format(time.DateTime),
|
|
UpdateTime: now.Format(time.DateTime),
|
|
}
|
|
|
|
err := tx.Table("forward").Create(mForward).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = forward.ListenerManager.Add(mForward)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
success(ctx, "ok")
|
|
}
|
|
|
|
func Update(ctx *gin.Context) {
|
|
type Req struct {
|
|
Id int `json:"id" binding:"required"`
|
|
TargetAddr string `json:"target_addr" binding:"min=1"`
|
|
LocalPort int `json:"local_port" binding:"required"`
|
|
Name string `json:"name" binding:"required"`
|
|
Protocol int `json:"protocol" binding:"oneof=0 1"`
|
|
Status int `json:"status" binding:"oneof=0 1"`
|
|
}
|
|
|
|
req := &Req{}
|
|
err := ctx.ShouldBindJSON(req)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = checkAddr(fmt.Sprintf("0.0.0.0:%d", req.LocalPort))
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
targetAddr, err := parseAddrList(req.TargetAddr)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = checkUnique(req.Id, req.LocalPort, req.Protocol)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
current, err := getForwardById(req.Id)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = db.DB().Transaction(func(tx *gorm.DB) error {
|
|
mForward := &model.Forward{
|
|
LocalPort: req.LocalPort,
|
|
TargetAddr: targetAddr,
|
|
Name: req.Name,
|
|
Protocol: req.Protocol,
|
|
UpdateTime: time.Now().Format(time.DateTime),
|
|
}
|
|
|
|
err = tx.Table("forward").Select("*").Omit("create_time", "status").Where("id = ?", req.Id).
|
|
Limit(1).
|
|
Updates(mForward).Error
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
forward.ListenerManager.Remove(current)
|
|
|
|
if current.Status == 1 {
|
|
err = forward.ListenerManager.Add(mForward)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
success(ctx, "ok")
|
|
}
|
|
|
|
func Delete(ctx *gin.Context) {
|
|
type Req struct {
|
|
Id int `json:"id" binding:"required"`
|
|
}
|
|
|
|
req := &Req{}
|
|
err := ctx.ShouldBindJSON(req)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
mForward, err := getForwardById(req.Id)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = db.DB().Table("forward").Select("*").Where("id = ?", req.Id).Limit(1).Delete(nil).Error
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
forward.ListenerManager.Remove(mForward)
|
|
|
|
success(ctx, "ok")
|
|
}
|
|
|
|
var (
|
|
hostReg = regexp.MustCompile(`^[a-zA-Z0-9]+[a-zA-Z0-9\-.]*[a-zA-Z]+$`)
|
|
letterReg = regexp.MustCompile(`[a-zA-Z\-]`)
|
|
emptyReg = regexp.MustCompile(`\s+`)
|
|
)
|
|
|
|
func checkAddr(addr string) error {
|
|
addr = strings.TrimSpace(addr)
|
|
|
|
if emptyReg.MatchString(addr) {
|
|
return fmt.Errorf("addr format err: %s", addr)
|
|
}
|
|
|
|
sp := strings.SplitN(addr, ":", 2)
|
|
|
|
if len(sp) != 2 {
|
|
return fmt.Errorf("addr format err: %s", addr)
|
|
}
|
|
|
|
host := sp[0]
|
|
port := sp[1]
|
|
intPort := cast.ToInt(port)
|
|
|
|
if intPort > 65535 || intPort < 10 {
|
|
return fmt.Errorf("port [%s] out of range: 10 - 65535", port)
|
|
}
|
|
|
|
if !letterReg.MatchString(host) {
|
|
_, err := netip.ParseAddrPort(addr)
|
|
if err != nil {
|
|
return fmt.Errorf("ip [%s] format err: %s", host, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
if !hostReg.MatchString(host) {
|
|
return fmt.Errorf("host [%s] format err", host)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func checkUnique(id, localPort, forwardType int) error {
|
|
query := db.DB().Table("forward").
|
|
Where("local_port = ?", localPort).
|
|
Where("protocol = ?", forwardType)
|
|
|
|
if id > 0 {
|
|
query.Where("id != ?", id)
|
|
}
|
|
|
|
var res []int
|
|
err := query.Select("id").Limit(1).Find(&res).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(res) > 0 {
|
|
return fmt.Errorf("%s port: %d already use", forward.Protocol(forwardType).String(), localPort)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getForwardById(id int) (*model.Forward, error) {
|
|
if id <= 0 {
|
|
return nil, fmt.Errorf("id err: %d", id)
|
|
}
|
|
mForward := &model.Forward{}
|
|
err := db.DB().Table("forward").Select("*").Where("id = ?", id).First(mForward).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mForward, nil
|
|
}
|
|
|
|
func SwitchStatus(ctx *gin.Context) {
|
|
type Req struct {
|
|
Id int `json:"id" binding:"required"`
|
|
Status int `json:"status" binding:"oneof=0 1"`
|
|
}
|
|
|
|
req := &Req{}
|
|
err := ctx.ShouldBindJSON(req)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
mForward, err := getForwardById(req.Id)
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = db.DB().Transaction(func(tx *gorm.DB) error {
|
|
err = tx.Table("forward").Where("id = ?", req.Id).Updates(map[string]any{
|
|
"status": req.Status,
|
|
}).Limit(1).Error
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch req.Status {
|
|
case 1:
|
|
err = forward.ListenerManager.Add(mForward)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
forward.ListenerManager.Remove(mForward)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
fail(ctx, err)
|
|
return
|
|
}
|
|
|
|
success(ctx, "ok")
|
|
}
|
|
|
|
// apiRes is the JSON envelope returned by every API handler.
//
// success writes Code 0 with Message "ok". fail writes Code 1 with the
// error text in Message.
type apiRes struct {
	Code    int    `json:"code"`    // 0 on success, 1 on failure
	Message string `json:"message"` // "ok" or a human-readable error
	Data    any    `json:"data"`    // payload; nil (null) on failure
}
|
|
|
|
func success(ctx *gin.Context, data any) {
|
|
ctx.JSON(http.StatusOK, &apiRes{
|
|
Code: 0,
|
|
Message: "ok",
|
|
Data: data,
|
|
})
|
|
}
|
|
|
|
var ginTrans ut.Translator
|
|
|
|
func InitGinTrans() {
|
|
v := binding.Validator.Engine().(*validator.Validate)
|
|
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
|
nameSp := strings.SplitN(fld.Tag.Get("json"), ",", 2)
|
|
if len(nameSp) == 0 {
|
|
return fld.Name
|
|
}
|
|
|
|
name := nameSp[0]
|
|
if name == "-" {
|
|
return ""
|
|
}
|
|
|
|
if name != "" {
|
|
return name
|
|
}
|
|
|
|
return fld.Name
|
|
})
|
|
enT := en.New()
|
|
|
|
uni := ut.New(enT, enT)
|
|
tr, _ := uni.GetTranslator("en")
|
|
_ = enTr.RegisterDefaultTranslations(v, tr)
|
|
|
|
ginTrans = tr
|
|
}
|
|
|
|
func fail(ctx *gin.Context, err error) {
|
|
res := &apiRes{Code: 1}
|
|
|
|
switch ve := err.(type) {
|
|
case validator.ValidationErrors:
|
|
if len(ve) > 0 {
|
|
res.Message = ve[0].Translate(ginTrans)
|
|
} else {
|
|
res.Message = ve.Error()
|
|
}
|
|
|
|
default:
|
|
res.Message = err.Error()
|
|
}
|
|
|
|
ctx.JSON(http.StatusOK, res)
|
|
}
|
|
|
|
var emptySplitReg = regexp.MustCompile(`\s*\n\s*`)
|
|
|
|
func parseAddrList(addrStr string) ([]string, error) {
|
|
sp := emptySplitReg.Split(addrStr, -1)
|
|
if len(sp) == 0 {
|
|
return nil, fmt.Errorf("addr list is empty")
|
|
}
|
|
|
|
res := make([]string, 0)
|
|
for i, _ := range sp {
|
|
sp[i] = strings.TrimSpace(sp[i])
|
|
|
|
if sp[i] == "" {
|
|
continue
|
|
}
|
|
|
|
err := checkAddr(sp[i])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
res = append(res, sp[i])
|
|
}
|
|
|
|
return res, nil
|
|
}
|