github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/scope.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/transform"
    27  )
    28  
// scope tracks relational dependencies necessary to type check expressions,
// resolve name definitions, and build relational nodes.
//
// Scopes form a chain through |parent|; name lookups (resolveColumn, hasTable,
// getExpr, getCte) walk that chain when a name is not defined locally.
type scope struct {
	// b is the Builder this scope belongs to. It supplies the column and
	// table id counters as well as the trigger context.
	b *Builder
	// parent is the enclosing scope, or nil for the outermost scope.
	parent *scope
	// ast is presumably the syntax node this scope was built for; it is not
	// read in this file — confirm against callers.
	ast ast.SQLNode
	// node is the plan node attached to this scope. copy deep-copies it and
	// aliasCte rewrites it when it is a *plan.SubqueryAlias.
	node sql.Node

	// activeSubquery, when non-nil, collects out of scope column references
	// (see resolveColumn) and backs correlated() and volatile().
	activeSubquery *subquery
	// refsSubquery is not read in this file; presumably marks that the scope
	// references a subquery — confirm against callers.
	refsSubquery bool

	// cols are definitions provided by this scope
	cols []scopeColumn
	// colset holds the ids of columns registered through addColumn
	colset sql.ColSet
	// extraCols are auxiliary output columns required
	// for sorting or grouping
	extraCols []scopeColumn
	// redirectCol is used for using and natural joins right-table
	// attributes that redirect to the left table intersection
	redirectCol map[string]scopeColumn
	// tables are the list of table definitions in this scope,
	// keyed by lower-cased table name
	tables map[string]sql.TableId
	// ctes are common table expressions defined in this scope
	// TODO these should be case-sensitive
	ctes map[string]*scope
	// groupBy collects aggregation functions and inputs
	groupBy *groupBy
	// windowFuncs is a list of window functions in the current scope
	windowFuncs []scopeColumn
	// windowDefs is presumably keyed by window name — confirm against callers
	windowDefs map[string]*sql.WindowDefinition
	// exprs collects unique expression ids for reference,
	// keyed by the lower-cased column string
	exprs map[string]columnId
	// proc is the stored procedure context; nil when no procedure is active
	proc *procCtx
}
    63  
// resolveColumn matches a variable use to a column definition with a unique
// expression id. |chooseFirst| is indicated for accepting ambiguous having and
// group by columns.
//
// Lookup order:
//  1. stored procedure variables (unqualified names only, and only when
//     |checkParent| is set and a procedure context is active),
//  2. redirected columns in redirectCol (unqualified names only),
//  3. columns defined in this scope, compared case-insensitively on column,
//     table and db; an empty |table| or |db| matches anything,
//  4. the group by out-scope, when one exists,
//  5. a hallucinated trigger column (see triggerCol), when the trigger
//     context has Call unset and lists unresolved tables,
//  6. the parent scope chain, when |checkParent| is set.
//
// It returns the resolved column and true, or the zero scopeColumn and false.
// An ambiguous reference is reported through handleErr, which panics.
//
// NOTE(review): redirect stores entries under from.String(), which is
// table-qualified whenever from.table is set, while step 2 looks up the bare
// column name — confirm callers only redirect unqualified columns.
func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bool) (scopeColumn, bool) {
	// procedure params take precedence
	if table == "" && checkParent && s.procActive() {
		col, ok := s.proc.GetVar(col)
		if ok {
			return col, true
		}
	}

	// Unqualified columns that have been redirected should return early to avoid ambiguous column errors.
	if table == "" && s.redirectCol != nil {
		if rCol, ok := s.redirectCol[col]; ok {
			return rCol, true
		}
	}

	// found holds the first matching candidate; foundCand records that one exists.
	var found scopeColumn
	var foundCand bool
	for _, c := range s.cols {
		if strings.EqualFold(c.col, col) && (strings.EqualFold(c.table, table) || table == "") && (strings.EqualFold(c.db, db) || db == "") {
			if foundCand {
				// A second match is only a problem when it is a different column.
				if found.equals(c) {
					continue
				}

				// During trigger DDL with unresolved tables, prefer a
				// hallucinated column over reporting ambiguity.
				if !s.b.TriggerCtx().Call && len(s.b.TriggerCtx().UnresolvedTables) > 0 {
					c, ok := s.triggerCol(table, col)
					if ok {
						return c, true
					}
				}
				// A candidate from the OnDupValuesPrefix table always loses
				// to the other candidate.
				if c.table == OnDupValuesPrefix {
					return found, true
				} else if found.table == OnDupValuesPrefix {
					return c, true
				}
				err := sql.ErrAmbiguousColumnName.New(col, []string{c.table, found.table})
				if c.table == "" {
					err = sql.ErrAmbiguousColumnOrAliasName.New(c.col)
				}
				// handleErr panics, so execution does not continue past here.
				s.handleErr(err)
			}
			// With |chooseFirst| or an active group by, the first match wins
			// and no ambiguity check is performed.
			if chooseFirst || s.groupBy != nil {
				return c, true
			}
			found = c
			foundCand = true
		}
	}
	if foundCand {
		return found, true
	}

	// Not defined locally: try the group by out-scope without consulting its parents.
	if s.groupBy != nil {
		if c, ok := s.groupBy.outScope.resolveColumn(db, table, col, false, false); ok {
			return c, true
		}
	}

	if !s.b.TriggerCtx().Call && len(s.b.TriggerCtx().UnresolvedTables) > 0 {
		c, ok := s.triggerCol(table, col)
		if ok {
			return c, true
		}
	}

	if !checkParent || s.parent == nil {
		return scopeColumn{}, false
	}

	// |chooseFirst| is not propagated to the parent lookup.
	c, foundCand := s.parent.resolveColumn(db, table, col, true, false)
	if !foundCand {
		return scopeColumn{}, false
	}

	// The column came from an enclosing scope: record it as an out of scope
	// reference on the parent's active subquery, if it has one.
	if s.parent.activeSubquery != nil {
		s.parent.activeSubquery.addOutOfScope(c.id)
	}
	return c, true
}
   147  
   148  func (s *scope) hasTable(table string) bool {
   149  	_, ok := s.tables[strings.ToLower(table)]
   150  	if ok {
   151  		return true
   152  	}
   153  	if s.parent != nil {
   154  		return s.parent.hasTable(table)
   155  	}
   156  	return false
   157  }
   158  
   159  // triggerCol is used to hallucinate a new column during trigger DDL
   160  // when we fail a resolveColumn.
   161  func (s *scope) triggerCol(table, col string) (scopeColumn, bool) {
   162  	// hallucinate tablecol
   163  	dbName := ""
   164  	if s.b.currentDatabase != nil {
   165  		dbName = s.b.currentDatabase.Name()
   166  	}
   167  	for _, t := range s.b.TriggerCtx().UnresolvedTables {
   168  		if strings.EqualFold(t, table) {
   169  			col := scopeColumn{db: dbName, table: table, col: col}
   170  			id := s.newColumn(col)
   171  			col.id = id
   172  			return col, true
   173  		}
   174  	}
   175  	if table == "" {
   176  		col := scopeColumn{db: dbName, table: table, col: col}
   177  		id := s.newColumn(col)
   178  		col.id = id
   179  		return col, true
   180  	}
   181  	return scopeColumn{}, false
   182  }
   183  
   184  // getExpr returns a columnId if the given expression has
   185  // been built.
   186  func (s *scope) getExpr(name string, checkCte bool) (columnId, bool) {
   187  	n := strings.ToLower(name)
   188  	id, ok := s.exprs[n]
   189  	if !ok && s.groupBy != nil {
   190  		id, ok = s.groupBy.outScope.getExpr(n, checkCte)
   191  	}
   192  	if !ok && checkCte && s.ctes != nil {
   193  		for _, cte := range s.ctes {
   194  			id, ok = cte.getExpr(n, false)
   195  			if ok {
   196  				break
   197  			}
   198  		}
   199  	}
   200  	// TODO: possibly want to look in parent scopes
   201  	if !ok && s.parent != nil {
   202  		return s.parent.getExpr(name, checkCte)
   203  	}
   204  	return id, ok
   205  }
   206  
// procActive reports whether this scope has a stored procedure context,
// i.e. whether initProc has populated s.proc.
func (s *scope) procActive() bool {
	return s.proc != nil
}
   210  
   211  func (s *scope) initProc() {
   212  	s.proc = &procCtx{
   213  		s:          s,
   214  		conditions: make(map[string]*plan.DeclareCondition),
   215  		cursors:    make(map[string]struct{}),
   216  		vars:       make(map[string]scopeColumn),
   217  		labels:     make(map[string]bool),
   218  		lastState:  dsVariable,
   219  	}
   220  }
   221  
// initGroupBy creates a container scope for aggregation
// functions and function inputs.
//
// The container's outScope is produced by replace(), so it shares this
// scope's parent rather than being a child of this scope.
func (s *scope) initGroupBy() {
	s.groupBy = &groupBy{outScope: s.replace()}
}
   227  
   228  // pushSubquery creates a new scope with the subquery already initialized.
   229  func (s *scope) pushSubquery() *scope {
   230  	newScope := s.push()
   231  	newScope.activeSubquery = &subquery{parent: s.nearestSubquery()}
   232  	return newScope
   233  }
   234  
   235  // replaceSubquery creates a new scope with the subquery already initialized.
   236  func (s *scope) replaceSubquery() *scope {
   237  	newScope := s.replace()
   238  	newScope.activeSubquery = &subquery{parent: s.nearestSubquery()}
   239  	return newScope
   240  }
   241  
// initSubquery creates a container for tracking out of scope
// column references and volatile functions.
//
// Unlike pushSubquery and replaceSubquery, the new tracker is installed on
// this scope itself and is not linked to any enclosing subquery.
func (s *scope) initSubquery() {
	s.activeSubquery = &subquery{}
}
   247  
   248  func (s *scope) correlated() sql.ColSet {
   249  	if s.activeSubquery == nil {
   250  		return sql.ColSet{}
   251  	}
   252  	return s.activeSubquery.correlated
   253  }
   254  
   255  func (s *scope) volatile() bool {
   256  	if s.activeSubquery == nil {
   257  		return false
   258  	}
   259  	return s.activeSubquery.volatile
   260  }
   261  
   262  func (s *scope) nearestSubquery() *subquery {
   263  	n := s
   264  	for n != nil {
   265  		if n.activeSubquery != nil {
   266  			return n.activeSubquery
   267  		}
   268  		n = n.parent
   269  	}
   270  	return nil
   271  }
   272  
   273  // setTableAlias updates column definitions in this scope to
   274  // appear sourced from a new table name.
   275  func (s *scope) setTableAlias(t string) {
   276  	t = strings.ToLower(t)
   277  	var oldTable string
   278  	for i := range s.cols {
   279  		beforeColStr := s.cols[i].String()
   280  		if oldTable == "" {
   281  			oldTable = s.cols[i].table
   282  		}
   283  		s.cols[i].table = t
   284  		id, ok := s.getExpr(beforeColStr, true)
   285  		if ok {
   286  			// todo better way to do projections
   287  			delete(s.exprs, beforeColStr)
   288  		}
   289  		s.exprs[strings.ToLower(s.cols[i].String())] = id
   290  	}
   291  	id, ok := s.tables[oldTable]
   292  	if !ok {
   293  		return
   294  	}
   295  	delete(s.tables, oldTable)
   296  	if s.tables == nil {
   297  		s.tables = make(map[string]sql.TableId)
   298  	}
   299  	s.tables[t] = id
   300  }
   301  
   302  // setColAlias updates the column name definitions for this scope
   303  // to the names in the input list.
   304  func (s *scope) setColAlias(cols []string) {
   305  	if len(cols) != len(s.cols) {
   306  		err := sql.ErrColumnCountMismatch.New()
   307  		s.b.handleErr(err)
   308  	}
   309  	ids := make([]columnId, len(cols))
   310  	for i := range s.cols {
   311  		beforeColStr := s.cols[i].String()
   312  		id, ok := s.getExpr(beforeColStr, true)
   313  		if ok {
   314  			// todo better way to do projections
   315  			delete(s.exprs, beforeColStr)
   316  		}
   317  		ids[i] = id
   318  		delete(s.exprs, beforeColStr)
   319  	}
   320  	for i := range s.cols {
   321  		name := strings.ToLower(cols[i])
   322  		s.cols[i].col = name
   323  		s.exprs[s.cols[i].String()] = ids[i]
   324  	}
   325  }
   326  
   327  // push creates a new scope referencing the current scope as a
   328  // parent. Variables in the new scope will have name visibility
   329  // into this scope.
   330  func (s *scope) push() *scope {
   331  	new := &scope{
   332  		b:      s.b,
   333  		parent: s,
   334  	}
   335  	if s.procActive() {
   336  		new.initProc()
   337  	}
   338  	return new
   339  }
   340  
   341  // replace creates a new scope with the same parent definition
   342  // visibility as the current scope. Useful for groupby and subqueries
   343  // that have more complex naming hierarchy.
   344  func (s *scope) replace() *scope {
   345  	if s == nil {
   346  		return &scope{}
   347  	}
   348  	return &scope{
   349  		b:      s.b,
   350  		parent: s.parent,
   351  	}
   352  }
   353  
// aliasCte copies a scope, but increments the column and table ids
// for the new relation.
//
// A nil scope yields nil. When |alias| is empty or already present in
// s.tables, a plain copy is returned with the original ids. Otherwise the
// copy's columns are rebuilt under |alias| with fresh column ids, and, when
// the scope's node is a *plan.SubqueryAlias, the node is rewritten with the
// new scope mapping, column set and table id.
//
// NOTE(review): s.tables is keyed by lower-cased names (see addTable) but is
// probed here with |alias| as given, so a mixed-case alias never matches —
// confirm whether callers always pass a lower-cased alias.
func (s *scope) aliasCte(alias string) *scope {
	if s == nil {
		return nil
	}
	outScope := s.copy()
	if _, ok := s.tables[alias]; ok || alias == "" {
		return outScope
	}

	// sq is nil when the copied node is not a subquery alias.
	sq, _ := outScope.node.(*plan.SubqueryAlias)

	// Register the alias first so every new column shares its table id.
	tabId := outScope.addTable(alias)
	outScope.cols = nil
	var colSet sql.ColSet
	// scopeMapping re-keys the subquery's existing mapping by the new column ids.
	scopeMapping := make(map[sql.ColumnId]sql.Expression)
	for _, c := range s.cols {
		// The id field passed here is ignored; newColumn assigns the real one.
		id := outScope.newColumn(scopeColumn{
			tableId:     tabId,
			db:          c.db,
			table:       alias,
			col:         c.col,
			originalCol: c.originalCol,
			id:          0,
			typ:         c.typ,
			nullable:    c.nullable,
		})
		colSet.Add(sql.ColumnId(id))
		// todo double scope mapping
		if sq != nil {
			scopeMapping[sql.ColumnId(id)] = sq.ScopeMapping[sql.ColumnId(c.id)]
		}
	}

	if sq != nil {
		outScope.node = sq.WithScopeMapping(scopeMapping).WithColumns(colSet).WithId(tabId)
	}
	return outScope
}
   394  
   395  // copy produces an identical scope with copied references.
   396  func (s *scope) copy() *scope {
   397  	if s == nil {
   398  		return nil
   399  	}
   400  
   401  	ret := *s
   402  	if ret.node != nil {
   403  		ret.node, _ = DeepCopyNode(s.node)
   404  	}
   405  	if s.tables != nil {
   406  		ret.tables = make(map[string]sql.TableId, len(s.tables))
   407  		for k, v := range s.tables {
   408  			ret.tables[k] = v
   409  		}
   410  	}
   411  	if s.ctes != nil {
   412  		ret.ctes = make(map[string]*scope, len(s.ctes))
   413  		for k, v := range s.ctes {
   414  			ret.ctes[k] = v
   415  		}
   416  	}
   417  	if s.exprs != nil {
   418  		ret.exprs = make(map[string]columnId, len(s.exprs))
   419  		for k, v := range s.exprs {
   420  			ret.exprs[k] = v
   421  		}
   422  	}
   423  	if s.groupBy != nil {
   424  		gbCopy := *s.groupBy
   425  		ret.groupBy = &gbCopy
   426  	}
   427  	if s.cols != nil {
   428  		ret.cols = make([]scopeColumn, len(s.cols))
   429  		copy(ret.cols, s.cols)
   430  	}
   431  	if !s.colset.Empty() {
   432  		ret.colset = s.colset.Copy()
   433  	}
   434  
   435  	return &ret
   436  }
   437  
   438  // DeepCopyNode copies a sql.Node.
   439  func DeepCopyNode(node sql.Node) (sql.Node, error) {
   440  	n, _, err := transform.NodeExprs(node, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   441  		e, err := transform.Clone(e)
   442  		return e, transform.NewTree, err
   443  	})
   444  	return n, err
   445  }
   446  
   447  // addCte adds a cte definition to this scope for table resolution.
   448  func (s *scope) addCte(name string, cteScope *scope) {
   449  	if s.ctes == nil {
   450  		s.ctes = make(map[string]*scope)
   451  	}
   452  	s.ctes[name] = cteScope
   453  	s.addTable(name)
   454  }
   455  
   456  // getCte attempts to resolve a table name as a cte definition.
   457  func (s *scope) getCte(name string) *scope {
   458  	checkScope := s
   459  	for checkScope != nil {
   460  		if checkScope.ctes != nil {
   461  			cte, ok := checkScope.ctes[strings.ToLower(name)]
   462  			if ok {
   463  				return cte
   464  			}
   465  		}
   466  		checkScope = checkScope.parent
   467  	}
   468  	return nil
   469  }
   470  
   471  // redirect overwrites a definition with an alias rewrite,
   472  // without preventing us from resolving the original column.
   473  // This is used for resolving natural join projections.
   474  func (s *scope) redirect(from, to scopeColumn) {
   475  	if s.redirectCol == nil {
   476  		s.redirectCol = make(map[string]scopeColumn)
   477  	}
   478  	s.redirectCol[from.String()] = to
   479  }
   480  
   481  // addColumn interns and saves the given column to this scope.
   482  // todo: new IR should absorb interning and use bitmaps for
   483  // column identity
   484  func (s *scope) addColumn(col scopeColumn) {
   485  	s.cols = append(s.cols, col)
   486  	s.colset.Add(sql.ColumnId(col.id))
   487  	if s.exprs == nil {
   488  		s.exprs = make(map[string]columnId)
   489  	}
   490  	s.exprs[strings.ToLower(col.String())] = col.id
   491  	return
   492  }
   493  
   494  // newColumn adds the column to the current scope and assigns a
   495  // new columnId for referencing. newColumn builds a new expression
   496  // reference, whereas addColumn only adds a preexisting expression
   497  // definition to a given scope.
   498  func (s *scope) newColumn(col scopeColumn) columnId {
   499  	s.b.colId++
   500  	col.id = s.b.colId
   501  	if col.table != "" {
   502  		tabId := s.addTable(col.table)
   503  		col.tableId = tabId
   504  	}
   505  	s.addColumn(col)
   506  	return col.id
   507  }
   508  
   509  // addTable records adds a table name defined in this scope
   510  func (s *scope) addTable(name string) sql.TableId {
   511  	if name == "" {
   512  		return 0
   513  	}
   514  	name = strings.ToLower(name)
   515  	if s.tables == nil {
   516  		s.tables = make(map[string]sql.TableId)
   517  	}
   518  	if _, ok := s.tables[name]; !ok {
   519  		s.b.tabId++
   520  		s.tables[name] = s.b.tabId
   521  	}
   522  	return s.tables[name]
   523  }
   524  
// addExtraColumn marks an auxiliary column used in an
// aggregation, sorting, or having clause.
// The column is appended to extraCols only; cols, colset and exprs are
// not touched.
func (s *scope) addExtraColumn(col scopeColumn) {
	s.extraCols = append(s.extraCols, col)
}
   530  
// addColumns appends the given columns to this scope's cols.
// Unlike addColumn, it does not update colset or exprs, so the columns are
// visible to resolveColumn but not to getExpr.
func (s *scope) addColumns(cols []scopeColumn) {
	s.cols = append(s.cols, cols...)
}
   534  
   535  // appendColumnsFromScope merges column definitions for
   536  // multi-relational expressions.
   537  func (s *scope) appendColumnsFromScope(src *scope) {
   538  	s.cols = append(s.cols, src.cols...)
   539  	if len(src.exprs) > 0 && s.exprs == nil {
   540  		s.exprs = make(map[string]columnId)
   541  	}
   542  	for k, v := range src.exprs {
   543  		s.exprs[k] = v
   544  	}
   545  	if len(src.redirectCol) > 0 && s.redirectCol == nil {
   546  		s.redirectCol = make(map[string]scopeColumn)
   547  	}
   548  	for k, v := range src.redirectCol {
   549  		s.redirectCol[k] = v
   550  	}
   551  	if len(src.tables) > 0 && s.tables == nil {
   552  		s.tables = make(map[string]sql.TableId)
   553  	}
   554  	for k, v := range src.tables {
   555  		s.tables[k] = v
   556  	}
   557  	// these become pass-through columns in the new scope.
   558  	for i := len(src.cols); i < len(s.cols); i++ {
   559  		s.cols[i].scalar = nil
   560  	}
   561  }
   562  
// handleErr aborts the current build by panicking with |err| wrapped in a
// parseErr. It never returns normally.
// NOTE(review): presumably the panic is recovered at the builder's entry
// point — confirm against the Builder.
func (s *scope) handleErr(err error) {
	panic(parseErr{err})
}
   566  
   567  // tableId and columnId are temporary ways to track expression
   568  // and name uniqueness.
   569  // todo: the plan format should track these
   570  type tableId uint16
   571  type columnId uint16
   572  
// scopeColumn is a single column definition tracked by a scope.
type scopeColumn struct {
	// nullable is forwarded to the GetField built by scalarGf.
	nullable bool
	// descending is not read in this file; presumably the sort direction
	// for ordering columns — confirm against callers.
	descending bool
	// outOfScope is not read in this file; presumably marks a reference to
	// a column from an enclosing scope — confirm against callers.
	outOfScope bool
	// id is the unique column id, assigned by newColumn. The zero value
	// denotes an empty column (see empty).
	id columnId
	// typ is the column's type, forwarded to the GetField built by scalarGf.
	typ sql.Type
	// scalar is the expression attached to this column, if any. It is
	// cleared when the column is passed through by appendColumnsFromScope.
	scalar sql.Expression
	// tableId is the id of the owning table; newColumn sets it from
	// addTable when table is non-empty.
	tableId sql.TableId
	// db, table and col are the names resolveColumn compares against,
	// case-insensitively.
	db    string
	table string
	col   string
	// originalCol, when non-empty, is used instead of col as the field name
	// of the GetField built by scalarGf (see withOriginal).
	originalCol string
}
   586  
// empty returns true if a scopeColumn is the null value.
// Only the id is inspected: a column whose id is zero is considered empty
// regardless of its other fields.
func (c scopeColumn) empty() bool {
	return c.id == 0
}
   591  
   592  func (c scopeColumn) equals(other scopeColumn) bool {
   593  	if c.id == other.id {
   594  		return true
   595  	}
   596  	if c.unwrapGetFieldAliasId() == other.unwrapGetFieldAliasId() {
   597  		return true
   598  	}
   599  	return false
   600  }
   601  
   602  func (c scopeColumn) unwrapGetFieldAliasId() columnId {
   603  	if c.scalar != nil {
   604  		if a, ok := c.scalar.(*expression.Alias); ok {
   605  			if gf, ok := a.Child.(*expression.GetField); ok {
   606  				return columnId(gf.Id())
   607  			}
   608  		}
   609  	}
   610  	return c.id
   611  }
   612  
   613  func (c scopeColumn) withOriginal(col string) scopeColumn {
   614  	if !strings.EqualFold(c.db, sql.InformationSchemaDatabaseName) {
   615  		// info schema columns always presented as uppercase
   616  		c.originalCol = col
   617  	}
   618  	return c
   619  }
   620  
   621  // scalarGf returns a getField reference to this column's expression.
   622  func (c scopeColumn) scalarGf() sql.Expression {
   623  	if c.scalar != nil {
   624  		if p, ok := c.scalar.(*expression.ProcedureParam); ok {
   625  			return p
   626  		}
   627  	}
   628  	if c.originalCol != "" {
   629  		return expression.NewGetFieldWithTable(int(c.id), int(c.tableId), c.typ, c.db, c.table, c.originalCol, c.nullable)
   630  	}
   631  	return expression.NewGetFieldWithTable(int(c.id), int(c.tableId), c.typ, c.db, c.table, c.col, c.nullable)
   632  }
   633  
   634  func (c scopeColumn) String() string {
   635  	if c.table == "" {
   636  		return c.col
   637  	} else {
   638  		return fmt.Sprintf("%s.%s", c.table, c.col)
   639  	}
   640  }