// Source: proxyport/app/web/web.go
// (listing metadata: 498 lines, 9.2 KiB, Go)

package web
import (
"context"
"fmt"
"git.makemake.in/kzkzzzz/mycommon/mylog"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
enTr "github.com/go-playground/validator/v10/translations/en"
"github.com/spf13/cast"
"gorm.io/gorm"
"log"
"net/http"
"net/netip"
"proxyport/app/db"
"proxyport/app/forward"
"proxyport/app/model"
"proxyport/app/static"
"reflect"
"regexp"
"strings"
"time"
)
// Config holds the web server settings read by Start and by the handlers in
// this package. It is initialised to an empty Cfg.
var Config = &Cfg{}
// Cfg describes the settings of the web UI and its HTTP API.
type Cfg struct {
// ListenAddr is used as the http.Server address in Start.
ListenAddr string
// User and Password enable HTTP basic auth in Start when both are non-empty.
User string
Password string
// WebName is returned to the client by GetConfig under the "web_name" key.
WebName string
}
// Start builds the gin engine, registers the static assets and the API
// routes, and serves HTTP on Config.ListenAddr until ctx is cancelled.
//
// Basic auth is installed only when both Config.User and Config.Password are
// non-empty. Start blocks for the whole lifetime of the server. Once ctx is
// done the server is shut down gracefully: in-flight requests get a bounded
// grace period, after which the server is force-closed.
func Start(ctx context.Context) {
	InitGinTrans()

	engine := gin.Default()
	if Config.User != "" && Config.Password != "" {
		engine.Use(gin.BasicAuth(gin.Accounts{
			Config.User: Config.Password,
		}))
	}

	static.LoadStatic(engine)
	engine.Use(cors())

	// The handler parameter is named c so it does not shadow the outer ctx.
	engine.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "index.html", nil)
	})
	engine.GET("/GetConfig", GetConfig)
	engine.GET("/List", List)
	engine.POST("/Create", Create)
	engine.POST("/Update", Update)
	engine.POST("/Delete", Delete)
	engine.POST("/SwitchStatus", SwitchStatus)

	hs := &http.Server{
		Addr:    Config.ListenAddr,
		Handler: engine,
	}

	go func() {
		log.Printf("http listen on %s", hs.Addr)
		// ListenAndServe always returns a non-nil error; ErrServerClosed is
		// the expected result of Shutdown/Close and is not worth a warning.
		if err := hs.ListenAndServe(); err != http.ErrServerClosed {
			mylog.Warn(err)
		}
	}()

	<-ctx.Done()

	// ctx is already cancelled here, so the grace period needs its own context.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := hs.Shutdown(shutdownCtx); err != nil {
		mylog.Warn(err)
		// Graceful shutdown did not finish in time; force-close as a last
		// resort. Its error is ignored because nothing more can be done.
		_ = hs.Close()
	}
}
// cors returns a middleware that sets permissive CORS response headers on
// every request. OPTIONS requests are answered directly with 200 and the
// chain is aborted; every other method continues to the next handler.
func cors() gin.HandlerFunc {
	// Ordered list so the headers are written in a fixed sequence.
	headers := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "GET, POST"},
		{"Access-Control-Allow-Headers", "Content-Type"},
	}

	return func(c *gin.Context) {
		for _, kv := range headers {
			c.Header(kv[0], kv[1])
		}

		if c.Request.Method != http.MethodOptions {
			c.Next()
			return
		}
		c.AbortWithStatus(http.StatusOK)
	}
}
// GetConfig responds with the UI-facing configuration, currently only the
// configured web name under the "web_name" key.
func GetConfig(ctx *gin.Context) {
	payload := gin.H{"web_name": Config.WebName}
	success(ctx, payload)
}
// List responds with every row of the "forward" table, ordered by
// update_time descending.
func List(ctx *gin.Context) {
	rows := []*model.Forward{}

	query := db.DB().Table("forward").Select("*").Order("update_time desc")
	if err := query.Find(&rows).Error; err != nil {
		fail(ctx, err)
		return
	}

	success(ctx, rows)
}
// Create validates the JSON request, stores a new forward rule and registers
// its listener.
//
// The insert and the listener registration run inside one DB transaction, so
// an error from ListenerManager.Add rolls the insert back. The rule is always
// stored with status 1; the request's status field is validated but not used.
func Create(ctx *gin.Context) {
	type Req struct {
		TargetAddr  string `json:"target_addr" binding:"min=1"`
		LocalPort   int    `json:"local_port" binding:"required"`
		Name        string `json:"name" binding:"required"`
		ForwardType int    `json:"protocol" binding:"oneof=0 1"`
		Status      int    `json:"status" binding:"oneof=0 1"`
	}

	var req Req
	if err := ctx.ShouldBindJSON(&req); err != nil {
		fail(ctx, err)
		return
	}

	// The local port is validated with the same rules as any other address.
	localAddr := fmt.Sprintf("0.0.0.0:%d", req.LocalPort)
	if err := checkAddr(localAddr); err != nil {
		fail(ctx, err)
		return
	}

	targets, err := parseAddrList(req.TargetAddr)
	if err != nil {
		fail(ctx, err)
		return
	}

	if err := checkUnique(0, req.LocalPort, req.ForwardType); err != nil {
		fail(ctx, err)
		return
	}

	stamp := time.Now().Format(time.DateTime)
	record := &model.Forward{
		LocalPort:  req.LocalPort,
		TargetAddr: targets,
		Name:       req.Name,
		Protocol:   req.ForwardType,
		Status:     1,
		CreateTime: stamp,
		UpdateTime: stamp,
	}

	txErr := db.DB().Transaction(func(tx *gorm.DB) error {
		if insertErr := tx.Table("forward").Create(record).Error; insertErr != nil {
			return insertErr
		}
		return forward.ListenerManager.Add(record)
	})
	if txErr != nil {
		fail(ctx, txErr)
		return
	}

	success(ctx, "ok")
}
// Update validates the JSON request, rewrites an existing forward rule and
// restarts its listener with the new settings.
//
// create_time and status are never modified here; the request's status field
// is validated but not used. The row update and the listener swap run inside
// one DB transaction. The old listener is always removed, and a new one is
// added only when the stored rule has status 1. If adding the new listener
// fails, the transaction is rolled back and the previous listener is re-added
// so the running state keeps matching the stored row.
func Update(ctx *gin.Context) {
	type Req struct {
		Id         int    `json:"id" binding:"required"`
		TargetAddr string `json:"target_addr" binding:"min=1"`
		LocalPort  int    `json:"local_port" binding:"required"`
		Name       string `json:"name" binding:"required"`
		Protocol   int    `json:"protocol" binding:"oneof=0 1"`
		Status     int    `json:"status" binding:"oneof=0 1"`
	}

	req := &Req{}
	if err := ctx.ShouldBindJSON(req); err != nil {
		fail(ctx, err)
		return
	}

	if err := checkAddr(fmt.Sprintf("0.0.0.0:%d", req.LocalPort)); err != nil {
		fail(ctx, err)
		return
	}

	targetAddr, err := parseAddrList(req.TargetAddr)
	if err != nil {
		fail(ctx, err)
		return
	}

	if err := checkUnique(req.Id, req.LocalPort, req.Protocol); err != nil {
		fail(ctx, err)
		return
	}

	current, err := getForwardById(req.Id)
	if err != nil {
		fail(ctx, err)
		return
	}

	err = db.DB().Transaction(func(tx *gorm.DB) error {
		// NOTE(review): mForward carries only the fields set here (no id or
		// status); confirm ListenerManager.Add does not rely on them.
		mForward := &model.Forward{
			LocalPort:  req.LocalPort,
			TargetAddr: targetAddr,
			Name:       req.Name,
			Protocol:   req.Protocol,
			UpdateTime: time.Now().Format(time.DateTime),
		}

		updateErr := tx.Table("forward").Select("*").Omit("create_time", "status").
			Where("id = ?", req.Id).
			Limit(1).
			Updates(mForward).Error
		if updateErr != nil {
			return updateErr
		}

		forward.ListenerManager.Remove(current)
		if current.Status != 1 {
			return nil
		}

		if addErr := forward.ListenerManager.Add(mForward); addErr != nil {
			// Returning the error rolls the row back to its previous values,
			// so bring the previous listener back as well. Otherwise the rule
			// would stay enabled in the DB with nothing listening.
			if restoreErr := forward.ListenerManager.Add(current); restoreErr != nil {
				mylog.Warn(restoreErr)
			}
			return addErr
		}
		return nil
	})
	if err != nil {
		fail(ctx, err)
		return
	}

	success(ctx, "ok")
}
// Delete removes the forward rule with the requested id from the "forward"
// table and then removes its listener.
//
// The rule is loaded first, so a missing id fails before anything is deleted.
func Delete(ctx *gin.Context) {
	type Req struct {
		Id int `json:"id" binding:"required"`
	}

	var req Req
	if bindErr := ctx.ShouldBindJSON(&req); bindErr != nil {
		fail(ctx, bindErr)
		return
	}

	target, lookupErr := getForwardById(req.Id)
	if lookupErr != nil {
		fail(ctx, lookupErr)
		return
	}

	query := db.DB().Table("forward").Select("*").Where("id = ?", req.Id).Limit(1)
	if deleteErr := query.Delete(nil).Error; deleteErr != nil {
		fail(ctx, deleteErr)
		return
	}

	forward.ListenerManager.Remove(target)
	success(ctx, "ok")
}
// Patterns used by checkAddr when validating "host:port" strings.
var (
// hostReg matches a host name: one or more ASCII alphanumerics, then any mix
// of alphanumerics, '-' and '.', ending in one or more ASCII letters.
hostReg = regexp.MustCompile(`^[a-zA-Z0-9]+[a-zA-Z0-9\-.]*[a-zA-Z]+$`)
// letterReg matches any ASCII letter or '-' anywhere in the input; checkAddr
// uses it to decide whether the host part is a name rather than an IP.
letterReg = regexp.MustCompile(`[a-zA-Z\-]`)
// emptyReg matches any run of whitespace.
emptyReg = regexp.MustCompile(`\s+`)
)
// checkAddr validates a "host:port" address and returns nil when it is
// acceptable.
//
// Surrounding whitespace is trimmed; whitespace inside the address is an
// error. The text after the first ':' must be a plain decimal port in the
// range 10-65535. A host part containing no ASCII letter or '-' is treated as
// an IP and the whole address must parse with netip.ParseAddrPort; any other
// host must match hostReg.
func checkAddr(addr string) error {
	addr = strings.TrimSpace(addr)
	if emptyReg.MatchString(addr) {
		return fmt.Errorf("addr format err: %s", addr)
	}

	// Split on the first ':' only, so the port is everything after it.
	host, port, ok := strings.Cut(addr, ":")
	if !ok {
		return fmt.Errorf("addr format err: %s", addr)
	}

	// Require plain decimal digits before converting. Without this, the
	// lenient conversion below lets the host-name branch accept port
	// spellings that the IP branch (netip.ParseAddrPort) rejects.
	if port == "" || strings.Trim(port, "0123456789") != "" {
		return fmt.Errorf("port [%s] format err", port)
	}

	intPort := cast.ToInt(port)
	if intPort > 65535 || intPort < 10 {
		return fmt.Errorf("port [%s] out of range: 10 - 65535", port)
	}

	// No letter or dash in the host part: it can only be an IP literal.
	if !letterReg.MatchString(host) {
		if _, err := netip.ParseAddrPort(addr); err != nil {
			return fmt.Errorf("ip [%s] format err: %s", host, err)
		}
		return nil
	}

	if !hostReg.MatchString(host) {
		return fmt.Errorf("host [%s] format err", host)
	}
	return nil
}
// checkUnique reports an error when another forward rule already uses the
// same local port with the same protocol.
//
// When id is greater than zero, the row with that id is excluded from the
// check, so a rule being updated does not collide with itself. Query errors
// are returned unchanged.
func checkUnique(id, localPort, forwardType int) error {
	query := db.DB().Table("forward").
		Where("local_port = ?", localPort).
		Where("protocol = ?", forwardType)
	if id > 0 {
		// Keep the returned handle: dropping it would make the filter depend
		// on the chain method mutating query in place.
		query = query.Where("id != ?", id)
	}

	var res []int
	if err := query.Select("id").Limit(1).Find(&res).Error; err != nil {
		return err
	}

	if len(res) > 0 {
		return fmt.Errorf("%s port: %d already use", forward.Protocol(forwardType).String(), localPort)
	}
	return nil
}
// getForwardById loads the forward rule with the given id from the "forward"
// table. It returns an error for a non-positive id or when the query fails.
func getForwardById(id int) (*model.Forward, error) {
	if id <= 0 {
		return nil, fmt.Errorf("id err: %d", id)
	}

	var row model.Forward
	result := db.DB().Table("forward").Select("*").Where("id = ?", id).First(&row)
	if result.Error != nil {
		return nil, result.Error
	}
	return &row, nil
}
// SwitchStatus sets the status of a forward rule and adds or removes its
// listener to match.
//
// The status update and the listener change run inside one DB transaction.
// Status 1 adds the listener, and an error from that rolls the status change
// back. Any other status removes the listener.
func SwitchStatus(ctx *gin.Context) {
	type Req struct {
		Id     int `json:"id" binding:"required"`
		Status int `json:"status" binding:"oneof=0 1"`
	}

	req := &Req{}
	if err := ctx.ShouldBindJSON(req); err != nil {
		fail(ctx, err)
		return
	}

	mForward, err := getForwardById(req.Id)
	if err != nil {
		fail(ctx, err)
		return
	}

	err = db.DB().Transaction(func(tx *gorm.DB) error {
		// Limit must come before Updates: Updates executes the statement, so
		// anything chained after it cannot affect the query.
		updateErr := tx.Table("forward").
			Where("id = ?", req.Id).
			Limit(1).
			Updates(map[string]any{
				"status": req.Status,
			}).Error
		if updateErr != nil {
			return updateErr
		}

		if req.Status == 1 {
			return forward.ListenerManager.Add(mForward)
		}
		forward.ListenerManager.Remove(mForward)
		return nil
	})
	if err != nil {
		fail(ctx, err)
		return
	}

	success(ctx, "ok")
}
// apiRes is the JSON envelope written by success and fail.
type apiRes struct {
// Code is 0 in responses built by success and 1 in responses built by fail.
Code int `json:"code"`
// Message is "ok" on success or the error text on failure.
Message string `json:"message"`
// Data carries the payload passed to success; fail leaves it unset.
Data any `json:"data"`
}
// success writes an HTTP 200 JSON envelope with code 0, message "ok" and the
// given data as payload.
func success(ctx *gin.Context, data any) {
	body := apiRes{
		Message: "ok",
		Data:    data,
	}
	ctx.JSON(http.StatusOK, &body)
}
// ginTrans is the translator used by fail to render validation errors. It is
// assigned by InitGinTrans.
var ginTrans ut.Translator
// InitGinTrans configures gin's validator so that validation errors name
// fields by their JSON tag and are translated to English, and stores the
// translator in ginTrans.
//
// Fields tagged `json:"-"` are reported with an empty name; fields without a
// JSON name fall back to the Go field name. A failure to look up the
// translator or to register the default translations is logged and does not
// stop initialisation.
func InitGinTrans() {
	v := binding.Validator.Engine().(*validator.Validate)

	v.RegisterTagNameFunc(func(fld reflect.StructField) string {
		// Only the part before the first ',' is the JSON name; the rest holds
		// options such as omitempty.
		name, _, _ := strings.Cut(fld.Tag.Get("json"), ",")
		switch name {
		case "-":
			return ""
		case "":
			return fld.Name
		default:
			return name
		}
	})

	enT := en.New()
	uni := ut.New(enT, enT)

	tr, found := uni.GetTranslator("en")
	if !found {
		log.Printf("validator translator %q not found, using fallback", "en")
	}
	if err := enTr.RegisterDefaultTranslations(v, tr); err != nil {
		log.Printf("register validator translations: %v", err)
	}

	ginTrans = tr
}
// fail writes an HTTP 200 JSON envelope with code 1 and a message derived
// from err.
//
// For validator.ValidationErrors only the first entry is reported, translated
// with ginTrans. Any other error is reported via its Error() text.
func fail(ctx *gin.Context, err error) {
	var msg string

	if ve, ok := err.(validator.ValidationErrors); ok {
		if len(ve) == 0 {
			msg = ve.Error()
		} else {
			msg = ve[0].Translate(ginTrans)
		}
	} else {
		msg = err.Error()
	}

	ctx.JSON(http.StatusOK, &apiRes{Code: 1, Message: msg})
}
// emptySplitReg matches a newline together with any whitespace around it;
// parseAddrList uses it to split its input into one entry per line.
var emptySplitReg = regexp.MustCompile(`\s*\n\s*`)
// parseAddrList splits addrStr into one address per line, validates each
// entry with checkAddr and returns the trimmed, non-blank entries in order.
//
// Blank lines are skipped. The first invalid address aborts parsing and its
// error is returned. An error is also returned when no address remains after
// blank lines are dropped, e.g. for input made only of whitespace.
func parseAddrList(addrStr string) ([]string, error) {
	parts := emptySplitReg.Split(addrStr, -1)

	res := make([]string, 0, len(parts))
	for _, part := range parts {
		addr := strings.TrimSpace(part)
		if addr == "" {
			continue
		}
		if err := checkAddr(addr); err != nil {
			return nil, err
		}
		res = append(res, addr)
	}

	// Split always yields at least one element, so emptiness has to be
	// checked on the filtered result, not on the raw split.
	if len(res) == 0 {
		return nil, fmt.Errorf("addr list is empty")
	}
	return res, nil
}