github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/recursive_cte.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/types"
    22  )
    23  
    24  // RecursiveCte is defined by two subqueries
    25  // connected with a union:
    26  //
    27  //	ex => WITH RECURSIVE [name]([Columns]) as ([Init] UNION [Rec]) ...
    28  //
    29  // [Init] is a non-recursive select statement, and [Rec] selects from
    30  // the recursive relation [name] until exhaustion. Note that if [Rec] is
    31  // not recursive, the optimizer will fold the RecursiveCte into a
    32  // SubqueryAlias.
    33  //
    34  // The node is executed as follows:
    35  //  1. First, iterate the [Init] subquery.
    36  //  2. Collect the outputs of [Init] in a [temporary] buffer.
    37  //  3. When the iterator is exhausted, populate the recursive
    38  //     [working] table with the [temporary] buffer.
    39  //  4. Iterate [Rec], collecting outputs in the [temporary] buffer.
    40  //  5. Repeat steps (3) and (4) until [temporary] is empty.
    41  //
    42  // A RecursiveCte, its [Init], and its [Rec] have the same
    43  // projection count and types. [Init] will be resolved before
    44  // [Rec] or [RecursiveCte] to share schema types.
    45  type RecursiveCte struct {
    46  	union *SetOp
    47  	// ColumnNames used to name lazily-loaded schema fields
    48  	ColumnNames []string
    49  	// schema will match the types of [Init.Schema()], names of [Columns]
    50  	schema sql.Schema
    51  	// Working is a handle to our refreshable intermediate table
    52  	Working *RecursiveTable
    53  	name    string
    54  	id      sql.TableId
    55  	cols    sql.ColSet
    56  }
    57  
    58  var _ sql.Node = (*RecursiveCte)(nil)
    59  var _ sql.Nameable = (*RecursiveCte)(nil)
    60  var _ sql.RenameableNode = (*RecursiveCte)(nil)
    61  var _ sql.Expressioner = (*RecursiveCte)(nil)
    62  var _ sql.CollationCoercible = (*RecursiveCte)(nil)
    63  var _ TableIdNode = (*RecursiveCte)(nil)
    64  
    65  func NewRecursiveCte(initial, recursive sql.Node, name string, outputCols []string, deduplicate bool, l sql.Expression, sf sql.SortFields) *RecursiveCte {
    66  	return &RecursiveCte{
    67  		ColumnNames: outputCols,
    68  		union: &SetOp{
    69  			SetOpType:  UnionType,
    70  			BinaryNode: BinaryNode{left: initial, right: recursive},
    71  			Distinct:   deduplicate,
    72  			Limit:      l,
    73  			SortFields: sf,
    74  		},
    75  		name: name,
    76  	}
    77  }
    78  
    79  // WithId implements sql.TableIdNode
    80  func (r *RecursiveCte) WithId(id sql.TableId) TableIdNode {
    81  	ret := *r
    82  	ret.id = id
    83  	return &ret
    84  }
    85  
    86  // Id implements sql.TableIdNode
    87  func (r *RecursiveCte) Id() sql.TableId {
    88  	return r.id
    89  }
    90  
    91  // WithColumns implements sql.TableIdNode
    92  func (r *RecursiveCte) WithColumns(set sql.ColSet) TableIdNode {
    93  	ret := *r
    94  	ret.cols = set
    95  	return &ret
    96  }
    97  
    98  // ColumnNames implements sql.TableIdNode
    99  func (r *RecursiveCte) Columns() sql.ColSet {
   100  	return r.cols
   101  }
   102  
   103  func (r *RecursiveCte) WithName(s string) sql.Node {
   104  	ret := *r
   105  	ret.name = s
   106  	return &ret
   107  }
   108  
   109  // Name implements sql.Nameable
   110  func (r *RecursiveCte) Name() string {
   111  	return r.name
   112  }
   113  
   114  func (r *RecursiveCte) IsReadOnly() bool {
   115  	return r.union.BinaryNode.left.IsReadOnly() && r.union.BinaryNode.right.IsReadOnly()
   116  }
   117  
   118  // Left implements sql.BinaryNode
   119  func (r *RecursiveCte) Left() sql.Node {
   120  	return r.union.left
   121  }
   122  
   123  // Right implements sql.BinaryNode
   124  func (r *RecursiveCte) Right() sql.Node {
   125  	return r.union.right
   126  }
   127  
   128  func (r *RecursiveCte) Union() *SetOp {
   129  	return r.union
   130  }
   131  
   132  // WithSchema inherits [Init]'s schema at resolve time
   133  func (r *RecursiveCte) WithSchema(s sql.Schema) *RecursiveCte {
   134  	nr := *r
   135  	nr.schema = s
   136  	return &nr
   137  }
   138  
   139  // WithWorking populates the [working] table with a common schema
   140  func (r *RecursiveCte) WithWorking(t *RecursiveTable) *RecursiveCte {
   141  	nr := *r
   142  	nr.Working = t
   143  	return &nr
   144  }
   145  
   146  // Schema implements sql.Node
   147  func (r *RecursiveCte) Schema() sql.Schema {
   148  	return r.schema
   149  }
   150  
   151  // WithChildren implements sql.Node
   152  func (r *RecursiveCte) WithChildren(children ...sql.Node) (sql.Node, error) {
   153  	ret := *r
   154  	s, err := r.union.WithChildren(children...)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  	ret.union = s.(*SetOp)
   159  	return &ret, nil
   160  }
   161  
   162  func (r *RecursiveCte) Opaque() bool {
   163  	return true
   164  }
   165  
   166  func (r *RecursiveCte) Resolved() bool {
   167  	return r.union.Resolved()
   168  }
   169  
   170  func (r *RecursiveCte) Children() []sql.Node {
   171  	return r.union.Children()
   172  }
   173  
   174  func (r *RecursiveCte) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   175  	return r.union.CheckPrivileges(ctx, opChecker)
   176  }
   177  
   178  // CollationCoercibility implements the interface sql.CollationCoercible.
   179  func (*RecursiveCte) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   180  	return sql.Collation_binary, 7
   181  }
   182  
   183  func (r *RecursiveCte) Expressions() []sql.Expression {
   184  	return r.union.Expressions()
   185  }
   186  
   187  func (r *RecursiveCte) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   188  	ret := *r
   189  	s, err := r.union.WithExpressions(exprs...)
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  	ret.union = s.(*SetOp)
   194  	return &ret, nil
   195  }
   196  
   197  // String implements sql.Node
   198  func (r *RecursiveCte) String() string {
   199  	pr := sql.NewTreePrinter()
   200  	_ = pr.WriteNode("RecursiveCTE")
   201  	pr.WriteChildren(r.union.String())
   202  	return pr.String()
   203  }
   204  
   205  // DebugString implements sql.Node
   206  func (r *RecursiveCte) DebugString() string {
   207  	pr := sql.NewTreePrinter()
   208  	_ = pr.WriteNode("RecursiveCTE")
   209  	pr.WriteChildren(sql.DebugString(r.union))
   210  	return pr.String()
   211  }
   212  
   213  // Type implements sql.Node
   214  func (r *RecursiveCte) Type() sql.Type {
   215  	cols := r.schema
   216  	if len(cols) == 1 {
   217  		return cols[0].Type
   218  	}
   219  	ts := make([]sql.Type, len(cols))
   220  	for i, c := range cols {
   221  		ts[i] = c.Type
   222  	}
   223  	return types.CreateTuple(ts...)
   224  }
   225  
   226  // IsNullable implements sql.Node
   227  func (r *RecursiveCte) IsNullable() bool {
   228  	return true
   229  }
   230  
   231  func NewRecursiveTable(n string, s sql.Schema) *RecursiveTable {
   232  	return &RecursiveTable{
   233  		name:   n,
   234  		schema: s,
   235  	}
   236  }
   237  
   238  // RecursiveTable is a thin wrapper around an in memory
   239  // buffer for use with recursiveCteIter.
   240  type RecursiveTable struct {
   241  	name   string
   242  	schema sql.Schema
   243  	Buf    []sql.Row
   244  	id     sql.TableId
   245  	cols   sql.ColSet
   246  }
   247  
   248  var _ sql.Node = (*RecursiveTable)(nil)
   249  var _ sql.NameableNode = (*RecursiveTable)(nil)
   250  var _ sql.RenameableNode = (*RecursiveTable)(nil)
   251  var _ TableIdNode = (*RecursiveTable)(nil)
   252  var _ sql.CollationCoercible = (*RecursiveTable)(nil)
   253  
   254  // WithId implements sql.TableIdNode
   255  func (r *RecursiveTable) WithId(id sql.TableId) TableIdNode {
   256  	// currently recursive table pointers need to be stable at execution time
   257  	r.id = id
   258  	return r
   259  }
   260  
   261  // Id implements sql.TableIdNode
   262  func (r *RecursiveTable) Id() sql.TableId {
   263  	return r.id
   264  }
   265  
   266  // WithColumns implements sql.TableIdNode
   267  func (r *RecursiveTable) WithColumns(set sql.ColSet) TableIdNode {
   268  	// currently recursive table pointers need to be stable at execution time
   269  	r.cols = set
   270  	return r
   271  }
   272  
   273  // Columns implements sql.TableIdNode
   274  func (r *RecursiveTable) Columns() sql.ColSet {
   275  	return r.cols
   276  }
   277  
   278  func (r *RecursiveTable) WithName(s string) sql.Node {
   279  	ret := *r
   280  	r.name = s
   281  	return &ret
   282  }
   283  
   284  func (r *RecursiveTable) Resolved() bool {
   285  	return true
   286  }
   287  
   288  func (r *RecursiveTable) Name() string {
   289  	return r.name
   290  }
   291  
   292  func (r *RecursiveTable) IsReadOnly() bool {
   293  	return true
   294  }
   295  
   296  func (r *RecursiveTable) String() string {
   297  	return fmt.Sprintf("RecursiveTable(%s)", r.name)
   298  }
   299  
   300  func (r *RecursiveTable) Schema() sql.Schema {
   301  	return r.schema
   302  }
   303  
   304  func (r *RecursiveTable) Children() []sql.Node {
   305  	return nil
   306  }
   307  
   308  func (r *RecursiveTable) WithChildren(node ...sql.Node) (sql.Node, error) {
   309  	return r, nil
   310  }
   311  
   312  // CheckPrivileges implements the interface sql.Node.
   313  func (r *RecursiveTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   314  	return true
   315  }
   316  
   317  // CollationCoercibility implements the interface sql.CollationCoercible.
   318  func (*RecursiveTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   319  	return sql.Collation_binary, 7
   320  }