github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/greatest_least.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 "time" 22 23 "gopkg.in/src-d/go-errors.v1" 24 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 var ErrUintOverflow = errors.NewKind( 30 "Unsigned integer too big to fit on signed integer") 31 32 // compEval is used to implement Greatest/Least Eval() using a comparison function 33 func compEval( 34 returnType sql.Type, 35 args []sql.Expression, 36 ctx *sql.Context, 37 row sql.Row, 38 cmp compareFn, 39 ) (interface{}, error) { 40 41 if returnType == types.Null { 42 return nil, nil 43 } 44 45 var selectedNum float64 46 var selectedString string 47 var selectedTime time.Time 48 49 for i, arg := range args { 50 val, err := arg.Eval(ctx, row) 51 if err != nil { 52 return nil, err 53 } 54 55 switch t := val.(type) { 56 case int, int8, int16, int32, int64, uint, 57 uint8, uint16, uint32, uint64: 58 switch x := t.(type) { 59 case int: 60 t = int64(x) 61 case int8: 62 t = int64(x) 63 case int16: 64 t = int64(x) 65 case int32: 66 t = int64(x) 67 case uint: 68 i := int64(x) 69 if i < 0 { 70 return nil, ErrUintOverflow.New() 71 } 72 t = i 73 case uint64: 74 i := int64(x) 75 if i < 0 { 76 return nil, ErrUintOverflow.New() 77 } 78 t = i 79 case uint8: 80 t = int64(x) 81 case uint16: 82 t = int64(x) 83 case uint32: 84 t = int64(x) 85 } 86 ival := t.(int64) 87 if i == 0 || cmp(ival, int64(selectedNum)) { 88 selectedNum = float64(ival) 89 } 90 case float32, float64: 91 if x, ok := t.(float32); ok { 92 t = float64(x) 93 } 94 95 fval := t.(float64) 96 if i == 0 || cmp(fval, float64(selectedNum)) { 97 selectedNum = fval 98 } 99 100 case string: 101 if types.IsTextOnly(returnType) && (i == 0 || cmp(t, selectedString)) { 102 selectedString = t 103 } 104 105 fval, err := strconv.ParseFloat(t, 64) 106 if err != nil { 107 // MySQL just ignores non numerically convertible string arguments 108 // when mixed with numeric ones 109 continue 110 } 111 112 if i == 0 || cmp(fval, selectedNum) { 113 selectedNum = fval 114 } 115 case time.Time: 116 // Since we deviate from MySQL with int -> time handling, we only set the selectedTime variable 117 if i == 0 || cmp(t, selectedTime) { 118 selectedTime = t 119 } 120 case nil: 121 return nil, nil 122 default: 123 return nil, ErrUnsupportedType.New(t) 124 } 125 126 } 127 128 if types.IsDatetimeType(returnType) { 129 return selectedTime, nil 130 } 131 132 switch returnType { 133 case types.Int64: 134 return int64(selectedNum), nil 135 case types.LongText: 136 return selectedString, nil 137 } 138 139 // sql.Float64 140 return float64(selectedNum), nil 141 } 142 143 // compRetType is used to determine the type from args based on the rules described for 144 // Greatest/Least 145 func compRetType(args ...sql.Expression) (sql.Type, error) { 146 if len(args) == 0 { 147 return nil, sql.ErrInvalidArgumentNumber.New("LEAST", "1 or more", 0) 148 } 149 150 allString := true 151 allInt := true 152 allDatetime := true 153 154 for _, arg := range args { 155 if !arg.Resolved() { 156 return nil, nil 157 } 158 argType := arg.Type() 159 160 if svt, ok := argType.(sql.SystemVariableType); ok { 161 argType = svt.UnderlyingType() 162 } 163 164 if types.IsTuple(argType) { 165 return nil, sql.ErrInvalidType.New("tuple") 166 } else if types.IsNumber(argType) { 167 allString = false 168 allDatetime = false 169 if types.IsFloat(argType) { 170 allString = false 171 allInt = false 172 } 173 } else if types.IsText(argType) { 174 allInt = false 175 allDatetime = false 176 } else if types.IsTime(argType) { 177 allString = false 178 allInt = false 179 } else if types.IsDeferredType(argType) { 180 return argType, nil 181 } else if argType == types.Null { 182 // When a Null is present the return will always be Null 183 return types.Null, nil 184 } else { 185 return nil, ErrUnsupportedType.New(argType) 186 } 187 } 188 189 if allString { 190 return types.LongText, nil 191 } else if allInt { 192 return types.Int64, nil 193 } else if allDatetime { 194 return types.DatetimeMaxPrecision, nil 195 } else { 196 return types.Float64, nil 197 } 198 } 199 200 // Greatest returns the argument with the greatest numerical or string value. It allows for 201 // numeric (ints and floats) and string arguments and will return the used type 202 // when all arguments are of the same type or floats if there are numerically 203 // convertible strings or integers mixed with floats. When ints or floats 204 // are mixed with non numerically convertible strings, those are ignored. 205 type Greatest struct { 206 Args []sql.Expression 207 returnType sql.Type 208 } 209 210 var _ sql.FunctionExpression = (*Greatest)(nil) 211 212 // ErrUnsupportedType is returned when an argument to Greatest or Latest is not numeric or string 213 var ErrUnsupportedType = errors.NewKind("unsupported type for greatest/least argument: %T") 214 215 // NewGreatest creates a new Greatest UDF 216 func NewGreatest(args ...sql.Expression) (sql.Expression, error) { 217 retType, err := compRetType(args...) 218 if err != nil { 219 return nil, err 220 } 221 return &Greatest{Args: args, returnType: retType}, nil 222 } 223 224 // FunctionName implements sql.FunctionExpression 225 func (f *Greatest) FunctionName() string { 226 return "greatest" 227 } 228 229 // Description implements sql.FunctionExpression 230 func (f *Greatest) Description() string { 231 return "returns the greatest numeric or string value." 232 } 233 234 // Type implements the Expression interface. 235 func (f *Greatest) Type() sql.Type { 236 if f.returnType != nil { 237 return f.returnType 238 } 239 return f.Args[0].Type() 240 } 241 242 // IsNullable implements the Expression interface. 243 func (f *Greatest) IsNullable() bool { 244 for _, arg := range f.Args { 245 if arg.IsNullable() { 246 return true 247 } 248 } 249 return false 250 } 251 252 func (f *Greatest) String() string { 253 var args = make([]string, len(f.Args)) 254 for i, arg := range f.Args { 255 args[i] = arg.String() 256 } 257 return fmt.Sprintf("%s(%s)", f.FunctionName(), strings.Join(args, ",")) 258 } 259 260 // WithChildren implements the Expression interface. 261 func (f *Greatest) WithChildren(children ...sql.Expression) (sql.Expression, error) { 262 return NewGreatest(children...) 263 } 264 265 // Resolved implements the Expression interface. 266 func (f *Greatest) Resolved() bool { 267 for _, arg := range f.Args { 268 if !arg.Resolved() { 269 return false 270 } 271 } 272 return f.returnType != nil 273 } 274 275 // Children implements the Expression interface. 276 func (f *Greatest) Children() []sql.Expression { return f.Args } 277 278 type compareFn func(interface{}, interface{}) bool 279 280 func greaterThan(a, b interface{}) bool { 281 switch i := a.(type) { 282 case int64: 283 return i > b.(int64) 284 case float64: 285 return i > b.(float64) 286 case string: 287 return i > b.(string) 288 case time.Time: 289 return i.After(b.(time.Time)) 290 } 291 panic("Implementation error on greaterThan") 292 } 293 294 func lessThan(a, b interface{}) bool { 295 switch i := a.(type) { 296 case int64: 297 return i < b.(int64) 298 case float64: 299 return i < b.(float64) 300 case string: 301 return i < b.(string) 302 case time.Time: 303 return i.Before(b.(time.Time)) 304 } 305 panic("Implementation error on lessThan") 306 } 307 308 // Eval implements the Expression interface. 309 func (f *Greatest) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 310 return compEval(f.returnType, f.Args, ctx, row, greaterThan) 311 } 312 313 // Least returns the argument with the least numerical or string value. It allows for 314 // numeric (ints anf floats) and string arguments and will return the used type 315 // when all arguments are of the same type or floats if there are numerically 316 // convertible strings or integers mixed with floats. When ints or floats 317 // are mixed with non numerically convertible strings, those are ignored. 318 type Least struct { 319 Args []sql.Expression 320 returnType sql.Type 321 } 322 323 var _ sql.FunctionExpression = (*Least)(nil) 324 325 // NewLeast creates a new Least UDF 326 func NewLeast(args ...sql.Expression) (sql.Expression, error) { 327 retType, err := compRetType(args...) 328 if err != nil { 329 return nil, err 330 } 331 return &Least{Args: args, returnType: retType}, nil 332 } 333 334 // FunctionName implements sql.FunctionExpression 335 func (f *Least) FunctionName() string { 336 return "least" 337 } 338 339 // Description implements sql.FunctionExpression 340 func (f *Least) Description() string { 341 return "returns the smaller numeric or string value." 342 } 343 344 // Type implements the Expression interface. 345 func (f *Least) Type() sql.Type { 346 if f.returnType != nil { 347 return f.returnType 348 } 349 return f.Args[0].Type() 350 } 351 352 // IsNullable implements the Expression interface. 353 func (f *Least) IsNullable() bool { 354 for _, arg := range f.Args { 355 if arg.IsNullable() { 356 return true 357 } 358 } 359 return false 360 } 361 362 func (f *Least) String() string { 363 var args = make([]string, len(f.Args)) 364 for i, arg := range f.Args { 365 args[i] = arg.String() 366 } 367 return fmt.Sprintf("%s(%s)", f.FunctionName(), strings.Join(args, ", ")) 368 } 369 370 // WithChildren implements the Expression interface. 371 func (f *Least) WithChildren(children ...sql.Expression) (sql.Expression, error) { 372 return NewLeast(children...) 373 } 374 375 // Resolved implements the Expression interface. 376 func (f *Least) Resolved() bool { 377 for _, arg := range f.Args { 378 if !arg.Resolved() { 379 return false 380 } 381 } 382 return f.returnType != nil 383 } 384 385 // Children implements the Expression interface. 386 func (f *Least) Children() []sql.Expression { return f.Args } 387 388 // Eval implements the Expression interface. 389 func (f *Least) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 390 return compEval(f.returnType, f.Args, ctx, row, lessThan) 391 }