github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/set_op.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  
    21  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/transform"
    27  )
    28  
    29  func hasRecursiveCte(node sql.Node) bool {
    30  	hasRCTE := false
    31  	transform.Inspect(node, func(n sql.Node) bool {
    32  		if _, ok := n.(*plan.RecursiveTable); ok {
    33  			hasRCTE = true
    34  			return false
    35  		}
    36  		return true
    37  	})
    38  	return hasRCTE
    39  }
    40  
    41  func (b *Builder) buildSetOp(inScope *scope, u *ast.SetOp) (outScope *scope) {
    42  	leftScope := b.buildSelectStmt(inScope, u.Left)
    43  	rightScope := b.buildSelectStmt(inScope, u.Right)
    44  
    45  	var setOpType int
    46  	switch u.Type {
    47  	case ast.UnionStr, ast.UnionAllStr, ast.UnionDistinctStr:
    48  		setOpType = plan.UnionType
    49  	case ast.IntersectStr, ast.IntersectAllStr, ast.IntersectDistinctStr:
    50  		setOpType = plan.IntersectType
    51  	case ast.ExceptStr, ast.ExceptAllStr, ast.ExceptDistinctStr:
    52  		setOpType = plan.ExceptType
    53  	default:
    54  		b.handleErr(fmt.Errorf("unknown union type %s", u.Type))
    55  	}
    56  
    57  	if setOpType != plan.UnionType {
    58  		if hasRecursiveCte(leftScope.node) {
    59  			b.handleErr(sql.ErrRecursiveCTENotUnion.New())
    60  		}
    61  		if hasRecursiveCte(rightScope.node) {
    62  			b.handleErr(sql.ErrRecursiveCTENotUnion.New())
    63  		}
    64  	}
    65  
    66  	// all is not distinct
    67  	distinct := true
    68  	switch u.Type {
    69  	case ast.UnionAllStr, ast.IntersectAllStr, ast.ExceptAllStr:
    70  		distinct = false
    71  	}
    72  
    73  	limit := b.buildLimit(inScope, u.Limit)
    74  	offset := b.buildOffset(inScope, u.Limit)
    75  
    76  	for _, o := range u.OrderBy {
    77  		if expr, ok := o.Expr.(*ast.ColName); ok && len(expr.Qualifier.Name.String()) != 0 {
    78  			b.handleErr(ErrQualifiedOrderBy.New(expr.Qualifier.Name.String()))
    79  		}
    80  	}
    81  
    82  	// mysql errors for order by right projection
    83  	orderByScope := b.analyzeOrderBy(leftScope, leftScope, u.OrderBy)
    84  
    85  	var sortFields sql.SortFields
    86  	for _, c := range orderByScope.cols {
    87  		so := sql.Ascending
    88  		if c.descending {
    89  			so = sql.Descending
    90  		}
    91  		scalar := c.scalar
    92  		if scalar == nil {
    93  			scalar = c.scalarGf()
    94  		}
    95  		// Unions pass order bys to the top scope, where the original
    96  		// order by get field may no longer be accessible. Here it is
    97  		// safe to assume the alias has already been computed.
    98  		scalar, _, _ = transform.Expr(scalar, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
    99  			switch e := e.(type) {
   100  			case *expression.Alias:
   101  				return expression.NewGetField(int(c.id), e.Type(), e.Name(), e.IsNullable()), transform.NewTree, nil
   102  			default:
   103  				return e, transform.SameTree, nil
   104  			}
   105  		})
   106  		sf := sql.SortField{
   107  			Column: scalar,
   108  			Order:  so,
   109  		}
   110  		sortFields = append(sortFields, sf)
   111  	}
   112  
   113  	n, ok := leftScope.node.(*plan.SetOp)
   114  	if ok {
   115  		if len(n.SortFields) > 0 {
   116  			if len(sortFields) > 0 {
   117  				err := sql.ErrConflictingExternalQuery.New()
   118  				b.handleErr(err)
   119  			}
   120  			sortFields = n.SortFields
   121  		}
   122  		if n.Limit != nil {
   123  			if limit != nil {
   124  				err := fmt.Errorf("conflicing external LIMIT")
   125  				b.handleErr(err)
   126  			}
   127  			limit = n.Limit
   128  		}
   129  		if n.Offset != nil {
   130  			if offset != nil {
   131  				err := fmt.Errorf("conflicing external OFFSET")
   132  				b.handleErr(err)
   133  			}
   134  			offset = n.Offset
   135  		}
   136  		leftScope.node = plan.NewSetOp(n.SetOpType, n.Left(), n.Right(), n.Distinct, nil, nil, nil).WithColumns(n.Columns()).WithId(n.Id())
   137  	}
   138  
   139  	var cols sql.ColSet
   140  	for _, c := range leftScope.cols {
   141  		cols.Add(sql.ColumnId(c.id))
   142  	}
   143  	b.tabId++
   144  	tabId := b.tabId
   145  	ret := plan.NewSetOp(setOpType, leftScope.node, rightScope.node, distinct, limit, offset, sortFields).WithId(tabId).WithColumns(cols)
   146  	outScope = leftScope
   147  	outScope.node = b.mergeSetOpSchemas(ret.(*plan.SetOp))
   148  	return
   149  }
   150  
   151  func (b *Builder) mergeSetOpSchemas(u *plan.SetOp) sql.Node {
   152  	ls, rs := u.Left().Schema(), u.Right().Schema()
   153  	if len(ls) != len(rs) {
   154  		err := ErrUnionSchemasDifferentLength.New(len(ls), len(rs))
   155  		b.handleErr(err)
   156  	}
   157  
   158  	leftIds := colIdsForRel(u.Left())
   159  	rightIds := colIdsForRel(u.Right())
   160  
   161  	les, res := make([]sql.Expression, len(ls)), make([]sql.Expression, len(rs))
   162  	hasdiff := false
   163  	var err error
   164  	for i := range ls {
   165  		// todo: proj col ids should align with input column ids
   166  		les[i] = expression.NewGetFieldWithTable(int(leftIds[i]), 0, ls[i].Type, ls[i].DatabaseSource, ls[i].Source, ls[i].Name, ls[i].Nullable)
   167  		res[i] = expression.NewGetFieldWithTable(int(rightIds[i]), 0, rs[i].Type, rs[i].DatabaseSource, rs[i].Source, rs[i].Name, rs[i].Nullable)
   168  		if reflect.DeepEqual(ls[i].Type, rs[i].Type) {
   169  			continue
   170  		}
   171  		hasdiff = true
   172  
   173  		// try to get optimal type to convert both into
   174  		convertTo := expression.GetConvertToType(ls[i].Type, rs[i].Type)
   175  
   176  		// TODO: Principled type coercion...
   177  		les[i], err = b.f.buildConvert(les[i], convertTo, 0, 0)
   178  		res[i], err = b.f.buildConvert(res[i], convertTo, 0, 0)
   179  
   180  		// Preserve schema names across the conversion.
   181  		les[i] = expression.NewAlias(ls[i].Name, les[i])
   182  		res[i] = expression.NewAlias(rs[i].Name, res[i])
   183  	}
   184  	var ret sql.Node = u
   185  	if hasdiff {
   186  		ret, err = u.WithChildren(
   187  			plan.NewProject(les, u.Left()),
   188  			plan.NewProject(res, u.Right()),
   189  		)
   190  		if err != nil {
   191  			b.handleErr(err)
   192  		}
   193  	}
   194  	return ret
   195  }
   196  
   197  // colIdsForRel returns the padded column set returned by a node,
   198  // with 0's filled in for non-aliasable columns
   199  func colIdsForRel(n sql.Node) []sql.ColumnId {
   200  	var ids []sql.ColumnId
   201  	switch n := n.(type) {
   202  	case *plan.Project:
   203  		for _, p := range n.Projections {
   204  			if ide, ok := p.(sql.IdExpression); ok {
   205  				ids = append(ids, ide.Id())
   206  			} else {
   207  				ids = append(ids, 0)
   208  			}
   209  		}
   210  		return ids
   211  	case plan.TableIdNode:
   212  		cols := n.Columns()
   213  		if tn, ok := n.(sql.TableNode); ok {
   214  			if pkt, ok := tn.UnderlyingTable().(sql.PrimaryKeyTable); ok && len(pkt.PrimaryKeySchema().Schema) != len(n.Schema()) {
   215  				firstcol, _ := cols.Next(1)
   216  				for _, c := range n.Schema() {
   217  					ord := pkt.PrimaryKeySchema().IndexOfColName(c.Name)
   218  					colId := firstcol + sql.ColumnId(ord)
   219  					ids = append(ids, colId)
   220  				}
   221  				return ids
   222  			}
   223  		}
   224  		cols.ForEach(func(col sql.ColumnId) {
   225  			ids = append(ids, col)
   226  		})
   227  		return ids
   228  	default:
   229  		return colIdsForRel(n.Children()[0])
   230  	}
   231  	return nil
   232  }