vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletmanager/vreplication/table_plan_builder.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vreplication 18 19 import ( 20 "fmt" 21 "regexp" 22 "sort" 23 "strings" 24 25 "vitess.io/vitess/go/sqltypes" 26 "vitess.io/vitess/go/textutil" 27 "vitess.io/vitess/go/vt/binlog/binlogplayer" 28 "vitess.io/vitess/go/vt/key" 29 binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" 30 querypb "vitess.io/vitess/go/vt/proto/query" 31 "vitess.io/vitess/go/vt/schema" 32 "vitess.io/vitess/go/vt/sqlparser" 33 ) 34 35 // This file contains just the builders for ReplicatorPlan and TablePlan. 36 // ReplicatorPlan and TablePlan are in replicator_plan.go. 37 // TODO(sougou): reorganize this in a better fashion. 38 39 // ExcludeStr is the filter value for excluding tables that match a rule. 40 // TODO(sougou): support this on vstreamer side also. 41 const ExcludeStr = "exclude" 42 43 // tablePlanBuilder contains the metadata needed for building a TablePlan. 44 type tablePlanBuilder struct { 45 name sqlparser.IdentifierCS 46 sendSelect *sqlparser.Select 47 // selColumns keeps track of the columns we want to pull from source. 48 // If Lastpk is set, we compare this list against the table's pk and 49 // add missing references. 50 colExprs []*colExpr 51 onInsert insertType 52 pkCols []*colExpr 53 extraSourcePkCols []*colExpr 54 lastpk *sqltypes.Result 55 colInfos []*ColumnInfo 56 stats *binlogplayer.Stats 57 source *binlogdatapb.BinlogSource 58 } 59 60 // colExpr describes the processing to be performed to 61 // compute the value of one column of the target table. 62 type colExpr struct { 63 colName sqlparser.IdentifierCI 64 colType querypb.Type 65 // operation==opExpr: full expression is set 66 // operation==opCount: nothing is set. 67 // operation==opSum: for 'sum(a)', expr is set to 'a'. 68 operation operation 69 // expr stores the expected field name from vstreamer and dictates 70 // the generated bindvar names, like a_col or b_col. 71 expr sqlparser.Expr 72 // references contains all the column names referenced in the expression. 73 references map[string]bool 74 75 isGrouped bool 76 isPK bool 77 dataType string 78 columnType string 79 } 80 81 // operation is the opcode for the colExpr. 82 type operation int 83 84 // The following values are the various colExpr opcodes. 85 const ( 86 opExpr = operation(iota) 87 opCount 88 opSum 89 ) 90 91 // insertType describes the type of insert statement to generate. 92 // Please refer to TestBuildPlayerPlan for examples. 93 type insertType int 94 95 // The following values are the various insert types. 96 const ( 97 // insertNormal is for normal selects without a group by, like 98 // "select a+b as c from t". 99 insertNormal = insertType(iota) 100 // insertOnDup is for the more traditional grouped expressions, like 101 // "select a, b, count(*) as c from t group by a". For statements 102 // like these, "insert.. on duplicate key" statements will be generated 103 // causing "b" to be updated to the latest value (last value wins). 104 insertOnDup 105 // insertIgnore is for special grouped expressions where all columns are 106 // in the group by, like "select a, b, c from t group by a, b, c". 107 // This generates "insert ignore" statements (first value wins). 108 insertIgnore 109 ) 110 111 // buildReplicatorPlan builds a ReplicatorPlan for the tables that match the filter. 112 // The filter is matched against the target schema. For every table matched, 113 // a table-specific rule is built to be sent to the source. We don't send the 114 // original rule to the source because it may not match the same tables as the 115 // target. 116 // colInfoMap specifies the list of primary key columns for each table. 117 // copyState is a map of tables that have not been fully copied yet. 118 // If a table is not present in copyState, then it has been fully copied. If so, 119 // all replication events are applied. The table still has to match a Filter.Rule. 120 // If it has a non-nil entry, then the value is the last primary key (lastpk) 121 // that was copied. If so, only replication events < lastpk are applied. 122 // If the entry is nil, then copying of the table has not started yet. If so, 123 // no events are applied. 124 // The TablePlan built is a partial plan. The full plan for a table is built 125 // when we receive field information from events or rows sent by the source. 126 // buildExecutionPlan is the function that builds the full plan. 127 func buildReplicatorPlan(source *binlogdatapb.BinlogSource, colInfoMap map[string][]*ColumnInfo, copyState map[string]*sqltypes.Result, stats *binlogplayer.Stats) (*ReplicatorPlan, error) { 128 filter := source.Filter 129 plan := &ReplicatorPlan{ 130 VStreamFilter: &binlogdatapb.Filter{FieldEventMode: filter.FieldEventMode}, 131 TargetTables: make(map[string]*TablePlan), 132 TablePlans: make(map[string]*TablePlan), 133 ColInfoMap: colInfoMap, 134 stats: stats, 135 Source: source, 136 } 137 for tableName := range colInfoMap { 138 lastpk, ok := copyState[tableName] 139 if ok && lastpk == nil { 140 // Don't replicate uncopied tables. 141 continue 142 } 143 rule, err := MatchTable(tableName, filter) 144 if err != nil { 145 return nil, err 146 } 147 if rule == nil { 148 continue 149 } 150 colInfos, ok := colInfoMap[tableName] 151 if !ok { 152 return nil, fmt.Errorf("table %s not found in schema", tableName) 153 } 154 tablePlan, err := buildTablePlan(tableName, rule, colInfos, lastpk, stats, source) 155 if err != nil { 156 return nil, err 157 } 158 if tablePlan == nil { 159 // Table was excluded. 160 continue 161 } 162 if dup, ok := plan.TablePlans[tablePlan.SendRule.Match]; ok { 163 return nil, fmt.Errorf("more than one target for source table %s: %s and %s", tablePlan.SendRule.Match, dup.TargetName, tableName) 164 } 165 plan.VStreamFilter.Rules = append(plan.VStreamFilter.Rules, tablePlan.SendRule) 166 plan.TargetTables[tableName] = tablePlan 167 plan.TablePlans[tablePlan.SendRule.Match] = tablePlan 168 } 169 return plan, nil 170 } 171 172 // MatchTable is similar to tableMatches and buildPlan defined in vstreamer/planbuilder.go. 173 func MatchTable(tableName string, filter *binlogdatapb.Filter) (*binlogdatapb.Rule, error) { 174 for _, rule := range filter.Rules { 175 switch { 176 case strings.HasPrefix(rule.Match, "/"): 177 expr := strings.Trim(rule.Match, "/") 178 result, err := regexp.MatchString(expr, tableName) 179 if err != nil { 180 return nil, err 181 } 182 if !result { 183 continue 184 } 185 return rule, nil 186 case tableName == rule.Match: 187 return rule, nil 188 } 189 } 190 return nil, nil 191 } 192 193 func buildTablePlan(tableName string, rule *binlogdatapb.Rule, colInfos []*ColumnInfo, lastpk *sqltypes.Result, 194 stats *binlogplayer.Stats, source *binlogdatapb.BinlogSource) (*TablePlan, error) { 195 196 filter := rule.Filter 197 query := filter 198 // generate equivalent select statement if filter is empty or a keyrange. 199 switch { 200 case filter == "": 201 buf := sqlparser.NewTrackedBuffer(nil) 202 buf.Myprintf("select * from %v", sqlparser.NewIdentifierCS(tableName)) 203 query = buf.String() 204 case key.IsKeyRange(filter): 205 buf := sqlparser.NewTrackedBuffer(nil) 206 buf.Myprintf("select * from %v where in_keyrange(%v)", sqlparser.NewIdentifierCS(tableName), sqlparser.NewStrLiteral(filter)) 207 query = buf.String() 208 case filter == ExcludeStr: 209 return nil, nil 210 } 211 sel, fromTable, err := analyzeSelectFrom(query) 212 if err != nil { 213 return nil, err 214 } 215 sendRule := &binlogdatapb.Rule{ 216 Match: fromTable, 217 } 218 219 enumValuesMap := map[string](map[string]string){} 220 for k, v := range rule.ConvertEnumToText { 221 tokensMap := schema.ParseEnumOrSetTokensMap(v) 222 enumValuesMap[k] = tokensMap 223 } 224 225 if expr, ok := sel.SelectExprs[0].(*sqlparser.StarExpr); ok { 226 // If it's a "select *", we return a partial plan, and complete 227 // it when we get back field info from the stream. 228 if len(sel.SelectExprs) != 1 { 229 return nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel)) 230 } 231 if !expr.TableName.IsEmpty() { 232 return nil, fmt.Errorf("unsupported qualifier for '*' expression: %v", sqlparser.String(expr)) 233 } 234 sendRule.Filter = query 235 tablePlan := &TablePlan{ 236 TargetName: tableName, 237 SendRule: sendRule, 238 Lastpk: lastpk, 239 Stats: stats, 240 EnumValuesMap: enumValuesMap, 241 ConvertCharset: rule.ConvertCharset, 242 ConvertIntToEnum: rule.ConvertIntToEnum, 243 } 244 245 return tablePlan, nil 246 } 247 248 tpb := &tablePlanBuilder{ 249 name: sqlparser.NewIdentifierCS(tableName), 250 sendSelect: &sqlparser.Select{ 251 From: sel.From, 252 Where: sel.Where, 253 }, 254 lastpk: lastpk, 255 colInfos: colInfos, 256 stats: stats, 257 source: source, 258 } 259 260 if err := tpb.analyzeExprs(sel.SelectExprs); err != nil { 261 return nil, err 262 } 263 // It's possible that the target table does not materialize all 264 // the primary keys of the source table. In such situations, 265 // we still have to be able to validate the incoming event 266 // against the current lastpk. For this, we have to request 267 // the missing columns so we can compare against those values. 268 // If there is no lastpk to validate against, then we don't 269 // care. 270 if tpb.lastpk != nil { 271 for _, f := range tpb.lastpk.Fields { 272 tpb.addCol(sqlparser.NewIdentifierCI(f.Name)) 273 } 274 } 275 if err := tpb.analyzeGroupBy(sel.GroupBy); err != nil { 276 return nil, err 277 } 278 targetKeyColumnNames, err := textutil.SplitUnescape(rule.TargetUniqueKeyColumns, ",") 279 if err != nil { 280 return nil, err 281 } 282 pkColsInfo := tpb.getPKColsInfo(targetKeyColumnNames, colInfos) 283 if err := tpb.analyzePK(pkColsInfo); err != nil { 284 return nil, err 285 } 286 287 sourceKeyTargetColumnNames, err := textutil.SplitUnescape(rule.SourceUniqueKeyTargetColumns, ",") 288 if err != nil { 289 return nil, err 290 } 291 if err := tpb.analyzeExtraSourcePkCols(colInfos, sourceKeyTargetColumnNames); err != nil { 292 return nil, err 293 } 294 295 // if there are no columns being selected the select expression can be empty, so we "select 1" so we have a valid 296 // select to get a row back 297 if len(tpb.sendSelect.SelectExprs) == 0 { 298 tpb.sendSelect.SelectExprs = sqlparser.SelectExprs([]sqlparser.SelectExpr{ 299 &sqlparser.AliasedExpr{ 300 Expr: sqlparser.NewIntLiteral("1"), 301 }, 302 }) 303 } 304 commentsList := []string{} 305 if rule.SourceUniqueKeyColumns != "" { 306 commentsList = append(commentsList, fmt.Sprintf(`ukColumns="%s"`, rule.SourceUniqueKeyColumns)) 307 } 308 if len(commentsList) > 0 { 309 comments := sqlparser.Comments{ 310 fmt.Sprintf(`/*vt+ %s */`, strings.Join(commentsList, " ")), 311 } 312 tpb.sendSelect.Comments = comments.Parsed() 313 } 314 sendRule.Filter = sqlparser.String(tpb.sendSelect) 315 316 tablePlan := tpb.generate() 317 tablePlan.SendRule = sendRule 318 tablePlan.EnumValuesMap = enumValuesMap 319 tablePlan.ConvertCharset = rule.ConvertCharset 320 tablePlan.ConvertIntToEnum = rule.ConvertIntToEnum 321 return tablePlan, nil 322 } 323 324 func (tpb *tablePlanBuilder) generate() *TablePlan { 325 refmap := make(map[string]bool) 326 for _, cexpr := range tpb.pkCols { 327 for k := range cexpr.references { 328 refmap[k] = true 329 } 330 } 331 if tpb.lastpk != nil { 332 for _, f := range tpb.lastpk.Fields { 333 refmap[f.Name] = true 334 } 335 } 336 pkrefs := make([]string, 0, len(refmap)) 337 for k := range refmap { 338 pkrefs = append(pkrefs, k) 339 } 340 sort.Strings(pkrefs) 341 342 bvf := &bindvarFormatter{} 343 344 fieldsToSkip := make(map[string]bool) 345 for _, colInfo := range tpb.colInfos { 346 if colInfo.IsGenerated { 347 fieldsToSkip[colInfo.Name] = true 348 } 349 } 350 351 return &TablePlan{ 352 TargetName: tpb.name.String(), 353 Lastpk: tpb.lastpk, 354 BulkInsertFront: tpb.generateInsertPart(sqlparser.NewTrackedBuffer(bvf.formatter)), 355 BulkInsertValues: tpb.generateValuesPart(sqlparser.NewTrackedBuffer(bvf.formatter), bvf), 356 BulkInsertOnDup: tpb.generateOnDupPart(sqlparser.NewTrackedBuffer(bvf.formatter)), 357 Insert: tpb.generateInsertStatement(), 358 Update: tpb.generateUpdateStatement(), 359 Delete: tpb.generateDeleteStatement(), 360 PKReferences: pkrefs, 361 Stats: tpb.stats, 362 FieldsToSkip: fieldsToSkip, 363 HasExtraSourcePkColumns: (len(tpb.extraSourcePkCols) > 0), 364 } 365 } 366 367 func analyzeSelectFrom(query string) (sel *sqlparser.Select, from string, err error) { 368 statement, err := sqlparser.Parse(query) 369 if err != nil { 370 return nil, "", err 371 } 372 sel, ok := statement.(*sqlparser.Select) 373 if !ok { 374 return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(statement)) 375 } 376 if sel.Distinct { 377 return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) 378 } 379 if len(sel.From) > 1 { 380 return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) 381 } 382 node, ok := sel.From[0].(*sqlparser.AliasedTableExpr) 383 if !ok { 384 return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) 385 } 386 fromTable := sqlparser.GetTableName(node.Expr) 387 if fromTable.IsEmpty() { 388 return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) 389 } 390 return sel, fromTable.String(), nil 391 } 392 393 func (tpb *tablePlanBuilder) analyzeExprs(selExprs sqlparser.SelectExprs) error { 394 for _, selExpr := range selExprs { 395 cexpr, err := tpb.analyzeExpr(selExpr) 396 if err != nil { 397 return err 398 } 399 tpb.colExprs = append(tpb.colExprs, cexpr) 400 } 401 return nil 402 } 403 404 func (tpb *tablePlanBuilder) analyzeExpr(selExpr sqlparser.SelectExpr) (*colExpr, error) { 405 aliased, ok := selExpr.(*sqlparser.AliasedExpr) 406 if !ok { 407 return nil, fmt.Errorf("unexpected: %v", sqlparser.String(selExpr)) 408 } 409 as := aliased.As 410 if as.IsEmpty() { 411 // Require all non-trivial expressions to have an alias. 412 if colAs, ok := aliased.Expr.(*sqlparser.ColName); ok && colAs.Qualifier.IsEmpty() { 413 as = colAs.Name 414 } else { 415 return nil, fmt.Errorf("expression needs an alias: %v", sqlparser.String(aliased)) 416 } 417 } 418 cexpr := &colExpr{ 419 colName: as, 420 references: make(map[string]bool), 421 } 422 if expr, ok := aliased.Expr.(*sqlparser.ConvertUsingExpr); ok { 423 selExpr := &sqlparser.ConvertUsingExpr{ 424 Type: "utf8mb4", 425 Expr: &sqlparser.ColName{Name: as}, 426 } 427 cexpr.expr = expr 428 cexpr.operation = opExpr 429 tpb.sendSelect.SelectExprs = append(tpb.sendSelect.SelectExprs, &sqlparser.AliasedExpr{Expr: selExpr, As: as}) 430 cexpr.references[as.String()] = true 431 return cexpr, nil 432 } 433 if expr, ok := aliased.Expr.(*sqlparser.FuncExpr); ok { 434 switch fname := expr.Name.Lowered(); fname { 435 case "keyspace_id": 436 if len(expr.Exprs) != 0 { 437 return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) 438 } 439 tpb.sendSelect.SelectExprs = append(tpb.sendSelect.SelectExprs, &sqlparser.AliasedExpr{Expr: aliased.Expr}) 440 // The vstreamer responds with "keyspace_id" as the field name for this request. 441 cexpr.expr = &sqlparser.ColName{Name: sqlparser.NewIdentifierCI("keyspace_id")} 442 return cexpr, nil 443 } 444 } 445 if expr, ok := aliased.Expr.(sqlparser.AggrFunc); ok { 446 if expr.IsDistinct() { 447 return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) 448 } 449 switch fname := strings.ToLower(expr.AggrName()); fname { 450 case "count": 451 if _, ok := expr.(*sqlparser.CountStar); !ok { 452 return nil, fmt.Errorf("only count(*) is supported: %v", sqlparser.String(expr)) 453 } 454 cexpr.operation = opCount 455 return cexpr, nil 456 case "sum": 457 if len(expr.GetArgs()) != 1 { 458 return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) 459 } 460 innerCol, ok := expr.GetArg().(*sqlparser.ColName) 461 if !ok { 462 return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) 463 } 464 if !innerCol.Qualifier.IsEmpty() { 465 return nil, fmt.Errorf("unsupported qualifier for column: %v", sqlparser.String(innerCol)) 466 } 467 cexpr.operation = opSum 468 cexpr.expr = innerCol 469 tpb.addCol(innerCol.Name) 470 cexpr.references[innerCol.Name.String()] = true 471 return cexpr, nil 472 } 473 } 474 err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { 475 switch node := node.(type) { 476 case *sqlparser.ColName: 477 if !node.Qualifier.IsEmpty() { 478 return false, fmt.Errorf("unsupported qualifier for column: %v", sqlparser.String(node)) 479 } 480 tpb.addCol(node.Name) 481 cexpr.references[node.Name.String()] = true 482 case *sqlparser.Subquery: 483 return false, fmt.Errorf("unsupported subquery: %v", sqlparser.String(node)) 484 case sqlparser.AggrFunc: 485 return false, fmt.Errorf("unexpected: %v", sqlparser.String(node)) 486 } 487 return true, nil 488 }, aliased.Expr) 489 if err != nil { 490 return nil, err 491 } 492 cexpr.expr = aliased.Expr 493 return cexpr, nil 494 } 495 496 // addCol adds the specified column to the send query 497 // if it's not already present. 498 func (tpb *tablePlanBuilder) addCol(ident sqlparser.IdentifierCI) { 499 tpb.sendSelect.SelectExprs = append(tpb.sendSelect.SelectExprs, &sqlparser.AliasedExpr{ 500 Expr: &sqlparser.ColName{Name: ident}, 501 }) 502 } 503 504 func (tpb *tablePlanBuilder) analyzeGroupBy(groupBy sqlparser.GroupBy) error { 505 if groupBy == nil { 506 // If there's no grouping, the it's an insertNormal. 507 return nil 508 } 509 for _, expr := range groupBy { 510 colname, ok := expr.(*sqlparser.ColName) 511 if !ok { 512 return fmt.Errorf("unexpected: %v", sqlparser.String(expr)) 513 } 514 cexpr := tpb.findCol(colname.Name) 515 if cexpr == nil { 516 return fmt.Errorf("group by expression does not reference an alias in the select list: %v", sqlparser.String(expr)) 517 } 518 if cexpr.operation != opExpr { 519 return fmt.Errorf("group by expression is not allowed to reference an aggregate expression: %v", sqlparser.String(expr)) 520 } 521 cexpr.isGrouped = true 522 } 523 // If all colExprs are grouped, then it's an insertIgnore. 524 tpb.onInsert = insertIgnore 525 for _, cExpr := range tpb.colExprs { 526 if !cExpr.isGrouped { 527 // If some colExprs are not grouped, then it's an insertOnDup. 528 tpb.onInsert = insertOnDup 529 break 530 } 531 } 532 return nil 533 } 534 535 func (tpb *tablePlanBuilder) getPKColsInfo(uniqueKeyColumns []string, colInfos []*ColumnInfo) (pkColsInfo []*ColumnInfo) { 536 if len(uniqueKeyColumns) == 0 { 537 // No PK override 538 return colInfos 539 } 540 // A unique key is specified. We will re-assess colInfos based on the unique key 541 return recalculatePKColsInfoByColumnNames(uniqueKeyColumns, colInfos) 542 } 543 544 // analyzePK builds tpb.pkCols. 545 // Input cols must include all columns which participate in the PRIMARY KEY or the chosen UniqueKey. 546 // It's OK to also include columns not in the key. 547 // Input cols should be ordered according to key ordinal. 548 // e.g. if "UNIQUE KEY(c5,c2)" then we expect c5 to come before c2 549 func (tpb *tablePlanBuilder) analyzePK(cols []*ColumnInfo) error { 550 for _, col := range cols { 551 if !col.IsPK { 552 continue 553 } 554 if col.IsGenerated { 555 // It's possible that a GENERATED column is part of the PRIMARY KEY. That's valid. 556 // But then, we also know that we don't actually SELECT a GENERATED column, we just skip 557 // it silently and let it re-materialize by MySQL itself on the target. 558 continue 559 } 560 cexpr := tpb.findCol(sqlparser.NewIdentifierCI(col.Name)) 561 if cexpr == nil { 562 // TODO(shlomi): at some point in the futue we want to make this check stricter. 563 // We could be reading a generated column c1 which in turn selects some other column c2. 564 // We will want t oensure that `c2` is found in select list... 565 return fmt.Errorf("primary key column %v not found in select list", col) 566 } 567 if cexpr.operation != opExpr { 568 return fmt.Errorf("primary key column %v is not allowed to reference an aggregate expression", col) 569 } 570 cexpr.isPK = true 571 cexpr.dataType = col.DataType 572 cexpr.columnType = col.ColumnType 573 tpb.pkCols = append(tpb.pkCols, cexpr) 574 } 575 return nil 576 } 577 578 // analyzeExtraSourcePkCols builds tpb.extraSourcePkCols. 579 // Vreplication allows source and target tables to use different unique keys. Normally, both will 580 // use same PRIMARY KEY. Other times, same other UNIQUE KEY. Byut it's possible that cource and target 581 // unique keys will only have partial (or empty) shared list of columns. 582 // To be able to generate UPDATE/DELETE queries correctly, we need to know the identities of the 583 // source unique key columns, that are not already part of the target unique key columns. We call 584 // those columns "extra source pk columns". We will use them in the `WHERE` clause. 585 func (tpb *tablePlanBuilder) analyzeExtraSourcePkCols(colInfos []*ColumnInfo, sourceKeyTargetColumnNames []string) error { 586 sourceKeyTargetColumnNamesMap := map[string]bool{} 587 for _, name := range sourceKeyTargetColumnNames { 588 sourceKeyTargetColumnNamesMap[name] = true 589 } 590 591 for _, col := range colInfos { 592 if !sourceKeyTargetColumnNamesMap[col.Name] { 593 // This column is not interesting 594 continue 595 } 596 597 if cexpr := findCol(sqlparser.NewIdentifierCI(col.Name), tpb.pkCols); cexpr != nil { 598 // Column is already found in pkCols. It's not an "extra" column 599 continue 600 } 601 if cexpr := findCol(sqlparser.NewIdentifierCI(col.Name), tpb.colExprs); cexpr != nil { 602 tpb.extraSourcePkCols = append(tpb.extraSourcePkCols, cexpr) 603 } else { 604 // Column not found 605 if !col.IsGenerated { 606 // We shouldn't get here in any normal scenario. If a column is part of colInfos, 607 // then it must also exist in tpb.colExprs. 608 return fmt.Errorf("column %s not found in table expressions", col.Name) 609 } 610 } 611 } 612 return nil 613 } 614 615 // findCol finds a column in a list of expressions 616 func findCol(name sqlparser.IdentifierCI, exprs []*colExpr) *colExpr { 617 for _, cexpr := range exprs { 618 if cexpr.colName.Equal(name) { 619 return cexpr 620 } 621 } 622 return nil 623 } 624 625 func (tpb *tablePlanBuilder) findCol(name sqlparser.IdentifierCI) *colExpr { 626 return findCol(name, tpb.colExprs) 627 } 628 629 func (tpb *tablePlanBuilder) generateInsertStatement() *sqlparser.ParsedQuery { 630 bvf := &bindvarFormatter{} 631 buf := sqlparser.NewTrackedBuffer(bvf.formatter) 632 633 tpb.generateInsertPart(buf) 634 if tpb.lastpk == nil { 635 // If there's no lastpk, generate straight values. 636 buf.Myprintf(" values ", tpb.name) 637 tpb.generateValuesPart(buf, bvf) 638 } else { 639 // If there is a lastpk, generate values as a select from dual 640 // where the pks < lastpk 641 tpb.generateSelectPart(buf, bvf) 642 } 643 tpb.generateOnDupPart(buf) 644 645 return buf.ParsedQuery() 646 } 647 648 func (tpb *tablePlanBuilder) generateInsertPart(buf *sqlparser.TrackedBuffer) *sqlparser.ParsedQuery { 649 if tpb.onInsert == insertIgnore { 650 buf.Myprintf("insert ignore into %v(", tpb.name) 651 } else { 652 buf.Myprintf("insert into %v(", tpb.name) 653 } 654 separator := "" 655 for _, cexpr := range tpb.colExprs { 656 if tpb.isColumnGenerated(cexpr.colName) { 657 continue 658 } 659 buf.Myprintf("%s%v", separator, cexpr.colName) 660 separator = "," 661 } 662 buf.Myprintf(")", tpb.name) 663 return buf.ParsedQuery() 664 } 665 666 func (tpb *tablePlanBuilder) generateValuesPart(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) *sqlparser.ParsedQuery { 667 bvf.mode = bvAfter 668 separator := "(" 669 for _, cexpr := range tpb.colExprs { 670 if tpb.isColumnGenerated(cexpr.colName) { 671 continue 672 } 673 buf.Myprintf("%s", separator) 674 separator = "," 675 switch cexpr.operation { 676 case opExpr: 677 switch cexpr.colType { 678 case querypb.Type_JSON: 679 buf.Myprintf("convert(%v using utf8mb4)", cexpr.expr) 680 case querypb.Type_DATETIME: 681 sourceTZ := tpb.source.SourceTimeZone 682 targetTZ := tpb.source.TargetTimeZone 683 if sourceTZ != "" && targetTZ != "" { 684 buf.Myprintf("convert_tz(%v, '%s', '%s')", cexpr.expr, sourceTZ, targetTZ) 685 } else { 686 buf.Myprintf("%v", cexpr.expr) 687 } 688 default: 689 buf.Myprintf("%v", cexpr.expr) 690 } 691 case opCount: 692 buf.WriteString("1") 693 case opSum: 694 // NULL values must be treated as 0 for SUM. 695 buf.Myprintf("ifnull(%v, 0)", cexpr.expr) 696 } 697 } 698 buf.Myprintf(")") 699 return buf.ParsedQuery() 700 } 701 702 func (tpb *tablePlanBuilder) generateSelectPart(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) *sqlparser.ParsedQuery { 703 bvf.mode = bvAfter 704 buf.WriteString(" select ") 705 separator := "" 706 for _, cexpr := range tpb.colExprs { 707 if tpb.isColumnGenerated(cexpr.colName) { 708 continue 709 } 710 buf.Myprintf("%s", separator) 711 separator = ", " 712 switch cexpr.operation { 713 case opExpr: 714 buf.Myprintf("%v", cexpr.expr) 715 case opCount: 716 buf.WriteString("1") 717 case opSum: 718 buf.Myprintf("ifnull(%v, 0)", cexpr.expr) 719 } 720 } 721 buf.WriteString(" from dual where ") 722 tpb.generatePKConstraint(buf, bvf) 723 return buf.ParsedQuery() 724 } 725 726 func (tpb *tablePlanBuilder) generateOnDupPart(buf *sqlparser.TrackedBuffer) *sqlparser.ParsedQuery { 727 if tpb.onInsert != insertOnDup { 728 return nil 729 } 730 buf.Myprintf(" on duplicate key update ") 731 separator := "" 732 for _, cexpr := range tpb.colExprs { 733 // We don't know of a use case where the group by columns 734 // don't match the pk of a table. But we'll allow this, 735 // and won't update the pk column with the new value if 736 // this does happen. This can be revisited if there's 737 // a legitimate use case in the future that demands 738 // a different behavior. This rule is applied uniformly 739 // for updates and deletes also. 740 if cexpr.isGrouped || cexpr.isPK { 741 continue 742 } 743 if tpb.isColumnGenerated(cexpr.colName) { 744 continue 745 } 746 buf.Myprintf("%s%v=", separator, cexpr.colName) 747 separator = ", " 748 switch cexpr.operation { 749 case opExpr: 750 buf.Myprintf("values(%v)", cexpr.colName) 751 case opCount: 752 buf.Myprintf("%v+1", cexpr.colName) 753 case opSum: 754 buf.Myprintf("%v", cexpr.colName) 755 buf.Myprintf("+ifnull(values(%v), 0)", cexpr.colName) 756 } 757 } 758 return buf.ParsedQuery() 759 } 760 761 func (tpb *tablePlanBuilder) generateUpdateStatement() *sqlparser.ParsedQuery { 762 if tpb.onInsert == insertIgnore { 763 return tpb.generateInsertStatement() 764 } 765 bvf := &bindvarFormatter{} 766 buf := sqlparser.NewTrackedBuffer(bvf.formatter) 767 buf.Myprintf("update %v set ", tpb.name) 768 separator := "" 769 for _, cexpr := range tpb.colExprs { 770 if cexpr.isGrouped || cexpr.isPK { 771 continue 772 } 773 if tpb.isColumnGenerated(cexpr.colName) { 774 continue 775 } 776 buf.Myprintf("%s%v=", separator, cexpr.colName) 777 separator = ", " 778 switch cexpr.operation { 779 case opExpr: 780 bvf.mode = bvAfter 781 switch cexpr.colType { 782 case querypb.Type_JSON: 783 buf.Myprintf("convert(%v using utf8mb4)", cexpr.expr) 784 case querypb.Type_DATETIME: 785 sourceTZ := tpb.source.SourceTimeZone 786 targetTZ := tpb.source.TargetTimeZone 787 if sourceTZ != "" && targetTZ != "" { 788 buf.Myprintf("convert_tz(%v, '%s', '%s')", cexpr.expr, sourceTZ, targetTZ) 789 } else { 790 buf.Myprintf("%v", cexpr.expr) 791 } 792 default: 793 buf.Myprintf("%v", cexpr.expr) 794 } 795 case opCount: 796 buf.Myprintf("%v", cexpr.colName) 797 case opSum: 798 buf.Myprintf("%v", cexpr.colName) 799 bvf.mode = bvBefore 800 buf.Myprintf("-ifnull(%v, 0)", cexpr.expr) 801 bvf.mode = bvAfter 802 buf.Myprintf("+ifnull(%v, 0)", cexpr.expr) 803 } 804 } 805 tpb.generateWhere(buf, bvf) 806 return buf.ParsedQuery() 807 } 808 809 func (tpb *tablePlanBuilder) generateDeleteStatement() *sqlparser.ParsedQuery { 810 bvf := &bindvarFormatter{} 811 buf := sqlparser.NewTrackedBuffer(bvf.formatter) 812 switch tpb.onInsert { 813 case insertNormal: 814 buf.Myprintf("delete from %v", tpb.name) 815 tpb.generateWhere(buf, bvf) 816 case insertOnDup: 817 bvf.mode = bvBefore 818 buf.Myprintf("update %v set ", tpb.name) 819 separator := "" 820 for _, cexpr := range tpb.colExprs { 821 if cexpr.isGrouped || cexpr.isPK { 822 continue 823 } 824 buf.Myprintf("%s%v=", separator, cexpr.colName) 825 separator = ", " 826 switch cexpr.operation { 827 case opExpr: 828 buf.WriteString("null") 829 case opCount: 830 buf.Myprintf("%v-1", cexpr.colName) 831 case opSum: 832 buf.Myprintf("%v-ifnull(%v, 0)", cexpr.colName, cexpr.expr) 833 } 834 } 835 tpb.generateWhere(buf, bvf) 836 case insertIgnore: 837 return nil 838 } 839 return buf.ParsedQuery() 840 } 841 842 func (tpb *tablePlanBuilder) generateWhere(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) { 843 buf.WriteString(" where ") 844 bvf.mode = bvBefore 845 separator := "" 846 847 addWhereColumns := func(colExprs []*colExpr) { 848 for _, cexpr := range colExprs { 849 if _, ok := cexpr.expr.(*sqlparser.ColName); ok { 850 buf.Myprintf("%s%v=", separator, cexpr.colName) 851 buf.Myprintf("%v", cexpr.expr) 852 } else { 853 // Parenthesize non-trivial expressions. 854 buf.Myprintf("%s%v=(", separator, cexpr.colName) 855 buf.Myprintf("%v", cexpr.expr) 856 buf.Myprintf(")") 857 } 858 separator = " and " 859 } 860 } 861 addWhereColumns(tpb.pkCols) 862 addWhereColumns(tpb.extraSourcePkCols) 863 if tpb.lastpk != nil { 864 buf.WriteString(" and ") 865 tpb.generatePKConstraint(buf, bvf) 866 } 867 } 868 869 func (tpb *tablePlanBuilder) getCharsetAndCollation(pkname string) (charSet string, collation string) { 870 for _, colInfo := range tpb.colInfos { 871 if colInfo.IsPK && strings.EqualFold(colInfo.Name, pkname) { 872 if colInfo.CharSet != "" { 873 charSet = fmt.Sprintf(" _%s ", colInfo.CharSet) 874 } 875 if colInfo.Collation != "" { 876 collation = fmt.Sprintf(" COLLATE %s ", colInfo.Collation) 877 } 878 } 879 } 880 return charSet, collation 881 } 882 883 func (tpb *tablePlanBuilder) generatePKConstraint(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) { 884 type charSetCollation struct { 885 charSet string 886 collation string 887 } 888 var charSetCollations []*charSetCollation 889 separator := "(" 890 for _, pkname := range tpb.lastpk.Fields { 891 charSet, collation := tpb.getCharsetAndCollation(pkname.Name) 892 charSetCollations = append(charSetCollations, &charSetCollation{charSet: charSet, collation: collation}) 893 buf.Myprintf("%s%s%v%s", separator, charSet, &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(pkname.Name)}, collation) 894 separator = "," 895 } 896 separator = ") <= (" 897 for i, val := range tpb.lastpk.Rows[0] { 898 buf.WriteString(separator) 899 buf.WriteString(charSetCollations[i].charSet) 900 separator = "," 901 val.EncodeSQL(buf) 902 buf.WriteString(charSetCollations[i].collation) 903 } 904 buf.WriteString(")") 905 } 906 907 func (tpb *tablePlanBuilder) isColumnGenerated(col sqlparser.IdentifierCI) bool { 908 for _, colInfo := range tpb.colInfos { 909 if col.EqualString(colInfo.Name) && colInfo.IsGenerated { 910 return true 911 } 912 } 913 return false 914 } 915 916 // bindvarFormatter is a dual mode formatter. Its behavior 917 // can be changed dynamically changed to generate bind vars 918 // for the 'before' row or 'after' row by setting its mode 919 // to 'bvBefore' or 'bvAfter'. For example, inserts will always 920 // use bvAfter, whereas deletes will always use bvBefore. 921 // For updates, values being set will use bvAfter, whereas 922 // the where clause will use bvBefore. 923 type bindvarFormatter struct { 924 mode bindvarMode 925 } 926 927 type bindvarMode int 928 929 const ( 930 bvBefore = bindvarMode(iota) 931 bvAfter 932 ) 933 934 func (bvf *bindvarFormatter) formatter(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) { 935 if node, ok := node.(*sqlparser.ColName); ok { 936 switch bvf.mode { 937 case bvBefore: 938 buf.WriteArg(":", "b_"+node.Name.String()) 939 return 940 case bvAfter: 941 buf.WriteArg(":", "a_"+node.Name.String()) 942 return 943 } 944 } 945 node.Format(buf) 946 }