github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/cdc/sink/dmlsink/txn/mysql/mysql.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package mysql
    15  
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"database/sql/driver"
    20  	"fmt"
    21  	"math"
    22  	"net/url"
    23  	"strings"
    24  	"time"
    25  
    26  	dmysql "github.com/go-sql-driver/mysql"
    27  	lru "github.com/hashicorp/golang-lru"
    28  	"github.com/pingcap/errors"
    29  	"github.com/pingcap/failpoint"
    30  	"github.com/pingcap/log"
    31  	"github.com/pingcap/tidb/pkg/parser/charset"
    32  	"github.com/pingcap/tidb/pkg/parser/mysql"
    33  	"github.com/pingcap/tidb/pkg/sessionctx/variable"
    34  	"github.com/pingcap/tiflow/cdc/model"
    35  	"github.com/pingcap/tiflow/cdc/sink/dmlsink"
    36  	"github.com/pingcap/tiflow/cdc/sink/metrics"
    37  	"github.com/pingcap/tiflow/cdc/sink/metrics/txn"
    38  	"github.com/pingcap/tiflow/pkg/config"
    39  	cerror "github.com/pingcap/tiflow/pkg/errors"
    40  	"github.com/pingcap/tiflow/pkg/retry"
    41  	pmysql "github.com/pingcap/tiflow/pkg/sink/mysql"
    42  	"github.com/pingcap/tiflow/pkg/sqlmodel"
    43  	"github.com/pingcap/tiflow/pkg/util"
    44  	"github.com/prometheus/client_golang/prometheus"
    45  	"go.uber.org/zap"
    46  )
    47  
    48  const (
    49  	// Max interval for flushing transactions to the downstream.
    50  	maxFlushInterval = 10 * time.Millisecond
    51  
    52  	// networkDriftDuration is used to construct a context timeout for database operations.
    53  	networkDriftDuration = 5 * time.Second
    54  
    55  	defaultDMLMaxRetry uint64 = 8
    56  
    57  	// To limit memory usage for prepared statements.
    58  	prepStmtCacheSize int = 16 * 1024
    59  )
    60  
// mysqlBackend buffers transaction events for one sink worker and writes them
// to a MySQL-compatible downstream when flushed.
type mysqlBackend struct {
	// workerID is the index of this backend among those created together;
	// it is only used to tag log lines.
	workerID int
	// changefeed is the "<namespace>.<id>" label used in logs.
	changefeed string
	// db is the connection pool; NewMySQLBackends hands the same *sql.DB to
	// every backend it creates.
	db *sql.DB
	// cfg is the sink configuration, also shared by all backends created together.
	cfg *pmysql.Config
	// dmlMaxRetry is the maximum number of tries for executing one batch.
	dmlMaxRetry uint64

	// events holds the transactions buffered since the last successful flush.
	events []*dmlsink.TxnCallbackableEvent
	// rows is the total number of rows in events.
	rows int

	statistics                      *metrics.Statistics
	metricTxnSinkDMLBatchCommit     prometheus.Observer
	metricTxnSinkDMLBatchCallback   prometheus.Observer
	metricTxnPrepareStatementErrors prometheus.Counter

	// implement stmtCache to improve performance, especially when the downstream is TiDB.
	// It is nil when statement caching is disabled.
	stmtCache *lru.Cache
	// Indicate if the CachePrepStmts should be enabled or not
	cachePrepStmts bool
	// maxAllowedPacket is the downstream max_allowed_packet value (or the
	// default when it could not be queried); it decides whether a batch may be
	// sent as one multi-statement request.
	maxAllowedPacket int64
}
    82  
    83  // NewMySQLBackends creates a new MySQL sink using schema storage
    84  func NewMySQLBackends(
    85  	ctx context.Context,
    86  	changefeedID model.ChangeFeedID,
    87  	sinkURI *url.URL,
    88  	replicaConfig *config.ReplicaConfig,
    89  	dbConnFactory pmysql.Factory,
    90  	statistics *metrics.Statistics,
    91  ) ([]*mysqlBackend, error) {
    92  	changefeed := fmt.Sprintf("%s.%s", changefeedID.Namespace, changefeedID.ID)
    93  
    94  	cfg := pmysql.NewConfig()
    95  	err := cfg.Apply(config.GetGlobalServerConfig().TZ, changefeedID, sinkURI, replicaConfig)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, dbConnFactory)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	db, err := dbConnFactory(ctx, dsnStr)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	cfg.IsTiDB, err = pmysql.CheckIsTiDB(ctx, db)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	cfg.IsWriteSourceExisted, err = pmysql.CheckIfBDRModeIsSupported(ctx, db)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	// By default, cache-prep-stmts=true, an LRU cache is used for prepared statements,
   121  	// two connections are required to process a transaction.
   122  	// The first connection is held in the tx variable, which is used to manage the transaction.
   123  	// The second connection is requested through a call to s.db.Prepare
   124  	// in case of a cache miss for the statement query.
   125  	// The connection pool for CDC is configured with a static size, equal to the number of workers.
   126  	// CDC may hang at the "Get Connection" call is due to the limited size of the connection pool.
   127  	// When the connection pool is small,
   128  	// the chance of all connections being active at the same time increases,
   129  	// leading to exhaustion of available connections and a hang at the "Get Connection" call.
   130  	// This issue is less likely to occur when the connection pool is larger,
   131  	// as there are more connections available for use.
   132  	// Adding an extra connection to the connection pool solves the connection exhaustion issue.
   133  	db.SetMaxIdleConns(cfg.WorkerCount + 1)
   134  	db.SetMaxOpenConns(cfg.WorkerCount + 1)
   135  
   136  	// Inherit the default value of the prepared statement cache from the SinkURI Options
   137  	cachePrepStmts := cfg.CachePrepStmts
   138  	if cachePrepStmts {
   139  		// query the size of the prepared statement cache on serverside
   140  		maxPreparedStmtCount, err := pmysql.QueryMaxPreparedStmtCount(ctx, db)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  		if maxPreparedStmtCount == -1 {
   145  			// NOTE: seems TiDB doesn't follow MySQL's specification.
   146  			maxPreparedStmtCount = math.MaxInt
   147  		}
   148  		// if maxPreparedStmtCount == 0,
   149  		// it means that the prepared statement cache is disabled on serverside.
   150  		// if maxPreparedStmtCount/(cfg.WorkerCount+1) == 0, for each single connection,
   151  		// it means that the prepared statement cache is disabled on clientsize.
   152  		// Because each connection can not hold at lease one prepared statement.
   153  		if maxPreparedStmtCount == 0 || maxPreparedStmtCount/(cfg.WorkerCount+1) == 0 {
   154  			cachePrepStmts = false
   155  		}
   156  	}
   157  
   158  	var stmtCache *lru.Cache
   159  	if cachePrepStmts {
   160  		stmtCache, err = lru.NewWithEvict(prepStmtCacheSize, func(key, value interface{}) {
   161  			stmt := value.(*sql.Stmt)
   162  			stmt.Close()
   163  		})
   164  		if err != nil {
   165  			return nil, err
   166  		}
   167  	}
   168  
   169  	var maxAllowedPacket int64
   170  	maxAllowedPacket, err = pmysql.QueryMaxAllowedPacket(ctx, db)
   171  	if err != nil {
   172  		log.Warn("failed to query max_allowed_packet, use default value",
   173  			zap.String("changefeed", changefeed),
   174  			zap.Error(err))
   175  		maxAllowedPacket = int64(variable.DefMaxAllowedPacket)
   176  	}
   177  
   178  	backends := make([]*mysqlBackend, 0, cfg.WorkerCount)
   179  	for i := 0; i < cfg.WorkerCount; i++ {
   180  		backends = append(backends, &mysqlBackend{
   181  			workerID:    i,
   182  			changefeed:  changefeed,
   183  			db:          db,
   184  			cfg:         cfg,
   185  			dmlMaxRetry: defaultDMLMaxRetry,
   186  			statistics:  statistics,
   187  
   188  			metricTxnSinkDMLBatchCommit:     txn.SinkDMLBatchCommit.WithLabelValues(changefeedID.Namespace, changefeedID.ID),
   189  			metricTxnSinkDMLBatchCallback:   txn.SinkDMLBatchCallback.WithLabelValues(changefeedID.Namespace, changefeedID.ID),
   190  			metricTxnPrepareStatementErrors: txn.PrepareStatementErrors.WithLabelValues(changefeedID.Namespace, changefeedID.ID),
   191  			stmtCache:                       stmtCache,
   192  			cachePrepStmts:                  cachePrepStmts,
   193  			maxAllowedPacket:                maxAllowedPacket,
   194  		})
   195  	}
   196  
   197  	log.Info("MySQL backends is created",
   198  		zap.String("changefeed", changefeed),
   199  		zap.Int("workerCount", cfg.WorkerCount),
   200  		zap.Bool("forceReplicate", cfg.ForceReplicate))
   201  	return backends, nil
   202  }
   203  
   204  // OnTxnEvent implements interface backend.
   205  // It adds the event to the buffer, and return true if it needs flush immediately.
   206  func (s *mysqlBackend) OnTxnEvent(event *dmlsink.TxnCallbackableEvent) (needFlush bool) {
   207  	s.events = append(s.events, event)
   208  	s.rows += len(event.Event.Rows)
   209  	return s.rows >= s.cfg.MaxTxnRow
   210  }
   211  
   212  // Flush implements interface backend.
   213  func (s *mysqlBackend) Flush(ctx context.Context) (err error) {
   214  	if s.rows == 0 {
   215  		return
   216  	}
   217  
   218  	failpoint.Inject("MySQLSinkExecDMLError", func() {
   219  		// Add a delay to ensure the sink worker with `MySQLSinkHangLongTime`
   220  		// failpoint injected is executed first.
   221  		time.Sleep(time.Second * 2)
   222  		failpoint.Return(errors.Trace(dmysql.ErrInvalidConn))
   223  	})
   224  
   225  	for _, event := range s.events {
   226  		s.statistics.ObserveRows(event.Event.Rows...)
   227  	}
   228  
   229  	dmls := s.prepareDMLs()
   230  	log.Debug("prepare DMLs", zap.String("changefeed", s.changefeed), zap.Any("rows", s.rows),
   231  		zap.Strings("sqls", dmls.sqls), zap.Any("values", dmls.values))
   232  
   233  	start := time.Now()
   234  	if err := s.execDMLWithMaxRetries(ctx, dmls); err != nil {
   235  		if errors.Cause(err) != context.Canceled {
   236  			log.Error("execute DMLs failed", zap.String("changefeed", s.changefeed), zap.Error(err))
   237  		}
   238  		return errors.Trace(err)
   239  	}
   240  	startCallback := time.Now()
   241  	for _, callback := range dmls.callbacks {
   242  		callback()
   243  	}
   244  	s.metricTxnSinkDMLBatchCommit.Observe(startCallback.Sub(start).Seconds())
   245  	s.metricTxnSinkDMLBatchCallback.Observe(time.Since(startCallback).Seconds())
   246  
   247  	// Be friently to GC.
   248  	for i := 0; i < len(s.events); i++ {
   249  		s.events[i] = nil
   250  	}
   251  	if cap(s.events) > 1024 {
   252  		s.events = make([]*dmlsink.TxnCallbackableEvent, 0)
   253  	}
   254  	s.events = s.events[:0]
   255  	s.rows = 0
   256  	return
   257  }
   258  
   259  // Close implements interface backend.
   260  func (s *mysqlBackend) Close() (err error) {
   261  	if s.stmtCache != nil {
   262  		s.stmtCache.Purge()
   263  	}
   264  	if s.db != nil {
   265  		err = s.db.Close()
   266  		s.db = nil
   267  	}
   268  	return
   269  }
   270  
// MaxFlushInterval implements interface backend.
// It returns the package-level maxFlushInterval constant, i.e. the longest
// time buffered transactions should wait before being flushed.
func (s *mysqlBackend) MaxFlushInterval() time.Duration {
	return maxFlushInterval
}
   275  
// preparedDMLs is the result of translating a batch of buffered transaction
// events into executable SQL.
type preparedDMLs struct {
	// startTs collects the StartTs of the first row of each event, skipping a
	// value equal to the one just before it; it is only used in error logs.
	startTs []model.Ts
	// sqls and values are parallel slices: values[i] holds the arguments of sqls[i].
	sqls   []string
	values [][]interface{}
	// callbacks are the non-nil callbacks of the translated events, run after
	// the batch has been committed.
	callbacks []dmlsink.CallbackFunc
	// rowCount is the number of rows covered by the batch.
	rowCount int
	// approximateSize estimates the payload size: SQL text length plus the
	// approximate data size of the rows.
	approximateSize int64
}
   284  
   285  // convert2RowChanges is a helper function that convert the row change representation
   286  // of CDC into a general one.
   287  func convert2RowChanges(
   288  	row *model.RowChangedEvent,
   289  	tableInfo *model.TableInfo,
   290  	changeType sqlmodel.RowChangeType,
   291  ) *sqlmodel.RowChange {
   292  	tidbTableInfo := tableInfo.TableInfo
   293  	// RowChangedEvent doesn't contain data for virtual columns,
   294  	// so we need to create a new table info without virtual columns before pass it to NewRowChange.
   295  	if tableInfo.HasVirtualColumns() {
   296  		tidbTableInfo = model.BuildTiDBTableInfoWithoutVirtualColumns(tidbTableInfo)
   297  	}
   298  
   299  	preValues := make([]interface{}, 0, len(row.PreColumns))
   300  	for _, col := range row.PreColumns {
   301  		preValues = append(preValues, col.Value)
   302  	}
   303  
   304  	postValues := make([]interface{}, 0, len(row.Columns))
   305  	for _, col := range row.Columns {
   306  		postValues = append(postValues, col.Value)
   307  	}
   308  
   309  	var res *sqlmodel.RowChange
   310  
   311  	switch changeType {
   312  	case sqlmodel.RowChangeInsert:
   313  		res = sqlmodel.NewRowChange(
   314  			&row.TableInfo.TableName,
   315  			nil,
   316  			nil,
   317  			postValues,
   318  			tidbTableInfo,
   319  			nil, nil)
   320  	case sqlmodel.RowChangeUpdate:
   321  		res = sqlmodel.NewRowChange(
   322  			&row.TableInfo.TableName,
   323  			nil,
   324  			preValues,
   325  			postValues,
   326  			tidbTableInfo,
   327  			nil, nil)
   328  	case sqlmodel.RowChangeDelete:
   329  		res = sqlmodel.NewRowChange(
   330  			&row.TableInfo.TableName,
   331  			nil,
   332  			preValues,
   333  			nil,
   334  			tidbTableInfo,
   335  			nil, nil)
   336  	}
   337  	res.SetApproximateDataSize(row.ApproximateDataSize)
   338  	return res
   339  }
   340  
   341  func convertBinaryToString(cols []*model.ColumnData, tableInfo *model.TableInfo) {
   342  	for i, col := range cols {
   343  		if col == nil {
   344  			continue
   345  		}
   346  		colInfo := tableInfo.ForceGetColumnInfo(col.ColumnID)
   347  		if colInfo.GetCharset() != "" && colInfo.GetCharset() != charset.CharsetBin {
   348  			colValBytes, ok := col.Value.([]byte)
   349  			if ok {
   350  				cols[i].Value = string(colValBytes)
   351  			}
   352  		}
   353  	}
   354  }
   355  
   356  func (s *mysqlBackend) groupRowsByType(
   357  	event *dmlsink.TxnCallbackableEvent,
   358  	tableInfo *model.TableInfo,
   359  ) (insertRows, updateRows, deleteRows [][]*sqlmodel.RowChange) {
   360  	preAllocateSize := len(event.Event.Rows)
   361  	if preAllocateSize > s.cfg.MaxTxnRow {
   362  		preAllocateSize = s.cfg.MaxTxnRow
   363  	}
   364  
   365  	insertRow := make([]*sqlmodel.RowChange, 0, preAllocateSize)
   366  	updateRow := make([]*sqlmodel.RowChange, 0, preAllocateSize)
   367  	deleteRow := make([]*sqlmodel.RowChange, 0, preAllocateSize)
   368  
   369  	for _, row := range event.Event.Rows {
   370  		convertBinaryToString(row.Columns, tableInfo)
   371  		convertBinaryToString(row.PreColumns, tableInfo)
   372  
   373  		if row.IsInsert() {
   374  			insertRow = append(
   375  				insertRow,
   376  				convert2RowChanges(row, tableInfo, sqlmodel.RowChangeInsert))
   377  			if len(insertRow) >= s.cfg.MaxTxnRow {
   378  				insertRows = append(insertRows, insertRow)
   379  				insertRow = make([]*sqlmodel.RowChange, 0, preAllocateSize)
   380  			}
   381  		}
   382  
   383  		if row.IsDelete() {
   384  			deleteRow = append(
   385  				deleteRow,
   386  				convert2RowChanges(row, tableInfo, sqlmodel.RowChangeDelete))
   387  			if len(deleteRow) >= s.cfg.MaxTxnRow {
   388  				deleteRows = append(deleteRows, deleteRow)
   389  				deleteRow = make([]*sqlmodel.RowChange, 0, preAllocateSize)
   390  			}
   391  		}
   392  
   393  		if row.IsUpdate() {
   394  			updateRow = append(
   395  				updateRow,
   396  				convert2RowChanges(row, tableInfo, sqlmodel.RowChangeUpdate))
   397  			if len(updateRow) >= s.cfg.MaxMultiUpdateRowCount {
   398  				updateRows = append(updateRows, updateRow)
   399  				updateRow = make([]*sqlmodel.RowChange, 0, preAllocateSize)
   400  			}
   401  		}
   402  	}
   403  
   404  	if len(insertRow) > 0 {
   405  		insertRows = append(insertRows, insertRow)
   406  	}
   407  	if len(updateRow) > 0 {
   408  		updateRows = append(updateRows, updateRow)
   409  	}
   410  	if len(deleteRow) > 0 {
   411  		deleteRows = append(deleteRows, deleteRow)
   412  	}
   413  
   414  	return
   415  }
   416  
   417  func (s *mysqlBackend) batchSingleTxnDmls(
   418  	event *dmlsink.TxnCallbackableEvent,
   419  	tableInfo *model.TableInfo,
   420  	translateToInsert bool,
   421  ) (sqls []string, values [][]interface{}) {
   422  	insertRows, updateRows, deleteRows := s.groupRowsByType(event, tableInfo)
   423  
   424  	// handle delete
   425  	if len(deleteRows) > 0 {
   426  		for _, rows := range deleteRows {
   427  			sql, value := sqlmodel.GenDeleteSQL(rows...)
   428  			sqls = append(sqls, sql)
   429  			values = append(values, value)
   430  		}
   431  	}
   432  
   433  	// handle update
   434  	if len(updateRows) > 0 {
   435  		if s.cfg.IsTiDB {
   436  			for _, rows := range updateRows {
   437  				s, v := s.genUpdateSQL(rows...)
   438  				sqls = append(sqls, s...)
   439  				values = append(values, v...)
   440  			}
   441  			// The behavior of update statement differs between TiDB and MySQL.
   442  			// So we don't use batch update statement when downstream is MySQL.
   443  			// Ref:https://docs.pingcap.com/tidb/stable/sql-statement-update#mysql-compatibility
   444  		} else {
   445  			for _, rows := range updateRows {
   446  				for _, row := range rows {
   447  					sql, value := row.GenSQL(sqlmodel.DMLUpdate)
   448  					sqls = append(sqls, sql)
   449  					values = append(values, value)
   450  				}
   451  			}
   452  		}
   453  	}
   454  
   455  	// handle insert
   456  	if len(insertRows) > 0 {
   457  		for _, rows := range insertRows {
   458  			if translateToInsert {
   459  				sql, value := sqlmodel.GenInsertSQL(sqlmodel.DMLInsert, rows...)
   460  				sqls = append(sqls, sql)
   461  				values = append(values, value)
   462  			} else {
   463  				sql, value := sqlmodel.GenInsertSQL(sqlmodel.DMLReplace, rows...)
   464  				sqls = append(sqls, sql)
   465  				values = append(values, value)
   466  			}
   467  		}
   468  	}
   469  
   470  	return
   471  }
   472  
   473  func (s *mysqlBackend) genUpdateSQL(rows ...*sqlmodel.RowChange) ([]string, [][]interface{}) {
   474  	size := 0
   475  	for _, r := range rows {
   476  		size += int(r.GetApproximateDataSize())
   477  	}
   478  	if size < s.cfg.MaxMultiUpdateRowSize*len(rows) {
   479  		// use multi update in one SQL
   480  		sql, value := sqlmodel.GenUpdateSQL(rows...)
   481  		return []string{sql}, [][]interface{}{value}
   482  	}
   483  	// each row has one independent update SQL.
   484  	sqls := make([]string, 0, len(rows))
   485  	values := make([][]interface{}, 0, len(rows))
   486  	for _, row := range rows {
   487  		sql, value := row.GenSQL(sqlmodel.DMLUpdate)
   488  		sqls = append(sqls, sql)
   489  		values = append(values, value)
   490  	}
   491  	return sqls, values
   492  }
   493  
   494  func hasHandleKey(cols []*model.ColumnData, tableInfo *model.TableInfo) bool {
   495  	for _, col := range cols {
   496  		if col == nil {
   497  			continue
   498  		}
   499  		if tableInfo.ForceGetColumnFlagType(col.ColumnID).IsHandleKey() {
   500  			return true
   501  		}
   502  	}
   503  	return false
   504  }
   505  
// prepareDMLs converts model.RowChangedEvent list to query string list and args list
//
// Every buffered event is translated either through the batch path
// (batchSingleTxnDmls) or row by row through prepareUpdate / prepareDelete /
// prepareReplace. Along the way it collects the event callbacks, the row count
// and an approximate payload size.
//
// NOTE(review): events without rows are skipped before their callback is
// collected, so such a callback is not part of the result — confirm that empty
// events never reach this backend.
func (s *mysqlBackend) prepareDMLs() *preparedDMLs {
	// TODO: use a sync.Pool to reduce allocations.
	startTs := make([]uint64, 0, s.rows)
	sqls := make([]string, 0, s.rows)
	values := make([][]interface{}, 0, s.rows)
	callbacks := make([]dmlsink.CallbackFunc, 0, len(s.events))

	// translateToInsert control the update and insert behavior.
	// It starts as "not in safe mode" and, once an event turns it off below,
	// stays off for all following events of this batch.
	translateToInsert := !s.cfg.SafeMode

	rowCount := 0
	approximateSize := int64(0)
	for _, event := range s.events {
		if len(event.Event.Rows) == 0 {
			continue
		}
		rowCount += len(event.Event.Rows)

		// Record the StartTs of the event unless it equals the previous one.
		firstRow := event.Event.Rows[0]
		if len(startTs) == 0 || startTs[len(startTs)-1] != firstRow.StartTs {
			startTs = append(startTs, firstRow.StartTs)
		}

		// A row can be translated in to INSERT, when it was committed after
		// the table it belongs to been replicating by TiCDC, which means it must not be
		// replicated before, and there is no such row in downstream MySQL.
		translateToInsert = translateToInsert && firstRow.CommitTs > firstRow.ReplicatingTs
		log.Debug("translate to insert",
			zap.String("changefeed", s.changefeed),
			zap.Bool("translateToInsert", translateToInsert),
			zap.Uint64("firstRowCommitTs", firstRow.CommitTs),
			zap.Uint64("firstRowReplicatingTs", firstRow.ReplicatingTs),
			zap.Bool("safeMode", s.cfg.SafeMode))

		if event.Callback != nil {
			callbacks = append(callbacks, event.Callback)
		}

		// TODO: find a better threshold
		enableBatchModeThreshold := 1
		// Determine whether to use batch dml feature here.
		if s.cfg.BatchDMLEnable && len(event.Event.Rows) > enableBatchModeThreshold {
			// A delete carries its data in PreColumns, every other row in Columns.
			tableColumns := firstRow.Columns
			if firstRow.IsDelete() {
				tableColumns = firstRow.PreColumns
			}
			// only use batch dml when the table has a handle key
			if hasHandleKey(tableColumns, firstRow.TableInfo) {
				sql, value := s.batchSingleTxnDmls(event, firstRow.TableInfo, translateToInsert)
				sqls = append(sqls, sql...)
				values = append(values, value...)

				// Size estimate: SQL text plus the approximate data size of each row.
				for _, stmt := range sql {
					approximateSize += int64(len(stmt))
				}
				for _, row := range event.Event.Rows {
					approximateSize += row.ApproximateDataSize
				}
				continue
			}
		}

		// Fallback: one statement per row.
		quoteTable := firstRow.TableInfo.TableName.QuoteString()
		for _, row := range event.Event.Rows {
			var query string
			var args []interface{}
			// Update Event
			if len(row.PreColumns) != 0 && len(row.Columns) != 0 {
				query, args = prepareUpdate(
					quoteTable,
					row.GetPreColumns(),
					row.GetColumns(),
					s.cfg.ForceReplicate)
				if query != "" {
					sqls = append(sqls, query)
					values = append(values, args)
				}
				approximateSize += int64(len(query)) + row.ApproximateDataSize
				continue
			}

			// Delete Event
			if len(row.PreColumns) != 0 {
				query, args = prepareDelete(quoteTable, row.GetPreColumns(), s.cfg.ForceReplicate)
				if query != "" {
					sqls = append(sqls, query)
					values = append(values, args)
				}
			}

			// Insert Event
			// It will be translated directly into a
			// INSERT(not in safe mode)
			// or REPLACE(in safe mode) SQL.
			if len(row.Columns) != 0 {
				query, args = prepareReplace(
					quoteTable,
					row.GetColumns(),
					true, /* appendPlaceHolder */
					translateToInsert)
				if query != "" {
					sqls = append(sqls, query)
					values = append(values, args)
				}
			}

			// query holds whichever statement was generated last for this row.
			approximateSize += int64(len(query)) + row.ApproximateDataSize
		}
	}

	// Normalize "no callbacks" to nil.
	if len(callbacks) == 0 {
		callbacks = nil
	}

	return &preparedDMLs{
		startTs:         startTs,
		sqls:            sqls,
		values:          values,
		callbacks:       callbacks,
		rowCount:        rowCount,
		approximateSize: approximateSize,
	}
}
   630  
   631  // execute SQLs in the multi statements way.
   632  func (s *mysqlBackend) multiStmtExecute(
   633  	ctx context.Context, dmls *preparedDMLs, tx *sql.Tx, writeTimeout time.Duration,
   634  ) error {
   635  	var multiStmtArgs []any
   636  	for _, value := range dmls.values {
   637  		multiStmtArgs = append(multiStmtArgs, value...)
   638  	}
   639  	multiStmtSQL := strings.Join(dmls.sqls, ";")
   640  
   641  	log.Debug("exec row", zap.String("changefeed", s.changefeed), zap.Int("workerID", s.workerID),
   642  		zap.String("sql", multiStmtSQL), zap.Any("args", multiStmtArgs))
   643  	ctx, cancel := context.WithTimeout(ctx, writeTimeout)
   644  	defer cancel()
   645  	start := time.Now()
   646  	_, execError := tx.ExecContext(ctx, multiStmtSQL, multiStmtArgs...)
   647  	if execError != nil {
   648  		err := logDMLTxnErr(
   649  			wrapMysqlTxnError(execError),
   650  			start, s.changefeed, multiStmtSQL, dmls.rowCount, dmls.startTs)
   651  		if rbErr := tx.Rollback(); rbErr != nil {
   652  			if errors.Cause(rbErr) != context.Canceled {
   653  				log.Warn("failed to rollback txn", zap.String("changefeed", s.changefeed), zap.Error(rbErr))
   654  			}
   655  		}
   656  		return err
   657  	}
   658  	return nil
   659  }
   660  
   661  // execute SQLs in each preparedDMLs one by one in the same transaction.
   662  func (s *mysqlBackend) sequenceExecute(
   663  	ctx context.Context, dmls *preparedDMLs, tx *sql.Tx, writeTimeout time.Duration,
   664  ) error {
   665  	start := time.Now()
   666  	for i, query := range dmls.sqls {
   667  		args := dmls.values[i]
   668  		log.Debug("exec row", zap.String("changefeed", s.changefeed), zap.Int("workerID", s.workerID),
   669  			zap.String("sql", query), zap.Any("args", args))
   670  		ctx, cancelFunc := context.WithTimeout(ctx, writeTimeout)
   671  
   672  		var prepStmt *sql.Stmt
   673  		if s.cachePrepStmts {
   674  			if stmt, ok := s.stmtCache.Get(query); ok {
   675  				prepStmt = stmt.(*sql.Stmt)
   676  			} else if stmt, err := s.db.Prepare(query); err == nil {
   677  				prepStmt = stmt
   678  				s.stmtCache.Add(query, stmt)
   679  			} else {
   680  				// Generally it means the downstream database doesn't allow
   681  				// too many preapred statements. So clean some of them.
   682  				s.stmtCache.RemoveOldest()
   683  				s.metricTxnPrepareStatementErrors.Inc()
   684  			}
   685  		}
   686  
   687  		var execError error
   688  		if prepStmt == nil {
   689  			_, execError = tx.ExecContext(ctx, query, args...)
   690  		} else {
   691  			//nolint:sqlclosecheck
   692  			_, execError = tx.Stmt(prepStmt).ExecContext(ctx, args...)
   693  		}
   694  		if execError != nil {
   695  			err := logDMLTxnErr(
   696  				wrapMysqlTxnError(execError),
   697  				start, s.changefeed, query, dmls.rowCount, dmls.startTs)
   698  			if rbErr := tx.Rollback(); rbErr != nil {
   699  				if errors.Cause(rbErr) != context.Canceled {
   700  					log.Warn("failed to rollback txn", zap.String("changefeed", s.changefeed), zap.Error(rbErr))
   701  				}
   702  			}
   703  			cancelFunc()
   704  			return err
   705  		}
   706  		cancelFunc()
   707  	}
   708  	return nil
   709  }
   710  
// execDMLWithMaxRetries executes dmls in one downstream transaction, retrying
// the whole transaction through retry.Do with at most s.dmlMaxRetry tries and
// isRetryableDMLError as the retry predicate.
//
// Each attempt begins a transaction, sets the write source, runs the
// statements either as one multi-statement request or sequentially, and
// commits. A failed multi-statement attempt switches all following attempts to
// the sequential way.
func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *preparedDMLs) error {
	// sqls and values must be parallel slices; anything else is a programming error.
	if len(dmls.sqls) != len(dmls.values) {
		log.Error("unexpected number of sqls and values",
			zap.String("changefeed", s.changefeed),
			zap.Strings("sqls", dmls.sqls),
			zap.Any("values", dmls.values))
		return cerror.ErrUnexpected.FastGenByArgs("unexpected number of sqls and values")
	}

	start := time.Now()
	// approximateSize is multiplied by 2 because in extreme circumstances, every
	// byte in dmls can be escaped and adds one byte.
	fallbackToSeqWay := dmls.approximateSize*2 > s.maxAllowedPacket
	return retry.Do(pctx, func() error {
		// A parse failure is ignored here, leaving only networkDriftDuration
		// as the timeout.
		writeTimeout, _ := time.ParseDuration(s.cfg.WriteTimeout)
		writeTimeout += networkDriftDuration

		failpoint.Inject("MySQLSinkTxnRandomError", func() {
			log.Warn("inject MySQLSinkTxnRandomError")
			err := logDMLTxnErr(errors.Trace(driver.ErrBadConn), start, s.changefeed, "failpoint", 0, nil)
			failpoint.Return(err)
		})
		failpoint.Inject("MySQLSinkHangLongTime", func() { _ = util.Hang(pctx, time.Hour) })
		failpoint.Inject("MySQLDuplicateEntryError", func() {
			log.Warn("inject MySQLDuplicateEntryError")
			err := logDMLTxnErr(cerror.WrapError(cerror.ErrMySQLDuplicateEntry, &dmysql.MySQLError{
				Number:  uint16(mysql.ErrDupEntry),
				Message: "Duplicate entry",
			}), start, s.changefeed, "failpoint", 0, nil)
			failpoint.Return(err)
		})

		// The closure reports (rows, bytes, error) of one attempt to the statistics.
		err := s.statistics.RecordBatchExecution(func() (int, int64, error) {
			tx, err := s.db.BeginTx(pctx, nil)
			if err != nil {
				return 0, 0, logDMLTxnErr(
					wrapMysqlTxnError(err),
					start, s.changefeed, "BEGIN", dmls.rowCount, dmls.startTs)
			}

			// Set session variables first and then execute the transaction.
			// we try to set write source for each txn,
			// so we can use it to trace the data source
			if err = pmysql.SetWriteSource(pctx, s.cfg, tx); err != nil {
				err := logDMLTxnErr(
					wrapMysqlTxnError(err),
					start, s.changefeed,
					fmt.Sprintf("SET SESSION %s = %d", "tidb_cdc_write_source",
						s.cfg.SourceID),
					dmls.rowCount, dmls.startTs)
				if rbErr := tx.Rollback(); rbErr != nil {
					if errors.Cause(rbErr) != context.Canceled {
						log.Warn("failed to rollback txn", zap.String("changefeed", s.changefeed), zap.Error(rbErr))
					}
				}
				return 0, 0, err
			}

			// If interpolated SQL size exceeds maxAllowedPacket, mysql driver will
			// fall back to the sequential way.
			// error can be ErrPrepareMulti, ErrBadConn etc.
			// TODO: add a quick path to check whether we should fallback to
			// the sequence way.
			// Both execute helpers roll tx back themselves when they fail.
			if s.cfg.MultiStmtEnable && !fallbackToSeqWay {
				err = s.multiStmtExecute(pctx, dmls, tx, writeTimeout)
				if err != nil {
					// Remember the failure so that the next attempt goes sequentially.
					fallbackToSeqWay = true
					return 0, 0, err
				}
			} else {
				err = s.sequenceExecute(pctx, dmls, tx, writeTimeout)
				if err != nil {
					return 0, 0, err
				}
			}

			if err = tx.Commit(); err != nil {
				return 0, 0, logDMLTxnErr(
					wrapMysqlTxnError(err),
					start, s.changefeed, "COMMIT", dmls.rowCount, dmls.startTs)
			}
			return dmls.rowCount, dmls.approximateSize, nil
		})
		if err != nil {
			return errors.Trace(err)
		}
		log.Debug("Exec Rows succeeded",
			zap.String("changefeed", s.changefeed),
			zap.Int("workerID", s.workerID),
			zap.Int("numOfRows", dmls.rowCount))
		return nil
	}, retry.WithBackoffBaseDelay(pmysql.BackoffBaseDelay.Milliseconds()),
		retry.WithBackoffMaxDelay(pmysql.BackoffMaxDelay.Milliseconds()),
		retry.WithMaxTries(s.dmlMaxRetry),
		retry.WithIsRetryableErr(isRetryableDMLError))
}
   807  
   808  func wrapMysqlTxnError(err error) error {
   809  	errCode, ok := getSQLErrCode(err)
   810  	if !ok {
   811  		return cerror.WrapError(cerror.ErrMySQLTxnError, err)
   812  	}
   813  	switch errCode {
   814  	case mysql.ErrDupEntry:
   815  		return cerror.WrapError(cerror.ErrMySQLDuplicateEntry, err)
   816  	}
   817  	return cerror.WrapError(cerror.ErrMySQLTxnError, err)
   818  }
   819  
   820  func logDMLTxnErr(
   821  	err error, start time.Time, changefeed string,
   822  	query string, count int, startTs []model.Ts,
   823  ) error {
   824  	if len(query) > 1024 {
   825  		query = query[:1024]
   826  	}
   827  	if isRetryableDMLError(err) {
   828  		log.Warn("execute DMLs with error, retry later",
   829  			zap.Error(err), zap.Duration("duration", time.Since(start)),
   830  			zap.String("query", query), zap.Int("count", count),
   831  			zap.Uint64s("startTs", startTs),
   832  			zap.String("changefeed", changefeed))
   833  	} else {
   834  		log.Error("execute DMLs with error, can not retry",
   835  			zap.Error(err), zap.Duration("duration", time.Since(start)),
   836  			zap.String("query", query), zap.Int("count", count),
   837  			zap.String("changefeed", changefeed))
   838  	}
   839  	return errors.WithMessage(err, fmt.Sprintf("Failed query info: %s; ", query))
   840  }
   841  
   842  func isRetryableDMLError(err error) bool {
   843  	if !cerror.IsRetryableError(err) {
   844  		return false
   845  	}
   846  
   847  	errCode, ok := getSQLErrCode(err)
   848  	if !ok {
   849  		return true
   850  	}
   851  
   852  	switch errCode {
   853  	// when meet dup entry error, we don't retry and report the error directly to owner to restart the changefeed.
   854  	case mysql.ErrNoSuchTable, mysql.ErrBadDB, mysql.ErrDupEntry:
   855  		return false
   856  	}
   857  	return true
   858  }
   859  
   860  func getSQLErrCode(err error) (errors.ErrCode, bool) {
   861  	mysqlErr, ok := errors.Cause(err).(*dmysql.MySQLError)
   862  	if !ok {
   863  		return -1, false
   864  	}
   865  
   866  	return errors.ErrCode(mysqlErr.Number), true
   867  }
   868  
// setDMLMaxRetry overrides the maximum number of tries used when executing a
// batch of DMLs.
// Only for testing.
func (s *mysqlBackend) setDMLMaxRetry(maxRetry uint64) {
	s.dmlMaxRetry = maxRetry
}