package forward

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"time"

	"git.makemake.in/kzkzzzz/mycommon/mylog"
)

var _ IForward = (*TCP)(nil)

// TCP forwards TCP connections accepted on forwardInfo.LocalAddr to one of
// forwardInfo.TargetAddr, chosen at random per connection.
type TCP struct {
	forwardInfo *Info
	listener    net.Listener
}

// NewTCP returns a TCP forwarder for forwardInfo. Call Forward to start it.
func NewTCP(forwardInfo *Info) *TCP {
	return &TCP{
		forwardInfo: forwardInfo,
	}
}

// Forward starts listening on the local address and returns once the listener
// is ready. Accepting and proxying connections continues in a background
// goroutine until Stop is called.
//
// It returns an error if listening fails or does not complete within one second.
func (t *TCP) Forward() error {
	const listenTimeout = time.Second

	ctx, cancel := context.WithTimeout(context.Background(), listenTimeout)
	defer cancel()

	// ListenConfig honours ctx directly. A timeout therefore can never leave a
	// late-created listener open (and the port bound) with nobody to close it.
	var lc net.ListenConfig
	listener, err := lc.Listen(ctx, "tcp", t.forwardInfo.LocalAddr)
	if err != nil {
		if ctx.Err() != nil {
			return fmt.Errorf("listen timeout %s %s: %w", t.forwardInfo.LocalAddr, listenTimeout, err)
		}
		return fmt.Errorf("listen %s: %w", t.forwardInfo.LocalAddr, err)
	}
	t.listener = listener

	log.Printf("[TCP] [%s] %s -> %s", t.forwardInfo.Name, t.forwardInfo.LocalAddr, t.forwardInfo.TargetAddr)

	go t.acceptLoop(listener)
	return nil
}

// acceptLoop accepts connections from listener and hands each one to
// handleConn, until the listener is closed (see Stop).
//
// When the loop exits, the shared context is cancelled. Every in-flight
// forwarded connection is torn down as well.
func (t *TCP) acceptLoop(listener net.Listener) {
	const (
		initialBackoff = 5 * time.Millisecond
		maxBackoff     = time.Second
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var backoff time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// The listener was closed by Stop: normal shutdown.
				return
			}
			// Errors such as EMFILE are usually transient. Back off and retry
			// (as net/http.Server.Serve does) instead of killing the forwarder.
			if backoff == 0 {
				backoff = initialBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			mylog.Error(fmt.Errorf("accept on %s (retrying in %s): %w", t.forwardInfo.LocalAddr, backoff, err))
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		go t.handleConn(ctx, conn, t.forwardInfo.TargetAddr)
	}
}

// handleConn proxies localConn to a target picked at random from targetAddrList.
//
// Bytes are copied in both directions. As soon as either direction finishes,
// or mainCtx is cancelled, both connections are closed. localConn is always
// closed before returning.
func (t *TCP) handleConn(mainCtx context.Context, localConn net.Conn, targetAddrList []string) {
	const dialTimeout = 10 * time.Second

	defer localConn.Close()

	if len(targetAddrList) == 0 {
		// rand.Intn panics on 0, which would take down the whole process.
		mylog.Error(fmt.Errorf("tcp forward %s: no target address configured", localConn.RemoteAddr()))
		return
	}
	targetAddr := targetAddrList[rand.Intn(len(targetAddrList))]

	// Bound the dial, and let Stop (via mainCtx) abort a dial in progress.
	dialer := net.Dialer{Timeout: dialTimeout}
	targetConn, err := dialer.DialContext(mainCtx, "tcp", targetAddr)
	if err != nil {
		mylog.Error("Error connecting to target:", err)
		return
	}
	defer targetConn.Close()

	mylog.Debugf("tcp forward %s -> %s", localConn.RemoteAddr(), targetAddr)
	defer mylog.Warnf("tcp forward stop %s -> %s", localConn.RemoteAddr(), targetAddr)

	ctx, cancel := context.WithCancel(mainCtx)
	defer cancel()

	// Each direction cancels ctx when it ends. Returning then runs the deferred
	// Closes, which unblocks the other io.Copy. The "use of closed network
	// connection" error it then reports is expected, so it is not logged.
	copyThenCancel := func(dst, src net.Conn) {
		defer cancel()
		if _, err := io.Copy(dst, src); err != nil && !errors.Is(err, net.ErrClosed) {
			mylog.Error(err)
		}
	}
	go copyThenCancel(targetConn, localConn) // client -> target
	go copyThenCancel(localConn, targetConn) // target -> client

	<-ctx.Done()
}

// Stop closes the listener. This ends the accept loop and, through its
// cancelled context, closes every in-flight forwarded connection.
// It is a no-op if Forward never started listening.
func (t *TCP) Stop() {
	if t.listener != nil {
		// Best-effort: a Close error (e.g. already closed) is not actionable here.
		_ = t.listener.Close()
	}
}