vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/rewrite.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package planbuilder
    18  
    19  import (
    20  	"vitess.io/vitess/go/vt/sqlparser"
    21  	"vitess.io/vitess/go/vt/vterrors"
    22  	"vitess.io/vitess/go/vt/vtgate/engine"
    23  	"vitess.io/vitess/go/vt/vtgate/semantics"
    24  )
    25  
    26  type rewriter struct {
    27  	semTable     *semantics.SemTable
    28  	reservedVars *sqlparser.ReservedVars
    29  	inSubquery   int
    30  	err          error
    31  }
    32  
    33  func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.ReservedVars, statement sqlparser.Statement) error {
    34  	r := rewriter{
    35  		semTable:     semTable,
    36  		reservedVars: reservedVars,
    37  	}
    38  	sqlparser.Rewrite(statement, r.rewriteDown, r.rewriteUp)
    39  	return nil
    40  }
    41  
    42  func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool {
    43  	switch node := cursor.Node().(type) {
    44  	case *sqlparser.Select:
    45  		rewriteHavingClause(node)
    46  	case *sqlparser.ComparisonExpr:
    47  		err := rewriteInSubquery(cursor, r, node)
    48  		if err != nil {
    49  			r.err = err
    50  		}
    51  	case *sqlparser.ExistsExpr:
    52  		err := r.rewriteExistsSubquery(cursor, node)
    53  		if err != nil {
    54  			r.err = err
    55  		}
    56  		return false
    57  	case *sqlparser.AliasedTableExpr:
    58  		// rewrite names of the routed tables for the subquery
    59  		// We only need to do this for non-derived tables and if they are in a subquery
    60  		if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived || r.inSubquery == 0 {
    61  			break
    62  		}
    63  		// find the tableSet and tableInfo that this table points to
    64  		// tableInfo should contain the information for the original table that the routed table points to
    65  		tableSet := r.semTable.TableSetFor(node)
    66  		tableInfo, err := r.semTable.TableInfoFor(tableSet)
    67  		if err != nil {
    68  			// Fail-safe code, should never happen
    69  			break
    70  		}
    71  		// vindexTable is the original table
    72  		vindexTable := tableInfo.GetVindexTable()
    73  		if vindexTable == nil {
    74  			break
    75  		}
    76  		tableName := node.Expr.(sqlparser.TableName)
    77  		// if the table name matches what the original is, then we do not need to rewrite
    78  		if sqlparser.Equals.IdentifierCS(vindexTable.Name, tableName.Name) {
    79  			break
    80  		}
    81  		// if there is no as clause, then move the routed table to the as clause.
    82  		// i.e
    83  		// routed as x -> original as x
    84  		// routed -> original as routed
    85  		if node.As.IsEmpty() {
    86  			node.As = tableName.Name
    87  		}
    88  		// replace the table name with the original table
    89  		tableName.Name = vindexTable.Name
    90  		node.Expr = tableName
    91  	case *sqlparser.Subquery:
    92  		err := rewriteSubquery(cursor, r, node)
    93  		if err != nil {
    94  			r.err = err
    95  		}
    96  	}
    97  	return true
    98  }
    99  
   100  func (r *rewriter) rewriteUp(cursor *sqlparser.Cursor) bool {
   101  	switch cursor.Node().(type) {
   102  	case *sqlparser.Subquery:
   103  		r.inSubquery--
   104  	}
   105  	return r.err == nil
   106  }
   107  
   108  func rewriteInSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.ComparisonExpr) error {
   109  	subq, exp := semantics.GetSubqueryAndOtherSide(node)
   110  	if subq == nil || exp == nil {
   111  		return nil
   112  	}
   113  
   114  	semTableSQ, found := r.semTable.SubqueryRef[subq]
   115  	if !found {
   116  		return vterrors.VT13001("got subquery that was not in the subq map")
   117  	}
   118  
   119  	r.inSubquery++
   120  	argName, hasValuesArg := r.reservedVars.ReserveSubQueryWithHasValues()
   121  	semTableSQ.SetArgName(argName)
   122  	semTableSQ.SetHasValuesArg(hasValuesArg)
   123  	cursor.Replace(semTableSQ)
   124  	return nil
   125  }
   126  
   127  func rewriteSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Subquery) error {
   128  	semTableSQ, found := r.semTable.SubqueryRef[node]
   129  	if !found {
   130  		return vterrors.VT13001("got subquery that was not in the subq map")
   131  	}
   132  	if semTableSQ.GetArgName() != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue {
   133  		return nil
   134  	}
   135  	r.inSubquery++
   136  	argName := r.reservedVars.ReserveSubQuery()
   137  	semTableSQ.SetArgName(argName)
   138  	cursor.Replace(semTableSQ)
   139  	return nil
   140  }
   141  
   142  func (r *rewriter) rewriteExistsSubquery(cursor *sqlparser.Cursor, node *sqlparser.ExistsExpr) error {
   143  	semTableSQ, found := r.semTable.SubqueryRef[node.Subquery]
   144  	if !found {
   145  		return vterrors.VT13001("got subquery that was not in the subq map")
   146  	}
   147  
   148  	r.inSubquery++
   149  	hasValuesArg := r.reservedVars.ReserveHasValuesSubQuery()
   150  	semTableSQ.SetHasValuesArg(hasValuesArg)
   151  	cursor.Replace(semTableSQ)
   152  	return nil
   153  }
   154  
   155  func rewriteHavingClause(node *sqlparser.Select) {
   156  	if node.Having == nil {
   157  		return
   158  	}
   159  
   160  	selectExprMap := map[string]sqlparser.Expr{}
   161  	for _, selectExpr := range node.SelectExprs {
   162  		aliasedExpr, isAliased := selectExpr.(*sqlparser.AliasedExpr)
   163  		if !isAliased || aliasedExpr.As.IsEmpty() {
   164  			continue
   165  		}
   166  		selectExprMap[aliasedExpr.As.Lowered()] = aliasedExpr.Expr
   167  	}
   168  
   169  	// for each expression in the having clause, we check if it contains aggregation.
   170  	// if it does, we keep the expression in the having clause ; and if it does not
   171  	// and the expression is in the select list, we replace the expression by the one
   172  	// used in the select list and add it to the where clause instead of the having clause.
   173  	exprs := sqlparser.SplitAndExpression(nil, node.Having.Expr)
   174  	node.Having = nil
   175  	for _, expr := range exprs {
   176  		hasAggr := sqlparser.ContainsAggregation(expr)
   177  		if !hasAggr {
   178  			sqlparser.Rewrite(expr, func(cursor *sqlparser.Cursor) bool {
   179  				visitColName(cursor.Node(), selectExprMap, func(original sqlparser.Expr) {
   180  					if sqlparser.ContainsAggregation(original) {
   181  						hasAggr = true
   182  					}
   183  				})
   184  				return true
   185  			}, nil)
   186  		}
   187  		if hasAggr {
   188  			node.AddHaving(expr)
   189  		} else {
   190  			sqlparser.Rewrite(expr, func(cursor *sqlparser.Cursor) bool {
   191  				visitColName(cursor.Node(), selectExprMap, func(original sqlparser.Expr) {
   192  					cursor.Replace(original)
   193  				})
   194  				return true
   195  			}, nil)
   196  			node.AddWhere(expr)
   197  		}
   198  	}
   199  }
   200  func visitColName(cursor sqlparser.SQLNode, selectExprMap map[string]sqlparser.Expr, f func(original sqlparser.Expr)) {
   201  	switch x := cursor.(type) {
   202  	case *sqlparser.ColName:
   203  		if !x.Qualifier.IsEmpty() {
   204  			return
   205  		}
   206  		originalExpr, isInMap := selectExprMap[x.Name.Lowered()]
   207  		if isInMap {
   208  			f(originalExpr)
   209  		}
   210  		return
   211  	}
   212  }