137 lines
2.6 KiB
Go
137 lines
2.6 KiB
Go
package forward
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"time"

	"git.makemake.in/kzkzzzz/mycommon/mylog"
)
|
|
|
|
var _ IForward = (*TCP)(nil)
|
|
|
|
type TCP struct {
|
|
forwardInfo *Info
|
|
listener net.Listener
|
|
}
|
|
|
|
func NewTCP(forwardInfo *Info) *TCP {
|
|
return &TCP{
|
|
forwardInfo: forwardInfo,
|
|
}
|
|
}
|
|
|
|
func (t *TCP) Forward() error {
|
|
listenTimeout := time.Second
|
|
ctx0, cancel0 := context.WithTimeout(context.Background(), listenTimeout)
|
|
defer cancel0()
|
|
|
|
var (
|
|
errChan = make(chan error, 1)
|
|
listenerRes = make(chan net.Listener, 1)
|
|
)
|
|
|
|
go func() {
|
|
listener, err := net.Listen("tcp", t.forwardInfo.LocalAddr)
|
|
if err != nil {
|
|
mylog.Error(err)
|
|
errChan <- err
|
|
return
|
|
}
|
|
listenerRes <- listener
|
|
}()
|
|
|
|
select {
|
|
case listener := <-listenerRes:
|
|
t.listener = listener
|
|
|
|
case err := <-errChan:
|
|
return err
|
|
|
|
case <-ctx0.Done():
|
|
return fmt.Errorf("listen timeout %s %s", t.forwardInfo.LocalAddr, listenTimeout)
|
|
}
|
|
|
|
//
|
|
//mylog.Infof("start listen: %s", t.forwardInfo.LocalAddr)
|
|
//lc := net.ListenConfig{}
|
|
|
|
//listener, err := lc.Listen(ctx0, "tcp", t.forwardInfo.LocalAddr)
|
|
// 启动 TCP 监听
|
|
|
|
log.Printf("[TCP] [%s] %s -> %s", t.forwardInfo.Name, t.forwardInfo.LocalAddr, t.forwardInfo.TargetAddr)
|
|
|
|
go func() {
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
for {
|
|
// 接受连接
|
|
conn, err := t.listener.Accept()
|
|
if err != nil {
|
|
mylog.Error(err)
|
|
break
|
|
}
|
|
// 处理连接
|
|
go t.handleConn(ctx, conn, t.forwardInfo.TargetAddr)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *TCP) handleConn(mainCtx context.Context, localConn net.Conn, targetAddrList []string) {
|
|
defer localConn.Close()
|
|
|
|
targetAddr := targetAddrList[rand.Intn(len(targetAddrList))]
|
|
|
|
// 连接到目标地址
|
|
targetConn, err := net.Dial("tcp", targetAddr)
|
|
if err != nil {
|
|
mylog.Error("Error connecting to target:", err)
|
|
return
|
|
}
|
|
defer targetConn.Close()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
defer func() {
|
|
mylog.Warnf("tcp forward stop %s -> %+v", localConn.RemoteAddr(), targetAddrList)
|
|
}()
|
|
mylog.Debugf("tcp forward %s -> %s", localConn.RemoteAddr(), targetAddr)
|
|
|
|
go func() {
|
|
defer cancelCtx()
|
|
_, err := io.Copy(targetConn, localConn) // 从客户端转发到目标
|
|
if err != nil {
|
|
mylog.Error(err)
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer cancelCtx()
|
|
_, err := io.Copy(localConn, targetConn) // 从目标转发到客户端
|
|
if err != nil {
|
|
mylog.Error(err)
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-mainCtx.Done():
|
|
cancelCtx()
|
|
case <-ctx.Done():
|
|
|
|
}
|
|
|
|
}
|
|
|
|
func (t *TCP) Stop() {
|
|
if t.listener != nil {
|
|
t.listener.Close()
|
|
}
|
|
}
|