proxyport/app/forward/tcp.go

137 lines
2.6 KiB
Go
Raw Normal View History

2024-09-30 16:37:53 +08:00
package forward
import (
"context"
"fmt"
"git.makemake.in/kzkzzzz/mycommon/mylog"
"io"
2024-10-10 10:53:08 +08:00
"log"
2024-09-30 16:37:53 +08:00
"math/rand"
"net"
"time"
)
// TCP forwards TCP connections accepted on forwardInfo.LocalAddr to one of
// the addresses in forwardInfo.TargetAddr.
type TCP struct {
	// forwardInfo carries the rule's Name, LocalAddr and TargetAddr.
	forwardInfo *Info
	// listener is assigned by Forward and closed by Stop.
	listener net.Listener
}

// Compile-time check that *TCP satisfies IForward.
var _ IForward = (*TCP)(nil)
// NewTCP returns a TCP forwarder for the given rule. No socket is opened
// until Forward is called.
func NewTCP(forwardInfo *Info) *TCP {
	t := new(TCP)
	t.forwardInfo = forwardInfo
	return t
}
func (t *TCP) Forward() error {
listenTimeout := time.Second
ctx0, cancel0 := context.WithTimeout(context.Background(), listenTimeout)
defer cancel0()
var (
errChan = make(chan error, 1)
listenerRes = make(chan net.Listener, 1)
)
go func() {
listener, err := net.Listen("tcp", t.forwardInfo.LocalAddr)
if err != nil {
mylog.Error(err)
errChan <- err
return
}
listenerRes <- listener
}()
select {
case listener := <-listenerRes:
t.listener = listener
case err := <-errChan:
return err
case <-ctx0.Done():
return fmt.Errorf("listen timeout %s %s", t.forwardInfo.LocalAddr, listenTimeout)
}
//
//mylog.Infof("start listen: %s", t.forwardInfo.LocalAddr)
//lc := net.ListenConfig{}
//listener, err := lc.Listen(ctx0, "tcp", t.forwardInfo.LocalAddr)
// 启动 TCP 监听
2024-10-10 10:53:08 +08:00
log.Printf("[TCP] [%s] %s -> %s", t.forwardInfo.Name, t.forwardInfo.LocalAddr, t.forwardInfo.TargetAddr)
2024-09-30 16:37:53 +08:00
go func() {
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
for {
// 接受连接
conn, err := t.listener.Accept()
if err != nil {
mylog.Error(err)
break
}
// 处理连接
go t.handleConn(ctx, conn, t.forwardInfo.TargetAddr)
}
}()
return nil
}
// handleConn proxies a single accepted client connection. It picks a target
// uniformly at random from targetAddrList, dials it, and copies bytes in both
// directions. Proxying stops as soon as either copy direction returns (EOF or
// error) or mainCtx is cancelled. Both connections are closed on return.
func (t *TCP) handleConn(mainCtx context.Context, localConn net.Conn, targetAddrList []string) {
	defer localConn.Close()

	// rand.Intn panics for n == 0, so an empty target list must be rejected.
	if len(targetAddrList) == 0 {
		mylog.Error("Error connecting to target: no target address configured")
		return
	}
	targetAddr := targetAddrList[rand.Intn(len(targetAddrList))]

	// Connect to the target address. DialContext lets cancellation of mainCtx
	// (e.g. after Stop) abort a dial that is still in progress.
	var dialer net.Dialer
	targetConn, err := dialer.DialContext(mainCtx, "tcp", targetAddr)
	if err != nil {
		mylog.Error("Error connecting to target:", err)
		return
	}
	defer targetConn.Close()

	// ctx is done when either copy direction finishes or mainCtx is cancelled.
	ctx, cancelCtx := context.WithCancel(mainCtx)
	defer cancelCtx()

	defer func() {
		mylog.Warnf("tcp forward stop %s -> %s", localConn.RemoteAddr(), targetAddr)
	}()

	mylog.Debugf("tcp forward %s -> %s", localConn.RemoteAddr(), targetAddr)

	pipe := func(dst, src net.Conn) {
		defer cancelCtx()
		_, err := io.Copy(dst, src)
		// Once ctx is cancelled, the deferred Close calls tear both connections
		// down and the surviving copy fails with "use of closed network
		// connection". That error is expected, so it is not logged.
		if err != nil && ctx.Err() == nil {
			mylog.Error(err)
		}
	}
	go pipe(targetConn, localConn) // client -> target
	go pipe(localConn, targetConn) // target -> client

	<-ctx.Done()
}
// Stop closes the listener opened by Forward, if there is one. Accept then
// fails, which ends Forward's accept loop and cancels the context that its
// in-flight connections watch. Any error from Close is ignored.
func (t *TCP) Stop() {
	if t.listener == nil {
		return
	}
	_ = t.listener.Close()
}