diff --git a/mymysql/batchwriter.go b/mymysql/batchwriter.go index aaa774b..5b0c59b 100644 --- a/mymysql/batchwriter.go +++ b/mymysql/batchwriter.go @@ -36,13 +36,13 @@ type ( ) type BatchWriterConfig struct { - channelBuffer int - batchSize int - batchInterval time.Duration - jobNum int - asyncWorkerNum int - duplicateUpdate *clause.OnConflict - debug bool + channelBuffer int + batchSize int + batchInterval time.Duration + jobNum int + asyncWorkerNum int + clauseExpr []clause.Expression + debug bool } type BatchWriter[T any] struct { @@ -93,9 +93,9 @@ func WithAsyncWorkerNum(v int) BatchWriterOpt { } } -func WithDuplicateUpdate(v *clause.OnConflict) BatchWriterOpt { +func WithClause(v ...clause.Expression) BatchWriterOpt { return func(c *BatchWriterConfig) { - c.duplicateUpdate = v + c.clauseExpr = v } } @@ -269,8 +269,8 @@ func (bw *BatchWriter[T]) asyncWriteToDb(jobIndex int, copyDataList []T) { query := bw.db.Table(bw.tableName) - if bw.config.duplicateUpdate != nil { - query.Clauses(bw.config.duplicateUpdate) + if len(bw.config.clauseExpr) > 0 { + query.Clauses(bw.config.clauseExpr...) } err := query.Create(copyDataList).Error