vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletmanager/vreplication/table_plan_builder.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vreplication
    18  
    19  import (
    20  	"fmt"
    21  	"regexp"
    22  	"sort"
    23  	"strings"
    24  
    25  	"vitess.io/vitess/go/sqltypes"
    26  	"vitess.io/vitess/go/textutil"
    27  	"vitess.io/vitess/go/vt/binlog/binlogplayer"
    28  	"vitess.io/vitess/go/vt/key"
    29  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    30  	querypb "vitess.io/vitess/go/vt/proto/query"
    31  	"vitess.io/vitess/go/vt/schema"
    32  	"vitess.io/vitess/go/vt/sqlparser"
    33  )
    34  
    35  // This file contains just the builders for ReplicatorPlan and TablePlan.
    36  // ReplicatorPlan and TablePlan are in replicator_plan.go.
    37  // TODO(sougou): reorganize this in a better fashion.
    38  
    39  // ExcludeStr is the filter value for excluding tables that match a rule.
    40  // TODO(sougou): support this on vstreamer side also.
    41  const ExcludeStr = "exclude"
    42  
    43  // tablePlanBuilder contains the metadata needed for building a TablePlan.
    44  type tablePlanBuilder struct {
    45  	name       sqlparser.IdentifierCS
    46  	sendSelect *sqlparser.Select
    47  	// selColumns keeps track of the columns we want to pull from source.
    48  	// If Lastpk is set, we compare this list against the table's pk and
    49  	// add missing references.
    50  	colExprs          []*colExpr
    51  	onInsert          insertType
    52  	pkCols            []*colExpr
    53  	extraSourcePkCols []*colExpr
    54  	lastpk            *sqltypes.Result
    55  	colInfos          []*ColumnInfo
    56  	stats             *binlogplayer.Stats
    57  	source            *binlogdatapb.BinlogSource
    58  }
    59  
    60  // colExpr describes the processing to be performed to
    61  // compute the value of one column of the target table.
    62  type colExpr struct {
    63  	colName sqlparser.IdentifierCI
    64  	colType querypb.Type
    65  	// operation==opExpr: full expression is set
    66  	// operation==opCount: nothing is set.
    67  	// operation==opSum: for 'sum(a)', expr is set to 'a'.
    68  	operation operation
    69  	// expr stores the expected field name from vstreamer and dictates
    70  	// the generated bindvar names, like a_col or b_col.
    71  	expr sqlparser.Expr
    72  	// references contains all the column names referenced in the expression.
    73  	references map[string]bool
    74  
    75  	isGrouped  bool
    76  	isPK       bool
    77  	dataType   string
    78  	columnType string
    79  }
    80  
    81  // operation is the opcode for the colExpr.
    82  type operation int
    83  
    84  // The following values are the various colExpr opcodes.
    85  const (
    86  	opExpr = operation(iota)
    87  	opCount
    88  	opSum
    89  )
    90  
    91  // insertType describes the type of insert statement to generate.
    92  // Please refer to TestBuildPlayerPlan for examples.
    93  type insertType int
    94  
    95  // The following values are the various insert types.
    96  const (
    97  	// insertNormal is for normal selects without a group by, like
    98  	// "select a+b as c from t".
    99  	insertNormal = insertType(iota)
   100  	// insertOnDup is for the more traditional grouped expressions, like
   101  	// "select a, b, count(*) as c from t group by a". For statements
   102  	// like these, "insert.. on duplicate key" statements will be generated
   103  	// causing "b" to be updated to the latest value (last value wins).
   104  	insertOnDup
   105  	// insertIgnore is for special grouped expressions where all columns are
   106  	// in the group by, like "select a, b, c from t group by a, b, c".
   107  	// This generates "insert ignore" statements (first value wins).
   108  	insertIgnore
   109  )
   110  
   111  // buildReplicatorPlan builds a ReplicatorPlan for the tables that match the filter.
   112  // The filter is matched against the target schema. For every table matched,
   113  // a table-specific rule is built to be sent to the source. We don't send the
   114  // original rule to the source because it may not match the same tables as the
   115  // target.
   116  // colInfoMap specifies the list of primary key columns for each table.
   117  // copyState is a map of tables that have not been fully copied yet.
   118  // If a table is not present in copyState, then it has been fully copied. If so,
   119  // all replication events are applied. The table still has to match a Filter.Rule.
   120  // If it has a non-nil entry, then the value is the last primary key (lastpk)
   121  // that was copied.  If so, only replication events < lastpk are applied.
   122  // If the entry is nil, then copying of the table has not started yet. If so,
   123  // no events are applied.
   124  // The TablePlan built is a partial plan. The full plan for a table is built
   125  // when we receive field information from events or rows sent by the source.
   126  // buildExecutionPlan is the function that builds the full plan.
   127  func buildReplicatorPlan(source *binlogdatapb.BinlogSource, colInfoMap map[string][]*ColumnInfo, copyState map[string]*sqltypes.Result, stats *binlogplayer.Stats) (*ReplicatorPlan, error) {
   128  	filter := source.Filter
   129  	plan := &ReplicatorPlan{
   130  		VStreamFilter: &binlogdatapb.Filter{FieldEventMode: filter.FieldEventMode},
   131  		TargetTables:  make(map[string]*TablePlan),
   132  		TablePlans:    make(map[string]*TablePlan),
   133  		ColInfoMap:    colInfoMap,
   134  		stats:         stats,
   135  		Source:        source,
   136  	}
   137  	for tableName := range colInfoMap {
   138  		lastpk, ok := copyState[tableName]
   139  		if ok && lastpk == nil {
   140  			// Don't replicate uncopied tables.
   141  			continue
   142  		}
   143  		rule, err := MatchTable(tableName, filter)
   144  		if err != nil {
   145  			return nil, err
   146  		}
   147  		if rule == nil {
   148  			continue
   149  		}
   150  		colInfos, ok := colInfoMap[tableName]
   151  		if !ok {
   152  			return nil, fmt.Errorf("table %s not found in schema", tableName)
   153  		}
   154  		tablePlan, err := buildTablePlan(tableName, rule, colInfos, lastpk, stats, source)
   155  		if err != nil {
   156  			return nil, err
   157  		}
   158  		if tablePlan == nil {
   159  			// Table was excluded.
   160  			continue
   161  		}
   162  		if dup, ok := plan.TablePlans[tablePlan.SendRule.Match]; ok {
   163  			return nil, fmt.Errorf("more than one target for source table %s: %s and %s", tablePlan.SendRule.Match, dup.TargetName, tableName)
   164  		}
   165  		plan.VStreamFilter.Rules = append(plan.VStreamFilter.Rules, tablePlan.SendRule)
   166  		plan.TargetTables[tableName] = tablePlan
   167  		plan.TablePlans[tablePlan.SendRule.Match] = tablePlan
   168  	}
   169  	return plan, nil
   170  }
   171  
   172  // MatchTable is similar to tableMatches and buildPlan defined in vstreamer/planbuilder.go.
   173  func MatchTable(tableName string, filter *binlogdatapb.Filter) (*binlogdatapb.Rule, error) {
   174  	for _, rule := range filter.Rules {
   175  		switch {
   176  		case strings.HasPrefix(rule.Match, "/"):
   177  			expr := strings.Trim(rule.Match, "/")
   178  			result, err := regexp.MatchString(expr, tableName)
   179  			if err != nil {
   180  				return nil, err
   181  			}
   182  			if !result {
   183  				continue
   184  			}
   185  			return rule, nil
   186  		case tableName == rule.Match:
   187  			return rule, nil
   188  		}
   189  	}
   190  	return nil, nil
   191  }
   192  
   193  func buildTablePlan(tableName string, rule *binlogdatapb.Rule, colInfos []*ColumnInfo, lastpk *sqltypes.Result,
   194  	stats *binlogplayer.Stats, source *binlogdatapb.BinlogSource) (*TablePlan, error) {
   195  
   196  	filter := rule.Filter
   197  	query := filter
   198  	// generate equivalent select statement if filter is empty or a keyrange.
   199  	switch {
   200  	case filter == "":
   201  		buf := sqlparser.NewTrackedBuffer(nil)
   202  		buf.Myprintf("select * from %v", sqlparser.NewIdentifierCS(tableName))
   203  		query = buf.String()
   204  	case key.IsKeyRange(filter):
   205  		buf := sqlparser.NewTrackedBuffer(nil)
   206  		buf.Myprintf("select * from %v where in_keyrange(%v)", sqlparser.NewIdentifierCS(tableName), sqlparser.NewStrLiteral(filter))
   207  		query = buf.String()
   208  	case filter == ExcludeStr:
   209  		return nil, nil
   210  	}
   211  	sel, fromTable, err := analyzeSelectFrom(query)
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	sendRule := &binlogdatapb.Rule{
   216  		Match: fromTable,
   217  	}
   218  
   219  	enumValuesMap := map[string](map[string]string){}
   220  	for k, v := range rule.ConvertEnumToText {
   221  		tokensMap := schema.ParseEnumOrSetTokensMap(v)
   222  		enumValuesMap[k] = tokensMap
   223  	}
   224  
   225  	if expr, ok := sel.SelectExprs[0].(*sqlparser.StarExpr); ok {
   226  		// If it's a "select *", we return a partial plan, and complete
   227  		// it when we get back field info from the stream.
   228  		if len(sel.SelectExprs) != 1 {
   229  			return nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel))
   230  		}
   231  		if !expr.TableName.IsEmpty() {
   232  			return nil, fmt.Errorf("unsupported qualifier for '*' expression: %v", sqlparser.String(expr))
   233  		}
   234  		sendRule.Filter = query
   235  		tablePlan := &TablePlan{
   236  			TargetName:       tableName,
   237  			SendRule:         sendRule,
   238  			Lastpk:           lastpk,
   239  			Stats:            stats,
   240  			EnumValuesMap:    enumValuesMap,
   241  			ConvertCharset:   rule.ConvertCharset,
   242  			ConvertIntToEnum: rule.ConvertIntToEnum,
   243  		}
   244  
   245  		return tablePlan, nil
   246  	}
   247  
   248  	tpb := &tablePlanBuilder{
   249  		name: sqlparser.NewIdentifierCS(tableName),
   250  		sendSelect: &sqlparser.Select{
   251  			From:  sel.From,
   252  			Where: sel.Where,
   253  		},
   254  		lastpk:   lastpk,
   255  		colInfos: colInfos,
   256  		stats:    stats,
   257  		source:   source,
   258  	}
   259  
   260  	if err := tpb.analyzeExprs(sel.SelectExprs); err != nil {
   261  		return nil, err
   262  	}
   263  	// It's possible that the target table does not materialize all
   264  	// the primary keys of the source table. In such situations,
   265  	// we still have to be able to validate the incoming event
   266  	// against the current lastpk. For this, we have to request
   267  	// the missing columns so we can compare against those values.
   268  	// If there is no lastpk to validate against, then we don't
   269  	// care.
   270  	if tpb.lastpk != nil {
   271  		for _, f := range tpb.lastpk.Fields {
   272  			tpb.addCol(sqlparser.NewIdentifierCI(f.Name))
   273  		}
   274  	}
   275  	if err := tpb.analyzeGroupBy(sel.GroupBy); err != nil {
   276  		return nil, err
   277  	}
   278  	targetKeyColumnNames, err := textutil.SplitUnescape(rule.TargetUniqueKeyColumns, ",")
   279  	if err != nil {
   280  		return nil, err
   281  	}
   282  	pkColsInfo := tpb.getPKColsInfo(targetKeyColumnNames, colInfos)
   283  	if err := tpb.analyzePK(pkColsInfo); err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	sourceKeyTargetColumnNames, err := textutil.SplitUnescape(rule.SourceUniqueKeyTargetColumns, ",")
   288  	if err != nil {
   289  		return nil, err
   290  	}
   291  	if err := tpb.analyzeExtraSourcePkCols(colInfos, sourceKeyTargetColumnNames); err != nil {
   292  		return nil, err
   293  	}
   294  
   295  	// if there are no columns being selected the select expression can be empty, so we "select 1" so we have a valid
   296  	// select to get a row back
   297  	if len(tpb.sendSelect.SelectExprs) == 0 {
   298  		tpb.sendSelect.SelectExprs = sqlparser.SelectExprs([]sqlparser.SelectExpr{
   299  			&sqlparser.AliasedExpr{
   300  				Expr: sqlparser.NewIntLiteral("1"),
   301  			},
   302  		})
   303  	}
   304  	commentsList := []string{}
   305  	if rule.SourceUniqueKeyColumns != "" {
   306  		commentsList = append(commentsList, fmt.Sprintf(`ukColumns="%s"`, rule.SourceUniqueKeyColumns))
   307  	}
   308  	if len(commentsList) > 0 {
   309  		comments := sqlparser.Comments{
   310  			fmt.Sprintf(`/*vt+ %s */`, strings.Join(commentsList, " ")),
   311  		}
   312  		tpb.sendSelect.Comments = comments.Parsed()
   313  	}
   314  	sendRule.Filter = sqlparser.String(tpb.sendSelect)
   315  
   316  	tablePlan := tpb.generate()
   317  	tablePlan.SendRule = sendRule
   318  	tablePlan.EnumValuesMap = enumValuesMap
   319  	tablePlan.ConvertCharset = rule.ConvertCharset
   320  	tablePlan.ConvertIntToEnum = rule.ConvertIntToEnum
   321  	return tablePlan, nil
   322  }
   323  
   324  func (tpb *tablePlanBuilder) generate() *TablePlan {
   325  	refmap := make(map[string]bool)
   326  	for _, cexpr := range tpb.pkCols {
   327  		for k := range cexpr.references {
   328  			refmap[k] = true
   329  		}
   330  	}
   331  	if tpb.lastpk != nil {
   332  		for _, f := range tpb.lastpk.Fields {
   333  			refmap[f.Name] = true
   334  		}
   335  	}
   336  	pkrefs := make([]string, 0, len(refmap))
   337  	for k := range refmap {
   338  		pkrefs = append(pkrefs, k)
   339  	}
   340  	sort.Strings(pkrefs)
   341  
   342  	bvf := &bindvarFormatter{}
   343  
   344  	fieldsToSkip := make(map[string]bool)
   345  	for _, colInfo := range tpb.colInfos {
   346  		if colInfo.IsGenerated {
   347  			fieldsToSkip[colInfo.Name] = true
   348  		}
   349  	}
   350  
   351  	return &TablePlan{
   352  		TargetName:              tpb.name.String(),
   353  		Lastpk:                  tpb.lastpk,
   354  		BulkInsertFront:         tpb.generateInsertPart(sqlparser.NewTrackedBuffer(bvf.formatter)),
   355  		BulkInsertValues:        tpb.generateValuesPart(sqlparser.NewTrackedBuffer(bvf.formatter), bvf),
   356  		BulkInsertOnDup:         tpb.generateOnDupPart(sqlparser.NewTrackedBuffer(bvf.formatter)),
   357  		Insert:                  tpb.generateInsertStatement(),
   358  		Update:                  tpb.generateUpdateStatement(),
   359  		Delete:                  tpb.generateDeleteStatement(),
   360  		PKReferences:            pkrefs,
   361  		Stats:                   tpb.stats,
   362  		FieldsToSkip:            fieldsToSkip,
   363  		HasExtraSourcePkColumns: (len(tpb.extraSourcePkCols) > 0),
   364  	}
   365  }
   366  
   367  func analyzeSelectFrom(query string) (sel *sqlparser.Select, from string, err error) {
   368  	statement, err := sqlparser.Parse(query)
   369  	if err != nil {
   370  		return nil, "", err
   371  	}
   372  	sel, ok := statement.(*sqlparser.Select)
   373  	if !ok {
   374  		return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(statement))
   375  	}
   376  	if sel.Distinct {
   377  		return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel))
   378  	}
   379  	if len(sel.From) > 1 {
   380  		return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel))
   381  	}
   382  	node, ok := sel.From[0].(*sqlparser.AliasedTableExpr)
   383  	if !ok {
   384  		return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel))
   385  	}
   386  	fromTable := sqlparser.GetTableName(node.Expr)
   387  	if fromTable.IsEmpty() {
   388  		return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel))
   389  	}
   390  	return sel, fromTable.String(), nil
   391  }
   392  
   393  func (tpb *tablePlanBuilder) analyzeExprs(selExprs sqlparser.SelectExprs) error {
   394  	for _, selExpr := range selExprs {
   395  		cexpr, err := tpb.analyzeExpr(selExpr)
   396  		if err != nil {
   397  			return err
   398  		}
   399  		tpb.colExprs = append(tpb.colExprs, cexpr)
   400  	}
   401  	return nil
   402  }
   403  
   404  func (tpb *tablePlanBuilder) analyzeExpr(selExpr sqlparser.SelectExpr) (*colExpr, error) {
   405  	aliased, ok := selExpr.(*sqlparser.AliasedExpr)
   406  	if !ok {
   407  		return nil, fmt.Errorf("unexpected: %v", sqlparser.String(selExpr))
   408  	}
   409  	as := aliased.As
   410  	if as.IsEmpty() {
   411  		// Require all non-trivial expressions to have an alias.
   412  		if colAs, ok := aliased.Expr.(*sqlparser.ColName); ok && colAs.Qualifier.IsEmpty() {
   413  			as = colAs.Name
   414  		} else {
   415  			return nil, fmt.Errorf("expression needs an alias: %v", sqlparser.String(aliased))
   416  		}
   417  	}
   418  	cexpr := &colExpr{
   419  		colName:    as,
   420  		references: make(map[string]bool),
   421  	}
   422  	if expr, ok := aliased.Expr.(*sqlparser.ConvertUsingExpr); ok {
   423  		selExpr := &sqlparser.ConvertUsingExpr{
   424  			Type: "utf8mb4",
   425  			Expr: &sqlparser.ColName{Name: as},
   426  		}
   427  		cexpr.expr = expr
   428  		cexpr.operation = opExpr
   429  		tpb.sendSelect.SelectExprs = append(tpb.sendSelect.SelectExprs, &sqlparser.AliasedExpr{Expr: selExpr, As: as})
   430  		cexpr.references[as.String()] = true
   431  		return cexpr, nil
   432  	}
   433  	if expr, ok := aliased.Expr.(*sqlparser.FuncExpr); ok {
   434  		switch fname := expr.Name.Lowered(); fname {
   435  		case "keyspace_id":
   436  			if len(expr.Exprs) != 0 {
   437  				return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr))
   438  			}
   439  			tpb.sendSelect.SelectExprs = append(tpb.sendSelect.SelectExprs, &sqlparser.AliasedExpr{Expr: aliased.Expr})
   440  			// The vstreamer responds with "keyspace_id" as the field name for this request.
   441  			cexpr.expr = &sqlparser.ColName{Name: sqlparser.NewIdentifierCI("keyspace_id")}
   442  			return cexpr, nil
   443  		}
   444  	}
   445  	if expr, ok := aliased.Expr.(sqlparser.AggrFunc); ok {
   446  		if expr.IsDistinct() {
   447  			return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr))
   448  		}
   449  		switch fname := strings.ToLower(expr.AggrName()); fname {
   450  		case "count":
   451  			if _, ok := expr.(*sqlparser.CountStar); !ok {
   452  				return nil, fmt.Errorf("only count(*) is supported: %v", sqlparser.String(expr))
   453  			}
   454  			cexpr.operation = opCount
   455  			return cexpr, nil
   456  		case "sum":
   457  			if len(expr.GetArgs()) != 1 {
   458  				return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr))
   459  			}
   460  			innerCol, ok := expr.GetArg().(*sqlparser.ColName)
   461  			if !ok {
   462  				return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr))
   463  			}
   464  			if !innerCol.Qualifier.IsEmpty() {
   465  				return nil, fmt.Errorf("unsupported qualifier for column: %v", sqlparser.String(innerCol))
   466  			}
   467  			cexpr.operation = opSum
   468  			cexpr.expr = innerCol
   469  			tpb.addCol(innerCol.Name)
   470  			cexpr.references[innerCol.Name.String()] = true
   471  			return cexpr, nil
   472  		}
   473  	}
   474  	err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   475  		switch node := node.(type) {
   476  		case *sqlparser.ColName:
   477  			if !node.Qualifier.IsEmpty() {
   478  				return false, fmt.Errorf("unsupported qualifier for column: %v", sqlparser.String(node))
   479  			}
   480  			tpb.addCol(node.Name)
   481  			cexpr.references[node.Name.String()] = true
   482  		case *sqlparser.Subquery:
   483  			return false, fmt.Errorf("unsupported subquery: %v", sqlparser.String(node))
   484  		case sqlparser.AggrFunc:
   485  			return false, fmt.Errorf("unexpected: %v", sqlparser.String(node))
   486  		}
   487  		return true, nil
   488  	}, aliased.Expr)
   489  	if err != nil {
   490  		return nil, err
   491  	}
   492  	cexpr.expr = aliased.Expr
   493  	return cexpr, nil
   494  }
   495  
   496  // addCol adds the specified column to the send query
   497  // if it's not already present.
   498  func (tpb *tablePlanBuilder) addCol(ident sqlparser.IdentifierCI) {
   499  	tpb.sendSelect.SelectExprs = append(tpb.sendSelect.SelectExprs, &sqlparser.AliasedExpr{
   500  		Expr: &sqlparser.ColName{Name: ident},
   501  	})
   502  }
   503  
   504  func (tpb *tablePlanBuilder) analyzeGroupBy(groupBy sqlparser.GroupBy) error {
   505  	if groupBy == nil {
   506  		// If there's no grouping, the it's an insertNormal.
   507  		return nil
   508  	}
   509  	for _, expr := range groupBy {
   510  		colname, ok := expr.(*sqlparser.ColName)
   511  		if !ok {
   512  			return fmt.Errorf("unexpected: %v", sqlparser.String(expr))
   513  		}
   514  		cexpr := tpb.findCol(colname.Name)
   515  		if cexpr == nil {
   516  			return fmt.Errorf("group by expression does not reference an alias in the select list: %v", sqlparser.String(expr))
   517  		}
   518  		if cexpr.operation != opExpr {
   519  			return fmt.Errorf("group by expression is not allowed to reference an aggregate expression: %v", sqlparser.String(expr))
   520  		}
   521  		cexpr.isGrouped = true
   522  	}
   523  	// If all colExprs are grouped, then it's an insertIgnore.
   524  	tpb.onInsert = insertIgnore
   525  	for _, cExpr := range tpb.colExprs {
   526  		if !cExpr.isGrouped {
   527  			// If some colExprs are not grouped, then it's an insertOnDup.
   528  			tpb.onInsert = insertOnDup
   529  			break
   530  		}
   531  	}
   532  	return nil
   533  }
   534  
   535  func (tpb *tablePlanBuilder) getPKColsInfo(uniqueKeyColumns []string, colInfos []*ColumnInfo) (pkColsInfo []*ColumnInfo) {
   536  	if len(uniqueKeyColumns) == 0 {
   537  		// No PK override
   538  		return colInfos
   539  	}
   540  	// A unique key is specified. We will re-assess colInfos based on the unique key
   541  	return recalculatePKColsInfoByColumnNames(uniqueKeyColumns, colInfos)
   542  }
   543  
   544  // analyzePK builds tpb.pkCols.
   545  // Input cols must include all columns which participate in the PRIMARY KEY or the chosen UniqueKey.
   546  // It's OK to also include columns not in the key.
   547  // Input cols should be ordered according to key ordinal.
   548  // e.g. if "UNIQUE KEY(c5,c2)" then we expect c5 to come before c2
   549  func (tpb *tablePlanBuilder) analyzePK(cols []*ColumnInfo) error {
   550  	for _, col := range cols {
   551  		if !col.IsPK {
   552  			continue
   553  		}
   554  		if col.IsGenerated {
   555  			// It's possible that a GENERATED column is part of the PRIMARY KEY. That's valid.
   556  			// But then, we also know that we don't actually SELECT a GENERATED column, we just skip
   557  			// it silently and let it re-materialize by MySQL itself on the target.
   558  			continue
   559  		}
   560  		cexpr := tpb.findCol(sqlparser.NewIdentifierCI(col.Name))
   561  		if cexpr == nil {
   562  			// TODO(shlomi): at some point in the futue we want to make this check stricter.
   563  			// We could be reading a generated column c1 which in turn selects some other column c2.
   564  			// We will want t oensure that `c2` is found in select list...
   565  			return fmt.Errorf("primary key column %v not found in select list", col)
   566  		}
   567  		if cexpr.operation != opExpr {
   568  			return fmt.Errorf("primary key column %v is not allowed to reference an aggregate expression", col)
   569  		}
   570  		cexpr.isPK = true
   571  		cexpr.dataType = col.DataType
   572  		cexpr.columnType = col.ColumnType
   573  		tpb.pkCols = append(tpb.pkCols, cexpr)
   574  	}
   575  	return nil
   576  }
   577  
   578  // analyzeExtraSourcePkCols builds tpb.extraSourcePkCols.
   579  // Vreplication allows source and target tables to use different unique keys. Normally, both will
   580  // use same PRIMARY KEY. Other times, same other UNIQUE KEY. Byut it's possible that cource and target
   581  // unique keys will only have partial (or empty) shared list of columns.
   582  // To be able to generate UPDATE/DELETE queries correctly, we need to know the identities of the
   583  // source unique key columns, that are not already part of the target unique key columns. We call
   584  // those columns "extra source pk columns". We will use them in the `WHERE` clause.
   585  func (tpb *tablePlanBuilder) analyzeExtraSourcePkCols(colInfos []*ColumnInfo, sourceKeyTargetColumnNames []string) error {
   586  	sourceKeyTargetColumnNamesMap := map[string]bool{}
   587  	for _, name := range sourceKeyTargetColumnNames {
   588  		sourceKeyTargetColumnNamesMap[name] = true
   589  	}
   590  
   591  	for _, col := range colInfos {
   592  		if !sourceKeyTargetColumnNamesMap[col.Name] {
   593  			// This column is not interesting
   594  			continue
   595  		}
   596  
   597  		if cexpr := findCol(sqlparser.NewIdentifierCI(col.Name), tpb.pkCols); cexpr != nil {
   598  			// Column is already found in pkCols. It's not an "extra" column
   599  			continue
   600  		}
   601  		if cexpr := findCol(sqlparser.NewIdentifierCI(col.Name), tpb.colExprs); cexpr != nil {
   602  			tpb.extraSourcePkCols = append(tpb.extraSourcePkCols, cexpr)
   603  		} else {
   604  			// Column not found
   605  			if !col.IsGenerated {
   606  				// We shouldn't get here in any normal scenario. If a column is part of colInfos,
   607  				// then it must also exist in tpb.colExprs.
   608  				return fmt.Errorf("column %s not found in table expressions", col.Name)
   609  			}
   610  		}
   611  	}
   612  	return nil
   613  }
   614  
   615  // findCol finds a column in a list of expressions
   616  func findCol(name sqlparser.IdentifierCI, exprs []*colExpr) *colExpr {
   617  	for _, cexpr := range exprs {
   618  		if cexpr.colName.Equal(name) {
   619  			return cexpr
   620  		}
   621  	}
   622  	return nil
   623  }
   624  
   625  func (tpb *tablePlanBuilder) findCol(name sqlparser.IdentifierCI) *colExpr {
   626  	return findCol(name, tpb.colExprs)
   627  }
   628  
   629  func (tpb *tablePlanBuilder) generateInsertStatement() *sqlparser.ParsedQuery {
   630  	bvf := &bindvarFormatter{}
   631  	buf := sqlparser.NewTrackedBuffer(bvf.formatter)
   632  
   633  	tpb.generateInsertPart(buf)
   634  	if tpb.lastpk == nil {
   635  		// If there's no lastpk, generate straight values.
   636  		buf.Myprintf(" values ", tpb.name)
   637  		tpb.generateValuesPart(buf, bvf)
   638  	} else {
   639  		// If there is a lastpk, generate values as a select from dual
   640  		// where the pks < lastpk
   641  		tpb.generateSelectPart(buf, bvf)
   642  	}
   643  	tpb.generateOnDupPart(buf)
   644  
   645  	return buf.ParsedQuery()
   646  }
   647  
   648  func (tpb *tablePlanBuilder) generateInsertPart(buf *sqlparser.TrackedBuffer) *sqlparser.ParsedQuery {
   649  	if tpb.onInsert == insertIgnore {
   650  		buf.Myprintf("insert ignore into %v(", tpb.name)
   651  	} else {
   652  		buf.Myprintf("insert into %v(", tpb.name)
   653  	}
   654  	separator := ""
   655  	for _, cexpr := range tpb.colExprs {
   656  		if tpb.isColumnGenerated(cexpr.colName) {
   657  			continue
   658  		}
   659  		buf.Myprintf("%s%v", separator, cexpr.colName)
   660  		separator = ","
   661  	}
   662  	buf.Myprintf(")", tpb.name)
   663  	return buf.ParsedQuery()
   664  }
   665  
   666  func (tpb *tablePlanBuilder) generateValuesPart(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) *sqlparser.ParsedQuery {
   667  	bvf.mode = bvAfter
   668  	separator := "("
   669  	for _, cexpr := range tpb.colExprs {
   670  		if tpb.isColumnGenerated(cexpr.colName) {
   671  			continue
   672  		}
   673  		buf.Myprintf("%s", separator)
   674  		separator = ","
   675  		switch cexpr.operation {
   676  		case opExpr:
   677  			switch cexpr.colType {
   678  			case querypb.Type_JSON:
   679  				buf.Myprintf("convert(%v using utf8mb4)", cexpr.expr)
   680  			case querypb.Type_DATETIME:
   681  				sourceTZ := tpb.source.SourceTimeZone
   682  				targetTZ := tpb.source.TargetTimeZone
   683  				if sourceTZ != "" && targetTZ != "" {
   684  					buf.Myprintf("convert_tz(%v, '%s', '%s')", cexpr.expr, sourceTZ, targetTZ)
   685  				} else {
   686  					buf.Myprintf("%v", cexpr.expr)
   687  				}
   688  			default:
   689  				buf.Myprintf("%v", cexpr.expr)
   690  			}
   691  		case opCount:
   692  			buf.WriteString("1")
   693  		case opSum:
   694  			// NULL values must be treated as 0 for SUM.
   695  			buf.Myprintf("ifnull(%v, 0)", cexpr.expr)
   696  		}
   697  	}
   698  	buf.Myprintf(")")
   699  	return buf.ParsedQuery()
   700  }
   701  
   702  func (tpb *tablePlanBuilder) generateSelectPart(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) *sqlparser.ParsedQuery {
   703  	bvf.mode = bvAfter
   704  	buf.WriteString(" select ")
   705  	separator := ""
   706  	for _, cexpr := range tpb.colExprs {
   707  		if tpb.isColumnGenerated(cexpr.colName) {
   708  			continue
   709  		}
   710  		buf.Myprintf("%s", separator)
   711  		separator = ", "
   712  		switch cexpr.operation {
   713  		case opExpr:
   714  			buf.Myprintf("%v", cexpr.expr)
   715  		case opCount:
   716  			buf.WriteString("1")
   717  		case opSum:
   718  			buf.Myprintf("ifnull(%v, 0)", cexpr.expr)
   719  		}
   720  	}
   721  	buf.WriteString(" from dual where ")
   722  	tpb.generatePKConstraint(buf, bvf)
   723  	return buf.ParsedQuery()
   724  }
   725  
   726  func (tpb *tablePlanBuilder) generateOnDupPart(buf *sqlparser.TrackedBuffer) *sqlparser.ParsedQuery {
   727  	if tpb.onInsert != insertOnDup {
   728  		return nil
   729  	}
   730  	buf.Myprintf(" on duplicate key update ")
   731  	separator := ""
   732  	for _, cexpr := range tpb.colExprs {
   733  		// We don't know of a use case where the group by columns
   734  		// don't match the pk of a table. But we'll allow this,
   735  		// and won't update the pk column with the new value if
   736  		// this does happen. This can be revisited if there's
   737  		// a legitimate use case in the future that demands
   738  		// a different behavior. This rule is applied uniformly
   739  		// for updates and deletes also.
   740  		if cexpr.isGrouped || cexpr.isPK {
   741  			continue
   742  		}
   743  		if tpb.isColumnGenerated(cexpr.colName) {
   744  			continue
   745  		}
   746  		buf.Myprintf("%s%v=", separator, cexpr.colName)
   747  		separator = ", "
   748  		switch cexpr.operation {
   749  		case opExpr:
   750  			buf.Myprintf("values(%v)", cexpr.colName)
   751  		case opCount:
   752  			buf.Myprintf("%v+1", cexpr.colName)
   753  		case opSum:
   754  			buf.Myprintf("%v", cexpr.colName)
   755  			buf.Myprintf("+ifnull(values(%v), 0)", cexpr.colName)
   756  		}
   757  	}
   758  	return buf.ParsedQuery()
   759  }
   760  
   761  func (tpb *tablePlanBuilder) generateUpdateStatement() *sqlparser.ParsedQuery {
   762  	if tpb.onInsert == insertIgnore {
   763  		return tpb.generateInsertStatement()
   764  	}
   765  	bvf := &bindvarFormatter{}
   766  	buf := sqlparser.NewTrackedBuffer(bvf.formatter)
   767  	buf.Myprintf("update %v set ", tpb.name)
   768  	separator := ""
   769  	for _, cexpr := range tpb.colExprs {
   770  		if cexpr.isGrouped || cexpr.isPK {
   771  			continue
   772  		}
   773  		if tpb.isColumnGenerated(cexpr.colName) {
   774  			continue
   775  		}
   776  		buf.Myprintf("%s%v=", separator, cexpr.colName)
   777  		separator = ", "
   778  		switch cexpr.operation {
   779  		case opExpr:
   780  			bvf.mode = bvAfter
   781  			switch cexpr.colType {
   782  			case querypb.Type_JSON:
   783  				buf.Myprintf("convert(%v using utf8mb4)", cexpr.expr)
   784  			case querypb.Type_DATETIME:
   785  				sourceTZ := tpb.source.SourceTimeZone
   786  				targetTZ := tpb.source.TargetTimeZone
   787  				if sourceTZ != "" && targetTZ != "" {
   788  					buf.Myprintf("convert_tz(%v, '%s', '%s')", cexpr.expr, sourceTZ, targetTZ)
   789  				} else {
   790  					buf.Myprintf("%v", cexpr.expr)
   791  				}
   792  			default:
   793  				buf.Myprintf("%v", cexpr.expr)
   794  			}
   795  		case opCount:
   796  			buf.Myprintf("%v", cexpr.colName)
   797  		case opSum:
   798  			buf.Myprintf("%v", cexpr.colName)
   799  			bvf.mode = bvBefore
   800  			buf.Myprintf("-ifnull(%v, 0)", cexpr.expr)
   801  			bvf.mode = bvAfter
   802  			buf.Myprintf("+ifnull(%v, 0)", cexpr.expr)
   803  		}
   804  	}
   805  	tpb.generateWhere(buf, bvf)
   806  	return buf.ParsedQuery()
   807  }
   808  
   809  func (tpb *tablePlanBuilder) generateDeleteStatement() *sqlparser.ParsedQuery {
   810  	bvf := &bindvarFormatter{}
   811  	buf := sqlparser.NewTrackedBuffer(bvf.formatter)
   812  	switch tpb.onInsert {
   813  	case insertNormal:
   814  		buf.Myprintf("delete from %v", tpb.name)
   815  		tpb.generateWhere(buf, bvf)
   816  	case insertOnDup:
   817  		bvf.mode = bvBefore
   818  		buf.Myprintf("update %v set ", tpb.name)
   819  		separator := ""
   820  		for _, cexpr := range tpb.colExprs {
   821  			if cexpr.isGrouped || cexpr.isPK {
   822  				continue
   823  			}
   824  			buf.Myprintf("%s%v=", separator, cexpr.colName)
   825  			separator = ", "
   826  			switch cexpr.operation {
   827  			case opExpr:
   828  				buf.WriteString("null")
   829  			case opCount:
   830  				buf.Myprintf("%v-1", cexpr.colName)
   831  			case opSum:
   832  				buf.Myprintf("%v-ifnull(%v, 0)", cexpr.colName, cexpr.expr)
   833  			}
   834  		}
   835  		tpb.generateWhere(buf, bvf)
   836  	case insertIgnore:
   837  		return nil
   838  	}
   839  	return buf.ParsedQuery()
   840  }
   841  
   842  func (tpb *tablePlanBuilder) generateWhere(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) {
   843  	buf.WriteString(" where ")
   844  	bvf.mode = bvBefore
   845  	separator := ""
   846  
   847  	addWhereColumns := func(colExprs []*colExpr) {
   848  		for _, cexpr := range colExprs {
   849  			if _, ok := cexpr.expr.(*sqlparser.ColName); ok {
   850  				buf.Myprintf("%s%v=", separator, cexpr.colName)
   851  				buf.Myprintf("%v", cexpr.expr)
   852  			} else {
   853  				// Parenthesize non-trivial expressions.
   854  				buf.Myprintf("%s%v=(", separator, cexpr.colName)
   855  				buf.Myprintf("%v", cexpr.expr)
   856  				buf.Myprintf(")")
   857  			}
   858  			separator = " and "
   859  		}
   860  	}
   861  	addWhereColumns(tpb.pkCols)
   862  	addWhereColumns(tpb.extraSourcePkCols)
   863  	if tpb.lastpk != nil {
   864  		buf.WriteString(" and ")
   865  		tpb.generatePKConstraint(buf, bvf)
   866  	}
   867  }
   868  
   869  func (tpb *tablePlanBuilder) getCharsetAndCollation(pkname string) (charSet string, collation string) {
   870  	for _, colInfo := range tpb.colInfos {
   871  		if colInfo.IsPK && strings.EqualFold(colInfo.Name, pkname) {
   872  			if colInfo.CharSet != "" {
   873  				charSet = fmt.Sprintf(" _%s ", colInfo.CharSet)
   874  			}
   875  			if colInfo.Collation != "" {
   876  				collation = fmt.Sprintf(" COLLATE %s ", colInfo.Collation)
   877  			}
   878  		}
   879  	}
   880  	return charSet, collation
   881  }
   882  
   883  func (tpb *tablePlanBuilder) generatePKConstraint(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) {
   884  	type charSetCollation struct {
   885  		charSet   string
   886  		collation string
   887  	}
   888  	var charSetCollations []*charSetCollation
   889  	separator := "("
   890  	for _, pkname := range tpb.lastpk.Fields {
   891  		charSet, collation := tpb.getCharsetAndCollation(pkname.Name)
   892  		charSetCollations = append(charSetCollations, &charSetCollation{charSet: charSet, collation: collation})
   893  		buf.Myprintf("%s%s%v%s", separator, charSet, &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(pkname.Name)}, collation)
   894  		separator = ","
   895  	}
   896  	separator = ") <= ("
   897  	for i, val := range tpb.lastpk.Rows[0] {
   898  		buf.WriteString(separator)
   899  		buf.WriteString(charSetCollations[i].charSet)
   900  		separator = ","
   901  		val.EncodeSQL(buf)
   902  		buf.WriteString(charSetCollations[i].collation)
   903  	}
   904  	buf.WriteString(")")
   905  }
   906  
   907  func (tpb *tablePlanBuilder) isColumnGenerated(col sqlparser.IdentifierCI) bool {
   908  	for _, colInfo := range tpb.colInfos {
   909  		if col.EqualString(colInfo.Name) && colInfo.IsGenerated {
   910  			return true
   911  		}
   912  	}
   913  	return false
   914  }
   915  
   916  // bindvarFormatter is a dual mode formatter. Its behavior
   917  // can be changed dynamically changed to generate bind vars
   918  // for the 'before' row or 'after' row by setting its mode
   919  // to 'bvBefore' or 'bvAfter'. For example, inserts will always
   920  // use bvAfter, whereas deletes will always use bvBefore.
   921  // For updates, values being set will use bvAfter, whereas
   922  // the where clause will use bvBefore.
   923  type bindvarFormatter struct {
   924  	mode bindvarMode
   925  }
   926  
   927  type bindvarMode int
   928  
   929  const (
   930  	bvBefore = bindvarMode(iota)
   931  	bvAfter
   932  )
   933  
   934  func (bvf *bindvarFormatter) formatter(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) {
   935  	if node, ok := node.(*sqlparser.ColName); ok {
   936  		switch bvf.mode {
   937  		case bvBefore:
   938  			buf.WriteArg(":", "b_"+node.Name.String())
   939  			return
   940  		case bvAfter:
   941  			buf.WriteArg(":", "a_"+node.Name.String())
   942  			return
   943  		}
   944  	}
   945  	node.Format(buf)
   946  }