github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/fix_exec_indexes.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/transform"
    27  )
    28  
    29  // assignExecIndexes walks a query plan in-order and rewrites GetFields to use
    30  // execution appropriate indexing.
    31  func assignExecIndexes(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    32  	s := &idxScope{}
    33  	if !scope.IsEmpty() {
    34  		// triggers
    35  		s.triggerScope = true
    36  		s.addSchema(scope.Schema())
    37  		s = s.push()
    38  	}
    39  	ret, _, err := assignIndexesHelper(n, s)
    40  	if err != nil {
    41  		return n, transform.SameTree, err
    42  	}
    43  	return ret, transform.NewTree, nil
    44  }
    45  
    46  func assignIndexesHelper(n sql.Node, inScope *idxScope) (sql.Node, *idxScope, error) {
    47  	// copy scope, otherwise parent/lateral edits have non-local effects
    48  	outScope := inScope.copy()
    49  	err := outScope.visitChildren(n)
    50  	if err != nil {
    51  		return nil, nil, err
    52  	}
    53  	err = outScope.visitSelf(n)
    54  	if err != nil {
    55  		return nil, nil, err
    56  	}
    57  	ret, err := outScope.finalizeSelf(n)
    58  	return ret, outScope, err
    59  }
    60  
// idxScope accumulates the information needed to rewrite node column
// references for execution, including parent/child scopes, lateral
// scopes (if in the middle of a join tree), and child nodes and expressions.
// Collecting this info in one place makes it easier to compartmentalize
// finalization into an after phase.
//
// Fields:
//   - parentScopes: enclosing scopes (outer query, trigger row, or lateral
//     join siblings promoted by visitChildren). Most visitSelf cases include
//     them when fixing expressions.
//   - lateralScopes: join siblings visited earlier in the same join tree;
//     only specific visitSelf cases (joins, hash/index lookups, range heaps,
//     JSON tables) consult them.
//   - childScopes: output scopes of this node's rewritten children.
//   - ids: column ids used by getIdxId; set from columnIdsForNode in
//     finalizeSelf (except for InsertInto). NOTE(review): not guaranteed to be
//     the same length as columns — confirm before relying on alignment.
//   - columns: column names, "source.name" or bare "name", searched by getIdx.
//   - children / expressions / checks: rewritten pieces that finalizeSelf
//     reattaches to the node.
//   - triggerScope: forces name-based lookup (see getIdxId).
type idxScope struct {
	parentScopes  []*idxScope
	lateralScopes []*idxScope
	childScopes   []*idxScope
	ids           []sql.ColumnId
	columns       []string
	children      []sql.Node
	expressions   []sql.Expression
	checks        sql.CheckConstraints
	triggerScope  bool
}
    77  
    78  func (s *idxScope) addSchema(sch sql.Schema) {
    79  	for _, c := range sch {
    80  		if c.Source == "" {
    81  			s.columns = append(s.columns, c.Name)
    82  		} else {
    83  			s.columns = append(s.columns, fmt.Sprintf("%s.%s", c.Source, c.Name))
    84  		}
    85  	}
    86  }
    87  
    88  func (s *idxScope) addScope(other *idxScope) {
    89  	s.columns = append(s.columns, other.columns...)
    90  	s.ids = append(s.ids, other.ids...)
    91  }
    92  
// addLateral records other as a lateral scope of s (a join sibling visited
// earlier). Only specific cases in visitSelf (joins, hash/index lookups,
// range heaps, JSON tables) consult lateral scopes.
func (s *idxScope) addLateral(other *idxScope) {
	s.lateralScopes = append(s.lateralScopes, other)
}

// addParent records other as a parent (enclosing) scope of s. Most cases in
// visitSelf include parent scopes when fixing expressions.
func (s *idxScope) addParent(other *idxScope) {
	s.parentScopes = append(s.parentScopes, other)
}
   100  
   101  func isQualified(s string) bool {
   102  	return strings.Contains(s, ".")
   103  }
   104  
   105  // unqualify is a helper function to remove the table prefix from a column, if it's present.
   106  func unqualify(s string) string {
   107  	if isQualified(s) {
   108  		return strings.Split(s, ".")[1]
   109  	}
   110  	return s
   111  }
   112  
   113  func (s *idxScope) getIdxId(id sql.ColumnId, name string) (int, bool) {
   114  	if s.triggerScope || id == 0 {
   115  		// todo: add expr ids for trigger columns and procedure params
   116  		return s.getIdx(name)
   117  	}
   118  	for i, c := range s.ids {
   119  		if c == id {
   120  			return i, true
   121  		}
   122  	}
   123  	// todo: fix places where this is necessary
   124  	return s.getIdx(name)
   125  }
   126  
   127  func (s *idxScope) getIdx(n string) (int, bool) {
   128  	// We match the column closet to our current scope. We have already
   129  	// resolved columns, so there will be no in-scope collisions.
   130  	if isQualified(n) {
   131  		for i := len(s.columns) - 1; i >= 0; i-- {
   132  			if strings.EqualFold(n, s.columns[i]) {
   133  				return i, true
   134  			}
   135  		}
   136  		// TODO: we do not have a good way to match columns over set_ops where the column has the same name, but are
   137  		//  from different tables and have different types.
   138  		n = unqualify(n)
   139  		for i := len(s.columns) - 1; i >= 0; i-- {
   140  			if strings.EqualFold(n, s.columns[i]) {
   141  				return i, true
   142  			}
   143  		}
   144  	} else {
   145  		for i := len(s.columns) - 1; i >= 0; i-- {
   146  			if strings.EqualFold(n, unqualify(s.columns[i])) {
   147  				return i, true
   148  			}
   149  		}
   150  	}
   151  	return -1, false
   152  }
   153  
   154  func (s *idxScope) copy() *idxScope {
   155  	if s == nil {
   156  		return &idxScope{}
   157  	}
   158  	var varsCopy []string
   159  	if len(s.columns) > 0 {
   160  		varsCopy = make([]string, len(s.columns))
   161  		copy(varsCopy, s.columns)
   162  	}
   163  	var lateralCopy []*idxScope
   164  	if len(s.lateralScopes) > 0 {
   165  		lateralCopy = make([]*idxScope, len(s.lateralScopes))
   166  		copy(lateralCopy, s.lateralScopes)
   167  	}
   168  	var parentCopy []*idxScope
   169  	if len(s.parentScopes) > 0 {
   170  		parentCopy = make([]*idxScope, len(s.parentScopes))
   171  		copy(parentCopy, s.parentScopes)
   172  	}
   173  	if len(s.columns) > 0 {
   174  		varsCopy = make([]string, len(s.columns))
   175  		copy(varsCopy, s.columns)
   176  	}
   177  	var idsCopy []sql.ColumnId
   178  	if len(s.ids) > 0 {
   179  		idsCopy = make([]sql.ColumnId, len(s.ids))
   180  		copy(idsCopy, s.ids)
   181  	}
   182  	return &idxScope{
   183  		lateralScopes: lateralCopy,
   184  		parentScopes:  parentCopy,
   185  		columns:       varsCopy,
   186  		ids:           idsCopy,
   187  	}
   188  }
   189  
   190  func (s *idxScope) push() *idxScope {
   191  	return &idxScope{
   192  		parentScopes: []*idxScope{s},
   193  	}
   194  }
   195  
// visitChildren walks children and gathers schema info for this node.
//
// Each child is rewritten with assignIndexesHelper. The rewritten children
// are appended to s.children, and the children's output scopes are appended
// to s.childScopes, for use by visitSelf and finalizeSelf. Some node types
// control exactly which scopes a child can see, or which child scopes are
// recorded. Returns the first error from rewriting a child.
func (s *idxScope) visitChildren(n sql.Node) error {
	switch n := n.(type) {
	case *plan.JoinNode:
		// lateralScope starts as a copy of s and accumulates already-visited
		// join children, so each later child can see its earlier siblings.
		// assignIndexesHelper copies lateralScope, so additions made after a
		// child is visited do not affect that child.
		lateralScope := s.copy()
		for _, c := range n.Children() {
			newC, cScope, err := assignIndexesHelper(c, lateralScope)
			if err != nil {
				return err
			}
			// child scope is always a child to the current scope
			s.childScopes = append(s.childScopes, cScope)
			if n.Op.IsLateral() {
				// lateral promotes the scope to parent relative to other join children
				lateralScope.addParent(cScope)
			} else {
				// child scope is lateral scope to join children, hidden by default from
				// most expressions
				lateralScope.addLateral(cScope)
			}
			s.children = append(s.children, newC)
		}
	case *plan.SubqueryAlias:
		sqScope := s.copy()
		if !n.OuterScopeVisibility && !n.IsLateral {
			// TODO: this should not apply to subqueries inside of lateral joins
			// Subqueries with no visibility have no parent scopes. Lateral
			// join subquery aliases continue to enjoy full visibility.
			sqScope.parentScopes = sqScope.parentScopes[:0]
			sqScope.lateralScopes = sqScope.lateralScopes[:0]
		}
		newC, cScope, err := assignIndexesHelper(n.Child, sqScope)
		if err != nil {
			return err
		}
		s.childScopes = append(s.childScopes, cScope)
		s.children = append(s.children, newC)
	case *plan.SetOp:
		// Every branch is rewritten and reattached, but only the first
		// branch's scope is recorded as this node's child scope.
		var keepScope *idxScope
		for i, c := range n.Children() {
			newC, cScope, err := assignIndexesHelper(c, s)
			if err != nil {
				return err
			}
			if i == 0 {
				keepScope = cScope
			}
			s.children = append(s.children, newC)
		}
		// keep only the first union scope to avoid double counting
		s.childScopes = append(s.childScopes, keepScope)
	case *plan.InsertInto:
		// The source's scope is discarded. childScopes[0] is therefore the
		// destination scope (visitSelf relies on this), and s.children is
		// ordered [source, destination] (finalizeSelf relies on this).
		newSrc, _, err := assignIndexesHelper(n.Source, s)
		if err != nil {
			return err
		}
		newDst, dScope, err := assignIndexesHelper(n.Destination, s)
		if err != nil {
			return err
		}
		s.children = append(s.children, newSrc)
		s.children = append(s.children, newDst)
		s.childScopes = append(s.childScopes, dScope)
	case *plan.Procedure, *plan.CreateTable:
		// do nothing: children are not visited, s.children stays nil, and
		// finalizeSelf therefore leaves the node's children untouched
	default:
		for _, c := range n.Children() {
			newC, cScope, err := assignIndexesHelper(c, s)
			if err != nil {
				return err
			}
			s.childScopes = append(s.childScopes, cScope)
			s.children = append(s.children, newC)
		}
	}
	return nil
}
   273  
// visitSelf fixes expression indexes for this node. Assumes |s.childScopes|
// is set, any partial |s.lateralScopes| are filled, and the self scope is
// unset.
//
// Rewritten expressions are appended to s.expressions, and rewritten check
// constraints to s.checks, for finalizeSelf to reattach. RangeHeap is the
// exception: it is updated in place. The scopes passed to fixExprToScope
// differ per node type, and that choice defines what each expression can see.
// Returns an error only for malformed ON DUPLICATE KEY UPDATE expressions.
//
// NOTE(review): several cases build scope lists with
// append(s.parentScopes, ...). This is only safe because s is produced by
// copy(), which allocates parentScopes with len == cap, so each append
// reallocates. If that invariant changes, HashLookup's leftScopes append
// could overwrite rightScopes' backing array.
func (s *idxScope) visitSelf(n sql.Node) error {
	switch n := n.(type) {
	case *plan.JoinNode:
		// join on expressions see everything
		scopes := append(append(s.parentScopes, s.lateralScopes...), s.childScopes...)
		for _, e := range n.Expressions() {
			s.expressions = append(s.expressions, fixExprToScope(e, scopes...))
		}
	case *plan.RangeHeap:
		// value indexes other side of join
		newValue := fixExprToScope(n.ValueColumnGf, s.lateralScopes...)
		// min/max index this child
		newMin := fixExprToScope(n.MinColumnGf, s.childScopes...)
		newMax := fixExprToScope(n.MaxColumnGf, s.childScopes...)
		// The node is mutated in place. The type assertions assume the three
		// columns are *expression.GetField and panic otherwise.
		n.MaxColumnGf = newMax
		n.MinColumnGf = newMin
		n.ValueColumnGf = newValue
		n.MaxColumnIndex = newMax.(*expression.GetField).Index()
		n.MinColumnIndex = newMin.(*expression.GetField).Index()
		n.ValueColumnIndex = newValue.(*expression.GetField).Index()
	case *plan.HashLookup:
		// right entry has parent and self visibility, no lateral join scope
		rightScopes := append(s.parentScopes, s.childScopes...)
		s.expressions = append(s.expressions, fixExprToScope(n.RightEntryKey, rightScopes...))
		// left probe is the join context accumulation
		leftScopes := append(s.parentScopes, s.lateralScopes...)
		s.expressions = append(s.expressions, fixExprToScope(n.LeftProbeKey, leftScopes...))
	case *plan.IndexedTableAccess:
		var scope []*idxScope
		switch n.Typ {
		case plan.ItaTypeStatic:
			// self-visibility
			scope = append(s.parentScopes, s.childScopes...)
		case plan.ItaTypeLookup:
			// join siblings
			scope = append(s.parentScopes, s.lateralScopes...)
		}
		for _, e := range n.Expressions() {
			s.expressions = append(s.expressions, fixExprToScope(e, scope...))
		}
	case *plan.ShowVariables:
		if n.Filter != nil {
			// The filter is resolved against the node's own output schema,
			// appended to a copy of this scope's columns, plus parent scopes.
			selfScope := s.copy()
			selfScope.addSchema(n.Schema())
			scope := append(s.parentScopes, selfScope)
			for _, e := range n.Expressions() {
				s.expressions = append(s.expressions, fixExprToScope(e, scope...))
			}
		}
	case *plan.JSONTable:
		scopes := append(s.parentScopes, s.lateralScopes...)
		for _, e := range n.Expressions() {
			s.expressions = append(s.expressions, fixExprToScope(e, scopes...))
		}
	case *plan.InsertInto:
		// rightSchema describes the row seen by the right-hand side of
		// ON DUPLICATE KEY UPDATE assignments: the destination columns
		// (old row) followed by the incoming values (new row).
		rightSchema := make(sql.Schema, len(n.Destination.Schema())*2)
		// schema = [oldrow][newrow]
		for oldRowIdx, c := range n.Destination.Schema() {
			rightSchema[oldRowIdx] = c
			newRowIdx := len(n.Destination.Schema()) + oldRowIdx
			if _, ok := n.Source.(*plan.Values); !ok && len(n.Destination.Schema()) == len(n.Source.Schema()) {
				// find source index that aligns with dest column
				// NOTE(review): assumes n.ColumnNames[j] corresponds to
				// n.Source.Schema()[j] — confirm.
				var matched bool
				for j, sourceCol := range n.ColumnNames {
					if strings.EqualFold(c.Name, sourceCol) {
						rightSchema[newRowIdx] = n.Source.Schema()[j]
						matched = true
						break
					}
				}
				if !matched {
					// todo: this is only used for load data. load data errors
					//  without a fallback, and fails to resolve defaults if I
					//  define the columns upfront.
					rightSchema[newRowIdx] = n.Source.Schema()[oldRowIdx]
				}
			} else {
				// VALUES source (or mismatched width): new-row columns are
				// the destination columns renamed with the on-dup values prefix.
				newC := c.Copy()
				newC.Source = planbuilder.OnDupValuesPrefix
				rightSchema[newRowIdx] = newC
			}
		}
		rightScope := &idxScope{}
		rightScope.addSchema(rightSchema)
		// visitChildren records only the destination scope for InsertInto.
		dstScope := s.childScopes[0]

		for _, e := range n.OnDupExprs {
			set, ok := e.(*expression.SetField)
			if !ok {
				return fmt.Errorf("on duplicate update expressions should be *expression.SetField; found %T", e)
			}
			// left uses destination schema
			// right uses |rightSchema|
			newLeft := fixExprToScope(set.LeftChild, dstScope)
			newRight := fixExprToScope(set.RightChild, rightScope)
			s.expressions = append(s.expressions, expression.NewSetField(newLeft, newRight))
		}
		for _, c := range n.Checks() {
			newE := fixExprToScope(c.Expr, dstScope)
			newCheck := *c
			newCheck.Expr = newE
			s.checks = append(s.checks, &newCheck)
		}
	case *plan.Update:
		newScope := s.copy()
		srcScope := s.childScopes[0]
		// schema is |old_row|-|new_row|; checks only receive half
		newScope.columns = append(newScope.columns, srcScope.columns[:len(srcScope.columns)/2]...)
		for _, c := range n.Checks() {
			newE := fixExprToScope(c.Expr, newScope)
			newCheck := *c
			newCheck.Expr = newE
			s.checks = append(s.checks, &newCheck)
		}
	default:
		if ne, ok := n.(sql.Expressioner); ok {
			scope := append(s.parentScopes, s.childScopes...)
			for _, e := range ne.Expressions() {
				// default nodes can't see lateral join nodes, unless we're in lateral
				// join and lateral scopes are promoted to parent status
				s.expressions = append(s.expressions, fixExprToScope(e, scope...))
			}
		}
	}
	return nil
}
   403  
// finalizeSelf builds the output node and fixes the return scope.
//
// It reattaches the children from visitChildren and the expressions and
// checks from visitSelf, then appends the node's output columns to s.columns.
// The caller resolves references to this node through that scope, which
// assignIndexesHelper returns. Returns any error from WithChildren or
// WithExpressions.
func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
	// assumes children scopes have been set
	switch n := n.(type) {
	case *plan.InsertInto:
		// The scope exposes the destination's columns. s.ids is not reset
		// here, so it keeps whatever copy() carried over from the input
		// scope. NOTE(review): confirm those ids can't be matched spuriously.
		// children are [source, destination], as recorded by visitChildren.
		s.addSchema(n.Destination.Schema())
		nn := *n
		nn.Source = s.children[0]
		nn.Destination = s.children[1]
		nn.OnDupExprs = s.expressions
		return nn.WithChecks(s.checks), nil
	default:
		s.ids = columnIdsForNode(n)

		s.addSchema(n.Schema())
		var err error
		// s.children is nil for leaves and for nodes whose children were not
		// visited (Procedure, CreateTable); those keep their children.
		if s.children != nil {
			n, err = n.WithChildren(s.children...)
			if err != nil {
				return nil, err
			}
		}
		if ne, ok := n.(sql.Expressioner); ok && s.expressions != nil {
			n, err = ne.WithExpressions(s.expressions...)
			if err != nil {
				return nil, err
			}
		}
		if nc, ok := n.(sql.CheckConstraintNode); ok && s.checks != nil {
			n = nc.WithChecks(s.checks)
		}
		if jn, ok := n.(*plan.JoinNode); ok {
			// Only the first parent scope's column count is used as the join's
			// scope length. NOTE(review): StripRowNode presumably removes that
			// outer-scope prefix from rows before they reach each child —
			// confirm against plan.StripRowNode.
			if len(s.parentScopes) == 0 {
				return n, nil
			}
			// TODO: combine scopes?
			scopeLen := len(s.parentScopes[0].columns)
			if scopeLen == 0 {
				return n, nil
			}
			n = jn.WithScopeLen(scopeLen)
			n, err = n.WithChildren(
				plan.NewStripRowNode(jn.Left(), scopeLen),
				plan.NewStripRowNode(jn.Right(), scopeLen),
			)
			if err != nil {
				return nil, err
			}
		}
		return n, nil
	}
}
   456  
// columnIdsForNode collects the column ids of a node's return schema.
// Projector nodes can return a subset of the full sql.PrimaryTableSchema.
// todo: pruning projections should update plan.TableIdNode .Columns()
// to avoid schema/column discontinuities.
//
// An id of 0 means "no id available"; getIdxId resolves such columns by name.
// Cases are tried in order, so a node that implements sql.Projector is
// handled by the Projector case even if it also matches a later case.
func columnIdsForNode(n sql.Node) []sql.ColumnId {
	var ret []sql.ColumnId
	switch n := n.(type) {
	case sql.Projector:
		// one id per projected expression; expressions without an id get 0
		for _, e := range n.ProjectedExprs() {
			if ide, ok := e.(sql.IdExpression); ok {
				ret = append(ret, ide.Id())
			} else {
				ret = append(ret, 0)
			}
		}
	case *plan.TableCountLookup:
		ret = append(ret, n.Id())
	case *plan.TableAlias:
		// Table alias's child either exposes 1) child ids or 2) is custom
		// table function. We currently do not update table columns in response
		// to table pruning, so we need to manually distinguish these cases.
		// todo: prune columns should update column ids and table alias ids
		switch n.Child.(type) {
		case sql.TableFunction:
			// todo: table functions that implement sql.Projector are not going
			// to work. Need to fix prune.
			n.Columns().ForEach(func(col sql.ColumnId) {
				ret = append(ret, col)
			})
		default:
			ret = append(ret, columnIdsForNode(n.Child)...)
		}
	case plan.TableIdNode:
		// the dual table exposes a single placeholder column with no id
		if rt, ok := n.(*plan.ResolvedTable); ok && plan.IsDualTable(rt.Table) {
			ret = append(ret, 0)
			break
		}

		// (the type assertion below is a no-op: n already is a plan.TableIdNode)
		cols := n.(plan.TableIdNode).Columns()
		if tn, ok := n.(sql.TableNode); ok {
			// When the node's schema width differs from the primary-key schema
			// (e.g. a pruned projection), derive each column's id from its
			// ordinal in the PK schema, offset from the first column id.
			// NOTE(review): presumably ids are assigned consecutively in PK
			// schema order; IndexOfColName returning -1 would yield
			// firstcol-1 — confirm neither can go wrong.
			if pkt, ok := tn.UnderlyingTable().(sql.PrimaryKeyTable); ok && len(pkt.PrimaryKeySchema().Schema) != len(n.Schema()) {
				firstcol, _ := cols.Next(1)
				for _, c := range n.Schema() {
					ord := pkt.PrimaryKeySchema().IndexOfColName(c.Name)
					colId := firstcol + sql.ColumnId(ord)
					ret = append(ret, colId)
				}
				break
			}
		}
		cols.ForEach(func(col sql.ColumnId) {
			ret = append(ret, col)
		})
	case *plan.JoinNode:
		// partial joins only return left-side columns
		if n.Op.IsPartial() {
			ret = append(ret, columnIdsForNode(n.Left())...)
		} else {
			ret = append(ret, columnIdsForNode(n.Left())...)
			ret = append(ret, columnIdsForNode(n.Right())...)
		}
	case *plan.ShowStatus:
		// synthetic ids 1..len(schema)
		for i := range n.Schema() {
			ret = append(ret, sql.ColumnId(i+1))
		}
	case *plan.Concat:
		ret = append(ret, columnIdsForNode(n.Left())...)
	default:
		// pass-through nodes: concatenate the children's ids in order
		for _, c := range n.Children() {
			ret = append(ret, columnIdsForNode(c)...)
		}
	}
	return ret
}
   530  
// fixExprToScope rewrites every GetField in e so that its index refers to the
// concatenation of scopes, in the order given. Nested *plan.Subquery queries
// are rewritten by assignIndexesHelper in a scope pushed on top of that
// concatenation.
//
// Lookup uses getIdxId: by column id when possible, otherwise by name. If any
// input scope is a trigger scope, every lookup is by name. GetFields that
// cannot be resolved are returned unchanged.
//
// NOTE(review): the error from transform.Expr (which here can only come from
// the subquery branch) is discarded. A failed subquery rewrite is therefore
// dropped silently, and the caller receives whatever transform.Expr returns
// alongside an error (likely nil) — confirm, and consider surfacing it.
// NOTE(review): ids and columns are concatenated independently, so an id
// position only lines up with a column position when every scope has
// len(ids) == len(columns) — confirm.
func fixExprToScope(e sql.Expression, scopes ...*idxScope) sql.Expression {
	newScope := &idxScope{}
	for _, s := range scopes {
		newScope.triggerScope = newScope.triggerScope || s.triggerScope
		newScope.addScope(s)
	}
	ret, _, _ := transform.Expr(e, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
		switch e := e.(type) {
		case *expression.GetField:
			// TODO: this is a swallowed error in some cases. It triggers falsely in queries involving the dual table, or
			//  queries where the columns being selected are only found in subqueries. Conversely, we actually want to ignore
			//  this error for the case of DEFAULT in a `plan.Values`, since we analyze the insert source in isolation (we
			//  don't have the destination schema, and column references in default values are determined in the build phase)
			idx, _ := newScope.getIdxId(e.Id(), e.String())
			if idx >= 0 {
				return e.WithIndex(idx), transform.NewTree, nil
			}
			return e, transform.SameTree, nil
		case *plan.Subquery:
			// this |outScope| prepends the subquery scope
			newQ, _, err := assignIndexesHelper(e.Query, newScope.push())
			if err != nil {
				return nil, transform.SameTree, err
			}
			return e.WithQuery(newQ), transform.NewTree, nil
		default:
			return e, transform.SameTree, nil
		}
	})
	return ret
}