github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/orderby.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/vitess/go/sqltypes" 22 ast "github.com/dolthub/vitess/go/vt/sqlparser" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/plan" 27 "github.com/dolthub/go-mysql-server/sql/transform" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 func (b *Builder) analyzeOrderBy(fromScope, projScope *scope, order ast.OrderBy) (outScope *scope) { 32 // Order by resolves to 33 // 1) alias in projScope 34 // 2) column name in fromScope 35 // 3) index into projection scope 36 37 // if regular col, make sure in aggOut or add to extra cols 38 39 outScope = fromScope.replace() 40 for _, o := range order { 41 var descending bool 42 switch strings.ToLower(o.Direction) { 43 default: 44 err := errInvalidSortOrder.New(o.Direction) 45 b.handleErr(err) 46 case ast.AscScr: 47 descending = false 48 case ast.DescScr: 49 descending = true 50 } 51 52 switch e := o.Expr.(type) { 53 case *ast.ColName: 54 // check for projection alias first 55 dbName := strings.ToLower(e.Qualifier.Qualifier.String()) 56 tblName := strings.ToLower(e.Qualifier.Name.String()) 57 colName := strings.ToLower(e.Name.String()) 58 c, ok := projScope.resolveColumn(dbName, tblName, colName, false, false) 59 if ok { 60 c.descending = descending 61 outScope.addColumn(c) 62 continue 63 } 64 65 // fromScope col 66 c, ok = fromScope.resolveColumn(dbName, tblName, colName, true, false) 67 if !ok { 68 err := sql.ErrColumnNotFound.New(e.Name) 69 b.handleErr(err) 70 } 71 c.descending = descending 72 c.scalar = c.scalarGf() 73 outScope.addColumn(c) 74 fromScope.addExtraColumn(c) 75 case *ast.SQLVal: 76 // integer literal into projScope 77 // else throw away 78 expr := b.normalizeValArg(e) 79 if val, ok := expr.(*ast.SQLVal); ok && val.Type == ast.IntVal { 80 lit := b.convertInt(string(val.Val), 10) 81 idx, _, err := types.Int64.Convert(lit.Value()) 82 if err != nil { 83 b.handleErr(err) 84 } 85 intIdx, ok := idx.(int64) 86 if !ok { 87 b.handleErr(fmt.Errorf("expected integer order by literal")) 88 } 89 // negative intIdx is allowed in MySQL, and is treated as a no-op 90 if intIdx < 0 { 91 continue 92 } 93 if projScope == nil || len(projScope.cols) == 0 { 94 err := fmt.Errorf("invalid order by ordinal context") 95 b.handleErr(err) 96 } 97 // MySQL throws a column not found for intIdx = 0 and intIdx > len(cols) 98 if intIdx > int64(len(projScope.cols)) || intIdx == 0 { 99 err := sql.ErrColumnNotFound.New(fmt.Sprintf("%d", intIdx)) 100 b.handleErr(err) 101 } 102 target := projScope.cols[intIdx-1] 103 scalar := target.scalar 104 if scalar == nil { 105 scalar = target.scalarGf() 106 } 107 if a, ok := target.scalar.(*expression.Alias); ok && a.Unreferencable() && fromScope.groupBy != nil { 108 for _, c := range fromScope.groupBy.outScope.cols { 109 if target.id == c.id { 110 target = c 111 } 112 } 113 } 114 outScope.addColumn(scopeColumn{ 115 tableId: target.tableId, 116 col: target.col, 117 scalar: scalar, 118 typ: target.typ, 119 nullable: target.nullable, 120 descending: descending, 121 id: target.id, 122 }) 123 } 124 default: 125 // track order by col 126 // replace aggregations with refs 127 // pick up auxiliary cols 128 expr := b.buildScalar(fromScope, e) 129 _, ok := outScope.getExpr(expr.String(), true) 130 if ok { 131 continue 132 } 133 // aggregate ref -> expr.String() in 134 // or compound expression 135 expr, _, _ = transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 136 // get fields outside of aggs need to be in extra cols 137 switch e := e.(type) { 138 case *expression.GetField: 139 c, ok := fromScope.resolveColumn("", strings.ToLower(e.Table()), strings.ToLower(e.Name()), true, false) 140 if !ok { 141 err := sql.ErrColumnNotFound.New(e.Name) 142 b.handleErr(err) 143 } 144 fromScope.addExtraColumn(c) 145 case sql.WindowAdaptableExpression: 146 // has to have been ref'd already 147 id, ok := fromScope.getExpr(e.String(), true) 148 if !ok { 149 err := fmt.Errorf("faild to ref aggregate expression: %s", e.String()) 150 b.handleErr(err) 151 } 152 return expression.NewGetField(int(id), e.Type(), e.String(), e.IsNullable()), transform.NewTree, nil 153 default: 154 } 155 return e, transform.SameTree, nil 156 }) 157 col := scopeColumn{ 158 col: expr.String(), 159 scalar: expr, 160 typ: expr.Type(), 161 nullable: expr.IsNullable(), 162 descending: descending, 163 } 164 outScope.newColumn(col) 165 } 166 } 167 return 168 } 169 170 func (b *Builder) normalizeValArg(e *ast.SQLVal) ast.Expr { 171 if e.Type != ast.ValArg || b.bindCtx == nil { 172 return e 173 } 174 name := strings.TrimPrefix(string(e.Val), ":") 175 if b.bindCtx.Bindings == nil { 176 err := fmt.Errorf("bind variable not provided: '%s'", name) 177 b.handleErr(err) 178 } 179 bv, ok := b.bindCtx.GetSubstitute(name) 180 if !ok { 181 err := fmt.Errorf("bind variable not provided: '%s'", name) 182 b.handleErr(err) 183 } 184 185 val, err := sqltypes.BindVariableToValue(bv) 186 if err != nil { 187 b.handleErr(err) 188 } 189 expr, err := ast.ExprFromValue(val) 190 switch e := expr.(type) { 191 case *ast.SQLVal: 192 return e 193 case *ast.NullVal: 194 return e 195 default: 196 err := fmt.Errorf("unknown ast.Expr: %T", e) 197 b.handleErr(err) 198 } 199 return nil 200 } 201 202 func (b *Builder) buildOrderBy(inScope, orderByScope *scope) { 203 if len(orderByScope.cols) == 0 { 204 return 205 } 206 var sortFields sql.SortFields 207 for _, c := range orderByScope.cols { 208 so := sql.Ascending 209 if c.descending { 210 so = sql.Descending 211 } 212 scalar := c.scalar 213 if scalar == nil { 214 scalar = c.scalarGf() 215 } 216 sf := sql.SortField{ 217 Column: scalar, 218 Order: so, 219 } 220 sortFields = append(sortFields, sf) 221 } 222 sort := plan.NewSort(sortFields, inScope.node) 223 inScope.node = sort 224 return 225 }