package mymysql

import (
	"context"
	"log"
	"sync"
	"time"

	"git.makemake.in/kzkzzzz/mycommon"
	"git.makemake.in/kzkzzzz/mycommon/mylog"
	"github.com/goccy/go-json"
	"github.com/rs/xid"
	"gorm.io/gorm/clause"
)

const (
	defaultDataBuffer     = 1e5             // channel buffer size
	defaultBatchSize      = 200             // flush after this many rows
	defaultIntervalTime   = time.Second * 2 // flush interval
	defaultJobNum         = 2               // number of db writer jobs
	defaultAsyncWorkerNum = 20              // max goroutines executing async db writes
)

type iWriterStop interface {
	StopWriter()
}

var (
	writerJobMap = &sync.Map{}
)

type (
	batchData[T any] struct {
		jobIndex int
		dataList []T
	}
)

type BatchWriterConfig struct {
	channelBuffer   int
	batchSize       int
	batchInterval   time.Duration
	jobNum          int
	asyncWorkerNum  int
	duplicateUpdate *clause.OnConflict
	debug           bool
}

type BatchWriter[T any] struct {
	db                   *MysqlDb
	tableName            string
	jobName              string
	uniqueId             string
	config               *BatchWriterConfig
	dataChan             chan T
	ctx                  context.Context
	cancel               context.CancelFunc
	stopChan             chan struct{}
	asyncWorkerLimitChan chan struct{}
	asyncWorkerWg        *sync.WaitGroup
}

type BatchWriterOpt func(c *BatchWriterConfig)

func WithWriteJobNum(v int) BatchWriterOpt {
	return func(c *BatchWriterConfig) {
		c.jobNum = v
	}
}

func WithWriteChannelBuffer(v int) BatchWriterOpt {
	return func(c *BatchWriterConfig) {
		c.channelBuffer = v
	}
}

func WithWriteBatchSize(v int) BatchWriterOpt {
	return func(c *BatchWriterConfig) {
		c.batchSize = v
	}
}

func WithWriteIntervalTime(v time.Duration) BatchWriterOpt {
	return func(c *BatchWriterConfig) {
		c.batchInterval = v
	}
}

func WithAsyncWorkerNum(v int) BatchWriterOpt {
	return func(c *BatchWriterConfig) {
		c.asyncWorkerNum = v
	}
}

func WithDuplicateUpdate(v *clause.OnConflict) BatchWriterOpt {
	return func(c *BatchWriterConfig) {
		c.duplicateUpdate = v
	}
}

func WithDebug(v bool) BatchWriterOpt {
	return func(c *BatchWriterConfig) {
		c.debug = v
	}
}

func NewBatchWrite[T any](db *MysqlDb, tableName, jobName string, opts ...BatchWriterOpt) *BatchWriter[T] {
	config := &BatchWriterConfig{}
	for _, opt := range opts {
		opt(config)
	}

	if config.batchInterval <= 0 {
		config.batchInterval = defaultIntervalTime
	}
	if config.channelBuffer <= 0 {
		config.channelBuffer = defaultDataBuffer
	}
	if config.jobNum <= 0 {
		config.jobNum = defaultJobNum
	}
	if config.asyncWorkerNum <= 0 {
		config.asyncWorkerNum = defaultAsyncWorkerNum
	}
	if config.batchSize <= 0 {
		config.batchSize = defaultBatchSize
	}

	bw := &BatchWriter[T]{
		db:                   db,
		tableName:            tableName,
		jobName:              jobName,
		uniqueId:             xid.New().String(),
		config:               config,
		dataChan:             make(chan T, config.channelBuffer), // honor the configured buffer so Write does not block on a consumer
		stopChan:             make(chan struct{}, 1),
		asyncWorkerLimitChan: make(chan struct{}, config.asyncWorkerNum),
		asyncWorkerWg:        &sync.WaitGroup{},
	}
	bw.ctx, bw.cancel = context.WithCancel(context.Background())

	// register the instance so pending data can be flushed when the program exits
	writerJobMap.Store(bw.uniqueId, bw)

	go func() {
		bw.start()
	}()
	return bw
}

// Write enqueues rows for batched insertion. It must not be called
// concurrently with StopWriter: once the data channel is closed, a
// late Write would panic on send.
func (bw *BatchWriter[T]) Write(data ...T) {
	if len(data) == 0 {
		return
	}

	if bw.ctx.Err() != nil {
		b, _ := json.Marshal(data)
		mylog.Errorf("[%s] save to db err: job is closed, data: (%s)", bw.tableName, b)
		return
	}

	for _, v := range data {
		bw.dataChan <- v
	}
}

func (bw *BatchWriter[T]) start() {
	wg := &sync.WaitGroup{}
	for i := 0; i < bw.config.jobNum; i++ {
		wg.Add(1)
		go func(jobIndex int) {
			defer wg.Done()
			bw.startJob(jobIndex)
		}(i)
	}
	wg.Wait()

	log.Printf("[table:%s - job:%s] batch write job stopped", bw.tableName, bw.jobName)
	close(bw.stopChan)
}

func (bw *BatchWriter[T]) startJob(jobIndex int) {
	tkTime := bw.config.batchInterval
	// add random jitter (0.5s ~ 3.5s) to the ticker so jobs do not flush in lockstep
	randN := float64(mycommon.RandRange(50, 350)) / float64(100)
	tkTime = tkTime + time.Duration(float64(time.Second)*randN)

	log.Printf("[table:%s - job:%s - %d] batch write job start, ticker time: %s", bw.tableName, bw.jobName, jobIndex, tkTime.String())

	tk := time.NewTicker(tkTime)
	defer tk.Stop()

	bd := &batchData[T]{
		jobIndex: jobIndex,
		dataList: make([]T, 0, bw.config.batchSize),
	}

loop:
	for {
		select {
		case <-bw.ctx.Done():
			break loop

		case <-tk.C:
			bw.writeToDb(bd)

		case data, ok := <-bw.dataChan:
			if !ok {
				break loop
			}
			bd.dataList = append(bd.dataList, data)
			if len(bd.dataList) >= bw.config.batchSize {
				bw.writeToDb(bd)
			}
		}
	}

	// drain whatever is still buffered in the channel (StopWriter closes it
	// right after canceling the context), then flush the final partial batch
	for data := range bw.dataChan {
		bd.dataList = append(bd.dataList, data)
		if len(bd.dataList) >= bw.config.batchSize {
			bw.writeToDb(bd)
		}
	}
	if len(bd.dataList) > 0 {
		bw.writeToDb(bd)
	}
}

func (bw *BatchWriter[T]) writeToDb(bd *batchData[T]) {
	if len(bd.dataList) == 0 {
		return
	}

	defer func() {
		// reset the slice for the next batch, keeping its capacity
		bd.dataList = bd.dataList[:0]
	}()

	// block until an async worker slot is free
	bw.asyncWorkerLimitChan <- struct{}{}

	// copy the batch so the async write does not race with the reset above
	copyDataList := make([]T, len(bd.dataList))
	copy(copyDataList, bd.dataList)

	bw.asyncWorkerWg.Add(1)
	go func() {
		defer func() {
			<-bw.asyncWorkerLimitChan
			bw.asyncWorkerWg.Done()
		}()
		bw.asyncWriteToDb(bd.jobIndex, copyDataList)
	}()
}

func (bw *BatchWriter[T]) asyncWriteToDb(jobIndex int, copyDataList []T) {
	if len(copyDataList) == 0 {
		return
	}

	query := bw.db.Table(bw.tableName)
	if bw.config.duplicateUpdate != nil {
		query = query.Clauses(bw.config.duplicateUpdate)
	}

	err := query.Create(copyDataList).Error
	if err == nil {
		return
	}

	// batch insert failed; a retry flow is left as a future improvement
	b, _ := json.Marshal(copyDataList)
	mylog.Errorf("[%s - %s] save to db err: %s data: (%s)", bw.tableName, bw.jobName, err, b)
}

func (bw *BatchWriter[T]) StopWriter() {
	if bw.ctx.Err() != nil {
		return
	}
	bw.cancel()
	close(bw.dataChan)
	<-bw.stopChan
	bw.asyncWorkerWg.Wait()
}

func StopAllBatchWriter() {
	writerJobMap.Range(func(k, v interface{}) bool {
		q := v.(iWriterStop)
		q.StopWriter()
		return true
	})
}

// Deprecated: use StopAllBatchWriter instead.
func StopWriter() {
	StopAllBatchWriter()
}
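
// Usage sketch (illustrative only, not part of this package): the model type
// ClickEvent, the table name, and how the *MysqlDb handle is obtained are
// assumptions made up for the example.
//
//	type ClickEvent struct {
//		Id  int64  `gorm:"column:id"`
//		Uid string `gorm:"column:uid"`
//	}
//
//	func runWriter(db *MysqlDb) {
//		// upsert on duplicate key: update uid when id already exists
//		onDup := &clause.OnConflict{
//			Columns:   []clause.Column{{Name: "id"}},
//			DoUpdates: clause.AssignmentColumns([]string{"uid"}),
//		}
//
//		w := NewBatchWrite[ClickEvent](db, "click_event", "click-writer",
//			WithWriteBatchSize(500),
//			WithWriteIntervalTime(time.Second),
//			WithDuplicateUpdate(onDup),
//		)
//		w.Write(ClickEvent{Id: 1, Uid: "u1"})
//
//		// on shutdown, flush and stop every registered writer
//		StopAllBatchWriter()
//	}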