vitess.io/vitess@v0.16.2/go/vt/wrangler/vdiff.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package wrangler
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"fmt"
    23  	"reflect"
    24  	"sort"
    25  	"strings"
    26  	"sync"
    27  	"time"
    28  
    29  	"google.golang.org/protobuf/encoding/prototext"
    30  
    31  	"vitess.io/vitess/go/mysql"
    32  	"vitess.io/vitess/go/mysql/collations"
    33  	"vitess.io/vitess/go/sqltypes"
    34  	"vitess.io/vitess/go/vt/binlog/binlogplayer"
    35  	"vitess.io/vitess/go/vt/concurrency"
    36  	"vitess.io/vitess/go/vt/discovery"
    37  	"vitess.io/vitess/go/vt/grpcclient"
    38  	"vitess.io/vitess/go/vt/key"
    39  	"vitess.io/vitess/go/vt/log"
    40  	"vitess.io/vitess/go/vt/logutil"
    41  	"vitess.io/vitess/go/vt/schema"
    42  	"vitess.io/vitess/go/vt/sqlparser"
    43  	"vitess.io/vitess/go/vt/topo"
    44  	"vitess.io/vitess/go/vt/topo/topoproto"
    45  	"vitess.io/vitess/go/vt/vtctl/schematools"
    46  	"vitess.io/vitess/go/vt/vtctl/workflow"
    47  	"vitess.io/vitess/go/vt/vterrors"
    48  	"vitess.io/vitess/go/vt/vtgate/engine"
    49  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    50  	"vitess.io/vitess/go/vt/vttablet/tabletconn"
    51  	"vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication"
    52  
    53  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    54  	querypb "vitess.io/vitess/go/vt/proto/query"
    55  	tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
    56  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    57  )
    58  
// maxVDiffReportSampleRows is the maximum number of sample rows shown for row
// differences in the final report.
const maxVDiffReportSampleRows = 10
    61  
// DiffReport is the summary of differences for one table.
//
// The integer fields are row counters and the slice fields hold sample rows
// for the corresponding kind of difference. VDiff sets TableName, reconciles
// extra rows that turn out to exist on both sides (adjusting ExtraRowsSource,
// ExtraRowsTarget, ProcessedRows and MatchingRows), and truncates the
// ExtraRows*Diffs samples before the report is printed. The struct is
// marshaled as-is when the "json" output format is requested.
type DiffReport struct {
	ProcessedRows        int
	MatchingRows         int
	MismatchedRows       int
	ExtraRowsSource      int
	ExtraRowsSourceDiffs []*RowDiff
	ExtraRowsTarget      int
	ExtraRowsTargetDiffs []*RowDiff
	MismatchedRowsSample []*DiffMismatch
	TableName            string
}
    74  
// DiffMismatch is a sample of row diffs between source and target.
// Source and Target hold the two sides of one mismatching row; VDiff prints
// them as "Source row" and "Target row" in the text report.
type DiffMismatch struct {
	Source *RowDiff
	Target *RowDiff
}
    80  
// RowDiff is a row that didn't match as part of the comparison.
// NOTE(review): Row presumably maps column name to that column's value, and
// Query presumably is a query that identifies the row; confirm against the
// code that populates these fields.
type RowDiff struct {
	Row   map[string]sqltypes.Value
	Query string
}
    86  
// vdiff contains the metadata for performing vdiff for one workflow.
// It is created and populated by Wrangler.VDiff.
type vdiff struct {
	// ts is the traffic switcher built for the workflow; it supplies the
	// source/target shards, keyspace names and clients used by the diff.
	ts *trafficSwitcher
	// sourceCell and targetCell are the cells used when picking tablets.
	sourceCell string
	targetCell string
	// tabletTypesStr is passed to the tablet picker as the allowed tablet types.
	tabletTypesStr string

	// differs uses the target table name for its key.
	differs map[string]*tableDiffer

	// The key for sources and targets is the shard name.
	// The source and target keyspaces are pulled from ts.
	sources map[string]*shardStreamer
	targets map[string]*shardStreamer

	workflow       string
	targetKeyspace string
	// tables is the optional list of tables to restrict the diff to; when
	// empty, every table matched by the workflow filter is diffed.
	tables []string
	// sourceTimeZone and targetTimeZone are copied from the traffic switcher
	// and drive the convert_tz() rewrite of datetime columns on the target.
	sourceTimeZone string
	targetTimeZone string
}
   108  
// compareColInfo contains the metadata for a column of the table being diffed.
// A nil collation makes newMergeSorter fall back to the binary collation.
type compareColInfo struct {
	colIndex  int                  // index of the column in the filter's select
	collation collations.Collation // is the collation of the column, if any
	isPK      bool                 // is this column part of the primary key
}
   115  
// tableDiffer performs a diff for one table in the workflow.
type tableDiffer struct {
	targetTable string
	// sourceExpression and targetExpression are select queries.
	sourceExpression string
	targetExpression string

	// compareCols has one entry per column of the select list, in select
	// order. Entries that belong to the primary key have isPK set (see
	// findPKs); they stay in this list.
	compareCols []compareColInfo
	// comparePKs is the list of pk columns to compare. The logic
	// for comparing pk columns is different from compareCols
	comparePKs []compareColInfo
	// pkCols has the indices of PK cols in the select list
	pkCols []int

	// selectPks is the list of pk columns as they appear in the select clause for the diff.
	selectPks []int

	// sourcePrimitive and targetPrimitive are used for streaming
	sourcePrimitive engine.Primitive
	targetPrimitive engine.Primitive
}
   140  
// shardStreamer streams rows from one shard. This works for
// the source as well as the target.
// shardStreamer satisfies engine.StreamExecutor, and can be
// added to Primitives of engine.MergeSort.
// shardStreamer is a member of vdiff, and gets reused by
// every tableDiffer. A new result channel gets instantiated
// for every tableDiffer iteration.
//
// primary is set when the vdiff is initialized, tablet is filled in by
// selectTablets, and position is updated by stopTargets with the furthest
// source position recorded by the stopped target streams.
// NOTE(review): snapshotPosition, result and err are presumably maintained by
// the query-stream code; confirm where they are written.
type shardStreamer struct {
	primary          *topo.TabletInfo
	tablet           *topodatapb.Tablet
	position         mysql.Position
	snapshotPosition string
	result           chan *sqltypes.Result
	err              error
}
   156  
   157  // VDiff reports differences between the sources and targets of a vreplication workflow.
   158  func (wr *Wrangler) VDiff(ctx context.Context, targetKeyspace, workflowName, sourceCell, targetCell, tabletTypesStr string,
   159  	filteredReplicationWaitTime time.Duration, format string, maxRows int64, tables string, debug, onlyPks bool,
   160  	maxExtraRowsToCompare int) (map[string]*DiffReport, error) {
   161  	log.Infof("Starting VDiff for %s.%s, sourceCell %s, targetCell %s, tabletTypes %s, timeout %s",
   162  		targetKeyspace, workflowName, sourceCell, targetCell, tabletTypesStr, filteredReplicationWaitTime.String())
   163  	// Assign defaults to sourceCell and targetCell if not specified.
   164  	if sourceCell == "" && targetCell == "" {
   165  		cells, err := wr.ts.GetCellInfoNames(ctx)
   166  		if err != nil {
   167  			return nil, err
   168  		}
   169  		if len(cells) == 0 {
   170  			// Unreachable
   171  			return nil, fmt.Errorf("there are no cells in the topo")
   172  		}
   173  		sourceCell = cells[0]
   174  		targetCell = sourceCell
   175  	}
   176  	if sourceCell == "" {
   177  		sourceCell = targetCell
   178  	}
   179  	if targetCell == "" {
   180  		targetCell = sourceCell
   181  	}
   182  
   183  	// Reuse migrater code to fetch and validate initial metadata about the workflow.
   184  	ts, err := wr.buildTrafficSwitcher(ctx, targetKeyspace, workflowName)
   185  	if err != nil {
   186  		wr.Logger().Errorf("buildTrafficSwitcher: %v", err)
   187  		return nil, err
   188  	}
   189  	if ts.frozen {
   190  		return nil, fmt.Errorf("invalid VDiff run: writes have been already been switched for workflow %s.%s",
   191  			targetKeyspace, workflowName)
   192  	}
   193  	if err := ts.validate(ctx); err != nil {
   194  		ts.Logger().Errorf("validate: %v", err)
   195  		return nil, err
   196  	}
   197  	tables = strings.TrimSpace(tables)
   198  	var includeTables []string
   199  	if tables != "" {
   200  		includeTables = strings.Split(tables, ",")
   201  	}
   202  	// Initialize vdiff
   203  	df := &vdiff{
   204  		ts:             ts,
   205  		sourceCell:     sourceCell,
   206  		targetCell:     targetCell,
   207  		tabletTypesStr: tabletTypesStr,
   208  		sources:        make(map[string]*shardStreamer),
   209  		targets:        make(map[string]*shardStreamer),
   210  		workflow:       workflowName,
   211  		targetKeyspace: targetKeyspace,
   212  		tables:         includeTables,
   213  		sourceTimeZone: ts.sourceTimeZone,
   214  		targetTimeZone: ts.targetTimeZone,
   215  	}
   216  	for shard, source := range ts.Sources() {
   217  		df.sources[shard] = &shardStreamer{
   218  			primary: source.GetPrimary(),
   219  		}
   220  	}
   221  	var oneTarget *workflow.MigrationTarget
   222  	for shard, target := range ts.Targets() {
   223  		df.targets[shard] = &shardStreamer{
   224  			primary: target.GetPrimary(),
   225  		}
   226  		oneTarget = target
   227  	}
   228  	var oneFilter *binlogdatapb.Filter
   229  	for _, bls := range oneTarget.Sources {
   230  		oneFilter = bls.Filter
   231  		break
   232  	}
   233  	req := &tabletmanagerdatapb.GetSchemaRequest{}
   234  	schm, err := schematools.GetSchema(ctx, wr.ts, wr.tmc, oneTarget.GetPrimary().Alias, req)
   235  	if err != nil {
   236  		return nil, vterrors.Wrap(err, "GetSchema")
   237  	}
   238  	if err = df.buildVDiffPlan(ctx, oneFilter, schm, df.tables); err != nil {
   239  		return nil, vterrors.Wrap(err, "buildVDiffPlan")
   240  	}
   241  
   242  	if err := df.selectTablets(ctx, ts); err != nil {
   243  		return nil, vterrors.Wrap(err, "selectTablets")
   244  	}
   245  	defer func() {
   246  		// We use a new context as we want to reset the state even
   247  		// when the parent context has timed out or been canceled.
   248  		log.Infof("Restarting the %q VReplication workflow on target tablets in keyspace %q", df.workflow, df.targetKeyspace)
   249  		restartCtx, restartCancel := context.WithTimeout(context.Background(), DefaultActionTimeout)
   250  		defer restartCancel()
   251  		if err := df.restartTargets(restartCtx); err != nil {
   252  			wr.Logger().Errorf("Could not restart workflow %q on target tablets in keyspace %q: %v, please restart it manually",
   253  				df.workflow, df.targetKeyspace, err)
   254  		}
   255  	}()
   256  
   257  	// Perform the diffs.
   258  	// We need a cancelable context to abort all running streams
   259  	// if one stream returns an error.
   260  	ctx, cancel := context.WithCancel(ctx)
   261  	defer cancel()
   262  
   263  	// TODO(sougou): parallelize
   264  	rowsToCompare := maxRows
   265  	diffReports := make(map[string]*DiffReport)
   266  	jsonOutput := ""
   267  	for table, td := range df.differs {
   268  		// Skip internal operation tables for vdiff
   269  		if schema.IsInternalOperationTableName(table) {
   270  			continue
   271  		}
   272  		if err := df.diffTable(ctx, wr, table, td, filteredReplicationWaitTime); err != nil {
   273  			return nil, err
   274  		}
   275  		// Perform the diff of source and target streams.
   276  		dr, err := td.diff(ctx, &rowsToCompare, debug, onlyPks, maxExtraRowsToCompare)
   277  		if err != nil {
   278  			return nil, vterrors.Wrap(err, "diff")
   279  		}
   280  		dr.TableName = table
   281  		// If the only difference is the order in which the rows were returned
   282  		// by MySQL on each side then we'll have the same number of extras on
   283  		// both sides. If that's the case, then let's see if the extra rows on
   284  		// both sides are actually different.
   285  		if (dr.ExtraRowsSource == dr.ExtraRowsTarget) && (dr.ExtraRowsSource <= maxExtraRowsToCompare) {
   286  			for i := range dr.ExtraRowsSourceDiffs {
   287  				foundMatch := false
   288  				for j := range dr.ExtraRowsTargetDiffs {
   289  					if reflect.DeepEqual(dr.ExtraRowsSourceDiffs[i], dr.ExtraRowsTargetDiffs[j]) {
   290  						dr.ExtraRowsSourceDiffs = append(dr.ExtraRowsSourceDiffs[:i], dr.ExtraRowsSourceDiffs[i+1:]...)
   291  						dr.ExtraRowsSource--
   292  						dr.ExtraRowsTargetDiffs = append(dr.ExtraRowsTargetDiffs[:j], dr.ExtraRowsTargetDiffs[j+1:]...)
   293  						dr.ExtraRowsTarget--
   294  						dr.ProcessedRows--
   295  						dr.MatchingRows++
   296  						foundMatch = true
   297  						break
   298  					}
   299  				}
   300  				// If we didn't find a match then the tables are in fact different and we can short circuit the second pass
   301  				if !foundMatch {
   302  					break
   303  				}
   304  			}
   305  		}
   306  		// We can now trim the extra rows diffs on both sides to the maxVDiffReportSampleRows value
   307  		if len(dr.ExtraRowsSourceDiffs) > maxVDiffReportSampleRows {
   308  			dr.ExtraRowsSourceDiffs = dr.ExtraRowsSourceDiffs[:maxVDiffReportSampleRows-1]
   309  		}
   310  		if len(dr.ExtraRowsTargetDiffs) > maxVDiffReportSampleRows {
   311  			dr.ExtraRowsTargetDiffs = dr.ExtraRowsTargetDiffs[:maxVDiffReportSampleRows-1]
   312  		}
   313  		diffReports[table] = dr
   314  	}
   315  	if format == "json" {
   316  		json, err := json.MarshalIndent(diffReports, "", "")
   317  		if err != nil {
   318  			wr.Logger().Printf("Error converting report to json: %v", err.Error())
   319  		}
   320  		jsonOutput += string(json)
   321  		wr.logger.Printf("%s", jsonOutput)
   322  	} else {
   323  		for table, dr := range diffReports {
   324  			wr.Logger().Printf("Summary for table %v:\n", table)
   325  			wr.Logger().Printf("\tProcessedRows: %v\n", dr.ProcessedRows)
   326  			wr.Logger().Printf("\tMatchingRows: %v\n", dr.MatchingRows)
   327  			wr.Logger().Printf("\tMismatchedRows: %v\n", dr.MismatchedRows)
   328  			wr.Logger().Printf("\tExtraRowsSource: %v\n", dr.ExtraRowsSource)
   329  			wr.Logger().Printf("\tExtraRowsTarget: %v\n", dr.ExtraRowsTarget)
   330  			for i, rs := range dr.ExtraRowsSourceDiffs {
   331  				wr.Logger().Printf("\tSample extra row in source %v:\n", i)
   332  				formatSampleRow(wr.Logger(), rs, debug)
   333  			}
   334  			for i, rs := range dr.ExtraRowsTargetDiffs {
   335  				wr.Logger().Printf("\tSample extra row in target %v:\n", i)
   336  				formatSampleRow(wr.Logger(), rs, debug)
   337  			}
   338  			for i, rs := range dr.MismatchedRowsSample {
   339  				wr.Logger().Printf("\tSample rows with mismatch %v:\n", i)
   340  				wr.Logger().Printf("\t\tSource row:\n")
   341  				formatSampleRow(wr.Logger(), rs.Source, debug)
   342  				wr.Logger().Printf("\t\tTarget row:\n")
   343  				formatSampleRow(wr.Logger(), rs.Target, debug)
   344  			}
   345  		}
   346  	}
   347  	return diffReports, nil
   348  }
   349  
   350  func (df *vdiff) diffTable(ctx context.Context, wr *Wrangler, table string, td *tableDiffer, filteredReplicationWaitTime time.Duration) error {
   351  	log.Infof("Starting vdiff for table %s", table)
   352  
   353  	log.Infof("Locking target keyspace %s", df.targetKeyspace)
   354  	ctx, unlock, lockErr := wr.ts.LockKeyspace(ctx, df.targetKeyspace, "vdiff")
   355  	if lockErr != nil {
   356  		log.Errorf("LockKeyspace failed: %v", lockErr)
   357  		wr.Logger().Errorf("LockKeyspace %s failed: %v", df.targetKeyspace)
   358  		return lockErr
   359  	}
   360  
   361  	var err error
   362  	defer func() {
   363  		unlock(&err)
   364  		if err != nil {
   365  			log.Errorf("UnlockKeyspace %s failed: %v", df.targetKeyspace, lockErr)
   366  		}
   367  	}()
   368  
   369  	// Stop the targets and record their source positions.
   370  	if err := df.stopTargets(ctx); err != nil {
   371  		return vterrors.Wrap(err, "stopTargets")
   372  	}
   373  	// Make sure all sources are past the target's positions and start a query stream that records the current source positions.
   374  	if err := df.startQueryStreams(ctx, df.ts.SourceKeyspaceName(), df.sources, td.sourceExpression, filteredReplicationWaitTime); err != nil {
   375  		return vterrors.Wrap(err, "startQueryStreams(sources)")
   376  	}
   377  	// Fast forward the targets to the newly recorded source positions.
   378  	if err := df.syncTargets(ctx, filteredReplicationWaitTime); err != nil {
   379  		return vterrors.Wrap(err, "syncTargets")
   380  	}
   381  	// Sources and targets are in sync. Start query streams on the targets.
   382  	if err := df.startQueryStreams(ctx, df.ts.TargetKeyspaceName(), df.targets, td.targetExpression, filteredReplicationWaitTime); err != nil {
   383  		return vterrors.Wrap(err, "startQueryStreams(targets)")
   384  	}
   385  	// Now that queries are running, target vreplication streams can be restarted.
   386  	return nil
   387  }
   388  
   389  // buildVDiffPlan builds all the differs.
   390  func (df *vdiff) buildVDiffPlan(ctx context.Context, filter *binlogdatapb.Filter, schm *tabletmanagerdatapb.SchemaDefinition, tablesToInclude []string) error {
   391  	df.differs = make(map[string]*tableDiffer)
   392  	for _, table := range schm.TableDefinitions {
   393  		rule, err := vreplication.MatchTable(table.Name, filter)
   394  		if err != nil {
   395  			return err
   396  		}
   397  		if rule == nil || rule.Filter == "exclude" {
   398  			continue
   399  		}
   400  		query := rule.Filter
   401  		if rule.Filter == "" || key.IsKeyRange(rule.Filter) {
   402  			buf := sqlparser.NewTrackedBuffer(nil)
   403  			buf.Myprintf("select * from %v", sqlparser.NewIdentifierCS(table.Name))
   404  			query = buf.String()
   405  		}
   406  		include := true
   407  		if len(tablesToInclude) > 0 {
   408  			include = false
   409  			for _, t := range tablesToInclude {
   410  				if t == table.Name {
   411  					include = true
   412  					break
   413  				}
   414  			}
   415  		}
   416  		if include {
   417  			df.differs[table.Name], err = df.buildTablePlan(table, query)
   418  			if err != nil {
   419  				return err
   420  			}
   421  		}
   422  	}
   423  	if len(tablesToInclude) > 0 && len(tablesToInclude) != len(df.differs) {
   424  		log.Errorf("one or more tables provided are not present in the workflow: %v, %+v", tablesToInclude, df.differs)
   425  		return fmt.Errorf("one or more tables provided are not present in the workflow: %v, %+v", tablesToInclude, df.differs)
   426  	}
   427  	return nil
   428  }
   429  
   430  // findPKs identifies PKs and removes them from the columns to do data comparison
   431  func findPKs(table *tabletmanagerdatapb.TableDefinition, targetSelect *sqlparser.Select, td *tableDiffer) (sqlparser.OrderBy, error) {
   432  	var orderby sqlparser.OrderBy
   433  	for _, pk := range table.PrimaryKeyColumns {
   434  		found := false
   435  		for i, selExpr := range targetSelect.SelectExprs {
   436  			expr := selExpr.(*sqlparser.AliasedExpr).Expr
   437  			colname := ""
   438  			switch ct := expr.(type) {
   439  			case *sqlparser.ColName:
   440  				colname = ct.Name.String()
   441  			case *sqlparser.FuncExpr: //eg. weight_string()
   442  				//no-op
   443  			default:
   444  				log.Warningf("Not considering column %v for PK, type %v not handled", selExpr, ct)
   445  			}
   446  			if strings.EqualFold(pk, colname) {
   447  				td.compareCols[i].isPK = true
   448  				td.comparePKs = append(td.comparePKs, td.compareCols[i])
   449  				td.selectPks = append(td.selectPks, i)
   450  				// We'll be comparing pks separately. So, remove them from compareCols.
   451  				td.pkCols = append(td.pkCols, i)
   452  				found = true
   453  				break
   454  			}
   455  		}
   456  		if !found {
   457  			// Unreachable.
   458  			return nil, fmt.Errorf("column %v not found in table %v", pk, table.Name)
   459  		}
   460  		orderby = append(orderby, &sqlparser.Order{
   461  			Expr:      &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(pk)},
   462  			Direction: sqlparser.AscOrder,
   463  		})
   464  	}
   465  	return orderby, nil
   466  }
   467  
   468  // If SourceTimeZone is defined in the BinlogSource, the VReplication workflow would have converted the datetime
   469  // columns expecting the source to have been in the SourceTimeZone and target in TargetTimeZone. We need to do the reverse
   470  // conversion in VDiff before comparing to the source
   471  func (df *vdiff) adjustForSourceTimeZone(targetSelectExprs sqlparser.SelectExprs, fields map[string]querypb.Type) sqlparser.SelectExprs {
   472  	if df.sourceTimeZone == "" {
   473  		return targetSelectExprs
   474  	}
   475  	log.Infof("source time zone specified: %s", df.sourceTimeZone)
   476  	var newSelectExprs sqlparser.SelectExprs
   477  	var modified bool
   478  	for _, expr := range targetSelectExprs {
   479  		converted := false
   480  		switch selExpr := expr.(type) {
   481  		case *sqlparser.AliasedExpr:
   482  			if colAs, ok := selExpr.Expr.(*sqlparser.ColName); ok {
   483  				var convertTZFuncExpr *sqlparser.FuncExpr
   484  				colName := colAs.Name.Lowered()
   485  				fieldType := fields[colName]
   486  				if fieldType == querypb.Type_DATETIME {
   487  					convertTZFuncExpr = &sqlparser.FuncExpr{
   488  						Name: sqlparser.NewIdentifierCI("convert_tz"),
   489  						Exprs: sqlparser.SelectExprs{
   490  							expr,
   491  							&sqlparser.AliasedExpr{Expr: sqlparser.NewStrLiteral(df.targetTimeZone)},
   492  							&sqlparser.AliasedExpr{Expr: sqlparser.NewStrLiteral(df.sourceTimeZone)},
   493  						},
   494  					}
   495  					log.Infof("converting datetime column %s using convert_tz()", colName)
   496  					newSelectExprs = append(newSelectExprs, &sqlparser.AliasedExpr{Expr: convertTZFuncExpr, As: colAs.Name})
   497  					converted = true
   498  					modified = true
   499  				}
   500  			}
   501  		}
   502  		if !converted { // not datetime
   503  			newSelectExprs = append(newSelectExprs, expr)
   504  		}
   505  	}
   506  	if modified { // at least one datetime was found
   507  		log.Infof("Found datetime columns when SourceTimeZone was set, resetting target SelectExprs after convert_tz()")
   508  		return newSelectExprs
   509  	}
   510  	return targetSelectExprs
   511  }
   512  
   513  func getColumnNameForSelectExpr(selectExpression sqlparser.SelectExpr) (string, error) {
   514  	aliasedExpr := selectExpression.(*sqlparser.AliasedExpr)
   515  	expr := aliasedExpr.Expr
   516  	var colname string
   517  	switch t := expr.(type) {
   518  	case *sqlparser.ColName:
   519  		colname = t.Name.Lowered()
   520  	case *sqlparser.FuncExpr: // only in case datetime was converted using convert_tz()
   521  		colname = aliasedExpr.As.Lowered()
   522  	default:
   523  		return "", fmt.Errorf("found target SelectExpr which was neither ColName or FuncExpr: %+v", aliasedExpr)
   524  	}
   525  	return colname, nil
   526  }
   527  
// buildTablePlan builds one tableDiffer.
//
// query is the select statement of the workflow rule for table. From it the
// function derives two statements: the source select (original expressions,
// FROM, GROUP BY and WHERE minus the keyrange filter) and the target select
// (one plain column per expression, read from the target table). Both are
// ordered by the table's primary key so the two row streams can be merged and
// compared. An error is returned if the query is not a select, contains an
// unsupported select expression, references a column that is not in the
// table, or if a primary key column is missing from the select list.
func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, query string) (*tableDiffer, error) {
	statement, err := sqlparser.Parse(query)
	if err != nil {
		return nil, err
	}
	sel, ok := statement.(*sqlparser.Select)
	if !ok {
		return nil, fmt.Errorf("unexpected: %v", sqlparser.String(statement))
	}
	td := &tableDiffer{
		targetTable: table.Name,
	}
	sourceSelect := &sqlparser.Select{}
	targetSelect := &sqlparser.Select{}
	// aggregates contains the list of Aggregate functions, if any.
	var aggregates []*engine.AggregateParams
	for _, selExpr := range sel.SelectExprs {
		switch selExpr := selExpr.(type) {
		case *sqlparser.StarExpr:
			// If it's a '*' expression, expand column list from the schema.
			for _, fld := range table.Fields {
				aliased := &sqlparser.AliasedExpr{Expr: &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(fld.Name)}}
				sourceSelect.SelectExprs = append(sourceSelect.SelectExprs, aliased)
				targetSelect.SelectExprs = append(targetSelect.SelectExprs, aliased)
			}
		case *sqlparser.AliasedExpr:
			// The target column is the alias if there is one, otherwise the
			// expression must itself be a plain column.
			var targetCol *sqlparser.ColName
			if !selExpr.As.IsEmpty() {
				targetCol = &sqlparser.ColName{Name: selExpr.As}
			} else {
				if colAs, ok := selExpr.Expr.(*sqlparser.ColName); ok {
					targetCol = colAs
				} else {
					return nil, fmt.Errorf("expression needs an alias: %v", sqlparser.String(selExpr))
				}
			}
			// If the input was "select a as b", then source will use "a" and target will use "b".
			sourceSelect.SelectExprs = append(sourceSelect.SelectExprs, selExpr)
			targetSelect.SelectExprs = append(targetSelect.SelectExprs, &sqlparser.AliasedExpr{Expr: targetCol})

			// Check if it's an aggregate expression
			if expr, ok := selExpr.Expr.(sqlparser.AggrFunc); ok {
				switch fname := strings.ToLower(expr.AggrName()); fname {
				case "count", "sum":
					// this will only work as long as aggregates can be pushed down to tablets
					// this won't work: "select count(*) from (select id from t limit 1)"
					// since vreplication only handles simple tables (no joins/derived tables) this is fine for now
					// but will need to be revisited when we add such support to vreplication
					//
					// Both count and sum are re-aggregated with "sum"; Col is the
					// position of the expression that was just appended.
					aggregateFuncType := "sum"
					aggregates = append(aggregates, &engine.AggregateParams{
						Opcode: engine.SupportedAggregates[aggregateFuncType],
						Col:    len(sourceSelect.SelectExprs) - 1,
					})
				}
			}
		default:
			return nil, fmt.Errorf("unexpected: %v", sqlparser.String(statement))
		}
	}

	// fields maps the lower-cased column name to its type.
	fields := make(map[string]querypb.Type)
	for _, field := range table.Fields {
		fields[strings.ToLower(field.Name)] = field.Type
	}

	targetSelect.SelectExprs = df.adjustForSourceTimeZone(targetSelect.SelectExprs, fields)
	// Start with adding all columns for comparison.
	td.compareCols = make([]compareColInfo, len(sourceSelect.SelectExprs))
	for i := range td.compareCols {
		td.compareCols[i].colIndex = i
		colname, err := getColumnNameForSelectExpr(targetSelect.SelectExprs[i])
		if err != nil {
			return nil, err
		}
		// Every target column must exist in the table's schema.
		_, ok := fields[colname]
		if !ok {
			return nil, fmt.Errorf("column %v not found in table %v", colname, table.Name)
		}
	}

	sourceSelect.From = sel.From
	// The target table name should be the one that matched the rule.
	// It can be different from the source table.
	targetSelect.From = sqlparser.TableExprs{
		&sqlparser.AliasedTableExpr{
			Expr: &sqlparser.TableName{
				Name: sqlparser.NewIdentifierCS(table.Name),
			},
		},
	}

	// findPKs also records the primary key columns in td.
	orderby, err := findPKs(table, targetSelect, td)
	if err != nil {
		return nil, err
	}
	// Remove in_keyrange. It's not understood by mysql.
	sourceSelect.Where = removeKeyrange(sel.Where)
	// The source should also perform the group by.
	sourceSelect.GroupBy = sel.GroupBy
	sourceSelect.OrderBy = orderby

	// The target should perform the order by, but not the group by.
	targetSelect.OrderBy = orderby

	td.sourceExpression = sqlparser.String(sourceSelect)
	td.targetExpression = sqlparser.String(targetSelect)

	td.sourcePrimitive = newMergeSorter(df.sources, td.comparePKs)
	td.targetPrimitive = newMergeSorter(df.targets, td.comparePKs)
	// If there were aggregate expressions, we have to re-aggregate
	// the results, which engine.OrderedAggregate can do.
	// Only the source primitive is wrapped.
	if len(aggregates) != 0 {
		td.sourcePrimitive = &engine.OrderedAggregate{
			Aggregates:  aggregates,
			GroupByKeys: pkColsToGroupByParams(td.pkCols),
			Input:       td.sourcePrimitive,
		}
	}

	return td, nil
}
   650  
   651  func pkColsToGroupByParams(pkCols []int) []*engine.GroupByParams {
   652  	var res []*engine.GroupByParams
   653  	for _, col := range pkCols {
   654  		res = append(res, &engine.GroupByParams{KeyCol: col, WeightStringCol: -1})
   655  	}
   656  	return res
   657  }
   658  
   659  // newMergeSorter creates an engine.MergeSort based on the shard streamers and pk columns.
   660  func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compareColInfo) *engine.MergeSort {
   661  	prims := make([]engine.StreamExecutor, 0, len(participants))
   662  	for _, participant := range participants {
   663  		prims = append(prims, participant)
   664  	}
   665  	ob := make([]engine.OrderByParams, 0, len(comparePKs))
   666  	for _, cpk := range comparePKs {
   667  		weightStringCol := -1
   668  		// if the collation is nil or unknown, use binary collation to compare as bytes
   669  		if cpk.collation == nil {
   670  			ob = append(ob, engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, CollationID: collations.CollationBinaryID})
   671  		} else {
   672  			ob = append(ob, engine.OrderByParams{Col: cpk.colIndex, WeightStringCol: weightStringCol, CollationID: cpk.collation.ID()})
   673  		}
   674  	}
   675  	return &engine.MergeSort{
   676  		Primitives: prims,
   677  		OrderBy:    ob,
   678  	}
   679  }
   680  
   681  // selectTablets selects the tablets that will be used for the diff.
   682  func (df *vdiff) selectTablets(ctx context.Context, ts *trafficSwitcher) error {
   683  	var wg sync.WaitGroup
   684  	var err1, err2 error
   685  
   686  	// Parallelize all discovery.
   687  	wg.Add(1)
   688  	go func() {
   689  		defer wg.Done()
   690  		err1 = df.forAll(df.sources, func(shard string, source *shardStreamer) error {
   691  			sourceTopo := df.ts.TopoServer()
   692  			if ts.ExternalTopo() != nil {
   693  				sourceTopo = ts.ExternalTopo()
   694  			}
   695  			tp, err := discovery.NewTabletPicker(sourceTopo, []string{df.sourceCell}, df.ts.SourceKeyspaceName(), shard, df.tabletTypesStr)
   696  			if err != nil {
   697  				return err
   698  			}
   699  
   700  			tablet, err := tp.PickForStreaming(ctx)
   701  			if err != nil {
   702  				return err
   703  			}
   704  			source.tablet = tablet
   705  			return nil
   706  		})
   707  	}()
   708  
   709  	wg.Add(1)
   710  	go func() {
   711  		defer wg.Done()
   712  		err2 = df.forAll(df.targets, func(shard string, target *shardStreamer) error {
   713  			tp, err := discovery.NewTabletPicker(df.ts.TopoServer(), []string{df.targetCell}, df.ts.TargetKeyspaceName(), shard, df.tabletTypesStr)
   714  			if err != nil {
   715  				return err
   716  			}
   717  
   718  			tablet, err := tp.PickForStreaming(ctx)
   719  			if err != nil {
   720  				return err
   721  			}
   722  			target.tablet = tablet
   723  			return nil
   724  		})
   725  	}()
   726  
   727  	wg.Wait()
   728  	if err1 != nil {
   729  		return err1
   730  	}
   731  	return err2
   732  }
   733  
   734  // stopTargets stops all the targets and records their source positions.
   735  func (df *vdiff) stopTargets(ctx context.Context) error {
   736  	var mu sync.Mutex
   737  
   738  	err := df.forAll(df.targets, func(shard string, target *shardStreamer) error {
   739  		query := fmt.Sprintf("update _vt.vreplication set state='Stopped', message='for vdiff' where db_name=%s and workflow=%s", encodeString(target.primary.DbName()), encodeString(df.ts.WorkflowName()))
   740  		_, err := df.ts.TabletManagerClient().VReplicationExec(ctx, target.primary.Tablet, query)
   741  		if err != nil {
   742  			return err
   743  		}
   744  		query = fmt.Sprintf("select source, pos from _vt.vreplication where db_name=%s and workflow=%s", encodeString(target.primary.DbName()), encodeString(df.ts.WorkflowName()))
   745  		p3qr, err := df.ts.TabletManagerClient().VReplicationExec(ctx, target.primary.Tablet, query)
   746  		if err != nil {
   747  			return err
   748  		}
   749  		qr := sqltypes.Proto3ToResult(p3qr)
   750  
   751  		for _, row := range qr.Rows {
   752  			var bls binlogdatapb.BinlogSource
   753  			rowBytes, err := row[0].ToBytes()
   754  			if err != nil {
   755  				return err
   756  			}
   757  			if err := prototext.Unmarshal(rowBytes, &bls); err != nil {
   758  				return err
   759  			}
   760  			pos, err := binlogplayer.DecodePosition(row[1].ToString())
   761  			if err != nil {
   762  				return err
   763  			}
   764  			func() {
   765  				mu.Lock()
   766  				defer mu.Unlock()
   767  
   768  				source, ok := df.sources[bls.Shard]
   769  				if !ok {
   770  					// Unreachable.
   771  					return
   772  				}
   773  				if !source.position.IsZero() && source.position.AtLeast(pos) {
   774  					return
   775  				}
   776  				source.position = pos
   777  			}()
   778  		}
   779  		return nil
   780  	})
   781  	if err != nil {
   782  		return err
   783  	}
   784  	return nil
   785  }
   786  
   787  // starQueryStreams makes sure the sources are past the target's positions, starts the query streams,
   788  // and records the snapshot position of the query. It creates a result channel which StreamExecute
   789  // will use to serve rows.
   790  func (df *vdiff) startQueryStreams(ctx context.Context, keyspace string, participants map[string]*shardStreamer, query string, filteredReplicationWaitTime time.Duration) error {
   791  	waitCtx, cancel := context.WithTimeout(ctx, filteredReplicationWaitTime)
   792  	defer cancel()
   793  	return df.forAll(participants, func(shard string, participant *shardStreamer) error {
   794  		// Iteration for each participant.
   795  		if participant.position.IsZero() {
   796  			return fmt.Errorf("workflow %s.%s: stream has not started on tablet %s", df.targetKeyspace, df.workflow, participant.primary.Alias.String())
   797  		}
   798  		log.Infof("WaitForPosition: tablet %s should reach position %s", participant.tablet.Alias.String(), mysql.EncodePosition(participant.position))
   799  		if err := df.ts.TabletManagerClient().WaitForPosition(waitCtx, participant.tablet, mysql.EncodePosition(participant.position)); err != nil {
   800  			log.Errorf("WaitForPosition error: %s", err)
   801  			return vterrors.Wrapf(err, "WaitForPosition for tablet %v", topoproto.TabletAliasString(participant.tablet.Alias))
   802  		}
   803  		participant.result = make(chan *sqltypes.Result, 1)
   804  		gtidch := make(chan string, 1)
   805  
   806  		// Start the stream in a separate goroutine.
   807  		go df.streamOne(ctx, keyspace, shard, participant, query, gtidch)
   808  
   809  		// Wait for the gtid to be sent. If it's not received, there was an error
   810  		// which would be stored in participant.err.
   811  		gtid, ok := <-gtidch
   812  		if !ok {
   813  			return participant.err
   814  		}
   815  		// Save the new position, as of when the query executed.
   816  		participant.snapshotPosition = gtid
   817  		return nil
   818  	})
   819  }
   820  
   821  // streamOne is called as a goroutine, and communicates its results through channels.
   822  // It first sends the snapshot gtid to gtidch.
   823  // Then it streams results to participant.result.
   824  // Before returning, it sets participant.err, and closes all channels.
   825  // If any channel is closed, then participant.err can be checked if there was an error.
   826  // The shardStreamer's StreamExecute consumes the result channel.
   827  func (df *vdiff) streamOne(ctx context.Context, keyspace, shard string, participant *shardStreamer, query string, gtidch chan string) {
   828  	defer close(participant.result)
   829  	defer close(gtidch)
   830  
   831  	// Wrap the streaming in a separate function so we can capture the error.
   832  	// This shows that the error will be set before the channels are closed.
   833  	participant.err = func() error {
   834  		conn, err := tabletconn.GetDialer()(participant.tablet, grpcclient.FailFast(false))
   835  		if err != nil {
   836  			return err
   837  		}
   838  		defer conn.Close(ctx)
   839  
   840  		target := &querypb.Target{
   841  			Keyspace:   keyspace,
   842  			Shard:      shard,
   843  			TabletType: participant.tablet.Type,
   844  		}
   845  		var fields []*querypb.Field
   846  		return conn.VStreamResults(ctx, target, query, func(vrs *binlogdatapb.VStreamResultsResponse) error {
   847  			if vrs.Fields != nil {
   848  				fields = vrs.Fields
   849  				gtidch <- vrs.Gtid
   850  			}
   851  			p3qr := &querypb.QueryResult{
   852  				Fields: fields,
   853  				Rows:   vrs.Rows,
   854  			}
   855  			result := sqltypes.Proto3ToResult(p3qr)
   856  			// Fields should be received only once, and sent only once.
   857  			if vrs.Fields == nil {
   858  				result.Fields = nil
   859  			}
   860  			select {
   861  			case participant.result <- result:
   862  			case <-ctx.Done():
   863  				return vterrors.Wrap(ctx.Err(), "VStreamResults")
   864  			}
   865  			return nil
   866  		})
   867  	}()
   868  }
   869  
   870  // syncTargets fast-forwards the vreplication to the source snapshot positons
   871  // and waits for the selected tablets to catch up to that point.
   872  func (df *vdiff) syncTargets(ctx context.Context, filteredReplicationWaitTime time.Duration) error {
   873  	waitCtx, cancel := context.WithTimeout(ctx, filteredReplicationWaitTime)
   874  	defer cancel()
   875  	err := df.ts.ForAllUIDs(func(target *workflow.MigrationTarget, uid uint32) error {
   876  		bls := target.Sources[uid]
   877  		pos := df.sources[bls.Shard].snapshotPosition
   878  		query := fmt.Sprintf("update _vt.vreplication set state='Running', stop_pos='%s', message='synchronizing for vdiff' where id=%d", pos, uid)
   879  		if _, err := df.ts.TabletManagerClient().VReplicationExec(ctx, target.GetPrimary().Tablet, query); err != nil {
   880  			return err
   881  		}
   882  		if err := df.ts.TabletManagerClient().VReplicationWaitForPos(waitCtx, target.GetPrimary().Tablet, int(uid), pos); err != nil {
   883  			return vterrors.Wrapf(err, "VReplicationWaitForPos for tablet %v", topoproto.TabletAliasString(target.GetPrimary().Tablet.Alias))
   884  		}
   885  		return nil
   886  	})
   887  	if err != nil {
   888  		return err
   889  	}
   890  
   891  	err = df.forAll(df.targets, func(shard string, target *shardStreamer) error {
   892  		pos, err := df.ts.TabletManagerClient().PrimaryPosition(ctx, target.primary.Tablet)
   893  		if err != nil {
   894  			return err
   895  		}
   896  		mpos, err := binlogplayer.DecodePosition(pos)
   897  		if err != nil {
   898  			return err
   899  		}
   900  		target.position = mpos
   901  		return nil
   902  	})
   903  	return err
   904  }
   905  
   906  // restartTargets restarts the stopped target vreplication streams.
   907  func (df *vdiff) restartTargets(ctx context.Context) error {
   908  	return df.forAll(df.targets, func(shard string, target *shardStreamer) error {
   909  		query := fmt.Sprintf("update _vt.vreplication set state='Running', message='', stop_pos='' where db_name=%s and workflow=%s",
   910  			encodeString(target.primary.DbName()), encodeString(df.ts.WorkflowName()))
   911  		log.Infof("Restarting the %q VReplication workflow on %q using %q", df.ts.WorkflowName(), target.primary.Alias, query)
   912  		var err error
   913  		// Let's retry a few times if we get a retryable error.
   914  		for i := 1; i <= 3; i++ {
   915  			_, err = df.ts.TabletManagerClient().VReplicationExec(ctx, target.primary.Tablet, query)
   916  			if err == nil || !mysql.IsEphemeralError(err) {
   917  				break
   918  			}
   919  			log.Warningf("Encountered the following error while restarting the %q VReplication workflow on %q, will retry (attempt #%d): %v",
   920  				df.ts.WorkflowName(), target.primary.Alias, i, err)
   921  		}
   922  		return err
   923  	})
   924  }
   925  
   926  func (df *vdiff) forAll(participants map[string]*shardStreamer, f func(string, *shardStreamer) error) error {
   927  	var wg sync.WaitGroup
   928  	allErrors := &concurrency.AllErrorRecorder{}
   929  	for shard, participant := range participants {
   930  		wg.Add(1)
   931  		go func(shard string, participant *shardStreamer) {
   932  			defer wg.Done()
   933  
   934  			if err := f(shard, participant); err != nil {
   935  				allErrors.RecordError(err)
   936  			}
   937  		}(shard, participant)
   938  	}
   939  	wg.Wait()
   940  	return allErrors.AggrError(vterrors.Aggregate)
   941  }
   942  
   943  //-----------------------------------------------------------------
   944  // primitiveExecutor
   945  
   946  // primitiveExecutor starts execution on the top level primitive
   947  // and provides convenience functions for row-by-row iteration.
   948  type primitiveExecutor struct {
   949  	prim     engine.Primitive
   950  	rows     [][]sqltypes.Value
   951  	resultch chan *sqltypes.Result
   952  	err      error
   953  }
   954  
   955  func newPrimitiveExecutor(ctx context.Context, prim engine.Primitive) *primitiveExecutor {
   956  	pe := &primitiveExecutor{
   957  		prim:     prim,
   958  		resultch: make(chan *sqltypes.Result, 1),
   959  	}
   960  	vcursor := &contextVCursor{}
   961  	go func() {
   962  		defer close(pe.resultch)
   963  		pe.err = vcursor.StreamExecutePrimitive(ctx, pe.prim, make(map[string]*querypb.BindVariable), true, func(qr *sqltypes.Result) error {
   964  			select {
   965  			case pe.resultch <- qr:
   966  			case <-ctx.Done():
   967  				return vterrors.Wrap(ctx.Err(), "Outer Stream")
   968  			}
   969  			return nil
   970  		})
   971  	}()
   972  	return pe
   973  }
   974  
   975  func (pe *primitiveExecutor) next() ([]sqltypes.Value, error) {
   976  	for len(pe.rows) == 0 {
   977  		qr, ok := <-pe.resultch
   978  		if !ok {
   979  			return nil, pe.err
   980  		}
   981  		pe.rows = qr.Rows
   982  	}
   983  
   984  	row := pe.rows[0]
   985  	pe.rows = pe.rows[1:]
   986  	return row, nil
   987  }
   988  
   989  func (pe *primitiveExecutor) drain(ctx context.Context) (int, error) {
   990  	count := 0
   991  	for {
   992  		row, err := pe.next()
   993  		if err != nil {
   994  			return 0, err
   995  		}
   996  		if row == nil {
   997  			return count, nil
   998  		}
   999  		count++
  1000  	}
  1001  }
  1002  
  1003  //-----------------------------------------------------------------
  1004  // shardStreamer
  1005  
  1006  func (sm *shardStreamer) StreamExecute(ctx context.Context, vcursor engine.VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
  1007  	for result := range sm.result {
  1008  		if err := callback(result); err != nil {
  1009  			return err
  1010  		}
  1011  	}
  1012  	return sm.err
  1013  }
  1014  
  1015  // humanInt formats large integers to a value easier to the eye: 100000=100k 1e12=1b 234000000=234m ...
  1016  func humanInt(n int64) string { // nolint
  1017  	var val float64
  1018  	var unit string
  1019  	switch true {
  1020  	case n < 1000:
  1021  		val = float64(n)
  1022  	case n < 1e6:
  1023  		val = float64(n) / 1000
  1024  		unit = "k"
  1025  	case n < 1e9:
  1026  		val = float64(n) / 1e6
  1027  		unit = "m"
  1028  	default:
  1029  		val = float64(n) / 1e9
  1030  		unit = "b"
  1031  	}
  1032  	s := fmt.Sprintf("%0.3f", val)
  1033  	s = strings.Replace(s, ".000", "", -1)
  1034  
  1035  	return fmt.Sprintf("%s%s", s, unit)
  1036  }
  1037  
  1038  //-----------------------------------------------------------------
  1039  // tableDiffer
  1040  
  1041  func (td *tableDiffer) diff(ctx context.Context, rowsToCompare *int64, debug, onlyPks bool, maxExtraRowsToCompare int) (*DiffReport, error) {
  1042  	sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive)
  1043  	targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive)
  1044  	dr := &DiffReport{}
  1045  	var sourceRow, targetRow []sqltypes.Value
  1046  	var err error
  1047  	advanceSource := true
  1048  	advanceTarget := true
  1049  	for {
  1050  		if dr.ProcessedRows%1e7 == 0 { // log progress every 10 million rows
  1051  			log.Infof("VDiff progress:: table %s: %s rows", td.targetTable, humanInt(int64(dr.ProcessedRows)))
  1052  		}
  1053  		*rowsToCompare--
  1054  		if *rowsToCompare < 0 {
  1055  			log.Infof("Stopping vdiff, specified limit reached")
  1056  			return dr, nil
  1057  		}
  1058  		if advanceSource {
  1059  			sourceRow, err = sourceExecutor.next()
  1060  			if err != nil {
  1061  				return nil, err
  1062  			}
  1063  		}
  1064  		if advanceTarget {
  1065  			targetRow, err = targetExecutor.next()
  1066  			if err != nil {
  1067  				return nil, err
  1068  			}
  1069  		}
  1070  
  1071  		if sourceRow == nil && targetRow == nil {
  1072  			return dr, nil
  1073  		}
  1074  
  1075  		advanceSource = true
  1076  		advanceTarget = true
  1077  
  1078  		if sourceRow == nil {
  1079  			diffRow, err := td.genRowDiff(td.sourceExpression, targetRow, debug, onlyPks)
  1080  			if err != nil {
  1081  				return nil, vterrors.Wrap(err, "unexpected error generating diff")
  1082  			}
  1083  			dr.ExtraRowsTargetDiffs = append(dr.ExtraRowsTargetDiffs, diffRow)
  1084  
  1085  			// drain target, update count
  1086  			count, err := targetExecutor.drain(ctx)
  1087  			if err != nil {
  1088  				return nil, err
  1089  			}
  1090  			dr.ExtraRowsTarget += 1 + count
  1091  			dr.ProcessedRows += 1 + count
  1092  			return dr, nil
  1093  		}
  1094  		if targetRow == nil {
  1095  			// no more rows from the target
  1096  			// we know we have rows from source, drain, update count
  1097  			diffRow, err := td.genRowDiff(td.sourceExpression, sourceRow, debug, onlyPks)
  1098  			if err != nil {
  1099  				return nil, vterrors.Wrap(err, "unexpected error generating diff")
  1100  			}
  1101  			dr.ExtraRowsSourceDiffs = append(dr.ExtraRowsSourceDiffs, diffRow)
  1102  
  1103  			count, err := sourceExecutor.drain(ctx)
  1104  			if err != nil {
  1105  				return nil, err
  1106  			}
  1107  			dr.ExtraRowsSource += 1 + count
  1108  			dr.ProcessedRows += 1 + count
  1109  			return dr, nil
  1110  		}
  1111  
  1112  		dr.ProcessedRows++
  1113  
  1114  		// Compare pk values.
  1115  		c, err := td.compare(sourceRow, targetRow, td.comparePKs, false)
  1116  		switch {
  1117  		case err != nil:
  1118  			return nil, err
  1119  		case c < 0:
  1120  			if dr.ExtraRowsSource < maxExtraRowsToCompare {
  1121  				diffRow, err := td.genRowDiff(td.sourceExpression, sourceRow, debug, onlyPks)
  1122  				if err != nil {
  1123  					return nil, vterrors.Wrap(err, "unexpected error generating diff")
  1124  				}
  1125  				dr.ExtraRowsSourceDiffs = append(dr.ExtraRowsSourceDiffs, diffRow)
  1126  			}
  1127  			dr.ExtraRowsSource++
  1128  			advanceTarget = false
  1129  			continue
  1130  		case c > 0:
  1131  			if dr.ExtraRowsTarget < maxExtraRowsToCompare {
  1132  				diffRow, err := td.genRowDiff(td.targetExpression, targetRow, debug, onlyPks)
  1133  				if err != nil {
  1134  					return nil, vterrors.Wrap(err, "unexpected error generating diff")
  1135  				}
  1136  				dr.ExtraRowsTargetDiffs = append(dr.ExtraRowsTargetDiffs, diffRow)
  1137  			}
  1138  			dr.ExtraRowsTarget++
  1139  			advanceSource = false
  1140  			continue
  1141  		}
  1142  
  1143  		// c == 0
  1144  		// Compare the non-pk values.
  1145  		c, err = td.compare(sourceRow, targetRow, td.compareCols, true)
  1146  		switch {
  1147  		case err != nil:
  1148  			return nil, err
  1149  		case c != 0:
  1150  			// We don't do a second pass to compare mismatched rows so we can cap the slice here
  1151  			if dr.MismatchedRows < maxVDiffReportSampleRows {
  1152  				sourceDiffRow, err := td.genRowDiff(td.targetExpression, sourceRow, debug, onlyPks)
  1153  				if err != nil {
  1154  					return nil, vterrors.Wrap(err, "unexpected error generating diff")
  1155  				}
  1156  				targetDiffRow, err := td.genRowDiff(td.targetExpression, targetRow, debug, onlyPks)
  1157  				if err != nil {
  1158  					return nil, vterrors.Wrap(err, "unexpected error generating diff")
  1159  				}
  1160  				dr.MismatchedRowsSample = append(dr.MismatchedRowsSample, &DiffMismatch{Source: sourceDiffRow, Target: targetDiffRow})
  1161  			}
  1162  			dr.MismatchedRows++
  1163  		default:
  1164  			dr.MatchingRows++
  1165  		}
  1166  	}
  1167  }
  1168  
  1169  func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []compareColInfo, compareOnlyNonPKs bool) (int, error) {
  1170  	for _, col := range cols {
  1171  		if col.isPK && compareOnlyNonPKs {
  1172  			continue
  1173  		}
  1174  		compareIndex := col.colIndex
  1175  		var c int
  1176  		var err error
  1177  		var collationID collations.ID
  1178  		// if the collation is nil or unknown, use binary collation to compare as bytes
  1179  		if col.collation == nil {
  1180  			collationID = collations.CollationBinaryID
  1181  		} else {
  1182  			collationID = col.collation.ID()
  1183  		}
  1184  		c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], collationID)
  1185  		if err != nil {
  1186  			return 0, err
  1187  		}
  1188  		if c != 0 {
  1189  			return c, nil
  1190  		}
  1191  	}
  1192  	return 0, nil
  1193  }
  1194  
  1195  func (td *tableDiffer) genRowDiff(queryStmt string, row []sqltypes.Value, debug, onlyPks bool) (*RowDiff, error) {
  1196  	drp := &RowDiff{}
  1197  	drp.Row = make(map[string]sqltypes.Value)
  1198  	statement, err := sqlparser.Parse(queryStmt)
  1199  	if err != nil {
  1200  		return nil, err
  1201  	}
  1202  	sel, ok := statement.(*sqlparser.Select)
  1203  	if !ok {
  1204  		return nil, fmt.Errorf("unexpected: %v", sqlparser.String(statement))
  1205  	}
  1206  
  1207  	if debug {
  1208  		drp.Query = td.genDebugQueryDiff(sel, row, onlyPks)
  1209  	}
  1210  
  1211  	if onlyPks {
  1212  		for _, pkI := range td.selectPks {
  1213  			buf := sqlparser.NewTrackedBuffer(nil)
  1214  			sel.SelectExprs[pkI].Format(buf)
  1215  			col := buf.String()
  1216  			drp.Row[col] = row[pkI]
  1217  		}
  1218  		return drp, nil
  1219  	}
  1220  
  1221  	for i := range sel.SelectExprs {
  1222  		buf := sqlparser.NewTrackedBuffer(nil)
  1223  		sel.SelectExprs[i].Format(buf)
  1224  		col := buf.String()
  1225  		drp.Row[col] = row[i]
  1226  	}
  1227  
  1228  	return drp, nil
  1229  }
  1230  
  1231  func (td *tableDiffer) genDebugQueryDiff(sel *sqlparser.Select, row []sqltypes.Value, onlyPks bool) string {
  1232  	buf := sqlparser.NewTrackedBuffer(nil)
  1233  	buf.Myprintf("select ")
  1234  
  1235  	if onlyPks {
  1236  		for i, pkI := range td.selectPks {
  1237  			pk := sel.SelectExprs[pkI]
  1238  			pk.Format(buf)
  1239  			if i != len(td.selectPks)-1 {
  1240  				buf.Myprintf(", ")
  1241  			}
  1242  		}
  1243  	} else {
  1244  		sel.SelectExprs.Format(buf)
  1245  	}
  1246  	buf.Myprintf(" from ")
  1247  	buf.Myprintf(sqlparser.ToString(sel.From))
  1248  	buf.Myprintf(" where ")
  1249  	for i, pkI := range td.selectPks {
  1250  		sel.SelectExprs[pkI].Format(buf)
  1251  		buf.Myprintf("=")
  1252  		row[pkI].EncodeSQL(buf)
  1253  		if i != len(td.selectPks)-1 {
  1254  			buf.Myprintf(" AND ")
  1255  		}
  1256  	}
  1257  	buf.Myprintf(";")
  1258  	return buf.String()
  1259  }
  1260  
  1261  //-----------------------------------------------------------------
  1262  // contextVCursor
  1263  
  1264  // contextVCursor satisfies VCursor interface
  1265  type contextVCursor struct {
  1266  	engine.VCursor
  1267  }
  1268  
  1269  func (vc *contextVCursor) ConnCollation() collations.ID {
  1270  	return collations.CollationBinaryID
  1271  }
  1272  
  1273  func (vc *contextVCursor) ExecutePrimitive(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
  1274  	return primitive.TryExecute(ctx, vc, bindVars, wantfields)
  1275  }
  1276  
  1277  func (vc *contextVCursor) StreamExecutePrimitive(ctx context.Context, primitive engine.Primitive, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
  1278  	return primitive.TryStreamExecute(ctx, vc, bindVars, wantfields, callback)
  1279  }
  1280  
  1281  //-----------------------------------------------------------------
  1282  // Utility functions
  1283  
  1284  func removeKeyrange(where *sqlparser.Where) *sqlparser.Where {
  1285  	if where == nil {
  1286  		return nil
  1287  	}
  1288  	if isFuncKeyrange(where.Expr) {
  1289  		return nil
  1290  	}
  1291  	where.Expr = removeExprKeyrange(where.Expr)
  1292  	return where
  1293  }
  1294  
  1295  func removeExprKeyrange(node sqlparser.Expr) sqlparser.Expr {
  1296  	switch node := node.(type) {
  1297  	case *sqlparser.AndExpr:
  1298  		if isFuncKeyrange(node.Left) {
  1299  			return removeExprKeyrange(node.Right)
  1300  		}
  1301  		if isFuncKeyrange(node.Right) {
  1302  			return removeExprKeyrange(node.Left)
  1303  		}
  1304  		return &sqlparser.AndExpr{
  1305  			Left:  removeExprKeyrange(node.Left),
  1306  			Right: removeExprKeyrange(node.Right),
  1307  		}
  1308  	}
  1309  	return node
  1310  }
  1311  
  1312  func isFuncKeyrange(expr sqlparser.Expr) bool {
  1313  	funcExpr, ok := expr.(*sqlparser.FuncExpr)
  1314  	return ok && funcExpr.Name.EqualString("in_keyrange")
  1315  }
  1316  
  1317  func formatSampleRow(logger logutil.Logger, rd *RowDiff, debug bool) {
  1318  	keys := make([]string, 0, len(rd.Row))
  1319  	for k := range rd.Row {
  1320  		keys = append(keys, k)
  1321  	}
  1322  
  1323  	sort.Strings(keys)
  1324  
  1325  	for _, k := range keys {
  1326  		logger.Printf("\t\t\t %s: %s\n", k, formatValue(rd.Row[k]))
  1327  	}
  1328  
  1329  	if debug {
  1330  		logger.Printf("\t\tDebugQuery: %v\n", rd.Query)
  1331  	}
  1332  }
  1333  
  1334  func formatValue(val sqltypes.Value) string {
  1335  	if val.Type() == sqltypes.Null {
  1336  		return "null (NULL_TYPE)"
  1337  	}
  1338  	if val.IsQuoted() || val.Type() == sqltypes.Bit {
  1339  		if len(val.Raw()) >= 20 {
  1340  			rawBytes := val.Raw()[:20]
  1341  			rawBytes = append(rawBytes, []byte("...[TRUNCATED]")...)
  1342  			return fmt.Sprintf("%q (%v)", rawBytes, val.Type())
  1343  		}
  1344  		return fmt.Sprintf("%q (%v)", val.Raw(), val.Type())
  1345  	}
  1346  	return fmt.Sprintf("%s (%v)", val.Raw(), val.Type())
  1347  }