github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/greatest_least.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  	"strings"
    21  	"time"
    22  
    23  	"gopkg.in/src-d/go-errors.v1"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  var ErrUintOverflow = errors.NewKind(
    30  	"Unsigned integer too big to fit on signed integer")
    31  
    32  // compEval is used to implement Greatest/Least Eval() using a comparison function
    33  func compEval(
    34  	returnType sql.Type,
    35  	args []sql.Expression,
    36  	ctx *sql.Context,
    37  	row sql.Row,
    38  	cmp compareFn,
    39  ) (interface{}, error) {
    40  
    41  	if returnType == types.Null {
    42  		return nil, nil
    43  	}
    44  
    45  	var selectedNum float64
    46  	var selectedString string
    47  	var selectedTime time.Time
    48  
    49  	for i, arg := range args {
    50  		val, err := arg.Eval(ctx, row)
    51  		if err != nil {
    52  			return nil, err
    53  		}
    54  
    55  		switch t := val.(type) {
    56  		case int, int8, int16, int32, int64, uint,
    57  			uint8, uint16, uint32, uint64:
    58  			switch x := t.(type) {
    59  			case int:
    60  				t = int64(x)
    61  			case int8:
    62  				t = int64(x)
    63  			case int16:
    64  				t = int64(x)
    65  			case int32:
    66  				t = int64(x)
    67  			case uint:
    68  				i := int64(x)
    69  				if i < 0 {
    70  					return nil, ErrUintOverflow.New()
    71  				}
    72  				t = i
    73  			case uint64:
    74  				i := int64(x)
    75  				if i < 0 {
    76  					return nil, ErrUintOverflow.New()
    77  				}
    78  				t = i
    79  			case uint8:
    80  				t = int64(x)
    81  			case uint16:
    82  				t = int64(x)
    83  			case uint32:
    84  				t = int64(x)
    85  			}
    86  			ival := t.(int64)
    87  			if i == 0 || cmp(ival, int64(selectedNum)) {
    88  				selectedNum = float64(ival)
    89  			}
    90  		case float32, float64:
    91  			if x, ok := t.(float32); ok {
    92  				t = float64(x)
    93  			}
    94  
    95  			fval := t.(float64)
    96  			if i == 0 || cmp(fval, float64(selectedNum)) {
    97  				selectedNum = fval
    98  			}
    99  
   100  		case string:
   101  			if types.IsTextOnly(returnType) && (i == 0 || cmp(t, selectedString)) {
   102  				selectedString = t
   103  			}
   104  
   105  			fval, err := strconv.ParseFloat(t, 64)
   106  			if err != nil {
   107  				// MySQL just ignores non numerically convertible string arguments
   108  				// when mixed with numeric ones
   109  				continue
   110  			}
   111  
   112  			if i == 0 || cmp(fval, selectedNum) {
   113  				selectedNum = fval
   114  			}
   115  		case time.Time:
   116  			// Since we deviate from MySQL with int -> time handling, we only set the selectedTime variable
   117  			if i == 0 || cmp(t, selectedTime) {
   118  				selectedTime = t
   119  			}
   120  		case nil:
   121  			return nil, nil
   122  		default:
   123  			return nil, ErrUnsupportedType.New(t)
   124  		}
   125  
   126  	}
   127  
   128  	if types.IsDatetimeType(returnType) {
   129  		return selectedTime, nil
   130  	}
   131  
   132  	switch returnType {
   133  	case types.Int64:
   134  		return int64(selectedNum), nil
   135  	case types.LongText:
   136  		return selectedString, nil
   137  	}
   138  
   139  	// sql.Float64
   140  	return float64(selectedNum), nil
   141  }
   142  
   143  // compRetType is used to determine the type from args based on the rules described for
   144  // Greatest/Least
   145  func compRetType(args ...sql.Expression) (sql.Type, error) {
   146  	if len(args) == 0 {
   147  		return nil, sql.ErrInvalidArgumentNumber.New("LEAST", "1 or more", 0)
   148  	}
   149  
   150  	allString := true
   151  	allInt := true
   152  	allDatetime := true
   153  
   154  	for _, arg := range args {
   155  		if !arg.Resolved() {
   156  			return nil, nil
   157  		}
   158  		argType := arg.Type()
   159  
   160  		if svt, ok := argType.(sql.SystemVariableType); ok {
   161  			argType = svt.UnderlyingType()
   162  		}
   163  
   164  		if types.IsTuple(argType) {
   165  			return nil, sql.ErrInvalidType.New("tuple")
   166  		} else if types.IsNumber(argType) {
   167  			allString = false
   168  			allDatetime = false
   169  			if types.IsFloat(argType) {
   170  				allString = false
   171  				allInt = false
   172  			}
   173  		} else if types.IsText(argType) {
   174  			allInt = false
   175  			allDatetime = false
   176  		} else if types.IsTime(argType) {
   177  			allString = false
   178  			allInt = false
   179  		} else if types.IsDeferredType(argType) {
   180  			return argType, nil
   181  		} else if argType == types.Null {
   182  			// When a Null is present the return will always be Null
   183  			return types.Null, nil
   184  		} else {
   185  			return nil, ErrUnsupportedType.New(argType)
   186  		}
   187  	}
   188  
   189  	if allString {
   190  		return types.LongText, nil
   191  	} else if allInt {
   192  		return types.Int64, nil
   193  	} else if allDatetime {
   194  		return types.DatetimeMaxPrecision, nil
   195  	} else {
   196  		return types.Float64, nil
   197  	}
   198  }
   199  
   200  // Greatest returns the argument with the greatest numerical or string value. It allows for
   201  // numeric (ints and floats) and string arguments and will return the used type
   202  // when all arguments are of the same type or floats if there are numerically
   203  // convertible strings or integers mixed with floats. When ints or floats
   204  // are mixed with non numerically convertible strings, those are ignored.
   205  type Greatest struct {
   206  	Args       []sql.Expression
   207  	returnType sql.Type
   208  }
   209  
   210  var _ sql.FunctionExpression = (*Greatest)(nil)
   211  
   212  // ErrUnsupportedType is returned when an argument to Greatest or Latest is not numeric or string
   213  var ErrUnsupportedType = errors.NewKind("unsupported type for greatest/least argument: %T")
   214  
   215  // NewGreatest creates a new Greatest UDF
   216  func NewGreatest(args ...sql.Expression) (sql.Expression, error) {
   217  	retType, err := compRetType(args...)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  	return &Greatest{Args: args, returnType: retType}, nil
   222  }
   223  
   224  // FunctionName implements sql.FunctionExpression
   225  func (f *Greatest) FunctionName() string {
   226  	return "greatest"
   227  }
   228  
   229  // Description implements sql.FunctionExpression
   230  func (f *Greatest) Description() string {
   231  	return "returns the greatest numeric or string value."
   232  }
   233  
   234  // Type implements the Expression interface.
   235  func (f *Greatest) Type() sql.Type {
   236  	if f.returnType != nil {
   237  		return f.returnType
   238  	}
   239  	return f.Args[0].Type()
   240  }
   241  
   242  // IsNullable implements the Expression interface.
   243  func (f *Greatest) IsNullable() bool {
   244  	for _, arg := range f.Args {
   245  		if arg.IsNullable() {
   246  			return true
   247  		}
   248  	}
   249  	return false
   250  }
   251  
   252  func (f *Greatest) String() string {
   253  	var args = make([]string, len(f.Args))
   254  	for i, arg := range f.Args {
   255  		args[i] = arg.String()
   256  	}
   257  	return fmt.Sprintf("%s(%s)", f.FunctionName(), strings.Join(args, ","))
   258  }
   259  
   260  // WithChildren implements the Expression interface.
   261  func (f *Greatest) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   262  	return NewGreatest(children...)
   263  }
   264  
   265  // Resolved implements the Expression interface.
   266  func (f *Greatest) Resolved() bool {
   267  	for _, arg := range f.Args {
   268  		if !arg.Resolved() {
   269  			return false
   270  		}
   271  	}
   272  	return f.returnType != nil
   273  }
   274  
   275  // Children implements the Expression interface.
   276  func (f *Greatest) Children() []sql.Expression { return f.Args }
   277  
   278  type compareFn func(interface{}, interface{}) bool
   279  
   280  func greaterThan(a, b interface{}) bool {
   281  	switch i := a.(type) {
   282  	case int64:
   283  		return i > b.(int64)
   284  	case float64:
   285  		return i > b.(float64)
   286  	case string:
   287  		return i > b.(string)
   288  	case time.Time:
   289  		return i.After(b.(time.Time))
   290  	}
   291  	panic("Implementation error on greaterThan")
   292  }
   293  
   294  func lessThan(a, b interface{}) bool {
   295  	switch i := a.(type) {
   296  	case int64:
   297  		return i < b.(int64)
   298  	case float64:
   299  		return i < b.(float64)
   300  	case string:
   301  		return i < b.(string)
   302  	case time.Time:
   303  		return i.Before(b.(time.Time))
   304  	}
   305  	panic("Implementation error on lessThan")
   306  }
   307  
   308  // Eval implements the Expression interface.
   309  func (f *Greatest) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   310  	return compEval(f.returnType, f.Args, ctx, row, greaterThan)
   311  }
   312  
   313  // Least returns the argument with the least numerical or string value. It allows for
   314  // numeric (ints anf floats) and string arguments and will return the used type
   315  // when all arguments are of the same type or floats if there are numerically
   316  // convertible strings or integers mixed with floats. When ints or floats
   317  // are mixed with non numerically convertible strings, those are ignored.
   318  type Least struct {
   319  	Args       []sql.Expression
   320  	returnType sql.Type
   321  }
   322  
   323  var _ sql.FunctionExpression = (*Least)(nil)
   324  
   325  // NewLeast creates a new Least UDF
   326  func NewLeast(args ...sql.Expression) (sql.Expression, error) {
   327  	retType, err := compRetType(args...)
   328  	if err != nil {
   329  		return nil, err
   330  	}
   331  	return &Least{Args: args, returnType: retType}, nil
   332  }
   333  
   334  // FunctionName implements sql.FunctionExpression
   335  func (f *Least) FunctionName() string {
   336  	return "least"
   337  }
   338  
   339  // Description implements sql.FunctionExpression
   340  func (f *Least) Description() string {
   341  	return "returns the smaller numeric or string value."
   342  }
   343  
   344  // Type implements the Expression interface.
   345  func (f *Least) Type() sql.Type {
   346  	if f.returnType != nil {
   347  		return f.returnType
   348  	}
   349  	return f.Args[0].Type()
   350  }
   351  
   352  // IsNullable implements the Expression interface.
   353  func (f *Least) IsNullable() bool {
   354  	for _, arg := range f.Args {
   355  		if arg.IsNullable() {
   356  			return true
   357  		}
   358  	}
   359  	return false
   360  }
   361  
   362  func (f *Least) String() string {
   363  	var args = make([]string, len(f.Args))
   364  	for i, arg := range f.Args {
   365  		args[i] = arg.String()
   366  	}
   367  	return fmt.Sprintf("%s(%s)", f.FunctionName(), strings.Join(args, ", "))
   368  }
   369  
   370  // WithChildren implements the Expression interface.
   371  func (f *Least) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   372  	return NewLeast(children...)
   373  }
   374  
   375  // Resolved implements the Expression interface.
   376  func (f *Least) Resolved() bool {
   377  	for _, arg := range f.Args {
   378  		if !arg.Resolved() {
   379  			return false
   380  		}
   381  	}
   382  	return f.returnType != nil
   383  }
   384  
   385  // Children implements the Expression interface.
   386  func (f *Least) Children() []sql.Expression { return f.Args }
   387  
   388  // Eval implements the Expression interface.
   389  func (f *Least) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   390  	return compEval(f.returnType, f.Args, ctx, row, lessThan)
   391  }