vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/scoper.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package semantics
    18  
    19  import (
    20  	"reflect"
    21  
    22  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    23  	"vitess.io/vitess/go/vt/vterrors"
    24  	"vitess.io/vitess/go/vt/vtgate/engine"
    25  
    26  	"vitess.io/vitess/go/vt/sqlparser"
    27  )
    28  
    29  type (
    30  	// scoper is responsible for figuring out the scoping for the query,
    31  	// and keeps the current scope when walking the tree
    32  	scoper struct {
    33  		rScope map[*sqlparser.Select]*scope
    34  		wScope map[*sqlparser.Select]*scope
    35  		scopes []*scope
    36  		org    originable
    37  		binder *binder
    38  
    39  		// These scopes are only used for rewriting ORDER BY 1 and GROUP BY 1
    40  		specialExprScopes map[*sqlparser.Literal]*scope
    41  	}
    42  
    43  	scope struct {
    44  		parent    *scope
    45  		stmt      sqlparser.Statement
    46  		tables    []TableInfo
    47  		isUnion   bool
    48  		joinUsing map[string]TableSet
    49  		stmtScope bool
    50  	}
    51  )
    52  
    53  func newScoper() *scoper {
    54  	return &scoper{
    55  		rScope:            map[*sqlparser.Select]*scope{},
    56  		wScope:            map[*sqlparser.Select]*scope{},
    57  		specialExprScopes: map[*sqlparser.Literal]*scope{},
    58  	}
    59  }
    60  
    61  func (s *scoper) down(cursor *sqlparser.Cursor) error {
    62  	node := cursor.Node()
    63  	switch node := node.(type) {
    64  	case *sqlparser.Update, *sqlparser.Delete:
    65  		currScope := newScope(s.currentScope())
    66  		currScope.stmtScope = true
    67  		s.push(currScope)
    68  
    69  		currScope.stmt = node.(sqlparser.Statement)
    70  	case *sqlparser.Select:
    71  		currScope := newScope(s.currentScope())
    72  		currScope.stmtScope = true
    73  		s.push(currScope)
    74  
    75  		// Needed for order by with Literal to find the Expression.
    76  		currScope.stmt = node
    77  
    78  		s.rScope[node] = currScope
    79  		s.wScope[node] = newScope(nil)
    80  	case sqlparser.TableExpr:
    81  		if isParentSelect(cursor) {
    82  			// when checking the expressions used in JOIN conditions, special rules apply where the ON expression
    83  			// can only see the two tables involved in the JOIN, and no other tables of that select statement.
    84  			// They are allowed to see the tables of the outer select query.
    85  			// To create this special context, we will find the parent scope of the select statement involved.
    86  			nScope := newScope(s.currentScope().findParentScopeOfStatement())
    87  			nScope.stmt = cursor.Parent().(*sqlparser.Select)
    88  			s.push(nScope)
    89  		}
    90  	case sqlparser.SelectExprs:
    91  		sel, parentIsSelect := cursor.Parent().(*sqlparser.Select)
    92  		if !parentIsSelect {
    93  			break
    94  		}
    95  
    96  		// adding a vTableInfo for each SELECT, so it can be used by GROUP BY, HAVING, ORDER BY
    97  		// the vTableInfo we are creating here should not be confused with derived tables' vTableInfo
    98  		wScope, exists := s.wScope[sel]
    99  		if !exists {
   100  			break
   101  		}
   102  		wScope.tables = []TableInfo{createVTableInfoForExpressions(node, s.currentScope().tables, s.org)}
   103  	case sqlparser.OrderBy:
   104  		if isParentSelectStatement(cursor) {
   105  			err := s.createSpecialScopePostProjection(cursor.Parent())
   106  			if err != nil {
   107  				return err
   108  			}
   109  			for _, order := range node {
   110  				lit := keepIntLiteral(order.Expr)
   111  				if lit != nil {
   112  					s.specialExprScopes[lit] = s.currentScope()
   113  				}
   114  			}
   115  		}
   116  	case sqlparser.GroupBy:
   117  		err := s.createSpecialScopePostProjection(cursor.Parent())
   118  		if err != nil {
   119  			return err
   120  		}
   121  		for _, expr := range node {
   122  			lit := keepIntLiteral(expr)
   123  			if lit != nil {
   124  				s.specialExprScopes[lit] = s.currentScope()
   125  			}
   126  		}
   127  	case *sqlparser.Where:
   128  		if node.Type != sqlparser.HavingClause {
   129  			break
   130  		}
   131  		return s.createSpecialScopePostProjection(cursor.Parent())
   132  	case *sqlparser.DerivedTable:
   133  		if node.Lateral {
   134  			return vterrors.VT12001("lateral derived tables")
   135  		}
   136  	}
   137  	return nil
   138  }
   139  
   140  func keepIntLiteral(e sqlparser.Expr) *sqlparser.Literal {
   141  	coll, ok := e.(*sqlparser.CollateExpr)
   142  	if ok {
   143  		e = coll.Expr
   144  	}
   145  	l, ok := e.(*sqlparser.Literal)
   146  	if !ok {
   147  		return nil
   148  	}
   149  	if l.Type != sqlparser.IntVal {
   150  		return nil
   151  	}
   152  	return l
   153  }
   154  
   155  func (s *scoper) up(cursor *sqlparser.Cursor) error {
   156  	node := cursor.Node()
   157  	switch node := node.(type) {
   158  	case sqlparser.OrderBy:
   159  		if isParentSelectStatement(cursor) {
   160  			s.popScope()
   161  		}
   162  	case *sqlparser.Select, sqlparser.GroupBy, *sqlparser.Update:
   163  		s.popScope()
   164  	case *sqlparser.Where:
   165  		if node.Type != sqlparser.HavingClause {
   166  			break
   167  		}
   168  		s.popScope()
   169  	case sqlparser.TableExpr:
   170  		if isParentSelect(cursor) {
   171  			curScope := s.currentScope()
   172  			s.popScope()
   173  			earlierScope := s.currentScope()
   174  			// copy curScope into the earlierScope
   175  			for _, table := range curScope.tables {
   176  				err := earlierScope.addTable(table)
   177  				if err != nil {
   178  					return err
   179  				}
   180  			}
   181  		}
   182  	}
   183  	return nil
   184  }
   185  
   186  func ValidAsMapKey(s sqlparser.SQLNode) bool {
   187  	return reflect.TypeOf(s).Comparable()
   188  }
   189  
   190  // createSpecialScopePostProjection is used for the special projection in ORDER BY, GROUP BY and HAVING
   191  func (s *scoper) createSpecialScopePostProjection(parent sqlparser.SQLNode) error {
   192  	switch parent := parent.(type) {
   193  	case *sqlparser.Select:
   194  		// In ORDER BY, GROUP BY and HAVING, we can see both the scope in the FROM part of the query, and the SELECT columns created
   195  		// so before walking the rest of the tree, we change the scope to match this behaviour
   196  		incomingScope := s.currentScope()
   197  		nScope := newScope(incomingScope)
   198  		nScope.tables = s.wScope[parent].tables
   199  		nScope.stmt = incomingScope.stmt
   200  		s.push(nScope)
   201  
   202  		if s.rScope[parent] != incomingScope {
   203  			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: scope counts did not match")
   204  		}
   205  	case *sqlparser.Union:
   206  		nScope := newScope(nil)
   207  		nScope.isUnion = true
   208  		var tableInfo *vTableInfo
   209  
   210  		for i, sel := range sqlparser.GetAllSelects(parent) {
   211  			if i == 0 {
   212  				nScope.stmt = sel
   213  				tableInfo = createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org)
   214  				nScope.tables = append(nScope.tables, tableInfo)
   215  			}
   216  			thisTableInfo := createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org)
   217  			if len(tableInfo.cols) != len(thisTableInfo.cols) {
   218  				return engine.ErrWrongNumberOfColumnsInSelect
   219  			}
   220  			for i, col := range tableInfo.cols {
   221  				// at this stage, we don't store the actual dependencies, we only store the expressions.
   222  				// only later will we walk the expression tree and figure out the deps. so, we need to create a
   223  				// composite expression that contains all the expressions in the SELECTs that this UNION consists of
   224  				tableInfo.cols[i] = sqlparser.AndExpressions(col, thisTableInfo.cols[i])
   225  			}
   226  		}
   227  
   228  		s.push(nScope)
   229  	}
   230  	return nil
   231  }
   232  
   233  func (s *scoper) currentScope() *scope {
   234  	size := len(s.scopes)
   235  	if size == 0 {
   236  		return nil
   237  	}
   238  	return s.scopes[size-1]
   239  }
   240  
   241  func (s *scoper) push(sc *scope) {
   242  	s.scopes = append(s.scopes, sc)
   243  }
   244  
   245  func (s *scoper) popScope() {
   246  	usingMap := s.currentScope().prepareUsingMap()
   247  	for ts, m := range usingMap {
   248  		s.binder.usingJoinInfo[ts] = m
   249  	}
   250  	l := len(s.scopes) - 1
   251  	s.scopes = s.scopes[:l]
   252  }
   253  
   254  func newScope(parent *scope) *scope {
   255  	return &scope{
   256  		parent:    parent,
   257  		joinUsing: map[string]TableSet{},
   258  	}
   259  }
   260  
   261  func (s *scope) addTable(info TableInfo) error {
   262  	name, err := info.Name()
   263  	if err != nil {
   264  		return err
   265  	}
   266  	tblName := name.Name.String()
   267  	for _, table := range s.tables {
   268  		name, err := table.Name()
   269  		if err != nil {
   270  			return err
   271  		}
   272  
   273  		if tblName == name.Name.String() {
   274  			return vterrors.VT03013(name.Name.String())
   275  		}
   276  	}
   277  	s.tables = append(s.tables, info)
   278  	return nil
   279  }
   280  
   281  func (s *scope) prepareUsingMap() (result map[TableSet]map[string]TableSet) {
   282  	result = map[TableSet]map[string]TableSet{}
   283  	for colName, tss := range s.joinUsing {
   284  		for _, ts := range tss.Constituents() {
   285  			m := result[ts]
   286  			if m == nil {
   287  				m = map[string]TableSet{}
   288  			}
   289  			m[colName] = tss
   290  			result[ts] = m
   291  		}
   292  	}
   293  	return
   294  }
   295  
   296  // findParentScopeOfStatement finds the scope that belongs to a statement.
   297  func (s *scope) findParentScopeOfStatement() *scope {
   298  	if s.stmtScope {
   299  		return s.parent
   300  	}
   301  	if s.parent == nil {
   302  		return nil
   303  	}
   304  	return s.parent.findParentScopeOfStatement()
   305  }