github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/cte.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/plan"
    25  )
    26  
    27  func (b *Builder) buildWith(inScope *scope, with *ast.With) (outScope *scope) {
    28  	// resolveCommonTableExpressions operates on With nodes. It replaces any matching UnresolvedTable references in the
    29  	// tree with the subqueries defined in the CTEs.
    30  
    31  	// CTE resolution:
    32  	// - pre-process, get the list of CTEs
    33  	// - find uses of those CTEs in the regular query body
    34  	// - replace references to the name with the subquery body
    35  	// - avoid infinite recursion of CTE referencing itself
    36  
    37  	// recursive CTE (more complicated)
    38  	// push recursive half right, minimize recursive side
    39  	// create *plan.RecursiveCte node
    40  	// replace recursive references of cte name with *plan.RecursiveTable
    41  
    42  	outScope = inScope.push()
    43  
    44  	for _, cte := range with.Ctes {
    45  		cte, ok := cte.(*ast.CommonTableExpr)
    46  		if !ok {
    47  			b.handleErr(sql.ErrUnsupportedFeature.New(fmt.Sprintf("Unsupported type of common table expression %T", cte)))
    48  		}
    49  
    50  		ate := cte.AliasedTableExpr
    51  		sq, ok := ate.Expr.(*ast.Subquery)
    52  		if !ok {
    53  			b.handleErr(sql.ErrUnsupportedFeature.New(fmt.Sprintf("Unsupported type of common table expression %T", ate.Expr)))
    54  		}
    55  
    56  		cteName := strings.ToLower(ate.As.String())
    57  		var cteScope *scope
    58  		if with.Recursive {
    59  			switch n := sq.Select.(type) {
    60  			case *ast.SetOp:
    61  				switch n.Type {
    62  				case ast.UnionStr, ast.UnionAllStr, ast.UnionDistinctStr:
    63  					cteScope = b.buildRecursiveCte(outScope, n, cteName, columnsToStrings(cte.Columns))
    64  				default:
    65  					b.handleErr(sql.ErrRecursiveCTEMissingUnion.New(cteName))
    66  				}
    67  			default:
    68  				if hasRecursiveTable(cteName, n) {
    69  					b.handleErr(sql.ErrRecursiveCTEMissingUnion.New(cteName))
    70  				}
    71  				cteScope = b.buildCte(outScope, ate, cteName, columnsToStrings(cte.Columns))
    72  			}
    73  		} else {
    74  			cteScope = b.buildCte(outScope, ate, cteName, columnsToStrings(cte.Columns))
    75  		}
    76  		inScope.addCte(cteName, cteScope)
    77  	}
    78  	return
    79  }
    80  
    81  func (b *Builder) buildCte(inScope *scope, e ast.TableExpr, name string, columns []string) *scope {
    82  	cteScope := b.buildDataSource(inScope, e)
    83  	b.renameSource(cteScope, name, columns)
    84  	switch n := cteScope.node.(type) {
    85  	case *plan.SubqueryAlias:
    86  		cteScope.node = n.WithColumnNames(columns)
    87  	}
    88  	return cteScope
    89  }
    90  
    91  func (b *Builder) buildRecursiveCte(inScope *scope, union *ast.SetOp, name string, columns []string) *scope {
    92  	l, r := splitRecursiveCteUnion(name, union)
    93  	if r == nil {
    94  		// not recursive
    95  		sqScope := inScope.pushSubquery()
    96  		cteScope := b.buildSelectStmt(sqScope, union)
    97  		b.renameSource(cteScope, name, columns)
    98  		switch n := cteScope.node.(type) {
    99  		case *plan.SetOp:
   100  			sq := plan.NewSubqueryAlias(name, "", n)
   101  			sq = sq.WithColumnNames(columns)
   102  			sq = sq.WithCorrelated(sqScope.correlated())
   103  			sq = sq.WithVolatile(sqScope.volatile())
   104  
   105  			tabId := cteScope.addTable(name)
   106  			var colset sql.ColSet
   107  			for i, c := range cteScope.cols {
   108  				c.tableId = tabId
   109  				cteScope.cols[i] = c
   110  				colset.Add(sql.ColumnId(c.id))
   111  			}
   112  
   113  			cteScope.node = sq.WithId(tabId).WithColumns(colset)
   114  		}
   115  		return cteScope
   116  	}
   117  
   118  	switch union.Type {
   119  	case ast.UnionStr, ast.UnionAllStr, ast.UnionDistinctStr:
   120  	default:
   121  		b.handleErr(sql.ErrRecursiveCTENotUnion.New(union.Type))
   122  	}
   123  
   124  	// resolve non-recursive portion
   125  	leftSqScope := inScope.pushSubquery()
   126  	leftScope := b.buildSelectStmt(leftSqScope, l)
   127  
   128  	// schema for non-recursive portion => recursive table
   129  	var rTable *plan.RecursiveTable
   130  	var rInit sql.Node
   131  	var recSch sql.Schema
   132  	cteScope := leftScope.replace()
   133  	tableId := cteScope.addTable(name)
   134  	var cols sql.ColSet
   135  	scopeMapping := make(map[sql.ColumnId]sql.Expression)
   136  	{
   137  		rInit = leftScope.node
   138  		recSch = make(sql.Schema, len(rInit.Schema()))
   139  		for i, c := range rInit.Schema() {
   140  			newC := c.Copy()
   141  			if len(columns) > 0 {
   142  				newC.Name = columns[i]
   143  			}
   144  			newC.Source = name
   145  			// the recursive part of the CTE may produce wider types than the left/non-recursive part
   146  			// we need to promote the type of the left part, so the final schema is the widest possible type
   147  			newC.Type = newC.Type.Promote()
   148  			recSch[i] = newC
   149  		}
   150  
   151  		for i, c := range leftScope.cols {
   152  			c.typ = recSch[i].Type
   153  			c.scalar = nil
   154  			c.table = name
   155  			toId := cteScope.newColumn(c)
   156  			scopeMapping[sql.ColumnId(toId)] = c.scalarGf()
   157  			cols.Add(sql.ColumnId(toId))
   158  		}
   159  		b.renameSource(cteScope, name, columns)
   160  
   161  		rTable = plan.NewRecursiveTable(name, recSch)
   162  		cteScope.node = rTable.WithId(tableId).WithColumns(cols)
   163  	}
   164  
   165  	rightInScope := inScope.replaceSubquery()
   166  	rightInScope.addCte(name, cteScope)
   167  	rightScope := b.buildSelectStmt(rightInScope, r)
   168  
   169  	// all is not distinct
   170  	distinct := true
   171  	switch union.Type {
   172  	case ast.UnionAllStr, ast.IntersectAllStr, ast.ExceptAllStr:
   173  		distinct = false
   174  	}
   175  	limit := b.buildLimit(inScope, union.Limit)
   176  
   177  	orderByScope := b.analyzeOrderBy(cteScope, leftScope, union.OrderBy)
   178  	var sortFields sql.SortFields
   179  	for _, c := range orderByScope.cols {
   180  		so := sql.Ascending
   181  		if c.descending {
   182  			so = sql.Descending
   183  		}
   184  		scalar := c.scalar
   185  		if scalar == nil {
   186  			scalar = c.scalarGf()
   187  		}
   188  		sf := sql.SortField{
   189  			Column: scalar,
   190  			Order:  so,
   191  		}
   192  		sortFields = append(sortFields, sf)
   193  	}
   194  
   195  	rcte := plan.NewRecursiveCte(rInit, rightScope.node, name, columns, distinct, limit, sortFields)
   196  	rcte = rcte.WithSchema(recSch).WithWorking(rTable)
   197  	corr := leftSqScope.correlated().Union(rightInScope.correlated())
   198  	vol := leftSqScope.activeSubquery.volatile || rightInScope.activeSubquery.volatile
   199  
   200  	rcteId := rcte.WithId(tableId).WithColumns(cols)
   201  
   202  	sq := plan.NewSubqueryAlias(name, "", rcteId)
   203  	sq = sq.WithColumnNames(columns)
   204  	sq = sq.WithCorrelated(corr)
   205  	sq = sq.WithVolatile(vol)
   206  	sq = sq.WithScopeMapping(scopeMapping)
   207  	cteScope.node = sq.WithId(tableId).WithColumns(cols)
   208  	b.renameSource(cteScope, name, columns)
   209  	return cteScope
   210  }
   211  
   212  // splitRecursiveCteUnion distinguishes between recursive and non-recursive
   213  // portions of a recursive CTE. We walk a left deep tree of unions downwards
   214  // as far as the right scope references the recursive binding. A subquery
   215  // alias or a non-recursive right scope terminates the walk. We transpose all
   216  // recursive right scopes into a new union tree, returning separate initial
   217  // and recursive trees. If the node is not a recursive union, the returned
   218  // right node will be nil.
   219  //
   220  // todo(max): better error messages to differentiate between syntax errors
   221  // "should have one or more non-recursive query blocks followed by one or more recursive ones"
   222  // "the recursive table must be referenced only once, and not in any subquery"
   223  func splitRecursiveCteUnion(name string, n ast.SelectStatement) (ast.SelectStatement, ast.SelectStatement) {
   224  	union, ok := n.(*ast.SetOp)
   225  	if !ok {
   226  		return n, nil
   227  	}
   228  
   229  	if !hasRecursiveTable(name, union.Right) {
   230  		return n, nil
   231  	}
   232  
   233  	l, r := splitRecursiveCteUnion(name, union.Left)
   234  	if r == nil {
   235  		return union.Left, union.Right
   236  	}
   237  
   238  	return l, &ast.SetOp{
   239  		Type:    union.Type,
   240  		Left:    r,
   241  		Right:   union.Right,
   242  		OrderBy: union.OrderBy,
   243  		With:    union.With,
   244  		Limit:   union.Limit,
   245  		Lock:    union.Lock,
   246  	}
   247  }
   248  
   249  // hasRecursiveTable returns true if the given scope references the
   250  // table name.
   251  func hasRecursiveTable(name string, s ast.SelectStatement) bool {
   252  	var found bool
   253  	ast.Walk(func(node ast.SQLNode) (bool, error) {
   254  		switch t := (node).(type) {
   255  		case *ast.AliasedTableExpr:
   256  			switch e := t.Expr.(type) {
   257  			case ast.TableName:
   258  				if strings.ToLower(e.Name.String()) == name {
   259  					found = true
   260  				}
   261  			}
   262  		}
   263  		return true, nil
   264  	}, s)
   265  	return found
   266  }