github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/with.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  )
    23  
    24  // With is a node to wrap the top-level node in a query plan so that any common table expressions can be applied in
    25  // analysis. It is removed during analysis.
    26  type With struct {
    27  	UnaryNode
    28  	CTEs      []*CommonTableExpression
    29  	Recursive bool
    30  }
    31  
    32  var _ sql.Node = (*With)(nil)
    33  var _ sql.CollationCoercible = (*With)(nil)
    34  var _ DisjointedChildrenNode = (*With)(nil)
    35  
    36  func NewWith(child sql.Node, ctes []*CommonTableExpression, recursive bool) *With {
    37  	return &With{
    38  		UnaryNode: UnaryNode{child},
    39  		CTEs:      ctes,
    40  		Recursive: recursive,
    41  	}
    42  }
    43  
    44  func (w *With) IsReadOnly() bool {
    45  	return w.Child.IsReadOnly()
    46  }
    47  
    48  func (w *With) String() string {
    49  	cteStrings := make([]string, len(w.CTEs))
    50  	for i, e := range w.CTEs {
    51  		cteStrings[i] = e.String()
    52  	}
    53  
    54  	pr := sql.NewTreePrinter()
    55  	if w.Recursive {
    56  		_ = pr.WriteNode("with recursive (%s)", strings.Join(cteStrings, ", "))
    57  	} else {
    58  		_ = pr.WriteNode("with(%s)", strings.Join(cteStrings, ", "))
    59  	}
    60  	_ = pr.WriteChildren(w.Child.String())
    61  	return pr.String()
    62  }
    63  
    64  func (w *With) DebugString() string {
    65  	cteStrings := make([]string, len(w.CTEs))
    66  	for i, e := range w.CTEs {
    67  		cteStrings[i] = sql.DebugString(e)
    68  	}
    69  
    70  	pr := sql.NewTreePrinter()
    71  	_ = pr.WriteNode("With(%s)", strings.Join(cteStrings, ", "))
    72  	_ = pr.WriteChildren(sql.DebugString(w.Child))
    73  	return pr.String()
    74  }
    75  
    76  func (w *With) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
    77  	panic("Cannot call RowIter on With node")
    78  }
    79  
    80  func (w *With) WithChildren(children ...sql.Node) (sql.Node, error) {
    81  	if len(children) != 1 {
    82  		return nil, sql.ErrInvalidChildrenNumber.New(w, len(children), 1)
    83  	}
    84  
    85  	return NewWith(children[0], w.CTEs, w.Recursive), nil
    86  }
    87  
    88  // CheckPrivileges implements the interface sql.Node.
    89  func (w *With) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
    90  	return w.Child.CheckPrivileges(ctx, opChecker)
    91  }
    92  
    93  // CollationCoercibility implements the interface sql.CollationCoercible.
    94  func (w *With) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    95  	return sql.GetCoercibility(ctx, w.Child)
    96  }
    97  
    98  // DisjointedChildren implements the interface DisjointedChildrenNode.
    99  func (w *With) DisjointedChildren() [][]sql.Node {
   100  	cteAliases := make([]sql.Node, len(w.CTEs))
   101  	for i := range cteAliases {
   102  		cteAliases[i] = w.CTEs[i].Subquery
   103  	}
   104  	return [][]sql.Node{
   105  		{w.UnaryNode.Child},
   106  		cteAliases,
   107  	}
   108  }
   109  
   110  // WithDisjointedChildren implements the interface DisjointedChildrenNode.
   111  func (w *With) WithDisjointedChildren(children [][]sql.Node) (sql.Node, error) {
   112  	if len(children) != 2 || len(children[0]) != 1 || len(children[1]) != len(w.CTEs) {
   113  		return nil, sql.ErrInvalidChildrenNumber.New(w, len(children), 2)
   114  	}
   115  	nw := *w
   116  	nw.UnaryNode.Child = children[0][0]
   117  	newCTEs := make([]*CommonTableExpression, len(w.CTEs))
   118  	copy(newCTEs, w.CTEs)
   119  	for i, cteAliasChild := range children[1] {
   120  		subqueryAlias, ok := cteAliasChild.(*SubqueryAlias)
   121  		if !ok {
   122  			return nil, fmt.Errorf("%T: expected `%T`, got `%T`", w, nw.CTEs[i].Subquery, cteAliasChild)
   123  		}
   124  		newCTEs[i] = &CommonTableExpression{
   125  			Subquery: subqueryAlias,
   126  			Columns:  w.CTEs[i].Columns,
   127  		}
   128  	}
   129  	nw.CTEs = newCTEs
   130  	return &nw, nil
   131  }
   132  
   133  type CommonTableExpression struct {
   134  	Subquery *SubqueryAlias
   135  	Columns  []string
   136  }
   137  
   138  func NewCommonTableExpression(subquery *SubqueryAlias, columns []string) *CommonTableExpression {
   139  	return &CommonTableExpression{
   140  		Subquery: subquery,
   141  		Columns:  columns,
   142  	}
   143  }
   144  
   145  func (e *CommonTableExpression) String() string {
   146  	pr := sql.NewTreePrinter()
   147  	if len(e.Columns) > 0 {
   148  		_ = pr.WriteNode("%s (%s)", e.Subquery.name, strings.Join(e.Columns, ","))
   149  	} else {
   150  		_ = pr.WriteNode("%s", e.Subquery.name)
   151  	}
   152  	_ = pr.WriteChildren(sql.DebugString(e.Subquery.Child))
   153  	return pr.String()
   154  }
   155  
   156  func (e *CommonTableExpression) DebugString() string {
   157  	pr := sql.NewTreePrinter()
   158  	if len(e.Columns) > 0 {
   159  		_ = pr.WriteNode("%s (%s)", e.Subquery.name, strings.Join(e.Columns, ","))
   160  	} else {
   161  		_ = pr.WriteNode("%s", e.Subquery.name)
   162  	}
   163  	_ = pr.WriteChildren(e.Subquery.Child.String())
   164  	return pr.String()
   165  }