github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/cte.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "fmt" 19 "strings" 20 21 ast "github.com/dolthub/vitess/go/vt/sqlparser" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/plan" 25 ) 26 27 func (b *Builder) buildWith(inScope *scope, with *ast.With) (outScope *scope) { 28 // resolveCommonTableExpressions operates on With nodes. It replaces any matching UnresolvedTable references in the 29 // tree with the subqueries defined in the CTEs. 30 31 // CTE resolution: 32 // - pre-process, get the list of CTEs 33 // - find uses of those CTEs in the regular query body 34 // - replace references to the name with the subquery body 35 // - avoid infinite recursion of CTE referencing itself 36 37 // recursive CTE (more complicated) 38 // push recursive half right, minimize recursive side 39 // create *plan.RecursiveCte node 40 // replace recursive references of cte name with *plan.RecursiveTable 41 42 outScope = inScope.push() 43 44 for _, cte := range with.Ctes { 45 cte, ok := cte.(*ast.CommonTableExpr) 46 if !ok { 47 b.handleErr(sql.ErrUnsupportedFeature.New(fmt.Sprintf("Unsupported type of common table expression %T", cte))) 48 } 49 50 ate := cte.AliasedTableExpr 51 sq, ok := ate.Expr.(*ast.Subquery) 52 if !ok { 53 b.handleErr(sql.ErrUnsupportedFeature.New(fmt.Sprintf("Unsupported type of common table expression %T", ate.Expr))) 54 } 55 56 cteName := strings.ToLower(ate.As.String()) 57 var cteScope *scope 58 if with.Recursive { 59 switch n := sq.Select.(type) { 60 case *ast.SetOp: 61 switch n.Type { 62 case ast.UnionStr, ast.UnionAllStr, ast.UnionDistinctStr: 63 cteScope = b.buildRecursiveCte(outScope, n, cteName, columnsToStrings(cte.Columns)) 64 default: 65 b.handleErr(sql.ErrRecursiveCTEMissingUnion.New(cteName)) 66 } 67 default: 68 if hasRecursiveTable(cteName, n) { 69 b.handleErr(sql.ErrRecursiveCTEMissingUnion.New(cteName)) 70 } 71 cteScope = b.buildCte(outScope, ate, cteName, columnsToStrings(cte.Columns)) 72 } 73 } else { 74 cteScope = b.buildCte(outScope, ate, cteName, columnsToStrings(cte.Columns)) 75 } 76 inScope.addCte(cteName, cteScope) 77 } 78 return 79 } 80 81 func (b *Builder) buildCte(inScope *scope, e ast.TableExpr, name string, columns []string) *scope { 82 cteScope := b.buildDataSource(inScope, e) 83 b.renameSource(cteScope, name, columns) 84 switch n := cteScope.node.(type) { 85 case *plan.SubqueryAlias: 86 cteScope.node = n.WithColumnNames(columns) 87 } 88 return cteScope 89 } 90 91 func (b *Builder) buildRecursiveCte(inScope *scope, union *ast.SetOp, name string, columns []string) *scope { 92 l, r := splitRecursiveCteUnion(name, union) 93 if r == nil { 94 // not recursive 95 sqScope := inScope.pushSubquery() 96 cteScope := b.buildSelectStmt(sqScope, union) 97 b.renameSource(cteScope, name, columns) 98 switch n := cteScope.node.(type) { 99 case *plan.SetOp: 100 sq := plan.NewSubqueryAlias(name, "", n) 101 sq = sq.WithColumnNames(columns) 102 sq = sq.WithCorrelated(sqScope.correlated()) 103 sq = sq.WithVolatile(sqScope.volatile()) 104 105 tabId := cteScope.addTable(name) 106 var colset sql.ColSet 107 for i, c := range cteScope.cols { 108 c.tableId = tabId 109 cteScope.cols[i] = c 110 colset.Add(sql.ColumnId(c.id)) 111 } 112 113 cteScope.node = sq.WithId(tabId).WithColumns(colset) 114 } 115 return cteScope 116 } 117 118 switch union.Type { 119 case ast.UnionStr, ast.UnionAllStr, ast.UnionDistinctStr: 120 default: 121 b.handleErr(sql.ErrRecursiveCTENotUnion.New(union.Type)) 122 } 123 124 // resolve non-recursive portion 125 leftSqScope := inScope.pushSubquery() 126 leftScope := b.buildSelectStmt(leftSqScope, l) 127 128 // schema for non-recursive portion => recursive table 129 var rTable *plan.RecursiveTable 130 var rInit sql.Node 131 var recSch sql.Schema 132 cteScope := leftScope.replace() 133 tableId := cteScope.addTable(name) 134 var cols sql.ColSet 135 scopeMapping := make(map[sql.ColumnId]sql.Expression) 136 { 137 rInit = leftScope.node 138 recSch = make(sql.Schema, len(rInit.Schema())) 139 for i, c := range rInit.Schema() { 140 newC := c.Copy() 141 if len(columns) > 0 { 142 newC.Name = columns[i] 143 } 144 newC.Source = name 145 // the recursive part of the CTE may produce wider types than the left/non-recursive part 146 // we need to promote the type of the left part, so the final schema is the widest possible type 147 newC.Type = newC.Type.Promote() 148 recSch[i] = newC 149 } 150 151 for i, c := range leftScope.cols { 152 c.typ = recSch[i].Type 153 c.scalar = nil 154 c.table = name 155 toId := cteScope.newColumn(c) 156 scopeMapping[sql.ColumnId(toId)] = c.scalarGf() 157 cols.Add(sql.ColumnId(toId)) 158 } 159 b.renameSource(cteScope, name, columns) 160 161 rTable = plan.NewRecursiveTable(name, recSch) 162 cteScope.node = rTable.WithId(tableId).WithColumns(cols) 163 } 164 165 rightInScope := inScope.replaceSubquery() 166 rightInScope.addCte(name, cteScope) 167 rightScope := b.buildSelectStmt(rightInScope, r) 168 169 // all is not distinct 170 distinct := true 171 switch union.Type { 172 case ast.UnionAllStr, ast.IntersectAllStr, ast.ExceptAllStr: 173 distinct = false 174 } 175 limit := b.buildLimit(inScope, union.Limit) 176 177 orderByScope := b.analyzeOrderBy(cteScope, leftScope, union.OrderBy) 178 var sortFields sql.SortFields 179 for _, c := range orderByScope.cols { 180 so := sql.Ascending 181 if c.descending { 182 so = sql.Descending 183 } 184 scalar := c.scalar 185 if scalar == nil { 186 scalar = c.scalarGf() 187 } 188 sf := sql.SortField{ 189 Column: scalar, 190 Order: so, 191 } 192 sortFields = append(sortFields, sf) 193 } 194 195 rcte := plan.NewRecursiveCte(rInit, rightScope.node, name, columns, distinct, limit, sortFields) 196 rcte = rcte.WithSchema(recSch).WithWorking(rTable) 197 corr := leftSqScope.correlated().Union(rightInScope.correlated()) 198 vol := leftSqScope.activeSubquery.volatile || rightInScope.activeSubquery.volatile 199 200 rcteId := rcte.WithId(tableId).WithColumns(cols) 201 202 sq := plan.NewSubqueryAlias(name, "", rcteId) 203 sq = sq.WithColumnNames(columns) 204 sq = sq.WithCorrelated(corr) 205 sq = sq.WithVolatile(vol) 206 sq = sq.WithScopeMapping(scopeMapping) 207 cteScope.node = sq.WithId(tableId).WithColumns(cols) 208 b.renameSource(cteScope, name, columns) 209 return cteScope 210 } 211 212 // splitRecursiveCteUnion distinguishes between recursive and non-recursive 213 // portions of a recursive CTE. We walk a left deep tree of unions downwards 214 // as far as the right scope references the recursive binding. A subquery 215 // alias or a non-recursive right scope terminates the walk. We transpose all 216 // recursive right scopes into a new union tree, returning separate initial 217 // and recursive trees. If the node is not a recursive union, the returned 218 // right node will be nil. 219 // 220 // todo(max): better error messages to differentiate between syntax errors 221 // "should have one or more non-recursive query blocks followed by one or more recursive ones" 222 // "the recursive table must be referenced only once, and not in any subquery" 223 func splitRecursiveCteUnion(name string, n ast.SelectStatement) (ast.SelectStatement, ast.SelectStatement) { 224 union, ok := n.(*ast.SetOp) 225 if !ok { 226 return n, nil 227 } 228 229 if !hasRecursiveTable(name, union.Right) { 230 return n, nil 231 } 232 233 l, r := splitRecursiveCteUnion(name, union.Left) 234 if r == nil { 235 return union.Left, union.Right 236 } 237 238 return l, &ast.SetOp{ 239 Type: union.Type, 240 Left: r, 241 Right: union.Right, 242 OrderBy: union.OrderBy, 243 With: union.With, 244 Limit: union.Limit, 245 Lock: union.Lock, 246 } 247 } 248 249 // hasRecursiveTable returns true if the given scope references the 250 // table name. 251 func hasRecursiveTable(name string, s ast.SelectStatement) bool { 252 var found bool 253 ast.Walk(func(node ast.SQLNode) (bool, error) { 254 switch t := (node).(type) { 255 case *ast.AliasedTableExpr: 256 switch e := t.Expr.(type) { 257 case ast.TableName: 258 if strings.ToLower(e.Name.String()) == name { 259 found = true 260 } 261 } 262 } 263 return true, nil 264 }, s) 265 return found 266 }