proxymysql/app/mysqlserver/handshake.go

393 lines
8.7 KiB
Go
Raw Permalink Normal View History

2024-02-04 18:17:06 +08:00
package mysqlserver
import (
"bytes"
"fmt"
"io"
"proxymysql/app/zlog"
)
// Authentication plugin method names used by the handshake code.
const (
// errAuthPluginMethod is the literal "err_auth_plugin_method".
// NOTE(review): presumably a placeholder name for an unsupported/invalid
// plugin; its callers are not visible in this file - confirm before relying on it.
errAuthPluginMethod = "err_auth_plugin_method"
// nativePasswordAuthPluginMethod is the plugin name "mysql_native_password".
nativePasswordAuthPluginMethod = "mysql_native_password"
)
// DefaultHandshakeCapability is the bitwise OR of the capability flags listed
// below (the Capability* constants are declared elsewhere in the package).
// NOTE(review): presumably the capability set advertised in the server-side
// handshake; the code that consumes it is not visible here - confirm.
var DefaultHandshakeCapability uint32 = CapabilityClientLongPassword |
CapabilityClientFoundRows |
CapabilityClientLongFlag |
CapabilityClientConnectWithDB |
CapabilityClientProtocol41 |
CapabilityClientTransactions |
CapabilityClientSecureConnection |
CapabilityClientMultiStatements |
CapabilityClientMultiResults |
CapabilityClientPluginAuth |
CapabilityClientPluginAuthLenencClientData |
CapabilityClientDeprecateEOF |
CapabilityClientConnAttr |
CapabilityClientQueryAttributes
// HandshakeV10 holds the fields of an initial handshake packet as decoded by
// ReadHandshakeV10 and encoded by (*HandshakeV10).ToByte.
//
// Note that ToByte does not write ProtocolVersion, CharsetCollation or
// ServerStatus from this struct; it emits package constants for those slots.
type HandshakeV10 struct {
// ProtocolVersion is the first payload byte of the handshake.
ProtocolVersion uint8
// ServerVersion is the NUL-terminated version string from the packet.
ServerVersion string
// AuthPluginMethod is the NUL-terminated plugin name at the end of the packet.
AuthPluginMethod string
// ConnectionId is the 4-byte connection (thread) id.
ConnectionId uint32
// CharsetCollation is the 1-byte charset/collation id.
CharsetCollation uint8
// ServerStatus is the 2-byte server status field.
ServerStatus uint16
// CapabilityFlag combines the lower and upper 2-byte capability fields.
CapabilityFlag uint32
// AuthPluginData is the scramble: part 1 (8 bytes) followed by part 2.
// ReadHandshakeV10 stores it in a 21-byte slice; ToByte requires at least 8 bytes.
AuthPluginData []byte
}
// ToByte encodes the handshake as a HandshakeV10 payload and returns it wrapped
// by WithHeaderPacket using sequence id 0.
//
// The protocol version, charset and server status slots are filled from the
// package constants protocolVersion, CharsetUtf8mb4GeneralCiId and
// ServerStatusAutocommit, not from the struct fields of the same meaning.
// The auth-plugin-data length byte is always 21 (8 + 13).
//
// hk.AuthPluginData must contain at least 8 bytes; slicing shorter data panics.
func (hk *HandshakeV10) ToByte() []byte {
	capsLower := uint16(hk.CapabilityFlag)
	capsUpper := uint16(hk.CapabilityFlag >> 16)

	// Sections in wire order; each entry is appended verbatim to the payload.
	sections := [][]byte{
		// protocol version
		WriteByte(protocolVersion),
		// server version
		WriteStringNull(hk.ServerVersion),
		// thread id
		WriteUint32(hk.ConnectionId),
		// auth-plugin-data-part-1: first 8 bytes of the scramble
		hk.AuthPluginData[:8],
		// filler byte terminating the first part of the scramble
		{0x00},
		// capability_flags_1: lower 2 bytes of the capability flags
		WriteUint16(capsLower),
		// character set
		WriteByte(CharsetUtf8mb4GeneralCiId),
		// status flags
		WriteUint16(ServerStatusAutocommit),
		// capability_flags_2: upper 2 bytes of the capability flags
		WriteUint16(capsUpper),
		// length of auth plugin data, always 21 (8 + 13)
		WriteByte(21),
		// reserved, all zeros
		make([]byte, 10),
		// auth-plugin-data-part-2: the rest of the scramble
		hk.AuthPluginData[8:],
		// auth plugin name
		WriteStringNull(hk.AuthPluginMethod),
	}

	var payload []byte
	for _, section := range sections {
		payload = append(payload, section...)
	}
	return WithHeaderPacket(payload, 0)
}
// HandshakeResponse holds the fields of a client handshake response packet as
// decoded by ReadHandshakeResponse and re-encoded by (*HandshakeResponse).ToByte.
type HandshakeResponse struct {
// ClientFlag is the 4-byte client capability bitmask; it controls which
// optional sections are read and written.
ClientFlag uint32
// MaxPacketSize is the 4-byte max packet size field.
MaxPacketSize uint32
// Charset is the 1-byte client charset id.
Charset uint8
// Username is the NUL-terminated user name.
Username string
// Password is the raw auth response bytes sent by the client.
Password []byte
// Database is set only when CapabilityClientConnectWithDB is in ClientFlag.
Database string
// AuthPluginMethod is set only when CapabilityClientPluginAuth is in ClientFlag.
AuthPluginMethod string
// ClientAttrLen is the attribute-section length read from the packet.
// ToByte does not use it; it recomputes the length from ClientAttrs.
ClientAttrLen uint64
// ClientAttrs are the decoded connection attributes (key -> value).
ClientAttrs map[string]string
// MysqlPacketHeader is the header of the packet this response was read from;
// ToByte uses its Length as a capacity hint and its SequenceId for the output.
MysqlPacketHeader
}
// ToByte encodes the handshake response as a HandshakeResponse41 payload and
// returns it wrapped by WithHeaderPacket using hp.SequenceId.
//
// Optional sections are emitted according to hp.ClientFlag:
//   - the auth response is length-encoded when
//     CapabilityClientPluginAuthLenencClientData is set, otherwise it is
//     prefixed with a single length byte;
//   - the database name when CapabilityClientConnectWithDB is set;
//   - the auth plugin name when CapabilityClientPluginAuth is set;
//   - the connection attributes when CapabilityClientConnAttr is set.
func (hp *HandshakeResponse) ToByte() []byte {
	// hp.Length is the size of the packet this response was parsed from; it is
	// only a capacity hint, append grows the slice if the output is larger.
	res := make([]byte, 0, hp.Length)

	res = append(res, WriteUint32(hp.ClientFlag)...)
	res = append(res, WriteUint32(hp.MaxPacketSize)...)
	res = append(res, WriteByte(hp.Charset)...)

	// 23 reserved bytes, all zero.
	res = append(res, make([]byte, 23)...)

	res = append(res, WriteStringNull(hp.Username)...)

	if hp.ClientFlag&CapabilityClientPluginAuthLenencClientData > 0 {
		res = append(res, WriteLengthEncodedString(hp.Password)...)
	} else {
		// One-byte length prefix followed by the raw auth response.
		res = append(res, uint8(len(hp.Password)))
		res = append(res, hp.Password...)
	}

	if hp.ClientFlag&CapabilityClientConnectWithDB > 0 {
		res = append(res, WriteStringNull(hp.Database)...)
	}

	if hp.ClientFlag&CapabilityClientPluginAuth > 0 {
		res = append(res, WriteStringNull(hp.AuthPluginMethod)...)
	}

	if hp.ClientFlag&CapabilityClientConnAttr > 0 {
		// The protocol requires the length-encoded total size of the attribute
		// section whenever the capability flag is set, even when there are no
		// attributes. Omitting it (as was done for an empty map) leaves the
		// packet truncated for a peer that honours the flag.
		attrByte := ClientAttrsToByte(hp.ClientAttrs)
		res = append(res, WriteLengthEncodedInt(uint64(len(attrByte)))...)
		res = append(res, attrByte...)
	}

	return WithHeaderPacket(res, hp.SequenceId)
}
// ClientAttrsToByte encodes connection attributes as a sequence of
// length-encoded key / length-encoded value pairs, without the leading
// total-length prefix. It returns nil for an empty or nil map.
//
// Pairs are emitted in map iteration order, so the byte order of the output
// is not stable between calls.
func ClientAttrsToByte(clientAttrs map[string]string) []byte {
	if len(clientAttrs) == 0 {
		return nil
	}

	var encoded []byte
	for name, value := range clientAttrs {
		// key: length prefix, then the raw bytes
		encoded = append(encoded, WriteLengthEncodedInt(uint64(len(name)))...)
		encoded = append(encoded, name...)
		// value: length prefix, then the raw bytes
		encoded = append(encoded, WriteLengthEncodedInt(uint64(len(value)))...)
		encoded = append(encoded, value...)
	}
	return encoded
}
// ParseClientAttrs decodes the connection-attribute section of a handshake
// response: a sequence of length-encoded key / length-encoded value pairs.
//
// data is the attribute payload without its leading total-length prefix.
// It returns the decoded pairs, or an error when a length prefix cannot be
// read or declares more bytes than remain in data.
func ParseClientAttrs(data []byte) (map[string]string, error) {
	buf := bytes.NewBuffer(data)
	attrs := make(map[string]string)

	// The buffer shrinks as fields are consumed, so "bytes remaining" is the
	// loop condition. This replaces a hand-maintained read counter that added
	// the key's prefix size where the value's prefix size was needed and
	// therefore miscounted whenever the two prefixes differed in width.
	for buf.Len() > 0 {
		keyLen, keyPos, ok := ReadLengthEncodedInt(buf.Bytes())
		if !ok {
			return nil, fmt.Errorf("read attrs key err")
		}
		buf.Next(keyPos)
		// The length comes from the client; reject it before converting to int
		// so an oversized value cannot turn negative and panic buf.Next.
		if keyLen > uint64(buf.Len()) {
			return nil, fmt.Errorf("read attrs key err: length %d exceeds remaining %d bytes", keyLen, buf.Len())
		}
		key := ReadString(buf.Next(int(keyLen)))

		valLen, valPos, ok := ReadLengthEncodedInt(buf.Bytes())
		if !ok {
			return nil, fmt.Errorf("read attrs value err")
		}
		buf.Next(valPos)
		if valLen > uint64(buf.Len()) {
			return nil, fmt.Errorf("read attrs value err: length %d exceeds remaining %d bytes", valLen, buf.Len())
		}
		value := ReadString(buf.Next(int(valLen)))

		attrs[key] = value
	}

	return attrs, nil
}
// ReadHandshakeV10 reads one packet from conn and decodes it as an initial
// HandshakeV10 packet.
//
// It returns an error when the packet cannot be read, when the first payload
// byte is not protocolVersion, when the packet is too short to hold the
// fixed-width fields, or when a NUL-terminated field (server version, scramble
// part 2, auth plugin name) is not terminated.
func ReadHandshakeV10(conn io.Reader) (*HandshakeV10, error) {
	pk, err := ReadMysqlPacket(conn)
	if err != nil {
		return nil, err
	}

	// Guard the index below: an empty payload would otherwise panic.
	if len(pk.Payload) == 0 {
		return nil, fmt.Errorf("protocol version err: empty handshake packet")
	}
	if pk.Payload[0] != protocolVersion {
		return nil, fmt.Errorf("protocol version err: %d", pk.Payload[0])
	}

	buf := bytes.NewBuffer(pk.Payload)
	res := &HandshakeV10{}

	res.ProtocolVersion = ReadByte(buf.Next(1))

	// server version, NUL-terminated
	versionByte, err := buf.ReadBytes(0x00)
	if err != nil {
		return nil, err
	}
	res.ServerVersion = ReadStringNull(versionByte)

	// Everything up to scramble part 2 is fixed width:
	// connection id (4) + scramble part 1 (8) + filler (1) + lower caps (2) +
	// charset (1) + status (2) + upper caps (2) + scramble length (1) + reserved (10).
	// buf.Next silently returns fewer bytes on a short buffer, so verify first
	// instead of handing truncated slices to the fixed-width readers.
	const fixedLen = 4 + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 10
	if buf.Len() < fixedLen {
		return nil, fmt.Errorf("handshake packet too short: %d bytes after server version, need %d", buf.Len(), fixedLen)
	}

	// connection id
	res.ConnectionId = ReadUint32(buf.Next(4))

	// The scramble is stored as 8 bytes of part 1 followed by up to 13 bytes of part 2.
	fullAuthPluginData := make([]byte, 21)
	copy(fullAuthPluginData, buf.Next(8))

	// filler 0x00
	buf.Next(1)

	// The capability flags arrive as two 16-bit halves around charset/status.
	fullCap := make([]byte, 4)
	copy(fullCap[:2], buf.Next(2))

	res.CharsetCollation = ReadByte(buf.Next(1))
	res.ServerStatus = ReadUint16(buf.Next(2))

	copy(fullCap[2:], buf.Next(2))
	res.CapabilityFlag = ReadUint32(fullCap)

	// auth_plugin_data_len (1) and reserved bytes (10) are skipped.
	buf.Next(1 + 10)

	// scramble part 2, read up to and including its terminating NUL
	authPluginData2, err := buf.ReadBytes(0x00)
	if err != nil {
		return nil, err
	}
	copy(fullAuthPluginData[8:], authPluginData2)
	res.AuthPluginData = fullAuthPluginData

	// auth plugin name, NUL-terminated
	authPluginMethod, err := buf.ReadBytes(0x00)
	if err != nil {
		return nil, err
	}
	res.AuthPluginMethod = ReadStringNull(authPluginMethod)

	return res, nil
}
// ReadHandshakeResponse reads one packet from conn and decodes it as a client
// HandshakeResponse41.
//
// Only clients that set both CapabilityClientProtocol41 and
// CapabilityClientPluginAuth are accepted. An error is returned when the
// packet cannot be read, is too short for its fixed-width fields, declares an
// auth response longer than the packet, or lacks a NUL terminator on the user
// name, database or plugin name.
//
// Connection attributes are best effort: a missing, oversized or unparsable
// attribute section is logged (where applicable) and leaves ClientAttrs nil
// rather than failing the handshake.
func ReadHandshakeResponse(conn io.Reader) (*HandshakeResponse, error) {
	pk, err := ReadMysqlPacket(conn)
	if err != nil {
		return nil, err
	}

	// The payload comes from an untrusted client. buf.Next silently returns
	// fewer bytes on a short buffer, so sizes are verified before every
	// fixed-width read.
	if len(pk.Payload) < 4 {
		return nil, fmt.Errorf("handshake response too short: %d bytes", len(pk.Payload))
	}

	buf := bytes.NewBuffer(pk.Payload)

	clientFlag := ReadUint32(buf.Next(4))
	if clientFlag&CapabilityClientProtocol41 == 0 {
		return nil, fmt.Errorf("only support CLIENT_PROTOCOL_41")
	}
	if clientFlag&CapabilityClientPluginAuth == 0 {
		return nil, fmt.Errorf("unsupport without ClientPluginAuth")
	}

	// max packet size (4) + charset (1) + reserved filler (23)
	const fixedRest = 4 + 1 + 23
	if buf.Len() < fixedRest {
		return nil, fmt.Errorf("handshake response too short: %d bytes", len(pk.Payload))
	}

	res := &HandshakeResponse{}
	res.MysqlPacketHeader = pk.MysqlPacketHeader
	res.ClientFlag = clientFlag
	res.MaxPacketSize = ReadUint32(buf.Next(4))
	res.Charset = ReadByte(buf.Next(1))

	// reserved filler
	buf.Next(23)

	username, err := buf.ReadBytes(0x00)
	if err != nil {
		return nil, err
	}
	res.Username = ReadStringNull(username)

	if clientFlag&CapabilityClientPluginAuthLenencClientData > 0 {
		length, pos, ok := ReadLengthEncodedInt(buf.Bytes())
		if !ok {
			return nil, fmt.Errorf("ReadLengthEncodedInt err")
		}
		buf.Next(pos)
		// Reject before converting to int: a huge client-supplied length would
		// become negative and panic buf.Next.
		if length > uint64(buf.Len()) {
			return nil, fmt.Errorf("auth response length %d exceeds remaining %d bytes", length, buf.Len())
		}
		res.Password = buf.Next(int(length))
	} else {
		if buf.Len() < 1 {
			return nil, fmt.Errorf("handshake response too short: missing auth response length")
		}
		length := ReadByte(buf.Next(1))
		res.Password = buf.Next(int(length))
	}

	if clientFlag&CapabilityClientConnectWithDB > 0 {
		db, err := buf.ReadBytes(0x00)
		if err != nil {
			return nil, err
		}
		res.Database = ReadStringNull(db)
	}

	// Always true at this point (clients without the flag were rejected above);
	// kept so each optional section is visibly gated by its capability flag.
	if clientFlag&CapabilityClientPluginAuth > 0 {
		authPluginMethod, err := buf.ReadBytes(0x00)
		if err != nil {
			return nil, err
		}
		res.AuthPluginMethod = ReadStringNull(authPluginMethod)
	}

	if clientFlag&CapabilityClientConnAttr > 0 {
		length, pos, ok := ReadLengthEncodedInt(buf.Bytes())
		if ok {
			res.ClientAttrLen = length
			buf.Next(pos)

			if length > uint64(buf.Len()) {
				// Never size an allocation from an unchecked client value: a
				// bogus 64-bit length would panic make or exhaust memory.
				zlog.Errorf("parse client attrs err: declared length %d exceeds remaining %d bytes", length, buf.Len())
			} else {
				// Copy so the parsed attributes do not alias the packet buffer.
				attrsByte := make([]byte, length)
				copy(attrsByte, buf.Next(int(length)))

				attrs, err := ParseClientAttrs(attrsByte)
				if err != nil {
					zlog.Errorf("parse client attrs err: %s", err)
				} else {
					res.ClientAttrs = attrs
				}
			}
		}
	}

	return res, nil
}