github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/apply_indexes_from_outer_scope.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
import (
	"fmt"
	"sort"
	"strings"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
	"github.com/dolthub/go-mysql-server/sql/plan"
	"github.com/dolthub/go-mysql-server/sql/transform"
)
    26  
    27  // applyIndexesFromOuterScope attempts to apply an indexed lookup to a subquery using variables from the outer scope.
    28  // It functions similarly to generateIndexScans, in that it applies an index to a table. But unlike that function, it must
    29  // apply, effectively, an indexed join between two tables, one of which is defined in the outer scope. This is similar
    30  // to the process in the join analyzer.
    31  func applyIndexesFromOuterScope(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    32  	if scope.IsEmpty() {
    33  		return n, transform.SameTree, nil
    34  	}
    35  
    36  	// this isn't good enough: we need to consider aliases defined in the outer scope as well for this analysis
    37  	tableAliases, err := getTableAliases(n, scope)
    38  	if err != nil {
    39  		return nil, transform.SameTree, err
    40  	}
    41  
    42  	indexLookups, err := getOuterScopeIndexes(ctx, a, n, scope, tableAliases)
    43  	if err != nil {
    44  		return nil, transform.SameTree, err
    45  	}
    46  
    47  	if len(indexLookups) == 0 {
    48  		return n, transform.SameTree, nil
    49  	}
    50  
    51  	childSelector := func(c transform.Context) bool {
    52  		switch c.Parent.(type) {
    53  		// We can't push any indexes down a branch that have already had an index pushed down it
    54  		case *plan.IndexedTableAccess:
    55  			return false
    56  		}
    57  		return true
    58  	}
    59  
    60  	// replace the tables with possible index lookups with indexed access
    61  	allSame := transform.SameTree
    62  	sameN := transform.SameTree
    63  	for _, idxLookup := range indexLookups {
    64  		n, sameN, err = transform.NodeWithCtx(n, childSelector, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
    65  			switch n := c.Node.(type) {
    66  			case *plan.IndexedTableAccess:
    67  				return n, transform.SameTree, nil
    68  			case *plan.TableAlias:
    69  				if strings.ToLower(n.Name()) == idxLookup.table {
    70  					return pushdownIndexToTable(ctx, a, n, idxLookup.index, idxLookup.keyExpr, idxLookup.nullmask)
    71  				}
    72  				return n, transform.SameTree, nil
    73  			case sql.TableNode:
    74  				if strings.ToLower(n.Name()) == idxLookup.table {
    75  					return pushdownIndexToTable(ctx, a, n, idxLookup.index, idxLookup.keyExpr, idxLookup.nullmask)
    76  				}
    77  				return n, transform.SameTree, nil
    78  			default:
    79  				return n, transform.SameTree, nil
    80  			}
    81  		})
    82  		allSame = allSame && sameN
    83  		if err != nil {
    84  			return nil, transform.SameTree, err
    85  		}
    86  	}
    87  
    88  	return n, allSame, nil
    89  }
    90  
    91  // pushdownIndexToTable attempts to push the index given down to the table given, if it implements
    92  // sql.IndexAddressableTable
    93  func pushdownIndexToTable(ctx *sql.Context, a *Analyzer, tableNode sql.NameableNode, index sql.Index, keyExpr []sql.Expression, nullmask []bool) (sql.Node, transform.TreeIdentity, error) {
    94  	return transform.Node(tableNode, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    95  		switch n := n.(type) {
    96  		case *plan.IndexedTableAccess:
    97  		case sql.TableNode:
    98  			table := getTable(tableNode)
    99  			if table == nil {
   100  				return n, transform.SameTree, nil
   101  			}
   102  			if _, ok := table.(sql.IndexAddressableTable); ok {
   103  				a.Log("table %q transformed with pushdown of index", tableNode.Name())
   104  				lb := plan.NewLookupBuilder(index, keyExpr, nullmask)
   105  
   106  				ret, err := plan.NewIndexedAccessForTableNode(n, lb)
   107  				if err != nil {
   108  					return nil, transform.SameTree, err
   109  				}
   110  
   111  				return ret, transform.NewTree, nil
   112  			}
   113  		}
   114  		return n, transform.SameTree, nil
   115  	})
   116  }
   117  
// subqueryIndexLookup describes one index lookup that applyIndexesFromOuterScope may push down to a table inside a
// subquery. Values are produced by getOuterScopeIndexes.
type subqueryIndexLookup struct {
	table    string           // table name (as reported by GetField.Table()) the lookup applies to
	keyExpr  []sql.Expression // one key per leading index expression: the comparand side of the matching predicate
	nullmask []bool           // per key: true when the predicate was a null-safe equality (<=>)
	index    sql.Index        // index on table used to perform the lookup
}
   124  
   125  func getOuterScopeIndexes(
   126  	ctx *sql.Context,
   127  	a *Analyzer,
   128  	node sql.Node,
   129  	scope *plan.Scope,
   130  	tableAliases TableAliases,
   131  ) ([]subqueryIndexLookup, error) {
   132  	indexSpan, ctx := ctx.Span("getOuterScopeIndexes")
   133  	defer indexSpan.End()
   134  
   135  	var indexes map[string]sql.Index
   136  	var exprsByTable joinExpressionsByTable
   137  
   138  	var err error
   139  	transform.Inspect(node, func(node sql.Node) bool {
   140  		switch node := node.(type) {
   141  		case *plan.Filter:
   142  
   143  			var indexAnalyzer *indexAnalyzer
   144  			indexAnalyzer, err = newIndexAnalyzerForNode(ctx, node.Child)
   145  			if err != nil {
   146  				return false
   147  			}
   148  			defer indexAnalyzer.releaseUsedIndexes()
   149  
   150  			indexes, exprsByTable, err = getSubqueryIndexes(ctx, a, node.Expression, scope, indexAnalyzer, tableAliases)
   151  			if err != nil {
   152  				return false
   153  			}
   154  		}
   155  
   156  		return true
   157  	})
   158  
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	if len(indexes) == 0 {
   163  		return nil, nil
   164  	}
   165  
   166  	var lookups []subqueryIndexLookup
   167  
   168  	for table, idx := range indexes {
   169  		if exprsByTable[table] != nil {
   170  			// creating a key expression can fail in some cases, just skip this table
   171  			keyExpr, nullmask, err := createIndexKeyExpr(ctx, idx, exprsByTable[table], tableAliases)
   172  			if err != nil {
   173  				return nil, err
   174  			}
   175  			if keyExpr == nil {
   176  				continue
   177  			}
   178  
   179  			lookups = append(lookups, subqueryIndexLookup{
   180  				table:    table,
   181  				keyExpr:  keyExpr,
   182  				nullmask: nullmask,
   183  				index:    idx,
   184  			})
   185  		}
   186  	}
   187  
   188  	return lookups, nil
   189  }
   190  
   191  // createIndexKeyExpr returns a slice of expressions to be used when creating an index lookup key for the table given.
   192  func createIndexKeyExpr(ctx *sql.Context, idx sql.Index, joinExprs []*joinColExpr, tableAliases TableAliases) ([]sql.Expression, []bool, error) {
   193  	// To allow partial matching, we need to see if the expressions are a prefix of the index
   194  	idxExpressions := idx.Expressions()
   195  	normalizedJoinExprStrs := make([]string, len(joinExprs))
   196  	for i := range joinExprs {
   197  		normalizedJoinExprStrs[i] = normalizeExpression(tableAliases, joinExprs[i].colExpr).String()
   198  	}
   199  	if ok, prefixCount := exprsAreIndexSubset(normalizedJoinExprStrs, idxExpressions); !ok || prefixCount != len(normalizedJoinExprStrs) {
   200  		return nil, nil, nil
   201  	}
   202  	// Since the expressions are a prefix, we cut the index expressions we are using to just those involved
   203  	idxPrefixExpressions := idxExpressions[:len(normalizedJoinExprStrs)]
   204  
   205  	keyExprs := make([]sql.Expression, len(idxPrefixExpressions))
   206  	nullmask := make([]bool, len(idxPrefixExpressions))
   207  IndexExpressions:
   208  	for i, idxExpr := range idxPrefixExpressions {
   209  		for j := range joinExprs {
   210  			if strings.EqualFold(idxExpr, normalizedJoinExprStrs[j]) {
   211  				keyExprs[i] = joinExprs[j].comparand
   212  				nullmask[i] = joinExprs[j].matchnull
   213  				continue IndexExpressions
   214  			}
   215  		}
   216  
   217  		return nil, nil, fmt.Errorf("index `%s` reported having prefix of `%v` but has expressions `%v`",
   218  			idx.ID(), normalizedJoinExprStrs, idxExpressions)
   219  	}
   220  
   221  	return keyExprs, nullmask, nil
   222  }
   223  
   224  func getSubqueryIndexes(
   225  	ctx *sql.Context,
   226  	a *Analyzer,
   227  	e sql.Expression,
   228  	scope *plan.Scope,
   229  	ia *indexAnalyzer,
   230  	tableAliases TableAliases,
   231  ) (map[string]sql.Index, joinExpressionsByTable, error) {
   232  	// build a list of candidate predicate expressions, those that might be used for an index lookup
   233  	var candidatePredicates []sql.Expression
   234  
   235  	for _, e := range expression.SplitConjunction(e) {
   236  		// We are only interested in expressions that involve an outer scope variable (those whose index is less than the
   237  		// scope length)
   238  		isScopeExpr := false
   239  		sql.Inspect(e, func(e sql.Expression) bool {
   240  			if gf, ok := e.(*expression.GetField); ok {
   241  				if scope.Correlated().Contains(sql.ColumnId(gf.Id())) {
   242  					isScopeExpr = true
   243  					return false
   244  				}
   245  			}
   246  			return true
   247  		})
   248  
   249  		if isScopeExpr {
   250  			candidatePredicates = append(candidatePredicates, e)
   251  		}
   252  	}
   253  
   254  	tablesInScope := tablesInScope(scope)
   255  
   256  	// group them by the table they reference
   257  	// TODO: this only works for equality, make it work for other operands
   258  	exprsByTable := joinExprsByTable(candidatePredicates)
   259  
   260  	result := make(map[string]sql.Index)
   261  	// For every predicate involving a table in the outer scope, see if there's an index lookup possible on its comparands
   262  	// (the tables in this scope)
   263  	for _, scopeTable := range tablesInScope {
   264  		indexCols := exprsByTable[scopeTable]
   265  		if indexCols != nil {
   266  			col := indexCols[0].comparandCol
   267  			idx := ia.MatchingIndex(ctx, col.Table(), col.Database(), normalizeExpressions(tableAliases, extractComparands(indexCols)...)...)
   268  			if idx != nil {
   269  				result[indexCols[0].comparandCol.Table()] = idx
   270  			}
   271  		}
   272  	}
   273  
   274  	return result, exprsByTable, nil
   275  }
   276  
   277  func tablesInScope(scope *plan.Scope) []string {
   278  	tables := make(map[string]bool)
   279  	for _, node := range scope.InnerToOuter() {
   280  		for _, col := range Schemas(node.Children()) {
   281  			tables[col.Source] = true
   282  		}
   283  	}
   284  	var tableSlice []string
   285  	for table := range tables {
   286  		tableSlice = append(tableSlice, table)
   287  	}
   288  	return tableSlice
   289  }
   290  
   291  // Schemas returns the Schemas for the nodes given appended in to a single one
   292  func Schemas(nodes []sql.Node) sql.Schema {
   293  	var schema sql.Schema
   294  	for _, n := range nodes {
   295  		schema = append(schema, n.Schema()...)
   296  	}
   297  	return schema
   298  }
   299  
// A joinColExpr captures a GetField expression used in a comparison, along with some contextual information about
// that comparison. For example, for the base expression col1 + 1 = col2 - 1:
//   - col refers to `col1`
//   - colExpr refers to `col1 + 1`
//   - comparand refers to `col2 - 1`
//   - comparandCol refers to `col2`
//   - comparison refers to `col1 + 1 = col2 - 1`
//
// extractJoinColumnExpr builds these only for = and <=> comparisons, always as a pair (one per side).
// TODO: rename
type joinColExpr struct {
	// The field (column) being evaluated, which may not be the entire term in the comparison
	col *expression.GetField
	// The entire expression on this side of the comparison
	colExpr sql.Expression
	// The expression this field is being compared to (the other term in the comparison)
	comparand sql.Expression
	// The field (column) extracted from comparand via expression.ExtractGetField
	comparandCol *expression.GetField
	// The comparison expression in which this joinColExpr is one term
	comparison sql.Expression
	// True when comparison is a null-safe equality (*expression.NullSafeEquals, `<=>`), i.e. it matches NULLs.
	matchnull bool
}
   323  
// joinColExprs is a list of joinColExpr.
type joinColExprs []*joinColExpr

// joinExpressionsByTable groups joinColExprs by the table of their col field, as built by joinExprsByTable.
type joinExpressionsByTable map[string]joinColExprs
   326  
   327  // extractComparands returns the comparand Expressions in the slice of joinColExpr given.
   328  func extractComparands(colExprs []*joinColExpr) []sql.Expression {
   329  	result := make([]sql.Expression, len(colExprs))
   330  	for i, expr := range colExprs {
   331  		result[i] = expr.comparand
   332  	}
   333  	return result
   334  }
   335  
   336  // joinExprsByTable returns a map of the expressions given keyed by their table name.
   337  func joinExprsByTable(exprs []sql.Expression) joinExpressionsByTable {
   338  	var result = make(joinExpressionsByTable)
   339  
   340  	for _, expr := range exprs {
   341  		leftExpr, rightExpr := extractJoinColumnExpr(expr)
   342  		if leftExpr != nil {
   343  			result[leftExpr.col.Table()] = append(result[leftExpr.col.Table()], leftExpr)
   344  		}
   345  
   346  		if rightExpr != nil {
   347  			result[rightExpr.col.Table()] = append(result[rightExpr.col.Table()], rightExpr)
   348  		}
   349  	}
   350  
   351  	return result
   352  }
   353  
   354  // extractJoinColumnExpr extracts a pair of joinColExprs from a join condition, one each for the left and right side of
   355  // the expression. Returns nils if either side of the expression doesn't reference a table column.
   356  // Both sides have to have getField (this is currently invalid: a.x + b.y = 1)
   357  func extractJoinColumnExpr(e sql.Expression) (leftCol *joinColExpr, rightCol *joinColExpr) {
   358  	switch e := e.(type) {
   359  	case *expression.Equals, *expression.NullSafeEquals:
   360  		cmp := e.(expression.Comparer)
   361  		left, right := cmp.Left(), cmp.Right()
   362  		if isEvaluable(left) || isEvaluable(right) {
   363  			return nil, nil
   364  		}
   365  
   366  		leftField, rightField := expression.ExtractGetField(left), expression.ExtractGetField(right)
   367  		if leftField == nil || rightField == nil {
   368  			return nil, nil
   369  		}
   370  
   371  		_, matchnull := e.(*expression.NullSafeEquals)
   372  
   373  		leftCol = &joinColExpr{
   374  			col:          leftField,
   375  			colExpr:      left,
   376  			comparand:    right,
   377  			comparandCol: rightField,
   378  			comparison:   cmp,
   379  			matchnull:    matchnull,
   380  		}
   381  		rightCol = &joinColExpr{
   382  			col:          rightField,
   383  			colExpr:      right,
   384  			comparand:    left,
   385  			comparandCol: leftField,
   386  			comparison:   cmp,
   387  			matchnull:    matchnull,
   388  		}
   389  		return leftCol, rightCol
   390  	default:
   391  		return nil, nil
   392  	}
   393  }
   394  
   395  func containsColumns(e sql.Expression) bool {
   396  	var result bool
   397  	sql.Inspect(e, func(e sql.Expression) bool {
   398  		_, ok1 := e.(*expression.GetField)
   399  		_, ok2 := e.(*expression.UnresolvedColumn)
   400  		if ok1 || ok2 {
   401  			result = true
   402  			return false
   403  		}
   404  		return true
   405  	})
   406  	return result
   407  }
   408  
   409  func containsSubquery(e sql.Expression) bool {
   410  	var result bool
   411  	sql.Inspect(e, func(e sql.Expression) bool {
   412  		if _, ok := e.(*plan.Subquery); ok {
   413  			result = true
   414  			return false
   415  		}
   416  		return true
   417  	})
   418  	return result
   419  }
   420  
   421  func isEvaluable(e sql.Expression) bool {
   422  	return !containsColumns(e) && !containsSubquery(e) && !containsBindvars(e) && !containsProcedureParam(e)
   423  }
   424  
   425  func containsBindvars(e sql.Expression) bool {
   426  	var result bool
   427  	sql.Inspect(e, func(e sql.Expression) bool {
   428  		if _, ok := e.(*expression.BindVar); ok {
   429  			result = true
   430  			return false
   431  		}
   432  		return true
   433  	})
   434  	return result
   435  }
   436  
   437  func containsProcedureParam(e sql.Expression) bool {
   438  	var result bool
   439  	sql.Inspect(e, func(e sql.Expression) bool {
   440  		_, result = e.(*expression.ProcedureParam)
   441  		return !result
   442  	})
   443  	return result
   444  }