github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/arithmetic.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  	"regexp"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/dolthub/vitess/go/mysql"
    26  	"github.com/dolthub/vitess/go/vt/sqlparser"
    27  	"github.com/shopspring/decimal"
    28  	errors "gopkg.in/src-d/go-errors.v1"
    29  
    30  	"github.com/dolthub/go-mysql-server/sql"
    31  	"github.com/dolthub/go-mysql-server/sql/types"
    32  )
    33  
    34  var (
    35  	// errUnableToCast means that we could not find common type for two arithemtic objects
    36  	errUnableToCast = errors.NewKind("Unable to cast between types: %T, %T")
    37  
    38  	// errUnableToEval means that we could not evaluate an expression
    39  	errUnableToEval = errors.NewKind("Unable to evaluate an expression: %v %s %v")
    40  
    41  	timeTypeRegex = regexp.MustCompile("[0-9]+")
    42  )
    43  
    44  func arithmeticWarning(ctx *sql.Context, errCode int, errMsg string) {
    45  	if ctx != nil && ctx.Session != nil {
    46  		ctx.Session.Warn(&sql.Warning{
    47  			Level:   "Warning",
    48  			Code:    errCode,
    49  			Message: errMsg,
    50  		})
    51  	}
    52  }
    53  
    54  // ArithmeticOp implements an arithmetic expression. Since we had separate expressions
    55  // for division and mod operation, we need to group all arithmetic together. Use this
    56  // expression to define any arithmetic operation that is separately implemented from
    57  // Arithmetic expression in the future.
    58  type ArithmeticOp interface {
    59  	sql.Expression
    60  	BinaryExpression
    61  	SetOpCount(int32)
    62  	Operator() string
    63  }
    64  
    65  var _ ArithmeticOp = (*Arithmetic)(nil)
    66  var _ sql.CollationCoercible = (*Arithmetic)(nil)
    67  
    68  // Arithmetic expressions include plus, minus and multiplication (+, -, *) operations.
    69  type Arithmetic struct {
    70  	BinaryExpressionStub
    71  	Op  string
    72  	ops int32
    73  }
    74  
    75  // NewArithmetic creates a new Arithmetic sql.Expression.
    76  func NewArithmetic(left, right sql.Expression, op string) *Arithmetic {
    77  	a := &Arithmetic{BinaryExpressionStub{LeftChild: left, RightChild: right}, op, 0}
    78  	ops := countArithmeticOps(a)
    79  	setArithmeticOps(a, ops)
    80  	return a
    81  }
    82  
    83  // NewPlus creates a new Arithmetic + sql.Expression.
    84  func NewPlus(left, right sql.Expression) *Arithmetic {
    85  	return NewArithmetic(left, right, sqlparser.PlusStr)
    86  }
    87  
    88  // NewMinus creates a new Arithmetic - sql.Expression.
    89  func NewMinus(left, right sql.Expression) *Arithmetic {
    90  	return NewArithmetic(left, right, sqlparser.MinusStr)
    91  }
    92  
    93  // NewMult creates a new Arithmetic * sql.Expression.
    94  func NewMult(left, right sql.Expression) *Arithmetic {
    95  	return NewArithmetic(left, right, sqlparser.MultStr)
    96  }
    97  
    98  func (a *Arithmetic) Operator() string {
    99  	return a.Op
   100  }
   101  
   102  func (a *Arithmetic) SetOpCount(i int32) {
   103  	a.ops = i
   104  }
   105  
   106  func (a *Arithmetic) String() string {
   107  	return fmt.Sprintf("(%s %s %s)", a.LeftChild.String(), a.Op, a.RightChild.String())
   108  }
   109  
   110  func (a *Arithmetic) DebugString() string {
   111  	return fmt.Sprintf("(%s %s %s)", sql.DebugString(a.LeftChild), a.Op, sql.DebugString(a.RightChild))
   112  }
   113  
   114  // IsNullable implements the sql.Expression interface.
   115  func (a *Arithmetic) IsNullable() bool {
   116  	if types.IsDatetimeType(a.Type()) || types.IsTimestampType(a.Type()) {
   117  		return true
   118  	}
   119  
   120  	return a.BinaryExpressionStub.IsNullable()
   121  }
   122  
   123  // Type returns the greatest type for given operation.
   124  func (a *Arithmetic) Type() sql.Type {
   125  	//TODO: what if both BindVars? should be constant folded
   126  	rTyp := a.RightChild.Type()
   127  	if types.IsDeferredType(rTyp) {
   128  		return rTyp
   129  	}
   130  	lTyp := a.LeftChild.Type()
   131  	if types.IsDeferredType(lTyp) {
   132  		return lTyp
   133  	}
   134  
   135  	// applies for + and - ops
   136  	if isInterval(a.LeftChild) || isInterval(a.RightChild) {
   137  		// TODO: need to use the precision stored in datetimeType; something like
   138  		//   return types.MustCreateDatetimeType(sqltypes.Datetime, 0)
   139  		return types.Datetime
   140  	}
   141  
   142  	if types.IsText(lTyp) || types.IsText(rTyp) {
   143  		return types.Float64
   144  	}
   145  
   146  	if types.IsJSON(lTyp) || types.IsJSON(rTyp) {
   147  		return types.Float64
   148  	}
   149  
   150  	if types.IsFloat(lTyp) || types.IsFloat(rTyp) {
   151  		return types.Float64
   152  	}
   153  
   154  	if types.IsYear(lTyp) && types.IsYear(rTyp) {
   155  		// MySQL just returns the largest int that fits
   156  		return types.Uint64
   157  	}
   158  
   159  	// Bit types are integers
   160  	if types.IsBit(lTyp) {
   161  		lTyp = types.Int64
   162  	}
   163  	if types.IsBit(rTyp) {
   164  		rTyp = types.Int64
   165  	}
   166  
   167  	// Dates are Integers
   168  	if types.IsDateType(lTyp) {
   169  		lTyp = types.Int64
   170  	}
   171  	if types.IsDateType(rTyp) {
   172  		rTyp = types.Int64
   173  	}
   174  
   175  	// Datetime(0) is treated as Int64, otherwise as Decimal
   176  	if types.IsDatetimeType(lTyp) {
   177  		if dtType, ok := lTyp.(sql.DatetimeType); ok {
   178  			scale := uint8(dtType.Precision())
   179  			if scale == 0 {
   180  				lTyp = types.Int64
   181  			} else {
   182  				lTyp = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale)
   183  			}
   184  		}
   185  	}
   186  	if types.IsDatetimeType(rTyp) {
   187  		if dtType, ok := rTyp.(sql.DatetimeType); ok {
   188  			scale := uint8(dtType.Precision())
   189  			if scale == 0 {
   190  				rTyp = types.Int64
   191  			} else {
   192  				rTyp = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale)
   193  			}
   194  		}
   195  	}
   196  
   197  	if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
   198  		return types.Uint64
   199  	}
   200  
   201  	if types.IsInteger(lTyp) && types.IsInteger(rTyp) {
   202  		return types.Int64
   203  	}
   204  
   205  	if types.IsDecimal(lTyp) && !types.IsDecimal(rTyp) {
   206  		return lTyp
   207  	}
   208  
   209  	if types.IsDecimal(rTyp) && !types.IsDecimal(lTyp) {
   210  		return rTyp
   211  	}
   212  
   213  	if types.IsDecimal(lTyp) && types.IsDecimal(rTyp) {
   214  		lPrec := lTyp.(sql.DecimalType).Precision()
   215  		lScale := lTyp.(sql.DecimalType).Scale()
   216  		rPrec := rTyp.(sql.DecimalType).Precision()
   217  		rScale := rTyp.(sql.DecimalType).Scale()
   218  
   219  		var prec, scale uint8
   220  		if lPrec > rPrec {
   221  			prec = lPrec
   222  		} else {
   223  			prec = rPrec
   224  		}
   225  
   226  		switch a.Op {
   227  		case sqlparser.PlusStr, sqlparser.MinusStr:
   228  			if lScale > rScale {
   229  				scale = lScale
   230  			} else {
   231  				scale = rScale
   232  			}
   233  			prec = prec + scale
   234  		case sqlparser.MultStr:
   235  			scale = lScale + rScale
   236  			prec = prec + scale
   237  		}
   238  
   239  		if prec > types.DecimalTypeMaxPrecision {
   240  			prec = types.DecimalTypeMaxPrecision
   241  		}
   242  		if scale > types.DecimalTypeMaxScale {
   243  			scale = types.DecimalTypeMaxScale
   244  		}
   245  
   246  		return types.MustCreateDecimalType(prec, scale)
   247  	}
   248  
   249  	// When in doubt return float64
   250  	return types.Float64
   251  }
   252  
   253  // CollationCoercibility implements the interface sql.CollationCoercible.
   254  func (*Arithmetic) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   255  	return sql.Collation_binary, 5
   256  }
   257  
   258  // WithChildren implements the Expression interface.
   259  func (a *Arithmetic) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   260  	if len(children) != 2 {
   261  		return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2)
   262  	}
   263  	// sanity check
   264  	switch strings.ToLower(a.Op) {
   265  	case sqlparser.DivStr:
   266  		return NewDiv(children[0], children[1]), nil
   267  	case sqlparser.ModStr:
   268  		return NewMod(children[0], children[1]), nil
   269  	}
   270  	return NewArithmetic(children[0], children[1], a.Op), nil
   271  }
   272  
   273  // Eval implements the Expression interface.
   274  func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   275  	lval, rval, err := a.evalLeftRight(ctx, row)
   276  	if err != nil {
   277  		return nil, err
   278  	}
   279  
   280  	if lval == nil || rval == nil {
   281  		return nil, nil
   282  	}
   283  
   284  	lval, rval, err = a.convertLeftRight(ctx, lval, rval)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  
   289  	var result interface{}
   290  	switch strings.ToLower(a.Op) {
   291  	case sqlparser.PlusStr:
   292  		result, err = plus(lval, rval)
   293  	case sqlparser.MinusStr:
   294  		result, err = minus(lval, rval)
   295  	case sqlparser.MultStr:
   296  		result, err = mult(lval, rval)
   297  	}
   298  
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  
   303  	// Decimals must be rounded
   304  	if res, ok := result.(decimal.Decimal); ok {
   305  		if isOutermostArithmeticOp(a, a.ops) {
   306  			finalScale, hasDiv := getFinalScale(ctx, row, a, 0)
   307  			if hasDiv {
   308  				// TODO: should always round regardless; we have bad Decimal defaults
   309  				return res.Round(finalScale), nil
   310  			}
   311  		}
   312  	}
   313  
   314  	return result, nil
   315  }
   316  
   317  func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {
   318  	var lval, rval interface{}
   319  	var err error
   320  
   321  	if i, ok := a.LeftChild.(*Interval); ok {
   322  		lval, err = i.EvalDelta(ctx, row)
   323  		if err != nil {
   324  			return nil, nil, err
   325  		}
   326  	} else {
   327  		lval, err = a.LeftChild.Eval(ctx, row)
   328  		if err != nil {
   329  			return nil, nil, err
   330  		}
   331  	}
   332  
   333  	if i, ok := a.RightChild.(*Interval); ok {
   334  		rval, err = i.EvalDelta(ctx, row)
   335  		if err != nil {
   336  			return nil, nil, err
   337  		}
   338  	} else {
   339  		rval, err = a.RightChild.Eval(ctx, row)
   340  		if err != nil {
   341  			return nil, nil, err
   342  		}
   343  	}
   344  
   345  	return lval, rval, nil
   346  }
   347  
   348  func (a *Arithmetic) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) {
   349  	typ := a.Type()
   350  
   351  	lIsTimeType := types.IsTime(a.LeftChild.Type())
   352  	rIsTimeType := types.IsTime(a.RightChild.Type())
   353  
   354  	if i, ok := left.(*TimeDelta); ok {
   355  		left = i
   356  	} else {
   357  		// these are the types we specifically want to capture from we get from Type()
   358  		if types.IsInteger(typ) || types.IsFloat(typ) || types.IsTime(typ) {
   359  			left = convertValueToType(ctx, typ, left, lIsTimeType)
   360  		} else {
   361  			left = convertToDecimalValue(left, lIsTimeType)
   362  		}
   363  	}
   364  
   365  	if i, ok := right.(*TimeDelta); ok {
   366  		right = i
   367  	} else {
   368  		// these are the types we specifically want to capture from we get from Type()
   369  		if types.IsInteger(typ) || types.IsFloat(typ) || types.IsTime(typ) {
   370  			right = convertValueToType(ctx, typ, right, rIsTimeType)
   371  		} else {
   372  			right = convertToDecimalValue(right, rIsTimeType)
   373  		}
   374  	}
   375  
   376  	return left, right, nil
   377  }
   378  
   379  func isInterval(expr sql.Expression) bool {
   380  	_, ok := expr.(*Interval)
   381  	return ok
   382  }
   383  
   384  // countArithmeticOps returns the number of arithmetic operators under the current node.
   385  // This lets us count how many arithmetic operators used one after the other
   386  func countArithmeticOps(e sql.Expression) int32 {
   387  	if e == nil {
   388  		return 0
   389  	}
   390  
   391  	if a, ok := e.(ArithmeticOp); ok {
   392  		return countArithmeticOps(a.Left()) + countArithmeticOps(a.Right()) + 1
   393  	}
   394  
   395  	return 0
   396  }
   397  
   398  // setArithmeticOps will set ops number with number counted by countArithmeticOps. This allows
   399  // us to keep track of whether the expression is the last arithmetic operation.
   400  func setArithmeticOps(e sql.Expression, ops int32) {
   401  	if e == nil {
   402  		return
   403  	}
   404  
   405  	if a, ok := e.(ArithmeticOp); ok {
   406  		a.SetOpCount(ops)
   407  		setArithmeticOps(a.Left(), ops)
   408  		setArithmeticOps(a.Right(), ops)
   409  	}
   410  
   411  	if tup, ok := e.(Tuple); ok {
   412  		for _, expr := range tup {
   413  			setArithmeticOps(expr, ops)
   414  		}
   415  	}
   416  
   417  	return
   418  }
   419  
   420  // isOutermostArithmeticOp return whether the expression we're currently on is
   421  // the last arithmetic operation of all continuous arithmetic operations.
   422  func isOutermostArithmeticOp(e sql.Expression, opScale int32) bool {
   423  	return opScale == countArithmeticOps(e)
   424  }
   425  
   426  // convertValueToType returns |val| converted into type |typ|. If the value is
   427  // invalid and cannot be converted to the given type, it returns nil, and it should be
   428  // interpreted as value of 0. For time types, all the numbers are parsed up to seconds only.
   429  // E.g: `2022-11-10 12:14:36` is parsed into `20221110121436` and `2022-03-24` is parsed into `20220324`.
   430  func convertValueToType(ctx *sql.Context, typ sql.Type, val interface{}, isTimeType bool) interface{} {
   431  	var cval interface{}
   432  	if isTimeType {
   433  		val = convertTimeTypeToString(val)
   434  	}
   435  
   436  	cval, _, err := typ.Convert(val)
   437  	if err != nil {
   438  		arithmeticWarning(ctx, mysql.ERTruncatedWrongValue, fmt.Sprintf("Truncated incorrect %s value: '%v'", typ.String(), val))
   439  		// the value is interpreted as 0, but we need to match the type of the other valid value
   440  		// to avoid additional conversion, the nil value is handled in each operation
   441  	}
   442  	return cval
   443  }
   444  
   445  // convertTimeTypeToString returns string value parsed from either time.Time or string
   446  // representation. all the numbers are parsed up to seconds only. The location can be
   447  // different between two time.Time values, so we set it to default UTC location before
   448  // parsing. E.g:
   449  // `2022-11-10 12:14:36` is parsed into `20221110121436`
   450  // `2022-03-24` is parsed into `20220324`.
   451  func convertTimeTypeToString(val interface{}) interface{} {
   452  	if t, ok := val.(time.Time); ok {
   453  		val = t.In(time.UTC).Format("2006-01-02 15:04:05")
   454  	}
   455  	if t, ok := val.(string); ok {
   456  		nums := timeTypeRegex.FindAllString(t, -1)
   457  		val = strings.Join(nums, "")
   458  	}
   459  
   460  	return val
   461  }
   462  
   463  func plus(lval, rval interface{}) (interface{}, error) {
   464  	switch l := lval.(type) {
   465  	case uint8:
   466  		switch r := rval.(type) {
   467  		case uint8:
   468  			return l + r, nil
   469  		}
   470  	case int8:
   471  		switch r := rval.(type) {
   472  		case int8:
   473  			return l + r, nil
   474  		}
   475  	case uint16:
   476  		switch r := rval.(type) {
   477  		case uint16:
   478  			return l + r, nil
   479  		}
   480  	case int16:
   481  		switch r := rval.(type) {
   482  		case int16:
   483  			return l + r, nil
   484  		}
   485  	case uint32:
   486  		switch r := rval.(type) {
   487  		case uint32:
   488  			return l + r, nil
   489  		}
   490  	case int32:
   491  		switch r := rval.(type) {
   492  		case int32:
   493  			return l + r, nil
   494  		}
   495  	case uint64:
   496  		switch r := rval.(type) {
   497  		case uint64:
   498  			return l + r, nil
   499  		}
   500  	case int64:
   501  		switch r := rval.(type) {
   502  		case int64:
   503  			return l + r, nil
   504  		}
   505  	case float32:
   506  		switch r := rval.(type) {
   507  		case float32:
   508  			return l + r, nil
   509  		}
   510  	case float64:
   511  		switch r := rval.(type) {
   512  		case float64:
   513  			return l + r, nil
   514  		}
   515  	case decimal.Decimal:
   516  		switch r := rval.(type) {
   517  		case decimal.Decimal:
   518  			return l.Add(r), nil
   519  		}
   520  	case time.Time:
   521  		switch r := rval.(type) {
   522  		case *TimeDelta:
   523  			return types.ValidateTime(r.Add(l)), nil
   524  		case time.Time:
   525  			return l.Unix() + r.Unix(), nil
   526  		}
   527  	case *TimeDelta:
   528  		switch r := rval.(type) {
   529  		case time.Time:
   530  			return types.ValidateTime(l.Add(r)), nil
   531  		}
   532  	}
   533  
   534  	return nil, errUnableToCast.New(lval, rval)
   535  }
   536  
   537  func minus(lval, rval interface{}) (interface{}, error) {
   538  	switch l := lval.(type) {
   539  	case uint8:
   540  		switch r := rval.(type) {
   541  		case uint8:
   542  			return l - r, nil
   543  		}
   544  	case int8:
   545  		switch r := rval.(type) {
   546  		case int8:
   547  			return l - r, nil
   548  		}
   549  	case uint16:
   550  		switch r := rval.(type) {
   551  		case uint16:
   552  			return l - r, nil
   553  		}
   554  	case int16:
   555  		switch r := rval.(type) {
   556  		case int16:
   557  			return l - r, nil
   558  		}
   559  	case uint32:
   560  		switch r := rval.(type) {
   561  		case uint32:
   562  			return l - r, nil
   563  		}
   564  	case int32:
   565  		switch r := rval.(type) {
   566  		case int32:
   567  			return l - r, nil
   568  		}
   569  	case uint64:
   570  		switch r := rval.(type) {
   571  		case uint64:
   572  			return l - r, nil
   573  		}
   574  	case int64:
   575  		switch r := rval.(type) {
   576  		case int64:
   577  			return l - r, nil
   578  		}
   579  	case float32:
   580  		switch r := rval.(type) {
   581  		case float32:
   582  			return l - r, nil
   583  		}
   584  	case float64:
   585  		switch r := rval.(type) {
   586  		case float64:
   587  			return l - r, nil
   588  		}
   589  	case decimal.Decimal:
   590  		switch r := rval.(type) {
   591  		case decimal.Decimal:
   592  			return l.Sub(r), nil
   593  		}
   594  	case time.Time:
   595  		switch r := rval.(type) {
   596  		case *TimeDelta:
   597  			return types.ValidateTime(r.Sub(l)), nil
   598  		case time.Time:
   599  			return l.Unix() - r.Unix(), nil
   600  		}
   601  	}
   602  
   603  	return nil, errUnableToCast.New(lval, rval)
   604  }
   605  
   606  func mult(lval, rval interface{}) (interface{}, error) {
   607  	switch l := lval.(type) {
   608  	case uint8:
   609  		switch r := rval.(type) {
   610  		case uint8:
   611  			return l * r, nil
   612  		}
   613  	case int8:
   614  		switch r := rval.(type) {
   615  		case int8:
   616  			return l * r, nil
   617  		}
   618  	case uint16:
   619  		switch r := rval.(type) {
   620  		case uint16:
   621  			return l * r, nil
   622  		}
   623  	case int16:
   624  		switch r := rval.(type) {
   625  		case int16:
   626  			return l * r, nil
   627  		}
   628  	case uint32:
   629  		switch r := rval.(type) {
   630  		case uint32:
   631  			return l * r, nil
   632  		}
   633  	case int32:
   634  		switch r := rval.(type) {
   635  		case int32:
   636  			return l * r, nil
   637  		}
   638  	case uint64:
   639  		switch r := rval.(type) {
   640  		case uint64:
   641  			return l * r, nil
   642  		}
   643  	case int64:
   644  		switch r := rval.(type) {
   645  		case int64:
   646  			return l * r, nil
   647  		}
   648  	case float32:
   649  		switch r := rval.(type) {
   650  		case float32:
   651  			return l * r, nil
   652  		}
   653  	case float64:
   654  		switch r := rval.(type) {
   655  		case float64:
   656  			return l * r, nil
   657  		}
   658  	case decimal.Decimal:
   659  		switch r := rval.(type) {
   660  		case decimal.Decimal:
   661  			return l.Mul(r), nil
   662  		}
   663  	}
   664  
   665  	return nil, errUnableToCast.New(lval, rval)
   666  }
   667  
   668  // UnaryMinus is an unary minus operator.
   669  type UnaryMinus struct {
   670  	UnaryExpression
   671  }
   672  
   673  var _ sql.Expression = (*UnaryMinus)(nil)
   674  var _ sql.CollationCoercible = (*UnaryMinus)(nil)
   675  
   676  // NewUnaryMinus creates a new UnaryMinus expression node.
   677  func NewUnaryMinus(child sql.Expression) *UnaryMinus {
   678  	return &UnaryMinus{UnaryExpression{Child: child}}
   679  }
   680  
   681  // Eval implements the sql.Expression interface.
   682  func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   683  	child, err := e.Child.Eval(ctx, row)
   684  	if err != nil {
   685  		return nil, err
   686  	}
   687  
   688  	if child == nil {
   689  		return nil, nil
   690  	}
   691  
   692  	if !types.IsNumber(e.Child.Type()) {
   693  		child, err = decimal.NewFromString(fmt.Sprintf("%v", child))
   694  		if err != nil {
   695  			child = 0.0
   696  		}
   697  	}
   698  
   699  	switch n := child.(type) {
   700  	case float64:
   701  		return -n, nil
   702  	case float32:
   703  		return -n, nil
   704  	case int:
   705  		return -n, nil
   706  	case int8:
   707  		return -n, nil
   708  	case int16:
   709  		return -n, nil
   710  	case int32:
   711  		return -n, nil
   712  	case int64:
   713  		return -n, nil
   714  	case uint:
   715  		return -int(n), nil
   716  	case uint8:
   717  		return -int8(n), nil
   718  	case uint16:
   719  		return -int16(n), nil
   720  	case uint32:
   721  		return -int32(n), nil
   722  	case uint64:
   723  		return -int64(n), nil
   724  	case decimal.Decimal:
   725  		return n.Neg(), err
   726  	case string:
   727  		// try getting int out of string value
   728  		i, iErr := strconv.ParseInt(n, 10, 64)
   729  		if iErr != nil {
   730  			return nil, sql.ErrInvalidType.New(reflect.TypeOf(n))
   731  		}
   732  		return -i, nil
   733  	default:
   734  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(n))
   735  	}
   736  }
   737  
   738  // Type implements the sql.Expression interface.
   739  func (e *UnaryMinus) Type() sql.Type {
   740  	typ := e.Child.Type()
   741  	if !types.IsNumber(typ) {
   742  		return types.Float64
   743  	}
   744  
   745  	if typ == types.Uint32 {
   746  		return types.Int32
   747  	}
   748  
   749  	if typ == types.Uint64 {
   750  		return types.Int64
   751  	}
   752  
   753  	return e.Child.Type()
   754  }
   755  
   756  // CollationCoercibility implements the interface sql.CollationCoercible.
   757  func (*UnaryMinus) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   758  	return sql.Collation_binary, 5
   759  }
   760  
   761  func (e *UnaryMinus) String() string {
   762  	return fmt.Sprintf("-%s", e.Child)
   763  }
   764  
   765  // WithChildren implements the Expression interface.
   766  func (e *UnaryMinus) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   767  	if len(children) != 1 {
   768  		return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1)
   769  	}
   770  	return NewUnaryMinus(children[0]), nil
   771  }