proxymysql/app/mysqlserver/record_query.go

276 lines
5.2 KiB
Go

package mysqlserver
import (
"bytes"
"github.com/huandu/go-sqlbuilder"
jsoniter "github.com/json-iterator/go"
"io"
"proxymysql/app/conf"
"proxymysql/app/db"
"proxymysql/app/zlog"
"regexp"
"strings"
"time"
)
var _ io.Writer = (*RecordQuery)(nil)
type RecordQuery struct {
pipeReader *io.PipeReader
pipeWriter *io.PipeWriter
stmtId uint32
stmtMap map[uint32]string
}
// NewRecordQuery returns a RecordQuery ready to receive MySQL client traffic
// through Write, and starts the background goroutine (readQuery) that decodes
// and records that traffic. Call Close to end the stream; the goroutine exits
// once its next read from the pipe fails.
func NewRecordQuery() *RecordQuery {
	reader, writer := io.Pipe()
	rq := &RecordQuery{
		pipeReader: reader,
		pipeWriter: writer,
		stmtMap:    make(map[uint32]string),
	}
	go rq.readQuery()
	return rq
}
// Write implements io.Writer by forwarding p into the internal pipe, where the
// readQuery goroutine decodes it as MySQL packets.
//
// Recording is best-effort: any error from the pipe (for example
// io.ErrClosedPipe after Close) is deliberately discarded and Write always
// reports len(p), nil, so a failure in query logging never surfaces as a write
// error to whatever stream this writer is attached to.
// NOTE(review): presumably attached via io.TeeReader/io.MultiWriter to the
// client->server traffic — confirm with callers.
//
// io.Pipe is unbuffered, so Write blocks until the reader has consumed p.
func (r *RecordQuery) Write(p []byte) (int, error) {
	// Intentionally ignored: see the best-effort note above.
	_, _ = r.pipeWriter.Write(p)
	return len(p), nil
}
// Close stops recording by closing both halves of the internal pipe: the
// writer first, so the reading goroutine sees end-of-stream, then the reader.
// Errors from the individual closes are ignored and Close always returns nil.
func (r *RecordQuery) Close() error {
	for _, end := range []io.Closer{r.pipeWriter, r.pipeReader} {
		_ = end.Close()
	}
	return nil
}
// readQuery is the consumer side of the pipe fed by Write. It runs in its own
// goroutine (started by NewRecordQuery), decodes the byte stream as MySQL
// client packets and records text queries, prepared statements and executed
// statements through saveToDb.
//
// Prepared statements are tracked in r.stmtMap. The server's reply to
// COM_STMT_PREPARE (which carries the real statement id) is not part of this
// stream, so ids are assigned locally by incrementing r.stmtId on every
// prepare. NOTE(review): this assumes the backend numbers statements 1, 2, 3,
// ... per connection — confirm. COM_STMT_EXECUTE and COM_STMT_CLOSE use the id
// carried in the packet when it is known locally and otherwise fall back to
// the most recently prepared statement.
//
// The goroutine exits when reading a packet fails — normally because Close
// ended the stream. On exit it closes the read half of the pipe so later calls
// to Write return immediately instead of blocking forever on a pipe that
// nobody reads any more.
func (r *RecordQuery) readQuery() {
	// Without this, an early exit (e.g. a malformed packet) would leave Write
	// blocked on the unbuffered pipe. Closing twice (here and in Close) is harmless.
	defer r.pipeReader.Close()

	// stmtKey resolves the statement id in a COM_STMT_EXECUTE / COM_STMT_CLOSE
	// payload ([cmd][stmt_id uint32 little-endian]...) to a key of r.stmtMap,
	// falling back to the latest prepared statement when the id is unknown.
	stmtKey := func(payload []byte) uint32 {
		if len(payload) >= 5 {
			id := uint32(payload[1]) | uint32(payload[2])<<8 |
				uint32(payload[3])<<16 | uint32(payload[4])<<24
			if _, ok := r.stmtMap[id]; ok {
				return id
			}
		}
		return r.stmtId
	}

	for {
		packet, err := ReadMysqlPacket(r.pipeReader)
		if err != nil {
			// io.EOF / io.ErrClosedPipe are the expected outcome of Close;
			// anything else is worth reporting.
			if err != io.EOF && err != io.ErrClosedPipe {
				zlog.Errorf("read pipe pack err: %s", err)
			}
			return
		}
		if len(packet.Payload) < 2 {
			continue
		}
		switch packet.Payload[0] {
		case ComQuery:
			query := string(packet.Payload[1:])
			zlog.Debugf("query: %s\n", query)
			r.saveToDb(query)
		case ComPrepare:
			query := string(packet.Payload[1:])
			zlog.Debugf("prepare: %s\n", query)
			r.stmtId++
			r.stmtMap[r.stmtId] = query
			r.saveToDb(query)
		case ComStmtExecute:
			query, ok := r.stmtMap[stmtKey(packet.Payload)]
			if !ok {
				// Never prepared (or already closed): no SQL text to record.
				zlog.Warnf("ComStmtExecute for unknown statement, skipped")
				continue
			}
			_, args := r.parseStmtArgs(strings.Count(query, "?"), packet.Payload)
			fullSqlQuery, err := sqlbuilder.MySQL.Interpolate(query, args)
			if err != nil {
				zlog.Errorf("ComStmtExecute builder sql err: %s", err)
				// Keep the statement text (without arguments) rather than
				// recording whatever Interpolate returned with the error.
				fullSqlQuery = query
			} else {
				zlog.Debugf("stmt: %s\n", fullSqlQuery)
			}
			r.saveToDb(fullSqlQuery)
		case ComStmtClose:
			delete(r.stmtMap, stmtKey(packet.Payload))
		}
	}
}
// BindArg describes one parameter of a COM_STMT_EXECUTE packet as decoded by
// parseStmtArgs.
type BindArg struct {
	ArgType  uint8 // MySQL type code sent by the client for this parameter.
	Unsigned uint8 // Raw flag byte sent after the type code (0x80 marks unsigned in the MySQL protocol).
	ArgValue any   // Decoded value; nil for SQL NULL or when it was not decoded.
}
// parseStmtArgs decodes the parameter values of a COM_STMT_EXECUTE payload.
//
// argNum is the number of placeholders in the prepared statement (readQuery in
// this file derives it by counting "?" in the SQL text, so a literal "?" inside
// a quoted string inflates it — NOTE(review)). data is the full packet
// payload, starting with the command byte:
//
//	command(1) stmt_id(4) flags(1) iteration_count(4)
//	NULL bitmap ((argNum+7)/8 bytes)
//	new-params-bound flag (1)
//	argNum x [type(1) flag(1)]        only when the flag above is 1
//	values of the non-NULL parameters
//
// It returns one BindArg per parameter plus the plain values (nil for SQL NULL)
// suitable for sqlbuilder interpolation. Both results are nil when there are no
// parameters, the header or type section is truncated, or the client did not
// resend parameter types (new-params-bound flag != 1; types from an earlier
// execute are not tracked here). If a value cannot be decoded the error is
// logged and decoding stops, leaving that parameter and all later ones nil.
//
// Integer parameters honour the unsigned flag: signed ones decode to
// int8/int16/int32/int64, unsigned ones to uint8/uint16/uint32/uint64, so
// negative values are logged correctly. Every read is bounds-checked, so a
// truncated packet cannot panic the recording goroutine.
// NOTE(review): FLOAT/DOUBLE (fixed 4/8 bytes) still go through the
// length-encoded-string path and are not decoded meaningfully.
func (r *RecordQuery) parseStmtArgs(argNum int, data []byte) ([]*BindArg, []any) {
	const (
		// headerLen covers command(1) + stmt_id(4) + flags(1) + iteration_count(4).
		headerLen = 1 + 4 + 1 + 4
		// MySQL binary-protocol type codes of 2-byte integers.
		mysqlTypeShort = 0x02
		mysqlTypeYear  = 0x0d
		// unsignedFlag is set in the flag byte that follows each type code.
		unsignedFlag = 0x80
	)
	if argNum <= 0 || len(data) < headerLen {
		return nil, nil
	}

	buf := bytes.NewBuffer(data[headerLen:])
	// take returns the next n bytes, or false when fewer than n remain.
	take := func(n int) ([]byte, bool) {
		if n < 0 || buf.Len() < n {
			return nil, false
		}
		return buf.Next(n), true
	}
	// readInt decodes a little-endian integer of width bytes (1, 2, 4 or 8)
	// into the sized Go integer type matching its signedness.
	readInt := func(width int, unsigned bool) (any, bool) {
		b, ok := take(width)
		if !ok {
			return nil, false
		}
		var u uint64
		for j := width - 1; j >= 0; j-- {
			u = u<<8 | uint64(b[j])
		}
		if unsigned {
			switch width {
			case 1:
				return uint8(u), true
			case 2:
				return uint16(u), true
			case 4:
				return uint32(u), true
			}
			return u, true
		}
		switch width {
		case 1:
			return int8(u), true
		case 2:
			return int16(u), true
		case 4:
			return int32(u), true
		}
		return int64(u), true
	}

	nullBitMap, ok := take((argNum + 7) / 8)
	if !ok {
		return nil, nil
	}
	bindFlag, ok := take(1)
	if !ok || bindFlag[0] != 0x01 {
		return nil, nil
	}

	bindArgs := make([]*BindArg, argNum)
	args := make([]any, argNum)
	for i := range bindArgs {
		typ, ok := take(2)
		if !ok {
			return nil, nil
		}
		bindArgs[i] = &BindArg{ArgType: typ[0], Unsigned: typ[1]}
	}

	for i, arg := range bindArgs {
		if nullBitMap[i/8]&(1<<(i%8)) != 0 {
			// SQL NULL: no value bytes on the wire; ArgValue and args[i] stay nil.
			continue
		}
		unsigned := arg.Unsigned&unsignedFlag != 0
		var (
			val     any
			decoded bool
		)
		switch arg.ArgType {
		case FieldTypeBit:
			val, decoded = readInt(1, true)
		case FieldTypeTiny:
			val, decoded = readInt(1, unsigned)
		case mysqlTypeShort, mysqlTypeYear:
			val, decoded = readInt(2, unsigned)
		case FieldTypeInt24, FieldTypeLong:
			// INT24 is sent as 4 bytes in the binary protocol.
			val, decoded = readInt(4, unsigned)
		case FieldTypeLongLong:
			val, decoded = readInt(8, unsigned)
		default:
			// Strings, decimals, blobs, ... are length-encoded.
			length, pos, lenOK := ReadLengthEncodedInt(buf.Bytes())
			if !lenOK {
				break
			}
			buf.Next(pos)
			var raw []byte
			if raw, decoded = take(int(length)); decoded {
				val = string(raw)
			}
		}
		if !decoded {
			// The offset of every later value is now unknown; stop here.
			zlog.Errorf("read args err %+v", buf.Bytes())
			return bindArgs, args
		}
		arg.ArgValue = val
		args[i] = val
	}
	return bindArgs, args
}
// SqlComment is one row of the sql_query_log table written by saveToDb. Besides
// the recorded SQL text it can carry admin/request metadata decoded from a
// TzAdmin comment embedded in the query (see parseSqlComment). Field order is
// kept as-is.
type SqlComment struct {
	// Admin and game identifiers; in this file only populated from the
	// TzAdmin JSON payload.
	AdminId                   int64
	AdminName, AdminRealName  string
	QueryGameId, HeaderGameId int32

	// Request context; likewise only populated from the TzAdmin payload.
	Ip, RequestPath, RequestInfo string

	// UnixMilli and CreateTime default to the recording time (Unix
	// milliseconds and "2006-01-02 15:04:05.000" text); Query is the SQL text
	// with the TzAdmin comment stripped. CallInfo is only set from the payload.
	UnixMilli            int64
	Query                string
	CallInfo, CreateTime string
}
// adminCommentReg matches the admin metadata comment a caller may embed in a
// query, of the form `/* TzAdmin-<json>-TzAdmin */` (any whitespace after "/*"
// and before "*/"). Submatch 1 is the JSON payload between the markers.
var adminCommentReg = regexp.MustCompile(`/\*\s+TzAdmin-([\s\S]+)-TzAdmin\s+\*/`)

// parseSqlComment builds the SqlComment row recorded for query.
//
// Query, UnixMilli and CreateTime default to the query text and the current
// time (a single clock reading, so both timestamps agree). When the query
// contains a TzAdmin comment, its JSON payload is decoded into the row; fields
// present in the payload override those defaults. The comment is then removed
// from the stored Query, which is always derived from the actual SQL text (a
// "Query" key in the payload cannot replace it). If the payload cannot be
// decoded, a warning is logged and the row keeps only the defaults, with the
// comment left in Query.
func (r *RecordQuery) parseSqlComment(query string) (sc *SqlComment) {
	now := time.Now()
	sc = &SqlComment{
		Query:      query,
		UnixMilli:  now.UnixMilli(),
		CreateTime: now.Format("2006-01-02 15:04:05.000"),
	}
	// Cheap pre-check before running the regexp. It must not be stricter than
	// adminCommentReg, which allows any whitespace (not only a space) before
	// the marker.
	if !strings.Contains(query, "TzAdmin-") {
		return sc
	}
	subMatch := adminCommentReg.FindStringSubmatch(query)
	if len(subMatch) < 2 {
		return sc
	}
	// Decode into a copy so a payload that fails half-way cannot leave a
	// partially overwritten row behind.
	meta := *sc
	if err := jsoniter.Unmarshal([]byte(subMatch[1]), &meta); err != nil {
		zlog.Warnf("解析sql admin信息失败 %s [%s]", err, subMatch[1])
		return sc
	}
	*sc = meta
	sc.Query = strings.TrimSpace(adminCommentReg.ReplaceAllString(query, ""))
	return sc
}
// saveToDb inserts a row for query into the sql_query_log table when
// conf.App.SaveLog is enabled. The row (including any TzAdmin metadata) is
// built by parseSqlComment. Insert failures are logged and otherwise ignored,
// because recording is best-effort.
//
// In this file it is called from the readQuery goroutine, so a slow insert
// delays consumption of the pipe. NOTE(review): with the unbuffered io.Pipe
// that may back-pressure Write — confirm this is acceptable.
func (r *RecordQuery) saveToDb(query string) {
	if !conf.App.SaveLog {
		return
	}
	sc := r.parseSqlComment(query)
	if err := db.GetDB().Table("sql_query_log").Create(sc).Error; err != nil {
		zlog.Errorf("save to db err: %s", err)
	}
}