393 lines
8.7 KiB
Go
393 lines
8.7 KiB
Go
|
package mysqlserver
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"proxymysql/app/zlog"
|
||
|
)
|
||
|
|
||
|
// errAuthPluginMethod is a placeholder plugin name.
// NOTE(review): presumably used to mark an unsupported or failed plugin
// negotiation; confirm against callers, it is not referenced in this section.
const errAuthPluginMethod = "err_auth_plugin_method"

// nativePasswordAuthPluginMethod is the wire name of the
// mysql_native_password authentication plugin.
const nativePasswordAuthPluginMethod = "mysql_native_password"
|
||
|
|
||
|
var DefaultHandshakeCapability uint32 = CapabilityClientLongPassword |
|
||
|
CapabilityClientFoundRows |
|
||
|
CapabilityClientLongFlag |
|
||
|
CapabilityClientConnectWithDB |
|
||
|
CapabilityClientProtocol41 |
|
||
|
CapabilityClientTransactions |
|
||
|
CapabilityClientSecureConnection |
|
||
|
CapabilityClientMultiStatements |
|
||
|
CapabilityClientMultiResults |
|
||
|
CapabilityClientPluginAuth |
|
||
|
CapabilityClientPluginAuthLenencClientData |
|
||
|
CapabilityClientDeprecateEOF |
|
||
|
CapabilityClientConnAttr |
|
||
|
CapabilityClientQueryAttributes
|
||
|
|
||
|
// HandshakeV10 holds the fields of the initial handshake packet as they are
// decoded by ReadHandshakeV10 and encoded by ToByte.
type HandshakeV10 struct {
	// ProtocolVersion is the first payload byte of the packet.
	ProtocolVersion uint8
	// ServerVersion is the null-terminated server version string.
	ServerVersion string
	// AuthPluginMethod is the null-terminated authentication plugin name
	// that closes the packet.
	AuthPluginMethod string
	// ConnectionId is the 4-byte connection (thread) id.
	ConnectionId uint32
	// CharsetCollation is the single-byte character set / collation id.
	CharsetCollation uint8

	// ServerStatus is the 2-byte server status flag field.
	ServerStatus uint16

	// CapabilityFlag combines the lower and upper 2-byte capability halves
	// that the packet carries separately.
	CapabilityFlag uint32

	// AuthPluginData is the scramble: 8 bytes of part 1 followed by part 2.
	// ReadHandshakeV10 always produces 21 bytes; ToByte needs at least 8.
	AuthPluginData []byte
}
|
||
|
|
||
|
func (hk *HandshakeV10) ToByte() []byte {
|
||
|
data := make([]byte, 0)
|
||
|
|
||
|
// protocol version
|
||
|
data = append(data, WriteByte(protocolVersion)...)
|
||
|
|
||
|
// server version
|
||
|
data = append(data, WriteStringNull(hk.ServerVersion)...)
|
||
|
|
||
|
// thread id
|
||
|
data = append(data, WriteUint32(hk.ConnectionId)...)
|
||
|
|
||
|
// auth-plugin-data-part-1 first 8 bytes of the plugin provided data (scramble)
|
||
|
data = append(data, hk.AuthPluginData[:8]...)
|
||
|
|
||
|
// filler 0x00 byte, terminating the first part of a scramble
|
||
|
data = append(data, 0x00)
|
||
|
|
||
|
// capability_flags_1 The lower 2 bytes of the Capabilities Flags
|
||
|
data = append(data, WriteUint16(uint16(hk.CapabilityFlag))...)
|
||
|
|
||
|
data = append(data, WriteByte(CharsetUtf8mb4GeneralCiId)...)
|
||
|
|
||
|
serverStatus := ServerStatusAutocommit
|
||
|
data = append(data, WriteUint16(serverStatus)...)
|
||
|
|
||
|
// The upper 2 bytes of the Capabilities Flags
|
||
|
data = append(data, WriteUint16(uint16(hk.CapabilityFlag>>16))...)
|
||
|
|
||
|
// Length of auth plugin data.
|
||
|
// Always 21 (8 + 13).
|
||
|
data = append(data, WriteByte(21)...)
|
||
|
|
||
|
// reserved. All 0s.
|
||
|
fillByte := make([]byte, 10)
|
||
|
data = append(data, fillByte...)
|
||
|
|
||
|
// auth-plugin-data-part-2 Rest of the plugin provided data (scramble), $len=MAX(13, length of auth-plugin-data - 8)
|
||
|
data = append(data, hk.AuthPluginData[8:]...)
|
||
|
|
||
|
data = append(data, WriteStringNull(hk.AuthPluginMethod)...)
|
||
|
|
||
|
return WithHeaderPacket(data, 0)
|
||
|
}
|
||
|
|
||
|
type HandshakeResponse struct {
|
||
|
ClientFlag uint32
|
||
|
MaxPacketSize uint32
|
||
|
Charset uint8
|
||
|
Username string
|
||
|
Password []byte
|
||
|
Database string
|
||
|
AuthPluginMethod string
|
||
|
ClientAttrLen uint64
|
||
|
ClientAttrs map[string]string
|
||
|
MysqlPacketHeader
|
||
|
}
|
||
|
|
||
|
func (hp *HandshakeResponse) ToByte() []byte {
|
||
|
res := make([]byte, 0, hp.Length)
|
||
|
|
||
|
res = append(res, WriteUint32(hp.ClientFlag)...)
|
||
|
res = append(res, WriteUint32(hp.MaxPacketSize)...)
|
||
|
|
||
|
res = append(res, WriteByte(hp.Charset)...)
|
||
|
|
||
|
fill := make([]byte, 23)
|
||
|
res = append(res, fill...)
|
||
|
|
||
|
res = append(res, WriteStringNull(hp.Username)...)
|
||
|
|
||
|
if hp.ClientFlag&CapabilityClientPluginAuthLenencClientData > 0 {
|
||
|
res = append(res, WriteLengthEncodedString(hp.Password)...)
|
||
|
} else {
|
||
|
res = append(res, uint8(len(hp.Password)))
|
||
|
res = append(res, hp.Password...)
|
||
|
}
|
||
|
|
||
|
if hp.ClientFlag&CapabilityClientConnectWithDB > 0 {
|
||
|
res = append(res, WriteStringNull(hp.Database)...)
|
||
|
}
|
||
|
|
||
|
if hp.ClientFlag&CapabilityClientPluginAuth > 0 {
|
||
|
res = append(res, WriteStringNull(hp.AuthPluginMethod)...)
|
||
|
}
|
||
|
|
||
|
if hp.ClientFlag&CapabilityClientConnAttr > 0 {
|
||
|
//res = append(res, WriteLengthEncodedInt(resp.ClientAttrLen)...)
|
||
|
attrByte := ClientAttrsToByte(hp.ClientAttrs)
|
||
|
attrLen := len(attrByte)
|
||
|
if attrLen > 0 {
|
||
|
res = append(res, WriteLengthEncodedInt(uint64(attrLen))...)
|
||
|
res = append(res, attrByte...)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//fmt.Printf("%+v\n", buf.Bytes())
|
||
|
|
||
|
return WithHeaderPacket(res, hp.SequenceId)
|
||
|
}
|
||
|
|
||
|
func ClientAttrsToByte(clientAttrs map[string]string) []byte {
|
||
|
if len(clientAttrs) == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
attrByte := make([]byte, 0)
|
||
|
|
||
|
for k, v := range clientAttrs {
|
||
|
keyByte := []byte(k)
|
||
|
valByte := []byte(v)
|
||
|
|
||
|
keyEncoded := WriteLengthEncodedInt(uint64(len(keyByte)))
|
||
|
|
||
|
valEncoded := WriteLengthEncodedInt(uint64(len(valByte)))
|
||
|
|
||
|
attrByte = append(attrByte, keyEncoded...)
|
||
|
attrByte = append(attrByte, keyByte...)
|
||
|
|
||
|
attrByte = append(attrByte, valEncoded...)
|
||
|
attrByte = append(attrByte, valByte...)
|
||
|
}
|
||
|
|
||
|
//fmt.Printf("len:%d attrByte:%+v\n", len(attrByte), attrByte)
|
||
|
|
||
|
return attrByte
|
||
|
}
|
||
|
|
||
|
func ParseClientAttrs(data []byte) (map[string]string, error) {
|
||
|
length := len(data)
|
||
|
|
||
|
buf := bytes.NewBuffer(data)
|
||
|
|
||
|
readLength := 0
|
||
|
|
||
|
attrs := make(map[string]string)
|
||
|
|
||
|
for readLength < length {
|
||
|
keyLen, keyPos, ok := ReadLengthEncodedInt(buf.Bytes())
|
||
|
if !ok {
|
||
|
return nil, fmt.Errorf("read attrs key err")
|
||
|
}
|
||
|
|
||
|
readLength += int(keyLen) + keyPos
|
||
|
buf.Next(keyPos)
|
||
|
|
||
|
key := ReadString(buf.Next(int(keyLen)))
|
||
|
|
||
|
valLen, valPos, ok := ReadLengthEncodedInt(buf.Bytes())
|
||
|
if !ok {
|
||
|
return nil, fmt.Errorf("read attrs value err")
|
||
|
}
|
||
|
|
||
|
readLength += int(valLen) + keyPos
|
||
|
buf.Next(valPos)
|
||
|
|
||
|
value := ReadString(buf.Next(int(valLen)))
|
||
|
|
||
|
attrs[key] = value
|
||
|
|
||
|
//fmt.Printf("%s - %s [%d]\n", key, value, readLength)
|
||
|
//break
|
||
|
}
|
||
|
|
||
|
//fmt.Printf("parse attrs ok: %+v\n", attrs)
|
||
|
|
||
|
return attrs, nil
|
||
|
}
|
||
|
|
||
|
func ReadHandshakeV10(conn io.Reader) (*HandshakeV10, error) {
|
||
|
pk, err := ReadMysqlPacket(conn)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if pk.Payload[0] != protocolVersion {
|
||
|
return nil, fmt.Errorf("protocol version err: %d", pk.Payload[0])
|
||
|
}
|
||
|
|
||
|
buf := bytes.NewBuffer(pk.Payload)
|
||
|
|
||
|
res := &HandshakeV10{}
|
||
|
|
||
|
res.ProtocolVersion = ReadByte(buf.Next(1))
|
||
|
|
||
|
// version
|
||
|
versionByte, err := buf.ReadBytes(0x00)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res.ServerVersion = ReadStringNull(versionByte)
|
||
|
|
||
|
//fmt.Println(string(versionByte), versionByte)
|
||
|
|
||
|
// connId
|
||
|
res.ConnectionId = ReadUint32(buf.Next(4))
|
||
|
//fmt.Println(binary.LittleEndian.Uint32(connId))
|
||
|
|
||
|
fullAuthPluginData := make([]byte, 21)
|
||
|
|
||
|
authPluginData1 := buf.Next(8)
|
||
|
|
||
|
copy(fullAuthPluginData, authPluginData1[:8])
|
||
|
//fmt.Println(string(salt1), salt1)
|
||
|
|
||
|
// filler 0x00
|
||
|
buf.Next(1)
|
||
|
|
||
|
cap1 := buf.Next(2)
|
||
|
//fmt.Println("cap1", cap1)
|
||
|
|
||
|
fullCap := make([]byte, 4)
|
||
|
|
||
|
copy(fullCap[:2], cap1)
|
||
|
|
||
|
charset := buf.Next(1)
|
||
|
res.CharsetCollation = ReadByte(charset)
|
||
|
|
||
|
serverStatus := buf.Next(2)
|
||
|
res.ServerStatus = ReadUint16(serverStatus)
|
||
|
|
||
|
cap2 := buf.Next(2)
|
||
|
copy(fullCap[2:], cap2)
|
||
|
|
||
|
//fmt.Printf("fullCap: %+v\n", fullCap)
|
||
|
res.CapabilityFlag = ReadUint32(fullCap)
|
||
|
|
||
|
//res.CapFlag2 = ReadUint16(cap2)
|
||
|
|
||
|
// auth_plugin_data_len, reserved. All 0s.
|
||
|
buf.Next(1 + 10)
|
||
|
|
||
|
authPluginData2, err := buf.ReadBytes(0x00)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
copy(fullAuthPluginData[8:], authPluginData2)
|
||
|
res.AuthPluginData = fullAuthPluginData
|
||
|
|
||
|
authPluginMethod, err := buf.ReadBytes(0x00)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res.AuthPluginMethod = ReadStringNull(authPluginMethod)
|
||
|
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
func ReadHandshakeResponse(conn io.Reader) (*HandshakeResponse, error) {
|
||
|
pk, err := ReadMysqlPacket(conn)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
buf := bytes.NewBuffer(pk.Payload)
|
||
|
|
||
|
clientFlag := ReadUint32(buf.Next(4))
|
||
|
if clientFlag&CapabilityClientProtocol41 == 0 {
|
||
|
return nil, fmt.Errorf("only support CLIENT_PROTOCOL_41")
|
||
|
}
|
||
|
|
||
|
if clientFlag&CapabilityClientPluginAuth == 0 {
|
||
|
return nil, fmt.Errorf("unsupport without ClientPluginAuth")
|
||
|
}
|
||
|
|
||
|
res := &HandshakeResponse{}
|
||
|
res.MysqlPacketHeader = pk.MysqlPacketHeader
|
||
|
|
||
|
res.ClientFlag = clientFlag
|
||
|
res.MaxPacketSize = ReadUint32(buf.Next(4))
|
||
|
|
||
|
res.Charset = ReadByte(buf.Next(1))
|
||
|
|
||
|
buf.Next(23)
|
||
|
|
||
|
username, err := buf.ReadBytes(0x00)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res.Username = ReadStringNull(username)
|
||
|
|
||
|
//fmt.Println(buf.Bytes())
|
||
|
|
||
|
if clientFlag&CapabilityClientPluginAuthLenencClientData > 0 {
|
||
|
length, pos, ok := ReadLengthEncodedInt(buf.Bytes())
|
||
|
if !ok {
|
||
|
return nil, fmt.Errorf("ReadLengthEncodedInt err")
|
||
|
}
|
||
|
|
||
|
buf.Next(pos)
|
||
|
|
||
|
password := buf.Next(int(length))
|
||
|
|
||
|
res.Password = password
|
||
|
|
||
|
} else {
|
||
|
length := ReadByte(buf.Next(1))
|
||
|
password := buf.Next(int(length))
|
||
|
|
||
|
res.Password = password
|
||
|
}
|
||
|
|
||
|
if clientFlag&CapabilityClientConnectWithDB > 0 {
|
||
|
db, err := buf.ReadBytes(0x00)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res.Database = ReadStringNull(db)
|
||
|
}
|
||
|
|
||
|
if clientFlag&CapabilityClientPluginAuth > 0 {
|
||
|
authPluginMethod, err := buf.ReadBytes(0x00)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res.AuthPluginMethod = ReadStringNull(authPluginMethod)
|
||
|
}
|
||
|
|
||
|
if clientFlag&CapabilityClientConnAttr > 0 {
|
||
|
length, pos, ok := ReadLengthEncodedInt(buf.Bytes())
|
||
|
if ok {
|
||
|
res.ClientAttrLen = length
|
||
|
buf.Next(pos)
|
||
|
|
||
|
attrsByte := make([]byte, length)
|
||
|
copy(attrsByte, buf.Next(int(res.ClientAttrLen)))
|
||
|
|
||
|
attrs, err := ParseClientAttrs(attrsByte)
|
||
|
if err != nil {
|
||
|
zlog.Errorf("parse client attrs err: %s", err)
|
||
|
} else {
|
||
|
res.ClientAttrs = attrs
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
//fmt.Printf("%+v\n", buf.Bytes())
|
||
|
}
|
||
|
|
||
|
//fmt.Printf("%+v\n", buf.Bytes())
|
||
|
|
||
|
return res, nil
|
||
|
}
|