vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletmanager/vdiff/table_plan.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vdiff
    18  
    19  import (
    20  	"fmt"
    21  	"strings"
    22  
    23  	tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
    24  
    25  	"vitess.io/vitess/go/vt/log"
    26  	querypb "vitess.io/vitess/go/vt/proto/query"
    27  	"vitess.io/vitess/go/vt/sqlparser"
    28  	"vitess.io/vitess/go/vt/vtgate/engine"
    29  )
    30  
    31  type tablePlan struct {
    32  	// sourceQuery and targetQuery are select queries.
    33  	sourceQuery string
    34  	targetQuery string
    35  
    36  	// compareCols is the list of non-pk columns to compare.
    37  	// If the value is -1, it's a pk column and should not be
    38  	// compared.
    39  	compareCols []compareColInfo
    40  	// comparePKs is the list of pk columns to compare. The logic
    41  	// for comparing pk columns is different from compareCols
    42  	comparePKs []compareColInfo
    43  	// pkCols has the indices of PK cols in the select list
    44  	pkCols []int
    45  
    46  	// selectPks is the list of pk columns as they appear in the select clause for the diff.
    47  	selectPks  []int
    48  	table      *tabletmanagerdatapb.TableDefinition
    49  	orderBy    sqlparser.OrderBy
    50  	aggregates []*engine.AggregateParams
    51  }
    52  
    53  func (td *tableDiffer) buildTablePlan() (*tablePlan, error) {
    54  	tp := &tablePlan{table: td.table}
    55  	statement, err := sqlparser.Parse(td.sourceQuery)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	sel, ok := statement.(*sqlparser.Select)
    60  	if !ok {
    61  		return nil, fmt.Errorf("unexpected: %v", sqlparser.String(statement))
    62  	}
    63  
    64  	sourceSelect := &sqlparser.Select{}
    65  	targetSelect := &sqlparser.Select{}
    66  	// aggregates is the list of Aggregate functions, if any.
    67  	var aggregates []*engine.AggregateParams
    68  	for _, selExpr := range sel.SelectExprs {
    69  		switch selExpr := selExpr.(type) {
    70  		case *sqlparser.StarExpr:
    71  			// If it's a '*' expression, expand column list from the schema.
    72  			for _, fld := range tp.table.Fields {
    73  				aliased := &sqlparser.AliasedExpr{Expr: &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(fld.Name)}}
    74  				sourceSelect.SelectExprs = append(sourceSelect.SelectExprs, aliased)
    75  				targetSelect.SelectExprs = append(targetSelect.SelectExprs, aliased)
    76  			}
    77  		case *sqlparser.AliasedExpr:
    78  			var targetCol *sqlparser.ColName
    79  			if !selExpr.As.IsEmpty() {
    80  				targetCol = &sqlparser.ColName{Name: selExpr.As}
    81  			} else {
    82  				if colAs, ok := selExpr.Expr.(*sqlparser.ColName); ok {
    83  					targetCol = colAs
    84  				} else {
    85  					return nil, fmt.Errorf("expression needs an alias: %v", sqlparser.String(selExpr))
    86  				}
    87  			}
    88  			// If the input was "select a as b", then source will use "a" and target will use "b".
    89  			sourceSelect.SelectExprs = append(sourceSelect.SelectExprs, selExpr)
    90  			targetSelect.SelectExprs = append(targetSelect.SelectExprs, &sqlparser.AliasedExpr{Expr: targetCol})
    91  
    92  			// Check if it's an aggregate expression
    93  			if expr, ok := selExpr.Expr.(sqlparser.AggrFunc); ok {
    94  				switch fname := strings.ToLower(expr.AggrName()); fname {
    95  				case "count", "sum":
    96  					// this will only work as long as aggregates can be pushed down to tablets
    97  					// this won't work: "select count(*) from (select id from t limit 1)"
    98  					// since vreplication only handles simple tables (no joins/derived tables) this is fine for now
    99  					// but will need to be revisited when we add such support to vreplication
   100  					aggregateFuncType := "sum"
   101  					aggregates = append(aggregates, &engine.AggregateParams{
   102  						Opcode: engine.SupportedAggregates[aggregateFuncType],
   103  						Col:    len(sourceSelect.SelectExprs) - 1,
   104  					})
   105  				}
   106  			}
   107  		default:
   108  			return nil, fmt.Errorf("unexpected: %v", sqlparser.String(statement))
   109  		}
   110  	}
   111  	fields := make(map[string]querypb.Type)
   112  	for _, field := range tp.table.Fields {
   113  		fields[strings.ToLower(field.Name)] = field.Type
   114  	}
   115  
   116  	targetSelect.SelectExprs = td.adjustForSourceTimeZone(targetSelect.SelectExprs, fields)
   117  	// Start with adding all columns for comparison.
   118  	tp.compareCols = make([]compareColInfo, len(sourceSelect.SelectExprs))
   119  	for i := range tp.compareCols {
   120  		tp.compareCols[i].colIndex = i
   121  		colname, err := getColumnNameForSelectExpr(targetSelect.SelectExprs[i])
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  		_, ok := fields[colname]
   126  		if !ok {
   127  			return nil, fmt.Errorf("column %v not found in table %v on tablet %v",
   128  				colname, tp.table.Name, td.wd.ct.vde.thisTablet.Alias)
   129  		}
   130  		tp.compareCols[i].colName = colname
   131  	}
   132  
   133  	sourceSelect.From = sel.From
   134  	// The target table name should the one that matched the rule.
   135  	// It can be different from the source table.
   136  	targetSelect.From = sqlparser.TableExprs{
   137  		&sqlparser.AliasedTableExpr{
   138  			Expr: &sqlparser.TableName{
   139  				Name: sqlparser.NewIdentifierCS(tp.table.Name),
   140  			},
   141  		},
   142  	}
   143  
   144  	err = tp.findPKs(targetSelect)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	// Remove in_keyrange. It's not understood by mysql.
   149  	sourceSelect.Where = sel.Where //removeKeyrange(sel.Where)
   150  	// The source should also perform the group by.
   151  	sourceSelect.GroupBy = sel.GroupBy
   152  	sourceSelect.OrderBy = tp.orderBy
   153  
   154  	// The target should perform the order by, but not the group by.
   155  	targetSelect.OrderBy = tp.orderBy
   156  
   157  	tp.sourceQuery = sqlparser.String(sourceSelect)
   158  	tp.targetQuery = sqlparser.String(targetSelect)
   159  	log.Info("VDiff query on source: %v", tp.sourceQuery)
   160  	log.Info("VDiff query on target: %v", tp.targetQuery)
   161  
   162  	tp.aggregates = aggregates
   163  	td.tablePlan = tp
   164  	return tp, err
   165  }
   166  
   167  // findPKs identifies PKs and removes them from the columns to do data comparison.
   168  func (tp *tablePlan) findPKs(targetSelect *sqlparser.Select) error {
   169  	var orderby sqlparser.OrderBy
   170  	for _, pk := range tp.table.PrimaryKeyColumns {
   171  		found := false
   172  		for i, selExpr := range targetSelect.SelectExprs {
   173  			expr := selExpr.(*sqlparser.AliasedExpr).Expr
   174  			colname := ""
   175  			switch ct := expr.(type) {
   176  			case *sqlparser.ColName:
   177  				colname = ct.Name.String()
   178  			case *sqlparser.FuncExpr: //eg. weight_string()
   179  				//no-op
   180  			default:
   181  				log.Warningf("Not considering column %v for PK, type %v not handled", selExpr, ct)
   182  			}
   183  			if strings.EqualFold(pk, colname) {
   184  				tp.compareCols[i].isPK = true
   185  				tp.comparePKs = append(tp.comparePKs, tp.compareCols[i])
   186  				tp.selectPks = append(tp.selectPks, i)
   187  				// We'll be comparing pks separately. So, remove them from compareCols.
   188  				tp.pkCols = append(tp.pkCols, i)
   189  				found = true
   190  				break
   191  			}
   192  		}
   193  		if !found {
   194  			// Unreachable.
   195  			return fmt.Errorf("column %v not found in table %v", pk, tp.table.Name)
   196  		}
   197  		orderby = append(orderby, &sqlparser.Order{
   198  			Expr:      &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(pk)},
   199  			Direction: sqlparser.AscOrder,
   200  		})
   201  	}
   202  	tp.orderBy = orderby
   203  	return nil
   204  }