github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/set_op.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "fmt" 19 "reflect" 20 21 ast "github.com/dolthub/vitess/go/vt/sqlparser" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/plan" 26 "github.com/dolthub/go-mysql-server/sql/transform" 27 ) 28 29 func hasRecursiveCte(node sql.Node) bool { 30 hasRCTE := false 31 transform.Inspect(node, func(n sql.Node) bool { 32 if _, ok := n.(*plan.RecursiveTable); ok { 33 hasRCTE = true 34 return false 35 } 36 return true 37 }) 38 return hasRCTE 39 } 40 41 func (b *Builder) buildSetOp(inScope *scope, u *ast.SetOp) (outScope *scope) { 42 leftScope := b.buildSelectStmt(inScope, u.Left) 43 rightScope := b.buildSelectStmt(inScope, u.Right) 44 45 var setOpType int 46 switch u.Type { 47 case ast.UnionStr, ast.UnionAllStr, ast.UnionDistinctStr: 48 setOpType = plan.UnionType 49 case ast.IntersectStr, ast.IntersectAllStr, ast.IntersectDistinctStr: 50 setOpType = plan.IntersectType 51 case ast.ExceptStr, ast.ExceptAllStr, ast.ExceptDistinctStr: 52 setOpType = plan.ExceptType 53 default: 54 b.handleErr(fmt.Errorf("unknown union type %s", u.Type)) 55 } 56 57 if setOpType != plan.UnionType { 58 if hasRecursiveCte(leftScope.node) { 59 b.handleErr(sql.ErrRecursiveCTENotUnion.New()) 60 } 61 if hasRecursiveCte(rightScope.node) { 62 b.handleErr(sql.ErrRecursiveCTENotUnion.New()) 63 } 64 } 65 66 // all is not distinct 67 distinct := true 68 switch u.Type { 69 case ast.UnionAllStr, ast.IntersectAllStr, ast.ExceptAllStr: 70 distinct = false 71 } 72 73 limit := b.buildLimit(inScope, u.Limit) 74 offset := b.buildOffset(inScope, u.Limit) 75 76 for _, o := range u.OrderBy { 77 if expr, ok := o.Expr.(*ast.ColName); ok && len(expr.Qualifier.Name.String()) != 0 { 78 b.handleErr(ErrQualifiedOrderBy.New(expr.Qualifier.Name.String())) 79 } 80 } 81 82 // mysql errors for order by right projection 83 orderByScope := b.analyzeOrderBy(leftScope, leftScope, u.OrderBy) 84 85 var sortFields sql.SortFields 86 for _, c := range orderByScope.cols { 87 so := sql.Ascending 88 if c.descending { 89 so = sql.Descending 90 } 91 scalar := c.scalar 92 if scalar == nil { 93 scalar = c.scalarGf() 94 } 95 // Unions pass order bys to the top scope, where the original 96 // order by get field may no longer be accessible. Here it is 97 // safe to assume the alias has already been computed. 98 scalar, _, _ = transform.Expr(scalar, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 99 switch e := e.(type) { 100 case *expression.Alias: 101 return expression.NewGetField(int(c.id), e.Type(), e.Name(), e.IsNullable()), transform.NewTree, nil 102 default: 103 return e, transform.SameTree, nil 104 } 105 }) 106 sf := sql.SortField{ 107 Column: scalar, 108 Order: so, 109 } 110 sortFields = append(sortFields, sf) 111 } 112 113 n, ok := leftScope.node.(*plan.SetOp) 114 if ok { 115 if len(n.SortFields) > 0 { 116 if len(sortFields) > 0 { 117 err := sql.ErrConflictingExternalQuery.New() 118 b.handleErr(err) 119 } 120 sortFields = n.SortFields 121 } 122 if n.Limit != nil { 123 if limit != nil { 124 err := fmt.Errorf("conflicing external LIMIT") 125 b.handleErr(err) 126 } 127 limit = n.Limit 128 } 129 if n.Offset != nil { 130 if offset != nil { 131 err := fmt.Errorf("conflicing external OFFSET") 132 b.handleErr(err) 133 } 134 offset = n.Offset 135 } 136 leftScope.node = plan.NewSetOp(n.SetOpType, n.Left(), n.Right(), n.Distinct, nil, nil, nil).WithColumns(n.Columns()).WithId(n.Id()) 137 } 138 139 var cols sql.ColSet 140 for _, c := range leftScope.cols { 141 cols.Add(sql.ColumnId(c.id)) 142 } 143 b.tabId++ 144 tabId := b.tabId 145 ret := plan.NewSetOp(setOpType, leftScope.node, rightScope.node, distinct, limit, offset, sortFields).WithId(tabId).WithColumns(cols) 146 outScope = leftScope 147 outScope.node = b.mergeSetOpSchemas(ret.(*plan.SetOp)) 148 return 149 } 150 151 func (b *Builder) mergeSetOpSchemas(u *plan.SetOp) sql.Node { 152 ls, rs := u.Left().Schema(), u.Right().Schema() 153 if len(ls) != len(rs) { 154 err := ErrUnionSchemasDifferentLength.New(len(ls), len(rs)) 155 b.handleErr(err) 156 } 157 158 leftIds := colIdsForRel(u.Left()) 159 rightIds := colIdsForRel(u.Right()) 160 161 les, res := make([]sql.Expression, len(ls)), make([]sql.Expression, len(rs)) 162 hasdiff := false 163 var err error 164 for i := range ls { 165 // todo: proj col ids should align with input column ids 166 les[i] = expression.NewGetFieldWithTable(int(leftIds[i]), 0, ls[i].Type, ls[i].DatabaseSource, ls[i].Source, ls[i].Name, ls[i].Nullable) 167 res[i] = expression.NewGetFieldWithTable(int(rightIds[i]), 0, rs[i].Type, rs[i].DatabaseSource, rs[i].Source, rs[i].Name, rs[i].Nullable) 168 if reflect.DeepEqual(ls[i].Type, rs[i].Type) { 169 continue 170 } 171 hasdiff = true 172 173 // try to get optimal type to convert both into 174 convertTo := expression.GetConvertToType(ls[i].Type, rs[i].Type) 175 176 // TODO: Principled type coercion... 177 les[i], err = b.f.buildConvert(les[i], convertTo, 0, 0) 178 res[i], err = b.f.buildConvert(res[i], convertTo, 0, 0) 179 180 // Preserve schema names across the conversion. 181 les[i] = expression.NewAlias(ls[i].Name, les[i]) 182 res[i] = expression.NewAlias(rs[i].Name, res[i]) 183 } 184 var ret sql.Node = u 185 if hasdiff { 186 ret, err = u.WithChildren( 187 plan.NewProject(les, u.Left()), 188 plan.NewProject(res, u.Right()), 189 ) 190 if err != nil { 191 b.handleErr(err) 192 } 193 } 194 return ret 195 } 196 197 // colIdsForRel returns the padded column set returned by a node, 198 // with 0's filled in for non-aliasable columns 199 func colIdsForRel(n sql.Node) []sql.ColumnId { 200 var ids []sql.ColumnId 201 switch n := n.(type) { 202 case *plan.Project: 203 for _, p := range n.Projections { 204 if ide, ok := p.(sql.IdExpression); ok { 205 ids = append(ids, ide.Id()) 206 } else { 207 ids = append(ids, 0) 208 } 209 } 210 return ids 211 case plan.TableIdNode: 212 cols := n.Columns() 213 if tn, ok := n.(sql.TableNode); ok { 214 if pkt, ok := tn.UnderlyingTable().(sql.PrimaryKeyTable); ok && len(pkt.PrimaryKeySchema().Schema) != len(n.Schema()) { 215 firstcol, _ := cols.Next(1) 216 for _, c := range n.Schema() { 217 ord := pkt.PrimaryKeySchema().IndexOfColName(c.Name) 218 colId := firstcol + sql.ColumnId(ord) 219 ids = append(ids, colId) 220 } 221 return ids 222 } 223 } 224 cols.ForEach(func(col sql.ColumnId) { 225 ids = append(ids, col) 226 }) 227 return ids 228 default: 229 return colIdsForRel(n.Children()[0]) 230 } 231 return nil 232 }