vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/semantic_state.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package semantics
    18  
    19  import (
    20  	"vitess.io/vitess/go/mysql/collations"
    21  	"vitess.io/vitess/go/sqltypes"
    22  	"vitess.io/vitess/go/vt/key"
    23  	querypb "vitess.io/vitess/go/vt/proto/query"
    24  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    25  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    26  	"vitess.io/vitess/go/vt/vterrors"
    27  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    28  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    29  
    30  	"vitess.io/vitess/go/vt/sqlparser"
    31  )
    32  
    33  type (
    34  	// TableInfo contains information about tables
    35  	TableInfo interface {
    36  		// Name returns the table name
    37  		Name() (sqlparser.TableName, error)
    38  
    39  		// GetVindexTable returns the vschema version of this TableInfo
    40  		GetVindexTable() *vindexes.Table
    41  
    42  		// IsInfSchema returns true if this table is information_schema
    43  		IsInfSchema() bool
    44  
    45  		// matches returns true if the provided table name matches this TableInfo
    46  		matches(name sqlparser.TableName) bool
    47  
    48  		// authoritative is true if we have exhaustive column information
    49  		authoritative() bool
    50  
    51  		// getExpr returns the AST struct behind this table
    52  		getExpr() *sqlparser.AliasedTableExpr
    53  
    54  		// getColumns returns the known column information for this table
    55  		getColumns() []ColumnInfo
    56  
    57  		dependencies(colName string, org originable) (dependencies, error)
    58  		getExprFor(s string) (sqlparser.Expr, error)
    59  		getTableSet(org originable) TableSet
    60  	}
    61  
    62  	// ColumnInfo contains information about columns
    63  	ColumnInfo struct {
    64  		Name string
    65  		Type Type
    66  	}
    67  
    68  	// ExprDependencies stores the tables that an expression depends on as a map
    69  	ExprDependencies map[sqlparser.Expr]TableSet
    70  
    71  	// SemTable contains semantic analysis information about the query.
    72  	SemTable struct {
    73  		Tables []TableInfo
    74  
    75  		// NotSingleRouteErr stores any errors that have to be generated if the query cannot be planned as a single route.
    76  		NotSingleRouteErr error
    77  		// NotUnshardedErr stores any errors that have to be generated if the query is not unsharded.
    78  		NotUnshardedErr error
    79  
    80  		// Recursive contains the dependencies from the expression to the actual tables
    81  		// in the query (i.e. not including derived tables). If an expression is a column on a derived table,
    82  		// this map will contain the accumulated dependencies for the column expression inside the derived table
    83  		Recursive ExprDependencies
    84  
    85  		// Direct keeps information about the closest dependency for an expression.
    86  		// It does not recurse inside derived tables and the like to find the original dependencies
    87  		Direct ExprDependencies
    88  
    89  		ExprTypes   map[sqlparser.Expr]Type
    90  		selectScope map[*sqlparser.Select]*scope
    91  		Comments    *sqlparser.ParsedComments
    92  		SubqueryMap map[sqlparser.Statement][]*sqlparser.ExtractedSubquery
    93  		SubqueryRef map[*sqlparser.Subquery]*sqlparser.ExtractedSubquery
    94  
    95  		// ColumnEqualities is used to enable transitive closures
    96  		// if a == b and b == c then a == c
    97  		ColumnEqualities map[columnName][]sqlparser.Expr
    98  
    99  		// DefaultCollation is the default collation for this query, which is usually
   100  		// inherited from the connection's default collation.
   101  		Collation collations.ID
   102  
   103  		Warning string
   104  
   105  		// ExpandedColumns is a map of all the added columns for a given table.
   106  		ExpandedColumns map[sqlparser.TableName][]*sqlparser.ColName
   107  
   108  		comparator *sqlparser.Comparator
   109  	}
   110  
   111  	columnName struct {
   112  		Table      TableSet
   113  		ColumnName string
   114  	}
   115  
   116  	// SchemaInformation is used tp provide table information from Vschema.
   117  	SchemaInformation interface {
   118  		FindTableOrVindex(tablename sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error)
   119  		ConnCollation() collations.ID
   120  	}
   121  )
   122  
   123  var (
   124  	// ErrNotSingleTable refers to an error happening when something should be used only for single tables
   125  	ErrNotSingleTable = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] should only be used for single tables")
   126  )
   127  
   128  // CopyDependencies copies the dependencies from one expression into the other
   129  func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) {
   130  	st.Recursive[to] = st.RecursiveDeps(from)
   131  	st.Direct[to] = st.DirectDeps(from)
   132  }
   133  
   134  // EmptySemTable creates a new empty SemTable
   135  func EmptySemTable() *SemTable {
   136  	return &SemTable{
   137  		Recursive:        map[sqlparser.Expr]TableSet{},
   138  		Direct:           map[sqlparser.Expr]TableSet{},
   139  		ColumnEqualities: map[columnName][]sqlparser.Expr{},
   140  	}
   141  }
   142  
   143  // TableSetFor returns the bitmask for this particular table
   144  func (st *SemTable) TableSetFor(t *sqlparser.AliasedTableExpr) TableSet {
   145  	for idx, t2 := range st.Tables {
   146  		if t == t2.getExpr() {
   147  			return SingleTableSet(idx)
   148  		}
   149  	}
   150  	return EmptyTableSet()
   151  }
   152  
   153  // ReplaceTableSetFor replaces the given single TabletSet with the new *sqlparser.AliasedTableExpr
   154  func (st *SemTable) ReplaceTableSetFor(id TableSet, t *sqlparser.AliasedTableExpr) {
   155  	if id.NumberOfTables() != 1 {
   156  		// This is probably a derived table
   157  		return
   158  	}
   159  	tblOffset := id.TableOffset()
   160  	if tblOffset > len(st.Tables) {
   161  		// This should not happen and is probably a bug, but the output query will still work fine
   162  		return
   163  	}
   164  	switch tbl := st.Tables[id.TableOffset()].(type) {
   165  	case *RealTable:
   166  		tbl.ASTNode = t
   167  	case *DerivedTable:
   168  		tbl.ASTNode = t
   169  	}
   170  }
   171  
   172  // TableInfoFor returns the table info for the table set. It should contains only single table.
   173  func (st *SemTable) TableInfoFor(id TableSet) (TableInfo, error) {
   174  	offset := id.TableOffset()
   175  	if offset < 0 {
   176  		return nil, ErrNotSingleTable
   177  	}
   178  	return st.Tables[offset], nil
   179  }
   180  
   181  // RecursiveDeps return the table dependencies of the expression.
   182  func (st *SemTable) RecursiveDeps(expr sqlparser.Expr) TableSet {
   183  	return st.Recursive.dependencies(expr)
   184  }
   185  
   186  // DirectDeps return the table dependencies of the expression.
   187  func (st *SemTable) DirectDeps(expr sqlparser.Expr) TableSet {
   188  	return st.Direct.dependencies(expr)
   189  }
   190  
   191  // AddColumnEquality adds a relation of the given colName to the ColumnEqualities map
   192  func (st *SemTable) AddColumnEquality(colName *sqlparser.ColName, expr sqlparser.Expr) {
   193  	ts := st.Direct.dependencies(colName)
   194  	columnName := columnName{
   195  		Table:      ts,
   196  		ColumnName: colName.Name.String(),
   197  	}
   198  	elem := st.ColumnEqualities[columnName]
   199  	elem = append(elem, expr)
   200  	st.ColumnEqualities[columnName] = elem
   201  }
   202  
   203  // GetExprAndEqualities returns a slice containing the given expression, and it's known equalities if any
   204  func (st *SemTable) GetExprAndEqualities(expr sqlparser.Expr) []sqlparser.Expr {
   205  	result := []sqlparser.Expr{expr}
   206  	switch expr := expr.(type) {
   207  	case *sqlparser.ColName:
   208  		table := st.DirectDeps(expr)
   209  		k := columnName{Table: table, ColumnName: expr.Name.String()}
   210  		result = append(result, st.ColumnEqualities[k]...)
   211  	}
   212  	return result
   213  }
   214  
   215  // TableInfoForExpr returns the table info of the table that this expression depends on.
   216  // Careful: this only works for expressions that have a single table dependency
   217  func (st *SemTable) TableInfoForExpr(expr sqlparser.Expr) (TableInfo, error) {
   218  	return st.TableInfoFor(st.Direct.dependencies(expr))
   219  }
   220  
   221  // GetSelectTables returns the table in the select.
   222  func (st *SemTable) GetSelectTables(node *sqlparser.Select) []TableInfo {
   223  	scope := st.selectScope[node]
   224  	return scope.tables
   225  }
   226  
   227  // AddExprs adds new select exprs to the SemTable.
   228  func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) {
   229  	tableSet := st.TableSetFor(tbl)
   230  	for _, col := range cols {
   231  		st.Recursive[col.(*sqlparser.AliasedExpr).Expr] = tableSet
   232  	}
   233  }
   234  
   235  // TypeFor returns the type of expressions in the query
   236  func (st *SemTable) TypeFor(e sqlparser.Expr) *querypb.Type {
   237  	typ, found := st.ExprTypes[e]
   238  	if found {
   239  		return &typ.Type
   240  	}
   241  	return nil
   242  }
   243  
   244  // CollationForExpr returns the collation name of expressions in the query
   245  func (st *SemTable) CollationForExpr(e sqlparser.Expr) collations.ID {
   246  	typ, found := st.ExprTypes[e]
   247  	if found {
   248  		return typ.Collation
   249  	}
   250  	return collations.Unknown
   251  }
   252  
   253  // NeedsWeightString returns true if the given expression needs weight_string to do safe comparisons
   254  func (st *SemTable) NeedsWeightString(e sqlparser.Expr) bool {
   255  	typ, found := st.ExprTypes[e]
   256  	if !found {
   257  		return true
   258  	}
   259  	return typ.Collation == collations.Unknown && !sqltypes.IsNumber(typ.Type)
   260  }
   261  
   262  func (st *SemTable) DefaultCollation() collations.ID {
   263  	return st.Collation
   264  }
   265  
// dependencies returns the table dependencies of the expression. This method finds table dependencies recursively.
//
// For an expr accepted by ValidAsMapKey, the result is memoized in d: a cached
// entry is returned immediately; otherwise the value accumulated below is
// written back into d by a deferred function when this call returns.
//
// The walk merges the cached dependencies of every sub-expression (including
// expr itself) found in d and does not descend below a sub-expression that had
// a cached entry. For an *sqlparser.ExtractedSubquery only the dependencies of
// its OtherSide (when non-nil) are merged, and its children are skipped.
func (d ExprDependencies) dependencies(expr sqlparser.Expr) (deps TableSet) {
	if ValidAsMapKey(expr) {
		// we have something that could live in the cache
		var found bool
		deps, found = d[expr]
		if found {
			return deps
		}
		// The closure reads the named result `deps` at return time, so it caches
		// whatever the walk below has accumulated.
		defer func() {
			d[expr] = deps
		}()
	}

	// During the original semantic analysis, all ColNames were found and bound to the corresponding tables
	// Here, we'll walk the expression tree and look to see if we can find any sub-expressions
	// that have already set dependencies.
	// The visitor never returns a non-nil error, so Walk's error is discarded.
	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
		expr, ok := node.(sqlparser.Expr)
		if !ok || !ValidAsMapKey(expr) {
			// if this is not an expression, or it is an expression we can't use as a map-key,
			// just carry on down the tree
			return true, nil
		}

		if extracted, ok := expr.(*sqlparser.ExtractedSubquery); ok {
			// Recurse through d.dependencies so OtherSide's result is cached too.
			if extracted.OtherSide != nil {
				set := d.dependencies(extracted.OtherSide)
				deps = deps.Merge(set)
			}
			return false, nil
		}
		// A miss yields the zero TableSet, which the merge treats as no dependencies.
		set, found := d[expr]
		deps = deps.Merge(set)

		// if we found a cached value, there is no need to continue down to visit children
		return !found, nil
	}, expr)

	return deps
}
   307  
   308  // RewriteDerivedTableExpression rewrites all the ColName instances in the supplied expression with
   309  // the expressions behind the column definition of the derived table
   310  // SELECT foo FROM (SELECT id+42 as foo FROM user) as t
   311  // We need `foo` to be translated to `id+42` on the inside of the derived table
   312  func RewriteDerivedTableExpression(expr sqlparser.Expr, vt TableInfo) sqlparser.Expr {
   313  	return sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
   314  		node, ok := cursor.Node().(*sqlparser.ColName)
   315  		if !ok {
   316  			return
   317  		}
   318  		exp, err := vt.getExprFor(node.Name.String())
   319  		if err == nil {
   320  			cursor.Replace(exp)
   321  			return
   322  		}
   323  
   324  		// cloning the expression and removing the qualifier
   325  		col := *node
   326  		col.Qualifier = sqlparser.TableName{}
   327  		cursor.Replace(&col)
   328  
   329  	}, nil).(sqlparser.Expr)
   330  }
   331  
   332  // FindSubqueryReference goes over the sub queries and searches for it by value equality instead of reference equality
   333  func (st *SemTable) FindSubqueryReference(subquery *sqlparser.Subquery) *sqlparser.ExtractedSubquery {
   334  	for foundSubq, extractedSubquery := range st.SubqueryRef {
   335  		if sqlparser.Equals.RefOfSubquery(subquery, foundSubq) {
   336  			return extractedSubquery
   337  		}
   338  	}
   339  	return nil
   340  }
   341  
   342  // GetSubqueryNeedingRewrite returns a list of sub-queries that need to be rewritten
   343  func (st *SemTable) GetSubqueryNeedingRewrite() []*sqlparser.ExtractedSubquery {
   344  	var res []*sqlparser.ExtractedSubquery
   345  	for _, extractedSubquery := range st.SubqueryRef {
   346  		if extractedSubquery.NeedsRewrite {
   347  			res = append(res, extractedSubquery)
   348  		}
   349  	}
   350  	return res
   351  }
   352  
   353  // CopyExprInfo lookups src in the ExprTypes map and, if a key is found, assign
   354  // the corresponding Type value of src to dest.
   355  func (st *SemTable) CopyExprInfo(src, dest sqlparser.Expr) {
   356  	srcType, found := st.ExprTypes[src]
   357  	if found {
   358  		st.ExprTypes[dest] = srcType
   359  	}
   360  }
   361  
   362  var _ evalengine.TranslationLookup = (*SemTable)(nil)
   363  
   364  var columnNotSupportedErr = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "column access not supported here")
   365  
   366  // ColumnLookup implements the TranslationLookup interface
   367  func (st *SemTable) ColumnLookup(*sqlparser.ColName) (int, error) {
   368  	return 0, columnNotSupportedErr
   369  }
   370  
   371  // SingleUnshardedKeyspace returns the single keyspace if all tables in the query are in the same, unsharded keyspace
   372  func (st *SemTable) SingleUnshardedKeyspace() (*vindexes.Keyspace, []*vindexes.Table) {
   373  	var ks *vindexes.Keyspace
   374  	var tables []*vindexes.Table
   375  	for _, table := range st.Tables {
   376  		vindexTable := table.GetVindexTable()
   377  
   378  		if vindexTable == nil {
   379  			_, isDT := table.getExpr().Expr.(*sqlparser.DerivedTable)
   380  			if isDT {
   381  				// derived tables are ok, as long as all real tables are from the same unsharded keyspace
   382  				// we check the real tables inside the derived table as well for same unsharded keyspace.
   383  				continue
   384  			}
   385  			return nil, nil
   386  		}
   387  		if vindexTable.Type != "" {
   388  			// A reference table is not an issue when seeing if a query is going to an unsharded keyspace
   389  			if vindexTable.Type == vindexes.TypeReference {
   390  				continue
   391  			}
   392  			return nil, nil
   393  		}
   394  		name, ok := table.getExpr().Expr.(sqlparser.TableName)
   395  		if !ok {
   396  			return nil, nil
   397  		}
   398  		if name.Name.String() != vindexTable.Name.String() {
   399  			// this points to a table alias. safer to not shortcut
   400  			return nil, nil
   401  		}
   402  		this := vindexTable.Keyspace
   403  		if this == nil || this.Sharded {
   404  			return nil, nil
   405  		}
   406  		if ks == nil {
   407  			ks = this
   408  		} else {
   409  			if ks != this {
   410  				return nil, nil
   411  			}
   412  		}
   413  		tables = append(tables, vindexTable)
   414  	}
   415  	return ks, tables
   416  }
   417  
   418  // EqualsExpr compares two expressions using the semantic analysis information.
   419  // This means that we use the binding info to recognize that two ColName's can point to the same
   420  // table column even though they are written differently. Example would be the `foobar` column in the following query:
   421  // `SELECT foobar FROM tbl ORDER BY tbl.foobar`
   422  // The expression in the select list is not equal to the one in the ORDER BY,
   423  // but they point to the same column and would be considered equal by this method
   424  func (st *SemTable) EqualsExpr(a, b sqlparser.Expr) bool {
   425  	return st.ASTEquals().Expr(a, b)
   426  }
   427  
   428  func (st *SemTable) ContainsExpr(e sqlparser.Expr, expres []sqlparser.Expr) bool {
   429  	for _, expre := range expres {
   430  		if st.EqualsExpr(e, expre) {
   431  			return true
   432  		}
   433  	}
   434  	return false
   435  }
   436  
   437  // AndExpressions ands together two or more expressions, minimising the expr when possible
   438  func (st *SemTable) AndExpressions(exprs ...sqlparser.Expr) sqlparser.Expr {
   439  	switch len(exprs) {
   440  	case 0:
   441  		return nil
   442  	case 1:
   443  		return exprs[0]
   444  	default:
   445  		result := (sqlparser.Expr)(nil)
   446  	outer:
   447  		// we'll loop and remove any duplicates
   448  		for i, expr := range exprs {
   449  			if expr == nil {
   450  				continue
   451  			}
   452  			if result == nil {
   453  				result = expr
   454  				continue outer
   455  			}
   456  
   457  			for j := 0; j < i; j++ {
   458  				if st.EqualsExpr(expr, exprs[j]) {
   459  					continue outer
   460  				}
   461  			}
   462  			result = &sqlparser.AndExpr{Left: result, Right: expr}
   463  		}
   464  		return result
   465  	}
   466  }
   467  
   468  // ASTEquals returns a sqlparser.Comparator that uses the semantic information in this SemTable to
   469  // explicitly compare column names for equality.
   470  func (st *SemTable) ASTEquals() *sqlparser.Comparator {
   471  	if st.comparator == nil {
   472  		st.comparator = &sqlparser.Comparator{
   473  			RefOfColName_: func(a, b *sqlparser.ColName) bool {
   474  				aDeps := st.RecursiveDeps(a)
   475  				bDeps := st.RecursiveDeps(b)
   476  				if aDeps != bDeps && (aDeps.IsEmpty() || bDeps.IsEmpty()) {
   477  					// if we don't know, we don't know
   478  					return sqlparser.Equals.RefOfColName(a, b)
   479  				}
   480  				return a.Name.Equal(b.Name) && aDeps == bDeps
   481  			},
   482  		}
   483  	}
   484  	return st.comparator
   485  }