
     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    14  package sqlmodel
    16  import (
    17  	"strings"
    19  	""
    20  	""
    21  	""
    22  )
    24  const (
    25  	// CommonIndexColumnsCount means common columns count of an index, index contains 1, 2,
    26  	// , 3 or 4 columns are common, but index contains 5 columns or more are not that common,
    27  	// so we use 4 as the common index column count. It will be used to pre-allocate slice space.
    28  	CommonIndexColumnsCount = 4
    29  )
    31  // SameTypeTargetAndColumns check whether two row changes have same type, target
    32  // and columns, so they can be merged to a multi-value DML.
    33  func SameTypeTargetAndColumns(lhs *RowChange, rhs *RowChange) bool {
    34  	if != {
    35  		return false
    36  	}
    37  	if lhs.sourceTable.Schema == rhs.sourceTable.Schema &&
    38  		lhs.sourceTable.Table == rhs.sourceTable.Table {
    39  		return true
    40  	}
    41  	if lhs.targetTable.Schema != rhs.targetTable.Schema ||
    42  		lhs.targetTable.Table != rhs.targetTable.Table {
    43  		return false
    44  	}
    46  	// when the targets are the same and the sources are not the same (same
    47  	// group of shard tables), this piece of code is run.
    48  	var lhsCols, rhsCols []string
    49  	switch {
    50  	case RowChangeDelete:
    51  		lhsCols, _ = lhs.whereColumnsAndValues()
    52  		rhsCols, _ = rhs.whereColumnsAndValues()
    53  	case RowChangeUpdate:
    54  		// not supported yet
    55  		return false
    56  	case RowChangeInsert:
    57  		for _, col := range lhs.sourceTableInfo.Columns {
    58  			lhsCols = append(lhsCols, col.Name.L)
    59  		}
    60  		for _, col := range rhs.sourceTableInfo.Columns {
    61  			rhsCols = append(rhsCols, col.Name.L)
    62  		}
    63  	}
    65  	if len(lhsCols) != len(rhsCols) {
    66  		return false
    67  	}
    68  	for i := 0; i < len(lhsCols); i++ {
    69  		if lhsCols[i] != rhsCols[i] {
    70  			return false
    71  		}
    72  	}
    73  	return true
    74  }
    76  // GenDeleteSQL generates the DELETE SQL and its arguments.
    77  // Input `changes` should have same target table and same columns for WHERE
    78  // (typically same PK/NOT NULL UK), otherwise the behaviour is undefined.
    79  func GenDeleteSQL(changes ...*RowChange) (string, []interface{}) {
    80  	if len(changes) == 0 {
    81  		log.L().DPanic("row changes is empty")
    82  		return "", nil
    83  	}
    85  	first := changes[0]
    87  	var buf strings.Builder
    88  	buf.Grow(1024)
    89  	buf.WriteString("DELETE FROM ")
    90  	buf.WriteString(first.targetTable.QuoteString())
    91  	buf.WriteString(" WHERE (")
    93  	allArgs := make([]interface{}, 0, len(changes)*CommonIndexColumnsCount)
    95  	for i, c := range changes {
    96  		if i > 0 {
    97  			buf.WriteString(") OR (")
    98  		}
    99  		args := c.genWhere(&buf)
   100  		allArgs = append(allArgs, args...)
   101  	}
   102  	buf.WriteString(")")
   103  	return buf.String(), allArgs
   104  }
   106  // GenUpdateSQL generates the UPDATE SQL and its arguments.
   107  // Input `changes` should have same target table and same columns for WHERE
   108  // (typically same PK/NOT NULL UK), otherwise the behaviour is undefined.
   109  func GenUpdateSQL(changes ...*RowChange) (string, []any) {
   110  	if len(changes) == 0 {
   111  		log.L().DPanic("row changes is empty")
   112  		return "", nil
   113  	}
   114  	var buf strings.Builder
   115  	buf.Grow(1024)
   117  	// Generate UPDATE `db`.`table` SET
   118  	first := changes[0]
   119  	buf.WriteString("UPDATE ")
   120  	buf.WriteString(first.targetTable.QuoteString())
   121  	buf.WriteString(" SET ")
   123  	// Pre-generate essential sub statements used after WHEN, WHERE.
   124  	var (
   125  		whenCaseStmts = make([]string, len(changes))
   126  		whenCaseArgs  = make([][]interface{}, len(changes))
   127  	)
   128  	whereColumns, _ := first.whereColumnsAndValues()
   130  	var whereBuf strings.Builder
   131  	for i, c := range changes {
   132  		whereBuf.Reset()
   133  		whereBuf.Grow(128)
   134  		whenCaseArgs[i] = c.genWhere(&whereBuf)
   135  		whenCaseStmts[i] = whereBuf.String()
   136  	}
   138  	// Build gegerated columns lower name set to accelerate the following check
   139  	targetGeneratedColSet := generatedColumnsNameSet(first.targetTableInfo.Columns)
   141  	// Generate `ColumnName`=CASE WHEN .. THEN .. END
   142  	// Use this value in order to identify which is the first CaseWhenThen line,
   143  	// because generated column can happen any where and it will be skipped.
   144  	isFirstCaseWhenThenLine := true
   145  	for _, column := range first.targetTableInfo.Columns {
   146  		// skip generated columns
   147  		if _, ok := targetGeneratedColSet[column.Name.L]; ok {
   148  			continue
   149  		}
   150  		if !isFirstCaseWhenThenLine {
   151  			// insert ", " after END of each lines except for the first line.
   152  			buf.WriteString(", ")
   153  		}
   155  		buf.WriteString(quotes.QuoteName(column.Name.String()) + "=CASE")
   156  		for i := range changes {
   157  			buf.WriteString(" WHEN ")
   158  			buf.WriteString(whenCaseStmts[i])
   159  			buf.WriteString(" THEN ?")
   160  		}
   161  		buf.WriteString(" END")
   162  		isFirstCaseWhenThenLine = false
   163  	}
   165  	// Generate WHERE (...) OR (...)
   166  	buf.WriteString(" WHERE (")
   167  	for i, s := range whenCaseStmts {
   168  		if i > 0 {
   169  			buf.WriteString(") OR (")
   170  		}
   171  		buf.WriteString(s)
   172  	}
   173  	buf.WriteString(")")
   175  	// Build args of the UPDATE SQL
   176  	var assignValueColumnCount int
   177  	var skipColIdx []int
   178  	for i, col := range first.sourceTableInfo.Columns {
   179  		if _, ok := targetGeneratedColSet[col.Name.L]; ok {
   180  			skipColIdx = append(skipColIdx, i)
   181  			continue
   182  		}
   183  		assignValueColumnCount++
   184  	}
   185  	whereValuesAtTheEnd := make([]any, 0, len(changes)*len(whereColumns))
   186  	args := make([]any, 0,
   187  		assignValueColumnCount*len(changes)*(len(whereColumns)+1)+len(whereValuesAtTheEnd))
   188  	argsPerCol := make([][]any, assignValueColumnCount)
   189  	for i := 0; i < assignValueColumnCount; i++ {
   190  		argsPerCol[i] = make([]any, 0, len(changes)*(len(whereColumns)+1))
   191  	}
   192  	for i, change := range changes {
   193  		whereValues := whenCaseArgs[i]
   194  		// a simple check about different number of WHERE values, not trying to
   195  		// cover all cases
   196  		if len(whereValues) != len(whereColumns) {
   197  			log.Panic("len(whereValues) != len(whereColumns)",
   198  				zap.Int("len(whereValues)", len(whereValues)),
   199  				zap.Int("len(whereColumns)", len(whereColumns)),
   200  				zap.Any("whereValues", whereValues),
   201  				zap.Stringer("sourceTable", change.sourceTable))
   202  		}
   204  		whereValuesAtTheEnd = append(whereValuesAtTheEnd, whereValues...)
   206  		i := 0 // used as index of skipColIdx
   207  		writeableCol := 0
   208  		for j, val := range change.postValues {
   209  			if i < len(skipColIdx) && skipColIdx[i] == j {
   210  				i++
   211  				continue
   212  			}
   213  			argsPerCol[writeableCol] = append(argsPerCol[writeableCol], whereValues...)
   214  			argsPerCol[writeableCol] = append(argsPerCol[writeableCol], val)
   215  			writeableCol++
   216  		}
   217  	}
   218  	for _, a := range argsPerCol {
   219  		args = append(args, a...)
   220  	}
   221  	args = append(args, whereValuesAtTheEnd...)
   223  	return buf.String(), args
   224  }
   226  // GenInsertSQL generates the INSERT SQL and its arguments.
   227  // Input `changes` should have same target table and same modifiable columns,
   228  // otherwise the behaviour is undefined.
   229  func GenInsertSQL(tp DMLType, changes ...*RowChange) (string, []interface{}) {
   230  	if len(changes) == 0 {
   231  		log.L().DPanic("row changes is empty")
   232  		return "", nil
   233  	}
   235  	first := changes[0]
   237  	var buf strings.Builder
   238  	buf.Grow(1024)
   239  	if tp == DMLReplace {
   240  		buf.WriteString("REPLACE INTO ")
   241  	} else {
   242  		buf.WriteString("INSERT INTO ")
   243  	}
   244  	buf.WriteString(first.targetTable.QuoteString())
   245  	buf.WriteString(" (")
   246  	columnNum := 0
   247  	var skipColIdx []int
   249  	// build gegerated columns lower name set to accelerate the following check
   250  	generatedColumns := generatedColumnsNameSet(first.targetTableInfo.Columns)
   251  	for i, col := range first.sourceTableInfo.Columns {
   252  		if _, ok := generatedColumns[col.Name.L]; ok {
   253  			skipColIdx = append(skipColIdx, i)
   254  			continue
   255  		}
   257  		if columnNum != 0 {
   258  			buf.WriteByte(',')
   259  		}
   260  		columnNum++
   261  		buf.WriteString(quotes.QuoteName(col.Name.O))
   262  	}
   263  	buf.WriteString(") VALUES ")
   264  	holder := valuesHolder(columnNum)
   265  	for i := range changes {
   266  		if i > 0 {
   267  			buf.WriteString(",")
   268  		}
   269  		buf.WriteString(holder)
   270  	}
   271  	if tp == DMLInsertOnDuplicateUpdate {
   272  		buf.WriteString(" ON DUPLICATE KEY UPDATE ")
   273  		i := 0 // used as index of skipColIdx
   274  		writtenFirstCol := false
   276  		for j, col := range first.sourceTableInfo.Columns {
   277  			if i < len(skipColIdx) && skipColIdx[i] == j {
   278  				i++
   279  				continue
   280  			}
   282  			if writtenFirstCol {
   283  				buf.WriteByte(',')
   284  			}
   285  			writtenFirstCol = true
   287  			colName := quotes.QuoteName(col.Name.O)
   288  			buf.WriteString(colName + "=VALUES(" + colName + ")")
   289  		}
   290  	}
   292  	args := make([]interface{}, 0, len(changes)*(len(first.sourceTableInfo.Columns)-len(skipColIdx)))
   293  	for _, change := range changes {
   294  		i := 0 // used as index of skipColIdx
   295  		for j, val := range change.postValues {
   296  			if i >= len(skipColIdx) {
   297  				args = append(args, change.postValues[j:]...)
   298  				break
   299  			}
   300  			if skipColIdx[i] == j {
   301  				i++
   302  				continue
   303  			}
   304  			args = append(args, val)
   305  		}
   306  	}
   307  	return buf.String(), args
   308  }