github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/subquery.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"sync"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql/transform"
    23  	"github.com/dolthub/go-mysql-server/sql/types"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  )
    27  
// Subquery is an expression whose value is derived by executing a subquery. It must be executed for every row in
// the outer result set. It's in the plan package instead of the expression package because it functions more like a
// plan Node than an expression.
//
// When the subquery is neither correlated nor volatile (see canCacheResults), the rows produced by the first
// execution are stored in the cache fields below and reused by later evaluations.
type Subquery struct {
	// The subquery to execute for each row in the outer result set
	Query sql.Node
	// The original verbatim select statement for this subquery
	QueryString string
	// correlated is a set of the field references in this subquery from out-of-scope
	correlated sql.ColSet
	// volatile indicates that the expression contains a non-deterministic function
	volatile bool
	// Whether results have been cached. Written while holding cacheMu.
	resultsCached bool
	// Cached results, if any. Written while holding cacheMu.
	cache []interface{}
	// Cached hash results, if any. Populated by HashMultiple while holding cacheMu.
	hashCache sql.KeyValueCache
	// Dispose function for the cache, if any. This would appear to violate the rule that nodes must be comparable by
	// reflect.DeepEquals, but it's safe in practice because the function is always nil until execution.
	disposeFunc sql.DisposeFunc
	// Mutex to guard the caches
	cacheMu sync.Mutex
	// TODO convert subquery expressions into apply joins
	// TODO move expression.Eval into an execution package
	// b builds the row iterator used to execute Query; it is attached with WithExecBuilder.
	b sql.NodeExecBuilder
	// TODO analyzer rule to connect builder access
}
    56  
    57  // NewSubquery returns a new subquery expression.
    58  func NewSubquery(node sql.Node, queryString string) *Subquery {
    59  	return &Subquery{Query: node, QueryString: queryString}
    60  }
    61  
    62  var _ sql.NonDeterministicExpression = (*Subquery)(nil)
    63  var _ sql.ExpressionWithNodes = (*Subquery)(nil)
    64  var _ sql.CollationCoercible = (*Subquery)(nil)
    65  
// StripRowNode wraps a single child node together with a column count. Its string, privilege and coercibility
// methods all delegate to the child.
// NOTE(review): the row handling itself is not implemented in this file; presumably the executor drops NumCols
// leading columns from each child row — confirm against the exec builder.
type StripRowNode struct {
	UnaryNode
	// NumCols is the column count carried by this node (preserved across WithChildren).
	NumCols int
}
    70  
    71  var _ sql.Node = (*StripRowNode)(nil)
    72  var _ sql.CollationCoercible = (*StripRowNode)(nil)
    73  
    74  func NewStripRowNode(child sql.Node, numCols int) sql.Node {
    75  	return &StripRowNode{UnaryNode: UnaryNode{child}, NumCols: numCols}
    76  }
    77  
    78  // Describe implements the sql.Describable interface
    79  func (srn *StripRowNode) Describe(options sql.DescribeOptions) string {
    80  	return sql.Describe(srn.Child, options)
    81  }
    82  
    83  // String implements the fmt.Stringer interface
    84  func (srn *StripRowNode) String() string {
    85  	return srn.Child.String()
    86  }
    87  
    88  func (srn *StripRowNode) IsReadOnly() bool {
    89  	return srn.Child.IsReadOnly()
    90  }
    91  
    92  // DebugString implements the sql.DebugStringer interface
    93  func (srn *StripRowNode) DebugString() string {
    94  	return sql.DebugString(srn.Child)
    95  }
    96  
    97  func (srn *StripRowNode) WithChildren(children ...sql.Node) (sql.Node, error) {
    98  	if len(children) != 1 {
    99  		return nil, sql.ErrInvalidChildrenNumber.New(srn, len(children), 1)
   100  	}
   101  	return NewStripRowNode(children[0], srn.NumCols), nil
   102  }
   103  
   104  // CheckPrivileges implements the interface sql.Node.
   105  func (srn *StripRowNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   106  	return srn.Child.CheckPrivileges(ctx, opChecker)
   107  }
   108  
   109  // CollationCoercibility implements the interface sql.CollationCoercible.
   110  func (srn *StripRowNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   111  	return sql.GetCoercibility(ctx, srn.Child)
   112  }
   113  
// PrependNode wraps its child by prepending column values onto any result rows
type PrependNode struct {
	UnaryNode
	// Row holds the values to prepend; it is shown by DebugString and carried over by WithChildren.
	Row sql.Row
}
   119  
   120  var _ sql.Node = (*PrependNode)(nil)
   121  var _ sql.CollationCoercible = (*PrependNode)(nil)
   122  
   123  func NewPrependNode(child sql.Node, row sql.Row) sql.Node {
   124  	return &PrependNode{
   125  		UnaryNode: UnaryNode{Child: child},
   126  		Row:       row,
   127  	}
   128  }
   129  
   130  func (p *PrependNode) String() string {
   131  	return p.Child.String()
   132  }
   133  
   134  func (p *PrependNode) IsReadOnly() bool {
   135  	return p.Child.IsReadOnly()
   136  }
   137  
   138  func (p *PrependNode) DebugString() string {
   139  	tp := sql.NewTreePrinter()
   140  	_ = tp.WriteNode("Prepend(%s)", sql.FormatRow(p.Row))
   141  	_ = tp.WriteChildren(sql.DebugString(p.Child))
   142  	return tp.String()
   143  }
   144  
   145  func (p *PrependNode) WithChildren(children ...sql.Node) (sql.Node, error) {
   146  	if len(children) != 1 {
   147  		return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1)
   148  	}
   149  	return NewPrependNode(children[0], p.Row), nil
   150  }
   151  
   152  // CheckPrivileges implements the interface sql.Node.
   153  func (p *PrependNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   154  	return p.Child.CheckPrivileges(ctx, opChecker)
   155  }
   156  
   157  // CollationCoercibility implements the interface sql.CollationCoercible.
   158  func (p *PrependNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   159  	return sql.GetCoercibility(ctx, p.Child)
   160  }
   161  
   162  // Eval implements the Expression interface.
   163  func (s *Subquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   164  	s.cacheMu.Lock()
   165  	cached := s.resultsCached
   166  	s.cacheMu.Unlock()
   167  
   168  	if cached {
   169  		if len(s.cache) == 0 {
   170  			return nil, nil
   171  		}
   172  		return s.cache[0], nil
   173  	}
   174  
   175  	rows, err := s.evalMultiple(ctx, row)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  
   180  	if len(rows) > 1 {
   181  		return nil, sql.ErrExpectedSingleRow.New()
   182  	}
   183  
   184  	if s.canCacheResults() {
   185  		s.cacheMu.Lock()
   186  		if !s.resultsCached {
   187  			s.cache, s.resultsCached = rows, true
   188  		}
   189  		s.cacheMu.Unlock()
   190  	}
   191  
   192  	if len(rows) == 0 {
   193  		return nil, nil
   194  	}
   195  	return rows[0], nil
   196  }
   197  
   198  // PrependRowInPlan returns a transformation function that prepends the row given to any row source in a query
   199  // plan. Any source of rows, as well as any node that alters the schema of its children, will be wrapped so that its
   200  // result rows are prepended with the row given.
   201  func PrependRowInPlan(row sql.Row, lateral bool) func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
   202  	return func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
   203  		switch n := n.(type) {
   204  		case sql.Table, sql.Projector, *ValueDerivedTable, *TableCountLookup, sql.TableFunction:
   205  			return NewPrependNode(n, row), transform.NewTree, nil
   206  		case *SetOp:
   207  			newSetOp := *n
   208  			newRight, _, err := transform.Node(n.Right(), PrependRowInPlan(row, lateral))
   209  			if err != nil {
   210  				return n, transform.SameTree, err
   211  			}
   212  			newLeft, _, err := transform.Node(n.Left(), PrependRowInPlan(row, lateral))
   213  			if err != nil {
   214  				return n, transform.SameTree, err
   215  			}
   216  			newSetOp.left = newLeft
   217  			newSetOp.right = newRight
   218  			return &newSetOp, transform.NewTree, nil
   219  		case *RecursiveCte:
   220  			newRecursiveCte := *n
   221  			newUnion, _, err := transform.Node(n.union, PrependRowInPlan(row, lateral))
   222  			newRecursiveCte.union = newUnion.(*SetOp)
   223  			return &newRecursiveCte, transform.NewTree, err
   224  		case *SubqueryAlias:
   225  			// For SubqueryAliases (i.e. DerivedTables), since they may have visibility to outer scopes, we need to
   226  			// transform their inner nodes to prepend the outer scope row data. Ideally, we would only do this when
   227  			// the subquery alias references those outer fields. That will also require updating subquery expression
   228  			// scope handling to also make the same optimization.
   229  			if n.OuterScopeVisibility || lateral {
   230  				newSubqueryAlias := *n
   231  				newChildNode, _, err := transform.Node(n.Child, PrependRowInPlan(row, lateral))
   232  				newSubqueryAlias.Child = newChildNode
   233  				return &newSubqueryAlias, transform.NewTree, err
   234  			} else {
   235  				return NewPrependNode(n, row), transform.NewTree, nil
   236  			}
   237  		}
   238  
   239  		return n, transform.SameTree, nil
   240  	}
   241  }
   242  
   243  func NewMax1Row(n sql.Node, name string) *Max1Row {
   244  	return &Max1Row{Child: n, name: name, Mu: &sync.Mutex{}}
   245  }
   246  
// Max1Row throws a runtime error if its child (usually subquery) tries
// to return more than one row.
type Max1Row struct {
	// Child is the wrapped node; schema, resolution state, privileges and coercibility are all taken from it.
	Child       sql.Node
	// name is returned by Name and replaced (on a copy) by WithName.
	name        string
	// Result, when non-nil, makes HasResults report true.
	// NOTE(review): presumably the single row produced by Child; it is populated outside this file — confirm.
	Result      sql.Row
	// Mu is allocated by NewMax1Row. Being a pointer, it is shared by the copies made in WithName and WithChildren.
	// NOTE(review): presumably guards Result and EmptyResult — confirm against the code that populates them.
	Mu          *sync.Mutex
	// EmptyResult makes HasResults report true even when Result is nil.
	// NOTE(review): presumably set when Child produced no rows — confirm.
	EmptyResult bool
}
   256  
   257  var _ sql.Node = (*Max1Row)(nil)
   258  var _ sql.CollationCoercible = (*Max1Row)(nil)
   259  var _ sql.NameableNode = (*Max1Row)(nil)
   260  var _ sql.RenameableNode = (*Max1Row)(nil)
   261  
   262  func (m *Max1Row) WithName(s string) sql.Node {
   263  	ret := *m
   264  	ret.name = s
   265  	return &ret
   266  }
   267  
   268  func (m *Max1Row) Name() string {
   269  	return m.name
   270  }
   271  
   272  func (m *Max1Row) IsReadOnly() bool {
   273  	return m.Child.IsReadOnly()
   274  }
   275  
   276  func (m *Max1Row) Resolved() bool {
   277  	return m.Child.Resolved()
   278  }
   279  
   280  func (m *Max1Row) Schema() sql.Schema {
   281  	return m.Child.Schema()
   282  }
   283  
   284  func (m *Max1Row) Children() []sql.Node {
   285  	return []sql.Node{m.Child}
   286  }
   287  
   288  func (m *Max1Row) String() string {
   289  	pr := sql.NewTreePrinter()
   290  	_ = pr.WriteNode("Max1Row")
   291  	children := []string{m.Child.String()}
   292  	_ = pr.WriteChildren(children...)
   293  	return pr.String()
   294  }
   295  
   296  func (m *Max1Row) DebugString() string {
   297  	pr := sql.NewTreePrinter()
   298  	_ = pr.WriteNode("Max1Row")
   299  	children := []string{sql.DebugString(m.Child)}
   300  	_ = pr.WriteChildren(children...)
   301  	return pr.String()
   302  }
   303  
   304  // HasResults returns true after a successful call to PopulateResults()
   305  func (m *Max1Row) HasResults() bool {
   306  	return m.Result != nil || m.EmptyResult
   307  }
   308  
   309  func (m *Max1Row) WithChildren(children ...sql.Node) (sql.Node, error) {
   310  	if len(children) != 1 {
   311  		return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1)
   312  	}
   313  	ret := *m
   314  
   315  	ret.Child = children[0]
   316  
   317  	return &ret, nil
   318  }
   319  
   320  func (m *Max1Row) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   321  	return m.Child.CheckPrivileges(ctx, opChecker)
   322  }
   323  
   324  // CollationCoercibility implements the interface sql.CollationCoercible.
   325  func (m *Max1Row) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   326  	return sql.GetCoercibility(ctx, m.Child)
   327  }
   328  
   329  // EvalMultiple returns all rows returned by a subquery.
   330  func (s *Subquery) EvalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, error) {
   331  	s.cacheMu.Lock()
   332  	cached := s.resultsCached
   333  	s.cacheMu.Unlock()
   334  	if cached {
   335  		return s.cache, nil
   336  	}
   337  
   338  	result, err := s.evalMultiple(ctx, row)
   339  	if err != nil {
   340  		return nil, err
   341  	}
   342  
   343  	if s.canCacheResults() {
   344  		s.cacheMu.Lock()
   345  		if s.resultsCached == false {
   346  			s.cache, s.resultsCached = result, true
   347  		}
   348  		s.cacheMu.Unlock()
   349  	}
   350  
   351  	return result, nil
   352  }
   353  
   354  func (s *Subquery) canCacheResults() bool {
   355  	return s.correlated.Empty() && !s.volatile
   356  }
   357  
   358  func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, error) {
   359  	// Any source of rows, as well as any node that alters the schema of its children, needs to be wrapped so that its
   360  	// result rows are prepended with the scope row.
   361  	q, _, err := transform.Node(s.Query, PrependRowInPlan(row, false))
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  
   366  	iter, err := s.b.Build(ctx, q, row)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  
   371  	returnsTuple := len(s.Query.Schema()) > 1
   372  
   373  	// Reduce the result row to the size of the expected schema. This means chopping off the first len(row) columns.
   374  	col := len(row)
   375  	var result []interface{}
   376  	for {
   377  		row, err := iter.Next(ctx)
   378  		if err == io.EOF {
   379  			break
   380  		}
   381  
   382  		if err != nil {
   383  			return nil, err
   384  		}
   385  
   386  		if returnsTuple {
   387  			result = append(result, append([]interface{}{}, row[col:]...))
   388  		} else {
   389  			result = append(result, row[col])
   390  		}
   391  	}
   392  
   393  	if err := iter.Close(ctx); err != nil {
   394  		return nil, err
   395  	}
   396  
   397  	return result, nil
   398  }
   399  
   400  // HashMultiple returns all rows returned by a subquery, backed by a sql.KeyValueCache. Keys are constructed using the
   401  // 64-bit hash of the values stored.
   402  func (s *Subquery) HashMultiple(ctx *sql.Context, row sql.Row) (sql.KeyValueCache, error) {
   403  	s.cacheMu.Lock()
   404  	cached := s.resultsCached && s.hashCache != nil
   405  	s.cacheMu.Unlock()
   406  	if cached {
   407  		return s.hashCache, nil
   408  	}
   409  
   410  	result, err := s.evalMultiple(ctx, row)
   411  	if err != nil {
   412  		return nil, err
   413  	}
   414  
   415  	if s.canCacheResults() {
   416  		s.cacheMu.Lock()
   417  		defer s.cacheMu.Unlock()
   418  		if !s.resultsCached || s.hashCache == nil {
   419  			hashCache, disposeFn := ctx.Memory.NewHistoryCache()
   420  			err = putAllRows(hashCache, result)
   421  			if err != nil {
   422  				return nil, err
   423  			}
   424  			s.cache, s.hashCache, s.disposeFunc, s.resultsCached = result, hashCache, disposeFn, true
   425  		}
   426  		return s.hashCache, nil
   427  	}
   428  
   429  	cache := sql.NewMapCache()
   430  	return cache, putAllRows(cache, result)
   431  }
   432  
   433  // HasResultRow returns whether the subquery has a result set > 0.
   434  func (s *Subquery) HasResultRow(ctx *sql.Context, row sql.Row) (bool, error) {
   435  	// First check if the query was cached.
   436  	s.cacheMu.Lock()
   437  	cached := s.resultsCached
   438  	s.cacheMu.Unlock()
   439  
   440  	if cached {
   441  		return len(s.cache) > 0, nil
   442  	}
   443  
   444  	// Any source of rows, as well as any node that alters the schema of its children, needs to be wrapped so that its
   445  	// result rows are prepended with the scope row.
   446  	q, _, err := transform.Node(s.Query, PrependRowInPlan(row, false))
   447  	if err != nil {
   448  		return false, err
   449  	}
   450  
   451  	iter, err := s.b.Build(ctx, q, row)
   452  	if err != nil {
   453  		return false, err
   454  	}
   455  
   456  	// Call the iterator once and see if it has a row. If io.EOF is received return false.
   457  	_, err = iter.Next(ctx)
   458  	if err == io.EOF {
   459  		err = iter.Close(ctx)
   460  		return false, err
   461  	} else if err != nil {
   462  		return false, err
   463  	}
   464  
   465  	err = iter.Close(ctx)
   466  	if err != nil {
   467  		return false, err
   468  	}
   469  
   470  	return true, nil
   471  }
   472  
   473  func putAllRows(cache sql.KeyValueCache, vals []interface{}) error {
   474  	for _, val := range vals {
   475  		rowKey, err := sql.HashOf(sql.NewRow(val))
   476  		if err != nil {
   477  			return err
   478  		}
   479  		err = cache.Put(rowKey, val)
   480  		if err != nil {
   481  			return err
   482  		}
   483  	}
   484  	return nil
   485  }
   486  
   487  // IsNullable implements the Expression interface.
   488  func (s *Subquery) IsNullable() bool {
   489  	return true
   490  }
   491  
   492  func (s *Subquery) String() string {
   493  	pr := sql.NewTreePrinter()
   494  	_ = pr.WriteNode("Subquery")
   495  	children := []string{fmt.Sprintf("cacheable: %t", s.canCacheResults()), s.Query.String()}
   496  	_ = pr.WriteChildren(children...)
   497  	return pr.String()
   498  }
   499  
   500  func (s *Subquery) DebugString() string {
   501  	pr := sql.NewTreePrinter()
   502  	_ = pr.WriteNode("Subquery")
   503  	children := []string{
   504  		fmt.Sprintf("cacheable: %t", s.canCacheResults()),
   505  		fmt.Sprintf("alias-string: %s", s.QueryString),
   506  		sql.DebugString(s.Query),
   507  	}
   508  	_ = pr.WriteChildren(children...)
   509  	return pr.String()
   510  }
   511  
   512  // Resolved implements the Expression interface.
   513  func (s *Subquery) Resolved() bool {
   514  	return s.Query.Resolved()
   515  }
   516  
   517  // Type implements the Expression interface.
   518  func (s *Subquery) Type() sql.Type {
   519  	qs := s.Query.Schema()
   520  	if len(qs) == 1 {
   521  		return s.Query.Schema()[0].Type
   522  	}
   523  	ts := make([]sql.Type, len(qs))
   524  	for i, c := range qs {
   525  		ts[i] = c.Type
   526  	}
   527  	return types.CreateTuple(ts...)
   528  }
   529  
   530  // WithChildren implements the Expression interface.
   531  func (s *Subquery) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   532  	if len(children) != 0 {
   533  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0)
   534  	}
   535  	return s, nil
   536  }
   537  
   538  // Children implements the Expression interface.
   539  func (s *Subquery) Children() []sql.Expression {
   540  	return nil
   541  }
   542  
   543  // NodeChildren implements the sql.ExpressionWithNodes interface.
   544  func (s *Subquery) NodeChildren() []sql.Node {
   545  	return []sql.Node{s.Query}
   546  }
   547  
   548  // WithNodeChildren implements the sql.ExpressionWithNodes interface.
   549  func (s *Subquery) WithNodeChildren(children ...sql.Node) (sql.ExpressionWithNodes, error) {
   550  	if len(children) != 1 {
   551  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1)
   552  	}
   553  	return s.WithQuery(children[0]), nil
   554  }
   555  
   556  // WithQuery returns the subquery with the query node changed.
   557  func (s *Subquery) WithQuery(node sql.Node) *Subquery {
   558  	ns := *s
   559  	ns.Query = node
   560  	return &ns
   561  }
   562  
   563  // WithExecBuilder returns the subquery with a recursive execution builder.
   564  func (s *Subquery) WithExecBuilder(b sql.NodeExecBuilder) *Subquery {
   565  	ns := *s
   566  	ns.b = b
   567  	return &ns
   568  }
   569  
   570  func (s *Subquery) IsNonDeterministic() bool {
   571  	return !s.canCacheResults()
   572  }
   573  
   574  func (s *Subquery) Volatile() bool {
   575  	return s.volatile
   576  }
   577  
   578  func (s *Subquery) WithVolatile() *Subquery {
   579  	ret := *s
   580  	ret.volatile = true
   581  	return &ret
   582  }
   583  
   584  func (s *Subquery) WithCorrelated(cols sql.ColSet) *Subquery {
   585  	ret := *s
   586  	ret.correlated = cols
   587  	return &ret
   588  }
   589  
   590  func (s *Subquery) Correlated() sql.ColSet {
   591  	return s.correlated
   592  }
   593  
   594  func (s *Subquery) CanCacheResults() bool {
   595  	return s.canCacheResults()
   596  }
   597  
   598  // Dispose implements sql.Disposable
   599  func (s *Subquery) Dispose() {
   600  	if s.disposeFunc != nil {
   601  		s.disposeFunc()
   602  		s.disposeFunc = nil
   603  	}
   604  	disposeNode(s.Query)
   605  }
   606  
   607  // CollationCoercibility implements the interface sql.CollationCoercible.
   608  func (s *Subquery) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   609  	return sql.GetCoercibility(ctx, s.Query)
   610  }