github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/sqlmodel/row_change.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package sqlmodel
    15  
    16  import (
    17  	"fmt"
    18  	"strings"
    19  
    20  	"github.com/pingcap/failpoint"
    21  	timodel "github.com/pingcap/tidb/pkg/parser/model"
    22  	"github.com/pingcap/tidb/pkg/sessionctx"
    23  	cdcmodel "github.com/pingcap/tiflow/cdc/model"
    24  	"github.com/pingcap/tiflow/dm/pkg/log"
    25  	"github.com/pingcap/tiflow/dm/pkg/utils"
    26  	"github.com/pingcap/tiflow/pkg/quotes"
    27  	"go.uber.org/zap"
    28  )
    29  
    30  // RowChangeType is the type of row change.
    31  type RowChangeType int
    32  
    33  // these constants represent types of row change.
    34  const (
    35  	RowChangeNull RowChangeType = iota
    36  	RowChangeInsert
    37  	RowChangeUpdate
    38  	RowChangeDelete
    39  )
    40  
    41  // String implements fmt.Stringer interface.
    42  func (t RowChangeType) String() string {
    43  	switch t {
    44  	case RowChangeInsert:
    45  		return "ChangeInsert"
    46  	case RowChangeUpdate:
    47  		return "ChangeUpdate"
    48  	case RowChangeDelete:
    49  		return "ChangeDelete"
    50  	}
    51  
    52  	return ""
    53  }
    54  
    55  // RowChange represents a row change, it can be further converted into DML SQL.
    56  // It also provides some utility functions about calculating causality of two
    57  // row changes, merging successive row changes into one row change, etc.
    58  type RowChange struct {
    59  	sourceTable *cdcmodel.TableName
    60  	targetTable *cdcmodel.TableName
    61  
    62  	preValues  []interface{}
    63  	postValues []interface{}
    64  
    65  	sourceTableInfo *timodel.TableInfo
    66  	targetTableInfo *timodel.TableInfo
    67  
    68  	tiSessionCtx sessionctx.Context
    69  
    70  	tp          RowChangeType
    71  	whereHandle *WhereHandle
    72  
    73  	approximateDataSize int64
    74  }
    75  
    76  // NewRowChange creates a new RowChange.
    77  // preValues stands for values exists before this change, postValues stands for
    78  // values exists after this change.
    79  // These parameters can be nil:
    80  // - targetTable: when same as sourceTable or not applicable
    81  // - preValues: when INSERT
    82  // - postValues: when DELETE
    83  // - targetTableInfo: when same as sourceTableInfo or not applicable
    84  // - tiSessionCtx: will use default sessionCtx which is UTC timezone
    85  // All arguments must not be changed after assigned to RowChange, any
    86  // modification (like convert []byte to string) should be done before
    87  // NewRowChange.
    88  func NewRowChange(
    89  	sourceTable *cdcmodel.TableName,
    90  	targetTable *cdcmodel.TableName,
    91  	preValues []interface{},
    92  	postValues []interface{},
    93  	sourceTableInfo *timodel.TableInfo,
    94  	downstreamTableInfo *timodel.TableInfo,
    95  	tiCtx sessionctx.Context,
    96  ) *RowChange {
    97  	if sourceTable == nil {
    98  		log.L().DPanic("sourceTable is nil")
    99  	}
   100  	if sourceTableInfo == nil {
   101  		log.L().DPanic("sourceTableInfo is nil")
   102  	}
   103  
   104  	ret := &RowChange{
   105  		sourceTable:     sourceTable,
   106  		preValues:       preValues,
   107  		postValues:      postValues,
   108  		sourceTableInfo: sourceTableInfo,
   109  	}
   110  
   111  	colCount := ret.ColumnCount()
   112  	if preValues != nil && len(preValues) != colCount {
   113  		log.L().DPanic("preValues length not equal to sourceTableInfo columns",
   114  			zap.Int("preValues", len(preValues)),
   115  			zap.Int("sourceTableInfo", colCount),
   116  			zap.Stringer("sourceTable", sourceTable))
   117  	}
   118  	if postValues != nil && len(postValues) != colCount {
   119  		log.L().DPanic("postValues length not equal to sourceTableInfo columns",
   120  			zap.Int("postValues", len(postValues)),
   121  			zap.Int("sourceTableInfo", colCount),
   122  			zap.Stringer("sourceTable", sourceTable))
   123  	}
   124  
   125  	if targetTable != nil {
   126  		ret.targetTable = targetTable
   127  	} else {
   128  		ret.targetTable = sourceTable
   129  	}
   130  
   131  	if downstreamTableInfo != nil {
   132  		ret.targetTableInfo = downstreamTableInfo
   133  	} else {
   134  		ret.targetTableInfo = sourceTableInfo
   135  	}
   136  
   137  	if tiCtx != nil {
   138  		ret.tiSessionCtx = tiCtx
   139  	} else {
   140  		ret.tiSessionCtx = utils.ZeroSessionCtx
   141  	}
   142  
   143  	ret.calculateType()
   144  
   145  	return ret
   146  }
   147  
   148  func (r *RowChange) calculateType() {
   149  	switch {
   150  	case r.preValues == nil && r.postValues != nil:
   151  		r.tp = RowChangeInsert
   152  	case r.preValues != nil && r.postValues != nil:
   153  		r.tp = RowChangeUpdate
   154  	case r.preValues != nil && r.postValues == nil:
   155  		r.tp = RowChangeDelete
   156  	default:
   157  		log.L().DPanic("preValues and postValues can't both be nil",
   158  			zap.Stringer("sourceTable", r.sourceTable))
   159  	}
   160  }
   161  
   162  // Type returns the RowChangeType of this RowChange. Caller can future decide
   163  // the DMLType when generate DML from it.
   164  func (r *RowChange) Type() RowChangeType {
   165  	return r.tp
   166  }
   167  
   168  // String implements Stringer interface.
   169  func (r *RowChange) String() string {
   170  	return fmt.Sprintf("type: %s, source table: %s, target table: %s, preValues: %v, postValues: %v",
   171  		r.tp, r.sourceTable, r.targetTable, r.preValues, r.postValues)
   172  }
   173  
   174  // TargetTableID returns a ID string for target table.
   175  func (r *RowChange) TargetTableID() string {
   176  	return r.targetTable.QuoteString()
   177  }
   178  
   179  // ColumnCount returns the number of columns of this RowChange.
   180  // TiDB TableInfo contains some internal columns like expression index, they
   181  // are not included in this count.
   182  func (r *RowChange) ColumnCount() int {
   183  	c := 0
   184  	for _, col := range r.sourceTableInfo.Columns {
   185  		if !col.Hidden {
   186  			c++
   187  		}
   188  	}
   189  	return c
   190  }
   191  
   192  // SourceTableInfo returns the TableInfo of source table.
   193  func (r *RowChange) SourceTableInfo() *timodel.TableInfo {
   194  	return r.sourceTableInfo
   195  }
   196  
   197  // UniqueNotNullIdx returns the unique and not null index.
   198  func (r *RowChange) UniqueNotNullIdx() *timodel.IndexInfo {
   199  	r.lazyInitWhereHandle()
   200  	return r.whereHandle.UniqueNotNullIdx
   201  }
   202  
   203  // SetWhereHandle can be used when caller has cached whereHandle, to avoid every
   204  // RowChange lazily initialize it.
   205  func (r *RowChange) SetWhereHandle(whereHandle *WhereHandle) {
   206  	r.whereHandle = whereHandle
   207  }
   208  
   209  // GetApproximateDataSize returns internal approximateDataSize, it could be zero
   210  // if this value is not set.
   211  func (r *RowChange) GetApproximateDataSize() int64 {
   212  	return r.approximateDataSize
   213  }
   214  
   215  // SetApproximateDataSize sets the approximate size of row change.
   216  func (r *RowChange) SetApproximateDataSize(approximateDataSize int64) {
   217  	r.approximateDataSize = approximateDataSize
   218  }
   219  
   220  func (r *RowChange) lazyInitWhereHandle() {
   221  	if r.whereHandle != nil {
   222  		return
   223  	}
   224  
   225  	r.whereHandle = GetWhereHandle(r.sourceTableInfo, r.targetTableInfo)
   226  }
   227  
   228  // whereColumnsAndValues returns columns and values to identify the row, to form
   229  // the WHERE clause.
   230  func (r *RowChange) whereColumnsAndValues() ([]string, []interface{}) {
   231  	r.lazyInitWhereHandle()
   232  
   233  	columns, values := r.sourceTableInfo.Columns, r.preValues
   234  
   235  	uniqueIndex := r.whereHandle.getWhereIdxByData(r.preValues)
   236  	if uniqueIndex != nil {
   237  		columns, values = getColsAndValuesOfIdx(r.sourceTableInfo.Columns, uniqueIndex, values)
   238  	}
   239  
   240  	columnNames := make([]string, 0, len(columns))
   241  	for _, column := range columns {
   242  		columnNames = append(columnNames, column.Name.O)
   243  	}
   244  
   245  	failpoint.Inject("DownstreamTrackerWhereCheck", func() {
   246  		if r.tp == RowChangeUpdate {
   247  			log.L().Info("UpdateWhereColumnsCheck",
   248  				zap.String("Columns", fmt.Sprintf("%v", columnNames)))
   249  		} else if r.tp == RowChangeDelete {
   250  			log.L().Info("DeleteWhereColumnsCheck",
   251  				zap.String("Columns", fmt.Sprintf("%v", columnNames)))
   252  		}
   253  	})
   254  
   255  	return columnNames, values
   256  }
   257  
   258  // genWhere generates WHERE clause for UPDATE and DELETE to identify the row.
   259  // the SQL part is written to `buf` and the args part is returned.
   260  func (r *RowChange) genWhere(buf *strings.Builder) []interface{} {
   261  	whereColumns, whereValues := r.whereColumnsAndValues()
   262  
   263  	for i, col := range whereColumns {
   264  		if i != 0 {
   265  			buf.WriteString(" AND ")
   266  		}
   267  		buf.WriteString(quotes.QuoteName(col))
   268  		if whereValues[i] == nil {
   269  			buf.WriteString(" IS ?")
   270  		} else {
   271  			buf.WriteString(" = ?")
   272  		}
   273  	}
   274  	return whereValues
   275  }
   276  
   277  func (r *RowChange) genDeleteSQL() (string, []interface{}) {
   278  	if r.tp != RowChangeDelete && r.tp != RowChangeUpdate {
   279  		log.L().DPanic("illegal type for genDeleteSQL",
   280  			zap.String("sourceTable", r.sourceTable.String()),
   281  			zap.Stringer("changeType", r.tp))
   282  		return "", nil
   283  	}
   284  
   285  	var buf strings.Builder
   286  	buf.Grow(1024)
   287  	buf.WriteString("DELETE FROM ")
   288  	buf.WriteString(r.targetTable.QuoteString())
   289  	buf.WriteString(" WHERE ")
   290  	whereArgs := r.genWhere(&buf)
   291  	buf.WriteString(" LIMIT 1")
   292  
   293  	return buf.String(), whereArgs
   294  }
   295  
   296  func (r *RowChange) genUpdateSQL() (string, []interface{}) {
   297  	if r.tp != RowChangeUpdate {
   298  		log.L().DPanic("illegal type for genUpdateSQL",
   299  			zap.String("sourceTable", r.sourceTable.String()),
   300  			zap.Stringer("changeType", r.tp))
   301  		return "", nil
   302  	}
   303  
   304  	var buf strings.Builder
   305  	buf.Grow(2048)
   306  	buf.WriteString("UPDATE ")
   307  	buf.WriteString(r.targetTable.QuoteString())
   308  	buf.WriteString(" SET ")
   309  
   310  	// Build target generated columns lower names set to accelerate following check
   311  	generatedColumns := generatedColumnsNameSet(r.targetTableInfo.Columns)
   312  	args := make([]interface{}, 0, len(r.preValues)+len(r.postValues))
   313  	writtenFirstCol := false
   314  	for i, col := range r.sourceTableInfo.Columns {
   315  		if _, ok := generatedColumns[col.Name.L]; ok {
   316  			continue
   317  		}
   318  
   319  		if writtenFirstCol {
   320  			buf.WriteString(", ")
   321  		}
   322  		writtenFirstCol = true
   323  		fmt.Fprintf(&buf, "%s = ?", quotes.QuoteName(col.Name.O))
   324  		args = append(args, r.postValues[i])
   325  	}
   326  
   327  	buf.WriteString(" WHERE ")
   328  	whereArgs := r.genWhere(&buf)
   329  	buf.WriteString(" LIMIT 1")
   330  
   331  	args = append(args, whereArgs...)
   332  	return buf.String(), args
   333  }
   334  
   335  func (r *RowChange) genInsertSQL(tp DMLType) (string, []interface{}) {
   336  	return GenInsertSQL(tp, r)
   337  }
   338  
   339  // DMLType indicates the type of DML.
   340  type DMLType int
   341  
   342  // these constants represent types of row change.
   343  const (
   344  	DMLNull DMLType = iota
   345  	DMLInsert
   346  	DMLReplace
   347  	DMLInsertOnDuplicateUpdate
   348  	DMLUpdate
   349  	DMLDelete
   350  )
   351  
   352  // String implements fmt.Stringer interface.
   353  func (t DMLType) String() string {
   354  	switch t {
   355  	case DMLInsert:
   356  		return "DMLInsert"
   357  	case DMLReplace:
   358  		return "DMLReplace"
   359  	case DMLUpdate:
   360  		return "DMLUpdate"
   361  	case DMLInsertOnDuplicateUpdate:
   362  		return "DMLInsertOnDuplicateUpdate"
   363  	case DMLDelete:
   364  		return "DMLDelete"
   365  	}
   366  
   367  	return ""
   368  }
   369  
   370  // GenSQL generated a DML SQL for this RowChange.
   371  func (r *RowChange) GenSQL(tp DMLType) (string, []interface{}) {
   372  	switch tp {
   373  	case DMLInsert, DMLReplace, DMLInsertOnDuplicateUpdate:
   374  		return r.genInsertSQL(tp)
   375  	case DMLUpdate:
   376  		return r.genUpdateSQL()
   377  	case DMLDelete:
   378  		return r.genDeleteSQL()
   379  	}
   380  	log.L().DPanic("illegal type for GenSQL",
   381  		zap.String("sourceTable", r.sourceTable.String()),
   382  		zap.Stringer("DMLType", tp))
   383  	return "", nil
   384  }
   385  
   386  // GetPreValues is only used in tests.
   387  func (r *RowChange) GetPreValues() []interface{} {
   388  	return r.preValues
   389  }
   390  
   391  // GetPostValues is only used in tests.
   392  func (r *RowChange) GetPostValues() []interface{} {
   393  	return r.postValues
   394  }
   395  
   396  // RowValues returns the values of this row change
   397  // for INSERT and UPDATE, it is the post values.
   398  // for DELETE, it is the pre values.
   399  func (r *RowChange) RowValues() []interface{} {
   400  	switch r.tp {
   401  	case RowChangeInsert, RowChangeUpdate:
   402  		return r.postValues
   403  	default:
   404  		return r.preValues
   405  	}
   406  }
   407  
   408  // GetSourceTable returns TableName of the source table.
   409  func (r *RowChange) GetSourceTable() *cdcmodel.TableName {
   410  	return r.sourceTable
   411  }
   412  
   413  // GetTargetTable returns TableName of the target table.
   414  func (r *RowChange) GetTargetTable() *cdcmodel.TableName {
   415  	return r.targetTable
   416  }