github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/with.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 ) 23 24 // With is a node to wrap the top-level node in a query plan so that any common table expressions can be applied in 25 // analysis. It is removed during analysis. 26 type With struct { 27 UnaryNode 28 CTEs []*CommonTableExpression 29 Recursive bool 30 } 31 32 var _ sql.Node = (*With)(nil) 33 var _ sql.CollationCoercible = (*With)(nil) 34 var _ DisjointedChildrenNode = (*With)(nil) 35 36 func NewWith(child sql.Node, ctes []*CommonTableExpression, recursive bool) *With { 37 return &With{ 38 UnaryNode: UnaryNode{child}, 39 CTEs: ctes, 40 Recursive: recursive, 41 } 42 } 43 44 func (w *With) IsReadOnly() bool { 45 return w.Child.IsReadOnly() 46 } 47 48 func (w *With) String() string { 49 cteStrings := make([]string, len(w.CTEs)) 50 for i, e := range w.CTEs { 51 cteStrings[i] = e.String() 52 } 53 54 pr := sql.NewTreePrinter() 55 if w.Recursive { 56 _ = pr.WriteNode("with recursive (%s)", strings.Join(cteStrings, ", ")) 57 } else { 58 _ = pr.WriteNode("with(%s)", strings.Join(cteStrings, ", ")) 59 } 60 _ = pr.WriteChildren(w.Child.String()) 61 return pr.String() 62 } 63 64 func (w *With) DebugString() string { 65 cteStrings := make([]string, len(w.CTEs)) 66 for i, e := range w.CTEs { 67 cteStrings[i] = sql.DebugString(e) 68 } 69 70 pr := sql.NewTreePrinter() 71 _ = pr.WriteNode("With(%s)", strings.Join(cteStrings, ", ")) 72 _ = pr.WriteChildren(sql.DebugString(w.Child)) 73 return pr.String() 74 } 75 76 func (w *With) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { 77 panic("Cannot call RowIter on With node") 78 } 79 80 func (w *With) WithChildren(children ...sql.Node) (sql.Node, error) { 81 if len(children) != 1 { 82 return nil, sql.ErrInvalidChildrenNumber.New(w, len(children), 1) 83 } 84 85 return NewWith(children[0], w.CTEs, w.Recursive), nil 86 } 87 88 // CheckPrivileges implements the interface sql.Node. 89 func (w *With) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 90 return w.Child.CheckPrivileges(ctx, opChecker) 91 } 92 93 // CollationCoercibility implements the interface sql.CollationCoercible. 94 func (w *With) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 95 return sql.GetCoercibility(ctx, w.Child) 96 } 97 98 // DisjointedChildren implements the interface DisjointedChildrenNode. 99 func (w *With) DisjointedChildren() [][]sql.Node { 100 cteAliases := make([]sql.Node, len(w.CTEs)) 101 for i := range cteAliases { 102 cteAliases[i] = w.CTEs[i].Subquery 103 } 104 return [][]sql.Node{ 105 {w.UnaryNode.Child}, 106 cteAliases, 107 } 108 } 109 110 // WithDisjointedChildren implements the interface DisjointedChildrenNode. 111 func (w *With) WithDisjointedChildren(children [][]sql.Node) (sql.Node, error) { 112 if len(children) != 2 || len(children[0]) != 1 || len(children[1]) != len(w.CTEs) { 113 return nil, sql.ErrInvalidChildrenNumber.New(w, len(children), 2) 114 } 115 nw := *w 116 nw.UnaryNode.Child = children[0][0] 117 newCTEs := make([]*CommonTableExpression, len(w.CTEs)) 118 copy(newCTEs, w.CTEs) 119 for i, cteAliasChild := range children[1] { 120 subqueryAlias, ok := cteAliasChild.(*SubqueryAlias) 121 if !ok { 122 return nil, fmt.Errorf("%T: expected `%T`, got `%T`", w, nw.CTEs[i].Subquery, cteAliasChild) 123 } 124 newCTEs[i] = &CommonTableExpression{ 125 Subquery: subqueryAlias, 126 Columns: w.CTEs[i].Columns, 127 } 128 } 129 nw.CTEs = newCTEs 130 return &nw, nil 131 } 132 133 type CommonTableExpression struct { 134 Subquery *SubqueryAlias 135 Columns []string 136 } 137 138 func NewCommonTableExpression(subquery *SubqueryAlias, columns []string) *CommonTableExpression { 139 return &CommonTableExpression{ 140 Subquery: subquery, 141 Columns: columns, 142 } 143 } 144 145 func (e *CommonTableExpression) String() string { 146 pr := sql.NewTreePrinter() 147 if len(e.Columns) > 0 { 148 _ = pr.WriteNode("%s (%s)", e.Subquery.name, strings.Join(e.Columns, ",")) 149 } else { 150 _ = pr.WriteNode("%s", e.Subquery.name) 151 } 152 _ = pr.WriteChildren(sql.DebugString(e.Subquery.Child)) 153 return pr.String() 154 } 155 156 func (e *CommonTableExpression) DebugString() string { 157 pr := sql.NewTreePrinter() 158 if len(e.Columns) > 0 { 159 _ = pr.WriteNode("%s (%s)", e.Subquery.name, strings.Join(e.Columns, ",")) 160 } else { 161 _ = pr.WriteNode("%s", e.Subquery.name) 162 } 163 _ = pr.WriteChildren(e.Subquery.Child.String()) 164 return pr.String() 165 }