github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/sqlmodel/multirow.go (about) 1 // Copyright 2022 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package sqlmodel 15 16 import ( 17 "strings" 18 19 "github.com/pingcap/log" 20 "github.com/pingcap/tiflow/pkg/quotes" 21 "go.uber.org/zap" 22 ) 23 24 const ( 25 // CommonIndexColumnsCount means common columns count of an index, index contains 1, 2, 26 // , 3 or 4 columns are common, but index contains 5 columns or more are not that common, 27 // so we use 4 as the common index column count. It will be used to pre-allocate slice space. 28 CommonIndexColumnsCount = 4 29 ) 30 31 // SameTypeTargetAndColumns check whether two row changes have same type, target 32 // and columns, so they can be merged to a multi-value DML. 33 func SameTypeTargetAndColumns(lhs *RowChange, rhs *RowChange) bool { 34 if lhs.tp != rhs.tp { 35 return false 36 } 37 if lhs.sourceTable.Schema == rhs.sourceTable.Schema && 38 lhs.sourceTable.Table == rhs.sourceTable.Table { 39 return true 40 } 41 if lhs.targetTable.Schema != rhs.targetTable.Schema || 42 lhs.targetTable.Table != rhs.targetTable.Table { 43 return false 44 } 45 46 // when the targets are the same and the sources are not the same (same 47 // group of shard tables), this piece of code is run. 48 var lhsCols, rhsCols []string 49 switch lhs.tp { 50 case RowChangeDelete: 51 lhsCols, _ = lhs.whereColumnsAndValues() 52 rhsCols, _ = rhs.whereColumnsAndValues() 53 case RowChangeUpdate: 54 // not supported yet 55 return false 56 case RowChangeInsert: 57 for _, col := range lhs.sourceTableInfo.Columns { 58 lhsCols = append(lhsCols, col.Name.L) 59 } 60 for _, col := range rhs.sourceTableInfo.Columns { 61 rhsCols = append(rhsCols, col.Name.L) 62 } 63 } 64 65 if len(lhsCols) != len(rhsCols) { 66 return false 67 } 68 for i := 0; i < len(lhsCols); i++ { 69 if lhsCols[i] != rhsCols[i] { 70 return false 71 } 72 } 73 return true 74 } 75 76 // GenDeleteSQL generates the DELETE SQL and its arguments. 77 // Input `changes` should have same target table and same columns for WHERE 78 // (typically same PK/NOT NULL UK), otherwise the behaviour is undefined. 79 func GenDeleteSQL(changes ...*RowChange) (string, []interface{}) { 80 if len(changes) == 0 { 81 log.L().DPanic("row changes is empty") 82 return "", nil 83 } 84 85 first := changes[0] 86 87 var buf strings.Builder 88 buf.Grow(1024) 89 buf.WriteString("DELETE FROM ") 90 buf.WriteString(first.targetTable.QuoteString()) 91 buf.WriteString(" WHERE (") 92 93 allArgs := make([]interface{}, 0, len(changes)*CommonIndexColumnsCount) 94 95 for i, c := range changes { 96 if i > 0 { 97 buf.WriteString(") OR (") 98 } 99 args := c.genWhere(&buf) 100 allArgs = append(allArgs, args...) 101 } 102 buf.WriteString(")") 103 return buf.String(), allArgs 104 } 105 106 // GenUpdateSQL generates the UPDATE SQL and its arguments. 107 // Input `changes` should have same target table and same columns for WHERE 108 // (typically same PK/NOT NULL UK), otherwise the behaviour is undefined. 109 func GenUpdateSQL(changes ...*RowChange) (string, []any) { 110 if len(changes) == 0 { 111 log.L().DPanic("row changes is empty") 112 return "", nil 113 } 114 var buf strings.Builder 115 buf.Grow(1024) 116 117 // Generate UPDATE `db`.`table` SET 118 first := changes[0] 119 buf.WriteString("UPDATE ") 120 buf.WriteString(first.targetTable.QuoteString()) 121 buf.WriteString(" SET ") 122 123 // Pre-generate essential sub statements used after WHEN, WHERE. 124 var ( 125 whenCaseStmts = make([]string, len(changes)) 126 whenCaseArgs = make([][]interface{}, len(changes)) 127 ) 128 whereColumns, _ := first.whereColumnsAndValues() 129 130 var whereBuf strings.Builder 131 for i, c := range changes { 132 whereBuf.Reset() 133 whereBuf.Grow(128) 134 whenCaseArgs[i] = c.genWhere(&whereBuf) 135 whenCaseStmts[i] = whereBuf.String() 136 } 137 138 // Build gegerated columns lower name set to accelerate the following check 139 targetGeneratedColSet := generatedColumnsNameSet(first.targetTableInfo.Columns) 140 141 // Generate `ColumnName`=CASE WHEN .. THEN .. END 142 // Use this value in order to identify which is the first CaseWhenThen line, 143 // because generated column can happen any where and it will be skipped. 144 isFirstCaseWhenThenLine := true 145 for _, column := range first.targetTableInfo.Columns { 146 // skip generated columns 147 if _, ok := targetGeneratedColSet[column.Name.L]; ok { 148 continue 149 } 150 if !isFirstCaseWhenThenLine { 151 // insert ", " after END of each lines except for the first line. 152 buf.WriteString(", ") 153 } 154 155 buf.WriteString(quotes.QuoteName(column.Name.String()) + "=CASE") 156 for i := range changes { 157 buf.WriteString(" WHEN ") 158 buf.WriteString(whenCaseStmts[i]) 159 buf.WriteString(" THEN ?") 160 } 161 buf.WriteString(" END") 162 isFirstCaseWhenThenLine = false 163 } 164 165 // Generate WHERE (...) OR (...) 166 buf.WriteString(" WHERE (") 167 for i, s := range whenCaseStmts { 168 if i > 0 { 169 buf.WriteString(") OR (") 170 } 171 buf.WriteString(s) 172 } 173 buf.WriteString(")") 174 175 // Build args of the UPDATE SQL 176 var assignValueColumnCount int 177 var skipColIdx []int 178 for i, col := range first.sourceTableInfo.Columns { 179 if _, ok := targetGeneratedColSet[col.Name.L]; ok { 180 skipColIdx = append(skipColIdx, i) 181 continue 182 } 183 assignValueColumnCount++ 184 } 185 whereValuesAtTheEnd := make([]any, 0, len(changes)*len(whereColumns)) 186 args := make([]any, 0, 187 assignValueColumnCount*len(changes)*(len(whereColumns)+1)+len(whereValuesAtTheEnd)) 188 argsPerCol := make([][]any, assignValueColumnCount) 189 for i := 0; i < assignValueColumnCount; i++ { 190 argsPerCol[i] = make([]any, 0, len(changes)*(len(whereColumns)+1)) 191 } 192 for i, change := range changes { 193 whereValues := whenCaseArgs[i] 194 // a simple check about different number of WHERE values, not trying to 195 // cover all cases 196 if len(whereValues) != len(whereColumns) { 197 log.Panic("len(whereValues) != len(whereColumns)", 198 zap.Int("len(whereValues)", len(whereValues)), 199 zap.Int("len(whereColumns)", len(whereColumns)), 200 zap.Any("whereValues", whereValues), 201 zap.Stringer("sourceTable", change.sourceTable)) 202 } 203 204 whereValuesAtTheEnd = append(whereValuesAtTheEnd, whereValues...) 205 206 i := 0 // used as index of skipColIdx 207 writeableCol := 0 208 for j, val := range change.postValues { 209 if i < len(skipColIdx) && skipColIdx[i] == j { 210 i++ 211 continue 212 } 213 argsPerCol[writeableCol] = append(argsPerCol[writeableCol], whereValues...) 214 argsPerCol[writeableCol] = append(argsPerCol[writeableCol], val) 215 writeableCol++ 216 } 217 } 218 for _, a := range argsPerCol { 219 args = append(args, a...) 220 } 221 args = append(args, whereValuesAtTheEnd...) 222 223 return buf.String(), args 224 } 225 226 // GenInsertSQL generates the INSERT SQL and its arguments. 227 // Input `changes` should have same target table and same modifiable columns, 228 // otherwise the behaviour is undefined. 229 func GenInsertSQL(tp DMLType, changes ...*RowChange) (string, []interface{}) { 230 if len(changes) == 0 { 231 log.L().DPanic("row changes is empty") 232 return "", nil 233 } 234 235 first := changes[0] 236 237 var buf strings.Builder 238 buf.Grow(1024) 239 if tp == DMLReplace { 240 buf.WriteString("REPLACE INTO ") 241 } else { 242 buf.WriteString("INSERT INTO ") 243 } 244 buf.WriteString(first.targetTable.QuoteString()) 245 buf.WriteString(" (") 246 columnNum := 0 247 var skipColIdx []int 248 249 // build gegerated columns lower name set to accelerate the following check 250 generatedColumns := generatedColumnsNameSet(first.targetTableInfo.Columns) 251 for i, col := range first.sourceTableInfo.Columns { 252 if _, ok := generatedColumns[col.Name.L]; ok { 253 skipColIdx = append(skipColIdx, i) 254 continue 255 } 256 257 if columnNum != 0 { 258 buf.WriteByte(',') 259 } 260 columnNum++ 261 buf.WriteString(quotes.QuoteName(col.Name.O)) 262 } 263 buf.WriteString(") VALUES ") 264 holder := valuesHolder(columnNum) 265 for i := range changes { 266 if i > 0 { 267 buf.WriteString(",") 268 } 269 buf.WriteString(holder) 270 } 271 if tp == DMLInsertOnDuplicateUpdate { 272 buf.WriteString(" ON DUPLICATE KEY UPDATE ") 273 i := 0 // used as index of skipColIdx 274 writtenFirstCol := false 275 276 for j, col := range first.sourceTableInfo.Columns { 277 if i < len(skipColIdx) && skipColIdx[i] == j { 278 i++ 279 continue 280 } 281 282 if writtenFirstCol { 283 buf.WriteByte(',') 284 } 285 writtenFirstCol = true 286 287 colName := quotes.QuoteName(col.Name.O) 288 buf.WriteString(colName + "=VALUES(" + colName + ")") 289 } 290 } 291 292 args := make([]interface{}, 0, len(changes)*(len(first.sourceTableInfo.Columns)-len(skipColIdx))) 293 for _, change := range changes { 294 i := 0 // used as index of skipColIdx 295 for j, val := range change.postValues { 296 if i >= len(skipColIdx) { 297 args = append(args, change.postValues[j:]...) 298 break 299 } 300 if skipColIdx[i] == j { 301 i++ 302 continue 303 } 304 args = append(args, val) 305 } 306 } 307 return buf.String(), args 308 }