vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/early_rewriter.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package semantics
    18  
    19  import (
    20  	"strconv"
    21  	"strings"
    22  
    23  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    24  
    25  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    26  	"vitess.io/vitess/go/vt/sqlparser"
    27  	"vitess.io/vitess/go/vt/vterrors"
    28  )
    29  
// earlyRewriter holds the state shared by the early AST rewrites in this file.
//
//   - binder: its usingJoinInfo is passed to the star expansion.
//   - scoper: supplies the current scope, the scopes recorded for literals
//     (specialExprScopes) and the `org` value handed to getTableSet.
//   - clause: name of the clause most recently entered ("order clause" or
//     "group statement"); only used to build error messages.
//   - warning: set when a STRAIGHT_JOIN has been turned into a normal join.
//   - expandedColumns: the columns produced by the last successful `*`
//     expansion, grouped by the table they came from.
type earlyRewriter struct {
	binder          *binder
	scoper          *scoper
	clause          string
	warning         string
	expandedColumns map[sqlparser.TableName][]*sqlparser.ColName
}
    37  
    38  func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error {
    39  	switch node := cursor.Node().(type) {
    40  	case *sqlparser.Where:
    41  		if node.Type != sqlparser.HavingClause {
    42  			return nil
    43  		}
    44  		rewriteHavingAndOrderBy(node, cursor.Parent())
    45  	case sqlparser.SelectExprs:
    46  		_, isSel := cursor.Parent().(*sqlparser.Select)
    47  		if !isSel {
    48  			return nil
    49  		}
    50  		err := r.expandStar(cursor, node)
    51  		if err != nil {
    52  			return err
    53  		}
    54  	case *sqlparser.JoinTableExpr:
    55  		if node.Join == sqlparser.StraightJoinType {
    56  			node.Join = sqlparser.NormalJoinType
    57  			r.warning = "straight join is converted to normal join"
    58  		}
    59  	case sqlparser.OrderBy:
    60  		r.clause = "order clause"
    61  		rewriteHavingAndOrderBy(node, cursor.Parent())
    62  	case *sqlparser.OrExpr:
    63  		newNode := rewriteOrFalse(*node)
    64  		if newNode != nil {
    65  			cursor.Replace(newNode)
    66  		}
    67  	case sqlparser.GroupBy:
    68  		r.clause = "group statement"
    69  
    70  	case *sqlparser.Literal:
    71  		newNode, err := r.rewriteOrderByExpr(node)
    72  		if err != nil {
    73  			return err
    74  		}
    75  		if newNode != nil {
    76  			cursor.Replace(newNode)
    77  		}
    78  	case *sqlparser.CollateExpr:
    79  		lit, ok := node.Expr.(*sqlparser.Literal)
    80  		if !ok {
    81  			return nil
    82  		}
    83  		newNode, err := r.rewriteOrderByExpr(lit)
    84  		if err != nil {
    85  			return err
    86  		}
    87  		if newNode != nil {
    88  			node.Expr = newNode
    89  		}
    90  	case *sqlparser.ComparisonExpr:
    91  		lft, lftOK := node.Left.(sqlparser.ValTuple)
    92  		rgt, rgtOK := node.Right.(sqlparser.ValTuple)
    93  		if !lftOK || !rgtOK || len(lft) != len(rgt) || node.Operator != sqlparser.EqualOp {
    94  			return nil
    95  		}
    96  		var predicates []sqlparser.Expr
    97  		for i, l := range lft {
    98  			r := rgt[i]
    99  			predicates = append(predicates, &sqlparser.ComparisonExpr{
   100  				Operator: sqlparser.EqualOp,
   101  				Left:     l,
   102  				Right:    r,
   103  				Escape:   node.Escape,
   104  			})
   105  		}
   106  		cursor.Replace(sqlparser.AndExpressions(predicates...))
   107  	}
   108  	return nil
   109  }
   110  
   111  func (r *earlyRewriter) expandStar(cursor *sqlparser.Cursor, node sqlparser.SelectExprs) error {
   112  	currentScope := r.scoper.currentScope()
   113  	var selExprs sqlparser.SelectExprs
   114  	changed := false
   115  	for _, selectExpr := range node {
   116  		starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr)
   117  		if !isStarExpr {
   118  			selExprs = append(selExprs, selectExpr)
   119  			continue
   120  		}
   121  		starExpanded, colNames, err := r.expandTableColumns(starExpr, currentScope.tables, r.binder.usingJoinInfo, r.scoper.org)
   122  		if err != nil {
   123  			return err
   124  		}
   125  		if !starExpanded || colNames == nil {
   126  			selExprs = append(selExprs, selectExpr)
   127  			continue
   128  		}
   129  		selExprs = append(selExprs, colNames...)
   130  		changed = true
   131  	}
   132  	if changed {
   133  		cursor.ReplaceAndRevisit(selExprs)
   134  	}
   135  	return nil
   136  }
   137  
   138  // rewriteHavingAndOrderBy rewrites columns on the ORDER BY/HAVING
   139  // clauses to use aliases from the SELECT expressions when available.
   140  // The scoping rules are:
   141  //   - A column identifier with no table qualifier that matches an alias introduced
   142  //     in SELECT points to that expression, and not at any table column
   143  //   - Except when expression aliased is an aggregation, and the column identifier in the
   144  //     HAVING/ORDER BY clause is inside an aggregation function
   145  //
   146  // This is a fucking weird scoping rule, but it's what MySQL seems to do... ¯\_(ツ)_/¯
   147  func rewriteHavingAndOrderBy(node, parent sqlparser.SQLNode) {
   148  	// TODO - clean up and comment this mess
   149  	sel, isSel := parent.(*sqlparser.Select)
   150  	if !isSel {
   151  		return
   152  	}
   153  
   154  	sqlparser.SafeRewrite(node, func(node, _ sqlparser.SQLNode) bool {
   155  		_, isSubQ := node.(*sqlparser.Subquery)
   156  		return !isSubQ
   157  	}, func(cursor *sqlparser.Cursor) bool {
   158  		col, ok := cursor.Node().(*sqlparser.ColName)
   159  		if !ok {
   160  			return true
   161  		}
   162  		if !col.Qualifier.IsEmpty() {
   163  			return true
   164  		}
   165  		_, parentIsAggr := cursor.Parent().(sqlparser.AggrFunc)
   166  		for _, e := range sel.SelectExprs {
   167  			ae, ok := e.(*sqlparser.AliasedExpr)
   168  			if !ok || !ae.As.Equal(col.Name) {
   169  				continue
   170  			}
   171  			_, aliasPointsToAggr := ae.Expr.(sqlparser.AggrFunc)
   172  			if parentIsAggr && aliasPointsToAggr {
   173  				return false
   174  			}
   175  
   176  			safeToRewrite := true
   177  			_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   178  				switch node.(type) {
   179  				case *sqlparser.ColName:
   180  					safeToRewrite = false
   181  					return false, nil
   182  				case sqlparser.AggrFunc:
   183  					return false, nil
   184  				}
   185  				return true, nil
   186  			}, ae.Expr)
   187  			if safeToRewrite {
   188  				cursor.Replace(ae.Expr)
   189  			}
   190  		}
   191  		return true
   192  	})
   193  }
   194  
   195  func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) {
   196  	currScope, found := r.scoper.specialExprScopes[node]
   197  	if !found {
   198  		return nil, nil
   199  	}
   200  	num, err := strconv.Atoi(node.Val)
   201  	if err != nil {
   202  		return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val)
   203  	}
   204  	stmt, isSel := currScope.stmt.(*sqlparser.Select)
   205  	if !isSel {
   206  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error invalid statement type, expect Select, got: %T", currScope.stmt)
   207  	}
   208  
   209  	if num < 1 || num > len(stmt.SelectExprs) {
   210  		return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause)
   211  	}
   212  
   213  	for i := 0; i < num; i++ {
   214  		expr := stmt.SelectExprs[i]
   215  		_, ok := expr.(*sqlparser.AliasedExpr)
   216  		if !ok {
   217  			return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column offsets in %s when using `%s`", r.clause, sqlparser.String(expr))
   218  		}
   219  	}
   220  
   221  	aliasedExpr, ok := stmt.SelectExprs[num-1].(*sqlparser.AliasedExpr)
   222  	if !ok {
   223  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node))
   224  	}
   225  
   226  	if !aliasedExpr.As.IsEmpty() {
   227  		return sqlparser.NewColName(aliasedExpr.As.String()), nil
   228  	}
   229  
   230  	expr := realCloneOfColNames(aliasedExpr.Expr, currScope.isUnion)
   231  	return expr, nil
   232  }
   233  
   234  // realCloneOfColNames clones all the expressions including ColName.
   235  // Since sqlparser.CloneRefOfColName does not clone col names, this method is needed.
   236  func realCloneOfColNames(expr sqlparser.Expr, union bool) sqlparser.Expr {
   237  	return sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
   238  		exp, ok := cursor.Node().(*sqlparser.ColName)
   239  		if !ok {
   240  			return
   241  		}
   242  
   243  		newColName := *exp
   244  		if union {
   245  			newColName.Qualifier = sqlparser.TableName{}
   246  		}
   247  		cursor.Replace(&newColName)
   248  	}, nil).(sqlparser.Expr)
   249  }
   250  
   251  func rewriteOrFalse(orExpr sqlparser.OrExpr) sqlparser.Expr {
   252  	// we are looking for the pattern `WHERE c = 1 OR 1 = 0`
   253  	isFalse := func(subExpr sqlparser.Expr) bool {
   254  		evalEnginePred, err := evalengine.Translate(subExpr, nil)
   255  		if err != nil {
   256  			return false
   257  		}
   258  
   259  		env := evalengine.EmptyExpressionEnv()
   260  		res, err := env.Evaluate(evalEnginePred)
   261  		if err != nil {
   262  			return false
   263  		}
   264  
   265  		boolValue, err := res.Value().ToBool()
   266  		if err != nil {
   267  			return false
   268  		}
   269  
   270  		return !boolValue
   271  	}
   272  
   273  	if isFalse(orExpr.Left) {
   274  		return orExpr.Right
   275  	} else if isFalse(orExpr.Right) {
   276  		return orExpr.Left
   277  	}
   278  
   279  	return nil
   280  }
   281  
   282  func rewriteJoinUsing(
   283  	current *scope,
   284  	using sqlparser.Columns,
   285  	org originable,
   286  ) error {
   287  	joinUsing := current.prepareUsingMap()
   288  	predicates := make([]sqlparser.Expr, 0, len(using))
   289  	for _, column := range using {
   290  		var foundTables []sqlparser.TableName
   291  		for _, tbl := range current.tables {
   292  			if !tbl.authoritative() {
   293  				return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "can't handle JOIN USING without authoritative tables")
   294  			}
   295  
   296  			currTable := tbl.getTableSet(org)
   297  			usingCols := joinUsing[currTable]
   298  			if usingCols == nil {
   299  				usingCols = map[string]TableSet{}
   300  			}
   301  			for _, col := range tbl.getColumns() {
   302  				_, found := usingCols[strings.ToLower(col.Name)]
   303  				if found {
   304  					tblName, err := tbl.Name()
   305  					if err != nil {
   306  						return err
   307  					}
   308  
   309  					foundTables = append(foundTables, tblName)
   310  					break // no need to look at other columns in this table
   311  				}
   312  			}
   313  		}
   314  		for i, lft := range foundTables {
   315  			for j := i + 1; j < len(foundTables); j++ {
   316  				rgt := foundTables[j]
   317  				predicates = append(predicates, &sqlparser.ComparisonExpr{
   318  					Operator: sqlparser.EqualOp,
   319  					Left:     sqlparser.NewColNameWithQualifier(column.String(), lft),
   320  					Right:    sqlparser.NewColNameWithQualifier(column.String(), rgt),
   321  				})
   322  			}
   323  		}
   324  	}
   325  
   326  	// now, we go up the scope until we find a SELECT with a where clause we can add this predicate to
   327  	for current != nil {
   328  		sel, found := current.stmt.(*sqlparser.Select)
   329  		if found {
   330  			if sel.Where == nil {
   331  				sel.Where = &sqlparser.Where{
   332  					Type: sqlparser.WhereClause,
   333  					Expr: sqlparser.AndExpressions(predicates...),
   334  				}
   335  			} else {
   336  				sel.Where.Expr = sqlparser.AndExpressions(append(predicates, sel.Where.Expr)...)
   337  			}
   338  			return nil
   339  		}
   340  		current = current.parent
   341  	}
   342  	return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "did not find WHERE clause")
   343  }
   344  
   345  func (r *earlyRewriter) expandTableColumns(
   346  	starExpr *sqlparser.StarExpr,
   347  	tables []TableInfo,
   348  	joinUsing map[TableSet]map[string]TableSet,
   349  	org originable,
   350  ) (bool, sqlparser.SelectExprs, error) {
   351  	unknownTbl := true
   352  	var colNames sqlparser.SelectExprs
   353  	starExpanded := true
   354  	expandedColumns := map[sqlparser.TableName][]*sqlparser.ColName{}
   355  	for _, tbl := range tables {
   356  		if !starExpr.TableName.IsEmpty() && !tbl.matches(starExpr.TableName) {
   357  			continue
   358  		}
   359  		unknownTbl = false
   360  		if !tbl.authoritative() {
   361  			starExpanded = false
   362  			break
   363  		}
   364  		tblName, err := tbl.Name()
   365  		if err != nil {
   366  			return false, nil, err
   367  		}
   368  
   369  		needsQualifier := len(tables) > 1
   370  		tableAliased := !tbl.getExpr().As.IsEmpty()
   371  		withQualifier := needsQualifier || tableAliased
   372  		currTable := tbl.getTableSet(org)
   373  		usingCols := joinUsing[currTable]
   374  		if usingCols == nil {
   375  			usingCols = map[string]TableSet{}
   376  		}
   377  
   378  		addColName := func(col ColumnInfo) {
   379  			var colName *sqlparser.ColName
   380  			var alias sqlparser.IdentifierCI
   381  			if withQualifier {
   382  				colName = sqlparser.NewColNameWithQualifier(col.Name, tblName)
   383  			} else {
   384  				colName = sqlparser.NewColName(col.Name)
   385  			}
   386  			if needsQualifier {
   387  				alias = sqlparser.NewIdentifierCI(col.Name)
   388  			}
   389  			colNames = append(colNames, &sqlparser.AliasedExpr{Expr: colName, As: alias})
   390  			vt := tbl.GetVindexTable()
   391  			if vt != nil {
   392  				keyspace := vt.Keyspace
   393  				var ks sqlparser.IdentifierCS
   394  				if keyspace != nil {
   395  					ks = sqlparser.NewIdentifierCS(keyspace.Name)
   396  				}
   397  				tblName := sqlparser.TableName{
   398  					Name:      tblName.Name,
   399  					Qualifier: ks,
   400  				}
   401  				expandedColumns[tblName] = append(expandedColumns[tblName], colName)
   402  			}
   403  		}
   404  
   405  		/*
   406  			Redundant column elimination and column ordering occurs according to standard SQL, producing this display order:
   407  			  *	First, coalesced common columns of the two joined tables, in the order in which they occur in the first table
   408  			  *	Second, columns unique to the first table, in order in which they occur in that table
   409  			  *	Third, columns unique to the second table, in order in which they occur in that table
   410  
   411  			From: https://dev.mysql.com/doc/refman/8.0/en/join.html
   412  		*/
   413  	outer:
   414  		// in this first loop we just find columns used in any JOIN USING used on this table
   415  		for _, col := range tbl.getColumns() {
   416  			ts, found := usingCols[col.Name]
   417  			if found {
   418  				for i, ts := range ts.Constituents() {
   419  					if ts == currTable {
   420  						if i == 0 {
   421  							addColName(col)
   422  						} else {
   423  							continue outer
   424  						}
   425  					}
   426  				}
   427  			}
   428  		}
   429  
   430  		// and this time around we are printing any columns not involved in any JOIN USING
   431  		for _, col := range tbl.getColumns() {
   432  			if ts, found := usingCols[col.Name]; found && currTable.IsSolvedBy(ts) {
   433  				continue
   434  			}
   435  
   436  			addColName(col)
   437  		}
   438  	}
   439  
   440  	if unknownTbl {
   441  		// This will only happen for case when starExpr has qualifier.
   442  		return false, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName))
   443  	}
   444  	if starExpanded {
   445  		r.expandedColumns = expandedColumns
   446  	}
   447  	return starExpanded, colNames, nil
   448  }