proxymysql/app/mysqlserver/conn.go

196 lines
3.5 KiB
Go

package mysqlserver
import (
"io"
"net"
"proxymysql/app/conf"
"proxymysql/app/zlog"
"sync"
"sync/atomic"
)
// connectionId is the process-wide counter from which client connection ids
// are drawn; getConnectionId advances it atomically.
var connectionId uint32
// MysqlPacketHeader carries the header fields of a MySQL protocol packet.
type MysqlPacketHeader struct {
	// Length is the payload length as carried in the packet's 3-byte length
	// field. NOTE(review): presumably filled in by the packet reader — confirm.
	Length uint32
	// SequenceId is the packet sequence number, serialized as the fourth
	// header byte by MysqlPacket.ToByte.
	SequenceId uint8
	// HeaderByte holds raw header bytes. NOTE(review): not read anywhere in
	// this file; confirm who populates it and whether it is still needed.
	HeaderByte []byte
}
// MysqlPacket is a single MySQL protocol packet: the embedded header fields
// followed by the raw payload bytes.
type MysqlPacket struct {
	MysqlPacketHeader
	// Payload is the packet body that follows the 4-byte header on the wire.
	Payload []byte
}
// ToByte serializes the packet into its wire form: a 3-byte payload length
// (encoded by WriteUint24), the one-byte sequence id, then the payload.
//
// The length field is computed from len(pd.Payload) rather than copied from
// pd.Length, so the header always describes exactly the bytes that follow it
// even when pd.Length is stale or was never set. pd.Length is not modified.
//
// NOTE(review): a single MySQL packet can describe at most 0xFFFFFF payload
// bytes; larger payloads would have to be split, which is not done here.
func (pd *MysqlPacket) ToByte() []byte {
	res := make([]byte, 4+len(pd.Payload))
	copy(res[:3], WriteUint24(uint32(len(pd.Payload))))
	res[3] = pd.SequenceId
	copy(res[4:], pd.Payload)
	return res
}
// ProxyConn is one proxied session: a connection accepted from a MySQL
// client and the connection the proxy opens to the backend server for it.
type ProxyConn struct {
	// clientConn is the connection accepted from the client.
	clientConn net.Conn
	// serverConn is the backend connection; it stays nil until Handle dials it.
	serverConn net.Conn
}
// NewProxyConn wraps an accepted client connection. The backend connection is
// left unset; it is dialed later by Handle.
func NewProxyConn(clientConn net.Conn) *ProxyConn {
	pc := new(ProxyConn)
	pc.clientConn = clientConn
	return pc
}
// getConnectionId returns the next connection id to advertise to the client
// in the server greeting. Ids come from the package-level connectionId
// counter shared by every ProxyConn, and 0 is never returned: when the
// uint32 counter wraps to 0, that value is skipped and numbering continues
// at 1.
//
// The wrap is handled by incrementing once more rather than storing 1. A
// Store would race with other goroutines' increments and could hand out the
// same id twice.
func (p *ProxyConn) getConnectionId() uint32 {
	for {
		if id := atomic.AddUint32(&connectionId, 1); id != 0 {
			return id
		}
		// The counter just wrapped to 0, which is never handed out; take the next value.
	}
}
// Handle proxies one client session. It performs these steps in order:
//   - dials the backend MySQL server;
//   - relays the connection handshake and authentication exchange, patching
//     a few fields of the server greeting along the way;
//   - pipes raw traffic in both directions until either side stops.
//
// If any step before streaming fails, the backend connection is closed and
// the error is returned unchanged. The client connection is left to the
// caller. Once streaming starts, copyStream closes both connections and
// Handle returns nil.
func (p *ProxyConn) Handle() error {
	serverConn, err := p.getServerConn()
	if err != nil {
		return err
	}
	p.serverConn = serverConn

	if err := p.relayHandshake(); err != nil {
		// Streaming never started, so nothing else will close the backend
		// connection; release it here instead of leaking the socket.
		serverConn.Close()
		return err
	}

	p.copyStream()
	return nil
}

// relayHandshake forwards the connection-phase packets between p.serverConn
// and p.clientConn. It relays the server greeting (with connection id, server
// version and capability flags rewritten), the client's handshake response,
// and the authentication exchange that follows.
func (p *ProxyConn) relayHandshake() error {
	// Wait for the server's initial handshake before doing anything else.
	hk, err := ReadHandshakeV10(p.serverConn)
	if err != nil {
		return err
	}
	hk.ConnectionId = p.getConnectionId()
	hk.ServerVersion = conf.App.ServerVersion
	// Temporarily strip SSL from the capabilities advertised to the client.
	hk.CapabilityFlag &^= uint32(CapabilityClientSSL)
	// Strip compression support as well.
	hk.CapabilityFlag &^= uint32(CapabilityClientCanUseCompress)
	if _, err := p.clientConn.Write(hk.ToByte()); err != nil {
		return err
	}

	resp, err := ReadHandshakeResponse(p.clientConn)
	if err != nil {
		return err
	}
	if _, err := p.serverConn.Write(resp.ToByte()); err != nil {
		return err
	}

	return p.authSwitch(p.serverConn)
}
// copyStream pipes bytes between the client and the backend in both
// directions and blocks until both directions have stopped.
//
// Every byte read from the client is also teed into a RecordQuery before it
// is forwarded to the server. When either direction ends, both connections
// are closed. Closing unblocks the io.Copy running in the other direction.
//
// NOTE(review): a Read/Write on a net.Conn closed by another goroutine returns
// an error, so the direction that finishes second normally logs a
// closed-connection error here at error level. Consider filtering
// net.ErrClosed.
func (p *ProxyConn) copyStream() {
	var wg sync.WaitGroup
	wg.Add(2)

	// closeBoth tears down both legs of the proxy. Close errors are ignored;
	// each connection ends up being closed once by each goroutine.
	closeBoth := func() {
		p.clientConn.Close()
		p.serverConn.Close()
	}

	// Backend -> client.
	go func() {
		defer wg.Done()
		defer closeBoth()
		if _, err := io.Copy(p.clientConn, p.serverConn); err != nil {
			zlog.Errorf("serverConn -> clientConn err: %s", err)
		}
	}()

	// Client -> backend, recording the client stream on the way through.
	go func() {
		recorder := NewRecordQuery()
		// Deferred calls run LIFO: close conns, then the recorder, then signal done.
		defer wg.Done()
		defer recorder.Close()
		defer closeBoth()
		if _, err := io.Copy(p.serverConn, io.TeeReader(p.clientConn, recorder)); err != nil {
			zlog.Errorf("clientConn -> serverConn err: %s", err)
		}
	}()

	wg.Wait()
	zlog.Debug("copy stream stop")
}
// authSwitch relays the authentication exchange that follows the client's
// HandshakeResponse. Each packet from serverConn is forwarded to the client.
// Unless that packet finished the exchange, the client's reply is then
// forwarded back to serverConn.
//
// The exchange is finished when the server sends a packet whose first payload
// byte is OKPacket or ErrPacket. That packet is still delivered to the client
// before authSwitch returns nil. Any read or write error is returned as-is.
//
// caching_sha2_password fast-auth success is a special case. The server sends
// AuthMoreData (0x01) with status 0x03 and follows it directly with OK, with
// no client reply in between. After relaying that notice the loop reads the
// server again instead of waiting on the client. Waiting on the client would
// deadlock, because the client itself is waiting for the OK.
func (p *ProxyConn) authSwitch(serverConn net.Conn) error {
	const (
		authMoreData    = 0x01 // AuthMoreData packet marker
		fastAuthSuccess = 0x03 // caching_sha2_password: cached credentials accepted
	)

	for {
		serverPkt, err := ReadMysqlPacket(serverConn)
		if err != nil {
			return err
		}
		payload := serverPkt.Payload
		finished := len(payload) > 0 && (payload[0] == OKPacket || payload[0] == ErrPacket)

		if _, err := p.clientConn.Write(serverPkt.ToByte()); err != nil {
			return err
		}
		if finished {
			return nil
		}

		if len(payload) == 2 && payload[0] == authMoreData && payload[1] == fastAuthSuccess {
			// The server sends its final OK/ERR next without waiting for the client.
			continue
		}

		clientPkt, err := ReadMysqlPacket(p.clientConn)
		if err != nil {
			return err
		}
		if _, err := serverConn.Write(clientPkt.ToByte()); err != nil {
			return err
		}
	}
}
// getServerConn opens a new TCP connection to the backend MySQL server at
// conf.App.RemoteDb. It uses a zero-value net.Dialer, so no dial timeout is
// configured.
func (p *ProxyConn) getServerConn() (net.Conn, error) {
	var dialer net.Dialer
	return dialer.Dial("tcp", conf.App.RemoteDb)
}