vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/binder.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package semantics
    18  
    19  import (
    20  	"strings"
    21  
    22  	"vitess.io/vitess/go/vt/vtgate/engine"
    23  
    24  	"vitess.io/vitess/go/vt/sqlparser"
    25  )
    26  
    27  // binder is responsible for finding all the column references in
    28  // the query and bind them to the table that they belong to.
    29  // While doing this, it will also find the types for columns and
    30  // store these in the typer:s expression map
    31  type binder struct {
    32  	recursive   ExprDependencies
    33  	direct      ExprDependencies
    34  	scoper      *scoper
    35  	tc          *tableCollector
    36  	org         originable
    37  	typer       *typer
    38  	subqueryMap map[sqlparser.Statement][]*sqlparser.ExtractedSubquery
    39  	subqueryRef map[*sqlparser.Subquery]*sqlparser.ExtractedSubquery
    40  
    41  	// every table will have an entry in the outer map. it will point to a map with all the columns
    42  	// that this map is joined with using USING.
    43  	// This information is used to expand `*` correctly, and is not available post-analysis
    44  	usingJoinInfo map[TableSet]map[string]TableSet
    45  }
    46  
    47  func newBinder(scoper *scoper, org originable, tc *tableCollector, typer *typer) *binder {
    48  	return &binder{
    49  		recursive:     map[sqlparser.Expr]TableSet{},
    50  		direct:        map[sqlparser.Expr]TableSet{},
    51  		scoper:        scoper,
    52  		org:           org,
    53  		tc:            tc,
    54  		typer:         typer,
    55  		subqueryMap:   map[sqlparser.Statement][]*sqlparser.ExtractedSubquery{},
    56  		subqueryRef:   map[*sqlparser.Subquery]*sqlparser.ExtractedSubquery{},
    57  		usingJoinInfo: map[TableSet]map[string]TableSet{},
    58  	}
    59  }
    60  
    61  func (b *binder) up(cursor *sqlparser.Cursor) error {
    62  	switch node := cursor.Node().(type) {
    63  	case *sqlparser.Subquery:
    64  		currScope := b.scoper.currentScope()
    65  		sq, err := b.createExtractedSubquery(cursor, currScope, node)
    66  		if err != nil {
    67  			return err
    68  		}
    69  
    70  		b.subqueryMap[currScope.stmt] = append(b.subqueryMap[currScope.stmt], sq)
    71  		b.subqueryRef[node] = sq
    72  
    73  		b.setSubQueryDependencies(node, currScope)
    74  	case *sqlparser.JoinCondition:
    75  		currScope := b.scoper.currentScope()
    76  		for _, ident := range node.Using {
    77  			name := sqlparser.NewColName(ident.String())
    78  			deps, err := b.resolveColumn(name, currScope, true)
    79  			if err != nil {
    80  				return err
    81  			}
    82  			currScope.joinUsing[ident.Lowered()] = deps.direct
    83  		}
    84  		if len(node.Using) > 0 {
    85  			err := rewriteJoinUsing(currScope, node.Using, b.org)
    86  			if err != nil {
    87  				return err
    88  			}
    89  			node.Using = nil
    90  		}
    91  	case *sqlparser.ColName:
    92  		currentScope := b.scoper.currentScope()
    93  		deps, err := b.resolveColumn(node, currentScope, false)
    94  		if err != nil {
    95  			if deps.direct.IsEmpty() ||
    96  				!strings.HasSuffix(err.Error(), "is ambiguous") ||
    97  				!b.canRewriteUsingJoin(deps, node) {
    98  				return err
    99  			}
   100  
   101  			// if we got here it means we are dealing with a ColName that is involved in a JOIN USING.
   102  			// we do the rewriting of these ColName structs here because it would be difficult to copy all the
   103  			// needed state over to the earlyRewriter
   104  			deps, err = b.rewriteJoinUsingColName(deps, node, currentScope)
   105  			if err != nil {
   106  				return err
   107  			}
   108  		}
   109  		b.recursive[node] = deps.recursive
   110  		b.direct[node] = deps.direct
   111  		if deps.typ != nil {
   112  			b.typer.setTypeFor(node, *deps.typ)
   113  		}
   114  	case *sqlparser.CountStar:
   115  		b.bindCountStar(node)
   116  	}
   117  	return nil
   118  }
   119  
   120  func (b *binder) bindCountStar(node *sqlparser.CountStar) {
   121  	scope := b.scoper.currentScope()
   122  	var ts TableSet
   123  	for _, tbl := range scope.tables {
   124  		switch tbl := tbl.(type) {
   125  		case *vTableInfo:
   126  			for _, col := range tbl.cols {
   127  				if sqlparser.Equals.Expr(node, col) {
   128  					ts = ts.Merge(b.recursive[col])
   129  				}
   130  			}
   131  		default:
   132  			expr := tbl.getExpr()
   133  			if expr != nil {
   134  				setFor := b.tc.tableSetFor(expr)
   135  				ts = ts.Merge(setFor)
   136  			}
   137  		}
   138  	}
   139  	b.recursive[node] = ts
   140  	b.direct[node] = ts
   141  }
   142  
   143  func (b *binder) rewriteJoinUsingColName(deps dependency, node *sqlparser.ColName, currentScope *scope) (dependency, error) {
   144  	constituents := deps.recursive.Constituents()
   145  	if len(constituents) < 1 {
   146  		return dependency{}, NewError(Buggy, "we should not have a *ColName that depends on nothing")
   147  	}
   148  	newTbl := constituents[0]
   149  	infoFor, err := b.tc.tableInfoFor(newTbl)
   150  	if err != nil {
   151  		return dependency{}, err
   152  	}
   153  	alias := infoFor.getExpr().As
   154  	if alias.IsEmpty() {
   155  		name, err := infoFor.Name()
   156  		if err != nil {
   157  			return dependency{}, err
   158  		}
   159  		node.Qualifier = name
   160  	} else {
   161  		node.Qualifier = sqlparser.TableName{
   162  			Name: sqlparser.NewIdentifierCS(alias.String()),
   163  		}
   164  	}
   165  	deps, err = b.resolveColumn(node, currentScope, false)
   166  	if err != nil {
   167  		return dependency{}, err
   168  	}
   169  	return deps, nil
   170  }
   171  
   172  // canRewriteUsingJoin will return true when this ColName is safe to rewrite since it can only belong to a USING JOIN
   173  func (b *binder) canRewriteUsingJoin(deps dependency, node *sqlparser.ColName) bool {
   174  	tbls := deps.direct.Constituents()
   175  	colName := node.Name.Lowered()
   176  	for _, tbl := range tbls {
   177  		m := b.usingJoinInfo[tbl]
   178  		if _, found := m[colName]; !found {
   179  			return false
   180  		}
   181  	}
   182  	return true
   183  }
   184  
   185  // setSubQueryDependencies sets the correct dependencies for the subquery
   186  // the binder usually only sets the dependencies of ColNames, but we need to
   187  // handle the subquery dependencies differently, so they are set manually here
   188  // this method will only keep dependencies to tables outside the subquery
   189  func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *scope) {
   190  	subqRecursiveDeps := b.recursive.dependencies(subq)
   191  	subqDirectDeps := b.direct.dependencies(subq)
   192  
   193  	tablesToKeep := EmptyTableSet()
   194  	sco := currScope
   195  	for sco != nil {
   196  		for _, table := range sco.tables {
   197  			tablesToKeep = tablesToKeep.Merge(table.getTableSet(b.org))
   198  		}
   199  		sco = sco.parent
   200  	}
   201  
   202  	b.recursive[subq] = subqRecursiveDeps.KeepOnly(tablesToKeep)
   203  	b.direct[subq] = subqDirectDeps.KeepOnly(tablesToKeep)
   204  }
   205  
   206  func (b *binder) createExtractedSubquery(cursor *sqlparser.Cursor, currScope *scope, subq *sqlparser.Subquery) (*sqlparser.ExtractedSubquery, error) {
   207  	if currScope.stmt == nil {
   208  		return nil, NewError(Buggy, "unable to bind subquery to select statement")
   209  	}
   210  
   211  	sq := &sqlparser.ExtractedSubquery{
   212  		Subquery: subq,
   213  		Original: subq,
   214  		OpCode:   int(engine.PulloutValue),
   215  	}
   216  
   217  	switch par := cursor.Parent().(type) {
   218  	case *sqlparser.ComparisonExpr:
   219  		switch par.Operator {
   220  		case sqlparser.InOp:
   221  			sq.OpCode = int(engine.PulloutIn)
   222  		case sqlparser.NotInOp:
   223  			sq.OpCode = int(engine.PulloutNotIn)
   224  		}
   225  		subq, exp := GetSubqueryAndOtherSide(par)
   226  		sq.Original = &sqlparser.ComparisonExpr{
   227  			Left:     exp,
   228  			Operator: par.Operator,
   229  			Right:    subq,
   230  		}
   231  		sq.OtherSide = exp
   232  	case *sqlparser.ExistsExpr:
   233  		sq.OpCode = int(engine.PulloutExists)
   234  		sq.Original = par
   235  	}
   236  	return sq, nil
   237  }
   238  
   239  func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti bool) (dependency, error) {
   240  	var thisDeps dependencies
   241  	for current != nil {
   242  		var err error
   243  		thisDeps, err = b.resolveColumnInScope(current, colName, allowMulti)
   244  		if err != nil {
   245  			err = makeAmbiguousError(colName, err)
   246  			if thisDeps == nil {
   247  				return dependency{}, err
   248  			}
   249  		}
   250  		if !thisDeps.empty() {
   251  			deps, thisErr := thisDeps.get()
   252  			if thisErr != nil {
   253  				err = makeAmbiguousError(colName, thisErr)
   254  			}
   255  			return deps, err
   256  		} else if err != nil {
   257  			return dependency{}, err
   258  		}
   259  		current = current.parent
   260  	}
   261  	return dependency{}, ShardedError{Inner: NewError(ColumnNotFound, colName)}
   262  }
   263  
   264  func (b *binder) resolveColumnInScope(current *scope, expr *sqlparser.ColName, allowMulti bool) (dependencies, error) {
   265  	var deps dependencies = &nothing{}
   266  	for _, table := range current.tables {
   267  		if !expr.Qualifier.IsEmpty() && !table.matches(expr.Qualifier) {
   268  			continue
   269  		}
   270  		thisDeps, err := table.dependencies(expr.Name.String(), b.org)
   271  		if err != nil {
   272  			return nil, err
   273  		}
   274  		deps = thisDeps.merge(deps, allowMulti)
   275  	}
   276  	if deps, isUncertain := deps.(*uncertain); isUncertain && deps.fail {
   277  		// if we have a failure from uncertain, we matched the column to multiple non-authoritative tables
   278  		return nil, ProjError{Inner: NewError(AmbiguousColumn, expr)}
   279  	}
   280  	return deps, nil
   281  }
   282  
   283  func makeAmbiguousError(colName *sqlparser.ColName, err error) error {
   284  	if err == ambigousErr {
   285  		err = NewError(AmbiguousColumn, colName)
   286  	}
   287  	return err
   288  }
   289  
   290  // GetSubqueryAndOtherSide returns the subquery and other side of a comparison, iff one of the sides is a SubQuery
   291  func GetSubqueryAndOtherSide(node *sqlparser.ComparisonExpr) (*sqlparser.Subquery, sqlparser.Expr) {
   292  	var subq *sqlparser.Subquery
   293  	var exp sqlparser.Expr
   294  	if lSubq, lIsSubq := node.Left.(*sqlparser.Subquery); lIsSubq {
   295  		subq = lSubq
   296  		exp = node.Right
   297  	} else if rSubq, rIsSubq := node.Right.(*sqlparser.Subquery); rIsSubq {
   298  		subq = rSubq
   299  		exp = node.Left
   300  	}
   301  	return subq, exp
   302  }