
     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package expression
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    24  	""
    25  	""
    26  	""
    28  	""
    29  	""
    30  )
// ErrIntDivDataOutOfRange is returned by integer division (DIV) when the quotient of two Decimal operands
// cannot be parsed into a signed 64-bit integer. The two format arguments are the left and right operands.
var ErrIntDivDataOutOfRange = errors.NewKind("BIGINT value is out of range (%s DIV %s)")

// divPrecInc is the number of scale digits added to the scale of the left operand for every "/" operation.
// MySQL exposes this as the div_precision_increment system variable (default 4); in this file it is a
// fixed constant.
const divPrecInc = 4

// divIntPrecInc is the block size, in decimal digits, that the working scale of a Decimal division is rounded
// up to (see Div.div). It is also the scale used by determineResultType when d.ops == -1.
const divIntPrecInc = 9

// ERDivisionByZero is the warning code passed to arithmeticWarning whenever a divisor evaluates to zero.
const ERDivisionByZero = 1365

// Compile-time checks that *Div satisfies the interfaces it is used through.
var _ ArithmeticOp = (*Div)(nil)
var _ sql.CollationCoercible = (*Div)(nil)
// Div expression represents the "/" arithmetic operation. The operands are held in the embedded
// BinaryExpressionStub as LeftChild and RightChild.
type Div struct {
	BinaryExpressionStub
	// ops is the arithmetic-operation count assigned through SetOpCount. Eval hands it to
	// isOutermostArithmeticOp to decide whether a Decimal result must be rounded, and
	// determineResultType treats the value -1 as a special case.
	ops    int32
	// divOps is the division count assigned by setDivOps (NewDiv seeds it from countDivOps). Type and
	// getFinalScale compare it against a running counter while walking down the left spine of the tree.
	divOps int32
}
    53  // NewDiv creates a new Div / sql.Expression.
    54  func NewDiv(left, right sql.Expression) *Div {
    55  	d := &Div{BinaryExpressionStub: BinaryExpressionStub{LeftChild: left, RightChild: right}}
    56  	setDivOps(d, countDivOps(d))
    57  	setArithmeticOps(d, countArithmeticOps(d))
    58  	return d
    59  }
// Operator returns the operator string for this expression, sqlparser.DivStr.
func (d *Div) Operator() string {
	return sqlparser.DivStr
}
// SetOpCount stores |i| as this node's arithmetic-operation count (the ops field).
func (d *Div) SetOpCount(i int32) {
	d.ops = i
}
// String renders the expression as "(left / right)" using the default formatting of both children.
func (d *Div) String() string {
	return fmt.Sprintf("(%s / %s)", d.LeftChild, d.RightChild)
}
// DebugString renders the expression as "(left / right)" using sql.DebugString for both children.
func (d *Div) DebugString() string {
	return fmt.Sprintf("(%s / %s)", sql.DebugString(d.LeftChild), sql.DebugString(d.RightChild))
}
// IsNullable implements the sql.Expression interface. It delegates to the embedded BinaryExpressionStub.
func (d *Div) IsNullable() bool {
	return d.BinaryExpressionStub.IsNullable()
}
    82  // Type returns the result type for this division expression. For nested division expressions, we prefer sending
    83  // the result back as a float when possible, since division with floats is more efficient than division with Decimals.
    84  // However, if this is the outermost division expression in an expression tree, we must return the result as a
    85  // Decimal type in order to match MySQL's results exactly.
    86  func (d *Div) Type() sql.Type {
    87  	return d.determineResultType(isOutermostDiv(d, 0, d.divOps))
    88  }
// internalType returns the internal result type for this division expression, i.e. determineResultType called
// with outermostResult set to false. For performance reasons, we prefer to use floats internally in division
// operations wherever possible, since division operations on floats can be orders of magnitude faster than
// division operations on Decimal types.
func (d *Div) internalType() sql.Type {
	return d.determineResultType(false)
}
// CollationCoercibility implements the interface sql.CollationCoercible. A division always reports the
// binary collation with a coercibility of 5, regardless of its operands.
func (*Div) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
	return sql.Collation_binary, 5
}
   102  // WithChildren implements the Expression interface.
   103  func (d *Div) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   104  	if len(children) != 2 {
   105  		return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 2)
   106  	}
   107  	return NewDiv(children[0], children[1]), nil
   108  }
   110  // Eval implements the Expression interface.
   111  func (d *Div) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   112  	lval, rval, err := d.evalLeftRight(ctx, row)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   117  	if lval == nil || rval == nil {
   118  		return nil, nil
   119  	}
   121  	lval, rval = d.convertLeftRight(ctx, lval, rval)
   123  	result, err := d.div(ctx, lval, rval)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   128  	// Decimals must be rounded
   129  	if res, ok := result.(decimal.Decimal); ok {
   130  		if isOutermostArithmeticOp(d, d.ops) {
   131  			finalScale, _ := getFinalScale(ctx, row, d, 0)
   132  			return res.Round(finalScale), nil
   133  		}
   134  	}
   136  	return result, nil
   137  }
   139  func (d *Div) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {
   140  	var lval, rval interface{}
   141  	var err error
   143  	// division used with Interval error is caught at parsing the query
   144  	lval, err = d.LeftChild.Eval(ctx, row)
   145  	if err != nil {
   146  		return nil, nil, err
   147  	}
   149  	// this operation is only done on the left value as the scale/fraction part of the leftmost value
   150  	// is used to calculate the scale of the final result. If the value is GetField of decimal type column
   151  	// the decimal value evaluated does not always match the scale of column type definition
   152  	if dt, ok := d.LeftChild.Type().(sql.DecimalType); ok {
   153  		if dVal, ok := lval.(decimal.Decimal); ok {
   154  			ts := int32(dt.Scale())
   155  			if ts > dVal.Exponent()*-1 {
   156  				lval, err = decimal.NewFromString(dVal.StringFixed(ts))
   157  				if err != nil {
   158  					return nil, nil, err
   159  				}
   160  			}
   161  		}
   162  	}
   164  	rval, err = d.RightChild.Eval(ctx, row)
   165  	if err != nil {
   166  		return nil, nil, err
   167  	}
   169  	return lval, rval, nil
   170  }
   172  // convertLeftRight returns the most appropriate type for left and right evaluated values,
   173  // which may or may not be converted from its original type.
   174  // It checks for float type column reference, then the both values converted to the same float type.
   175  // Integer column references are treated as floats internally for performance reason, but the final result
   176  // from the expression tree is converted to a Decimal in order to match MySQL's behavior.
   177  // The decimal types of left and right value does NOT need to be the same. Both the types
   178  // should be preserved.
   179  func (d *Div) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) {
   180  	typ := d.internalType()
   181  	lIsTimeType := types.IsTime(d.LeftChild.Type())
   182  	rIsTimeType := types.IsTime(d.RightChild.Type())
   184  	if types.IsFloat(typ) {
   185  		left = convertValueToType(ctx, typ, left, lIsTimeType)
   186  	} else {
   187  		left = convertToDecimalValue(left, lIsTimeType)
   188  	}
   190  	if types.IsFloat(typ) {
   191  		right = convertValueToType(ctx, typ, right, rIsTimeType)
   192  	} else {
   193  		right = convertToDecimalValue(right, rIsTimeType)
   194  	}
   196  	return left, right
   197  }
   199  func (d *Div) div(ctx *sql.Context, lval, rval interface{}) (interface{}, error) {
   200  	switch l := lval.(type) {
   201  	case float32:
   202  		switch r := rval.(type) {
   203  		case float32:
   204  			if r == 0 {
   205  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   206  				return nil, nil
   207  			}
   208  			return l / r, nil
   209  		}
   210  	case float64:
   211  		switch r := rval.(type) {
   212  		case float64:
   213  			if r == 0 {
   214  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   215  				return nil, nil
   216  			}
   217  			return l / r, nil
   218  		}
   219  	case decimal.Decimal:
   220  		switch r := rval.(type) {
   221  		case decimal.Decimal:
   222  			if r.Equal(decimal.NewFromInt(0)) {
   223  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   224  				return nil, nil
   225  			}
   227  			lScale, rScale := -1*l.Exponent(), -1*r.Exponent()
   228  			inc := int32(math.Ceil(float64(lScale+rScale+divPrecInc) / divIntPrecInc))
   229  			if lScale != 0 && rScale != 0 {
   230  				lInc := int32(math.Ceil(float64(lScale) / divIntPrecInc))
   231  				rInc := int32(math.Ceil(float64(rScale) / divIntPrecInc))
   232  				inc2 := lInc + rInc
   233  				if inc2 > inc {
   234  					inc = inc2
   235  				}
   236  			}
   237  			scale := inc * divIntPrecInc
   238  			l = l.Truncate(scale)
   239  			r = r.Truncate(scale)
   241  			// give it buffer of 2 additional scale to avoid the result to be rounded
   242  			res := l.DivRound(r, scale+2)
   243  			res = res.Truncate(scale)
   244  			return res, nil
   245  		}
   246  	}
   248  	return nil, errUnableToCast.New(lval, rval)
   249  }
   251  // determineResultType looks at the expressions in the expression tree with this division operation and determines
   252  // the result type of this division expression. This involves looking at the types of the expressions in the tree,
   253  // and looking for float types or Decimal types. If |outermostResult| is false, then we prefer to treat ints as floats
   254  // (instead of Decimals) for performance reasons, but when |outermostResult| is true, we must treat ints as Decimals
   255  // in order to match MySQL's behavior.
   256  func (d *Div) determineResultType(outermostResult bool) sql.Type {
   257  	//TODO: what if both BindVars? should be constant folded
   258  	rTyp := d.RightChild.Type()
   259  	if types.IsDeferredType(rTyp) {
   260  		return rTyp
   261  	}
   262  	lTyp := d.LeftChild.Type()
   263  	if types.IsDeferredType(lTyp) {
   264  		return lTyp
   265  	}
   267  	if types.IsText(lTyp) || types.IsText(rTyp) {
   268  		return types.Float64
   269  	}
   271  	if types.IsJSON(lTyp) || types.IsJSON(rTyp) {
   272  		return types.Float64
   273  	}
   275  	if types.IsFloat(lTyp) || types.IsFloat(rTyp) {
   276  		return types.Float64
   277  	}
   279  	// Decimal only results from here on
   281  	if types.IsDatetimeType(lTyp) {
   282  		if dtType, ok := lTyp.(sql.DatetimeType); ok {
   283  			scale := uint8(dtType.Precision() + divPrecInc)
   284  			if scale > types.DecimalTypeMaxScale {
   285  				scale = types.DecimalTypeMaxScale
   286  			}
   287  			// TODO: determine actual precision
   288  			return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale)
   289  		}
   290  	}
   292  	if types.IsDecimal(lTyp) {
   293  		prec, scale := lTyp.(sql.DecimalType).Precision(), lTyp.(sql.DecimalType).Scale()
   294  		scale = scale + divPrecInc
   295  		if d.ops == -1 {
   296  			scale = (scale/divIntPrecInc + 1) * divIntPrecInc
   297  			prec = prec + scale
   298  		} else {
   299  			prec = prec + divPrecInc
   300  		}
   302  		if prec > types.DecimalTypeMaxPrecision {
   303  			prec = types.DecimalTypeMaxPrecision
   304  		}
   305  		if scale > types.DecimalTypeMaxScale {
   306  			scale = types.DecimalTypeMaxScale
   307  		}
   308  		return types.MustCreateDecimalType(prec, scale)
   309  	}
   311  	// All other types are treated as if they were integers
   312  	if d.ops == -1 {
   313  		return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divIntPrecInc)
   314  	}
   315  	return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecInc)
   316  }
// getFloatOrMaxDecimalType walks the expression tree rooted at |e| and returns either Float64 or a Decimal
// type wide enough for every decimal column, cached Convert type and numeric literal it encountered.
//
// Float64 is returned when the walk ends with resType equal to Float64: a float column reference, an
// integer column reference while |treatIntsAsFloats| is set (used as a division optimization), or a
// function expression (other than "mod") whose own type is Float64 and that was the last one to set resType.
//
// Otherwise the result is a Decimal with precision maxWhole+maxFrac and scale maxFrac, where maxWhole and
// maxFrac are the largest integer-digit and fraction-digit counts seen. If that type cannot be created,
// Decimal(65, 10) is returned instead.
func getFloatOrMaxDecimalType(e sql.Expression, treatIntsAsFloats bool) sql.Type {
	var resType sql.Type
	// Largest number of digits seen before and after the decimal point, respectively.
	var maxWhole, maxFrac uint8
	// NOTE(review): the callback returns false once the answer is known to be Float64; presumably
	// sql.Inspect stops descending on false - confirm against sql.Inspect.
	sql.Inspect(e, func(expr sql.Expression) bool {
		switch c := expr.(type) {
		case *GetField:
			ct := c.Type()
			if treatIntsAsFloats && types.IsInteger(ct) {
				resType = types.Float64
				return false
			}
			// If there is float type column reference, the result type is always float.
			if types.IsFloat(ct) {
				resType = types.Float64
				return false
			}
			if types.IsDecimal(ct) {
				dt := ct.(sql.DecimalType)
				p, s := dt.Precision(), dt.Scale()
				if whole := p - s; whole > maxWhole {
					maxWhole = whole
				}
				if s > maxFrac {
					maxFrac = s
				}
			}
		case *Convert:
			// Only Convert nodes that carry a cached decimal type contribute digits.
			if c.cachedDecimalType != nil {
				p, s := GetPrecisionAndScale(c.cachedDecimalType)
				if whole := p - s; whole > maxWhole {
					maxWhole = whole
				}
				if s > maxFrac {
					maxFrac = s
				}
			}
		case *Literal:
			if types.IsNumber(c.Type()) {
				// The literal is evaluated with a nil context and nil row; an evaluation error simply
				// means the literal does not contribute digits.
				l, err := c.Eval(nil, nil)
				if err == nil {
					p, s := GetPrecisionAndScale(l)
					if whole := p - s; whole > maxWhole {
						maxWhole = whole
					}
					if s > maxFrac {
						maxFrac = s
					}
				}
			}
		case sql.FunctionExpression:
			// Mod.Type() calls this function, so asking a "mod" function for its type here would recurse
			// forever; skip it. The walk continues afterwards, so a later node may overwrite resType.
			if c.FunctionName() != "mod" {
				resType = c.Type()
			}
		}
		return true
	})
	// Any resType other than Float64 is ignored; the Decimal derived from the collected digits wins.
	if resType == types.Float64 {
		return resType
	}

	// defType is defined by evaluating all number literals available and defined column type.
	defType, derr := types.CreateDecimalType(maxWhole+maxFrac, maxFrac)
	if derr != nil {
		return types.MustCreateDecimalType(65, 10)
	}

	return defType
}
   391  // convertToDecimalValue returns value converted to decimaltype.
   392  // If the value is invalid, it returns decimal 0. This function
   393  // is used for 'div' or 'mod' arithmetic operation, which requires
   394  // the result value to have precise precision and scale.
   395  func convertToDecimalValue(val interface{}, isTimeType bool) interface{} {
   396  	if isTimeType {
   397  		val = convertTimeTypeToString(val)
   398  	}
   399  	switch v := val.(type) {
   400  	case bool:
   401  		val = 0
   402  		if v {
   403  			val = 1
   404  		}
   405  	default:
   406  	}
   408  	if _, ok := val.(decimal.Decimal); !ok {
   409  		p, s := GetPrecisionAndScale(val)
   410  		if p > types.DecimalTypeMaxPrecision {
   411  			p = types.DecimalTypeMaxPrecision
   412  		}
   413  		if s > types.DecimalTypeMaxScale {
   414  			s = types.DecimalTypeMaxScale
   415  		}
   416  		dtyp, err := types.CreateDecimalType(p, s)
   417  		if err != nil {
   418  			val = decimal.Zero
   419  		}
   420  		val, _, err = dtyp.Convert(val)
   421  		if err != nil {
   422  			val = decimal.Zero
   423  		}
   424  	}
   426  	return val
   427  }
   429  // countDivs returns the number of division operators in order on the left child node of the current node.
   430  // This lets us count how many division operator used one after the other. E.g. 24/3/2/1 will have this structure:
   431  //
   432  //		     'div'
   433  //		     /   \
   434  //		   'div'  1
   435  //		   /   \
   436  //		 'div'  2
   437  //		 /   \
   438  //	    24    3
   439  func countDivOps(e sql.Expression) int32 {
   440  	if e == nil {
   441  		return 0
   442  	}
   443  	if a, ok := e.(*Div); ok {
   444  		return countDivOps(a.LeftChild) + 1
   445  	}
   446  	if a, ok := e.(ArithmeticOp); ok {
   447  		return countDivOps(a.Left())
   448  	}
   449  	return 0
   450  }
   452  // setDivs will set each node's DivScale to the number counted by countDivs. This allows us to
   453  // keep track of whether the current Div expression is the last Div operation, so the result is
   454  // rounded appropriately.
   455  func setDivOps(e sql.Expression, divOps int32) {
   456  	if e == nil {
   457  		return
   458  	}
   459  	if a, isArithmeticOp := e.(ArithmeticOp); isArithmeticOp {
   460  		if d, ok := a.(*Div); ok {
   461  			d.divOps = divOps
   462  		}
   463  		setDivOps(a.Left(), divOps)
   464  		setDivOps(a.Right(), divOps)
   465  	}
   466  	if tup, ok := e.(Tuple); ok {
   467  		for _, expr := range tup {
   468  			setDivOps(expr, divOps)
   469  		}
   470  	}
   471  	return
   472  }
   474  // isOutermostDiv returns whether the expression we're currently evaluating is
   475  // the last division operation of all continuous divisions.
   476  // E.g. the top 'div' (divided by 1) is the outermost/last division that is calculated:
   477  //
   478  //		     'div'
   479  //		     /   \
   480  //		   'div'  1
   481  //		   /   \
   482  //		 'div'  2
   483  //		 /   \
   484  //	    24    3
   485  func isOutermostDiv(e sql.Expression, d, dScale int32) bool {
   486  	if e == nil {
   487  		return false
   488  	}
   490  	if a, ok := e.(*Div); ok {
   491  		d = d + 1
   492  		if d == dScale {
   493  			return true
   494  		}
   495  		return isOutermostDiv(a.LeftChild, d, dScale)
   496  	}
   498  	if a, ok := e.(ArithmeticOp); ok {
   499  		return isOutermostDiv(a.Left(), d, dScale)
   500  	}
   502  	return false
   503  }
// getFinalScale returns the scale (number of fractional digits) that the final result of |expr| should be
// rounded to, together with a flag telling whether a Div was found in the inspected part of the tree.
// It recurses through Div, Arithmetic and Mod nodes; every other expression is treated as a leaf.
//
// |divOpCnt| is the number of Div nodes already passed on the way down; callers start with 0. Every
// returned scale is capped at types.DecimalTypeMaxScale, except for leaves, whose scale comes straight from
// the literal value or declared Decimal type.
func getFinalScale(ctx *sql.Context, row sql.Row, expr sql.Expression, divOpCnt int32) (int32, bool) {
	if expr == nil {
		return 0, false
	}

	if div, isDiv := expr.(*Div); isDiv {
		// Each division contributes divPrecInc digits on top of the scale of its left side.
		// TODO: there's gotta be a better way of determining if this is the leftmost div...
		fScale := int32(divPrecInc)
		divOpCnt = divOpCnt + 1
		if divOpCnt == div.divOps {
			// The counter has reached the division count recorded on this node, which is taken to mean this
			// is the leftmost div: its left operand is evaluated and the scale of the actual value is used.
			// An evaluation error is not propagated; it yields a scale of 0 and "no div".
			// TODO: redundant call to Eval for LeftChild
			lval, err := div.LeftChild.Eval(ctx, row)
			if err != nil {
				return 0, false
			}
			_, s := GetPrecisionAndScale(lval)
			typ := div.LeftChild.Type()
			// A declared Decimal scale wins when it is larger than the scale of the evaluated value.
			if dt, dok := typ.(sql.DecimalType); dok {
				ts := dt.Scale()
				if ts > s {
					s = ts
				}
			}
			fScale += int32(s)
		} else {
			// We only care about left scale for divs
			lScale, _ := getFinalScale(ctx, row, div.LeftChild, divOpCnt)
			fScale += lScale
		}

		if fScale > types.DecimalTypeMaxScale {
			fScale = types.DecimalTypeMaxScale
		}
		return fScale, true
	}

	if a, isArith := expr.(*Arithmetic); isArith {
		lScale, lHasDiv := getFinalScale(ctx, row, a.Left(), divOpCnt)
		rScale, rHasDiv := getFinalScale(ctx, row, a.Right(), divOpCnt)
		var fScale int32
		switch a.Operator() {
		case sqlparser.PlusStr, sqlparser.MinusStr:
			// Addition and subtraction keep the larger of the two scales.
			if lScale > rScale {
				fScale = lScale
			} else {
				fScale = rScale
			}
		case sqlparser.MultStr:
			// Multiplication adds the scales. Any other operator leaves fScale at 0.
			fScale = lScale + rScale
		}
		if fScale > types.DecimalTypeMaxScale {
			fScale = types.DecimalTypeMaxScale
		}
		return fScale, lHasDiv || rHasDiv
	}

	// Mod currently uses the larger of its two operand scales.
	// TODO: this is just a guess of what mod should do with scale; test this
	if m, isMod := expr.(*Mod); isMod {
		fScale, leftHasDiv := getFinalScale(ctx, row, m.LeftChild, divOpCnt)
		rScale, rightHasDiv := getFinalScale(ctx, row, m.RightChild, divOpCnt)
		if rScale > fScale {
			fScale = rScale
		}
		if fScale > types.DecimalTypeMaxScale {
			fScale = types.DecimalTypeMaxScale
		}
		return fScale, leftHasDiv || rightHasDiv
	}

	// TODO: likely need a case for IntDiv

	// Leaf: start from the scale of a literal's value (0 for non-literals) and raise it to the declared
	// Decimal scale of the expression when that is larger.
	var fScale uint8
	if lit, isLit := expr.(*Literal); isLit {
		_, fScale = GetPrecisionAndScale(lit.value)
	}
	typ := expr.Type()
	if dt, dok := typ.(sql.DecimalType); dok {
		ts := dt.Scale()
		if ts > fScale {
			fScale = ts
		}
	}

	return int32(fScale), false
}
   593  // GetDecimalPrecisionAndScale returns precision and scale for given string formatted float/double number.
   594  func GetDecimalPrecisionAndScale(val string) (uint8, uint8) {
   595  	scale := 0
   596  	precScale := strings.Split(strings.TrimPrefix(val, "-"), ".")
   597  	if len(precScale) != 1 {
   598  		scale = len(precScale[1])
   599  	}
   600  	precision := len((precScale)[0]) + scale
   601  	return uint8(precision), uint8(scale)
   602  }
   604  // GetPrecisionAndScale converts the value to string format and parses it to get the precision and scale.
   605  func GetPrecisionAndScale(val interface{}) (uint8, uint8) {
   606  	var str string
   607  	switch v := val.(type) {
   608  	case time.Time:
   609  		str = fmt.Sprintf("%v", v.In(time.UTC).Format("2006-01-02 15:04:05"))
   610  	case decimal.Decimal:
   611  		str = v.StringFixed(v.Exponent() * -1)
   612  	case float32:
   613  		d := decimal.NewFromFloat32(v)
   614  		str = d.StringFixed(d.Exponent() * -1)
   615  	case float64:
   616  		d := decimal.NewFromFloat(v)
   617  		str = d.StringFixed(d.Exponent() * -1)
   618  	default:
   619  		str = fmt.Sprintf("%v", v)
   620  	}
   621  	return GetDecimalPrecisionAndScale(str)
   622  }
// Compile-time checks that *IntDiv satisfies the interfaces it is used through.
var _ ArithmeticOp = (*IntDiv)(nil)
var _ sql.CollationCoercible = (*IntDiv)(nil)

// IntDiv expression represents the integer "div" arithmetic operation. The operands are held in the embedded
// BinaryExpressionStub as LeftChild and RightChild.
type IntDiv struct {
	BinaryExpressionStub
	// ops is the arithmetic-operation count assigned through SetOpCount.
	ops int32
}
   633  // NewIntDiv creates a new IntDiv 'div' sql.Expression.
   634  func NewIntDiv(left, right sql.Expression) *IntDiv {
   635  	a := &IntDiv{BinaryExpressionStub{LeftChild: left, RightChild: right}, 0}
   636  	ops := countArithmeticOps(a)
   637  	setArithmeticOps(a, ops)
   638  	return a
   639  }
// Operator returns the operator string for this expression, sqlparser.IntDivStr.
func (i *IntDiv) Operator() string {
	return sqlparser.IntDivStr
}
// SetOpCount stores |i2| as this node's arithmetic-operation count (the ops field).
func (i *IntDiv) SetOpCount(i2 int32) {
	i.ops = i2
}
// String renders the expression as "(left div right)" using the default formatting of both children.
func (i *IntDiv) String() string {
	return fmt.Sprintf("(%s div %s)", i.LeftChild, i.RightChild)
}
// DebugString renders the expression as "(left div right)" using sql.DebugString for both children.
func (i *IntDiv) DebugString() string {
	return fmt.Sprintf("(%s div %s)", sql.DebugString(i.LeftChild), sql.DebugString(i.RightChild))
}
// IsNullable implements the sql.Expression interface. It delegates to the embedded BinaryExpressionStub.
func (i *IntDiv) IsNullable() bool {
	return i.BinaryExpressionStub.IsNullable()
}
   662  // Type returns the greatest type for given operation.
   663  func (i *IntDiv) Type() sql.Type {
   664  	lTyp := i.LeftChild.Type()
   665  	rTyp := i.RightChild.Type()
   667  	if types.IsUnsigned(lTyp) || types.IsUnsigned(rTyp) {
   668  		return types.Uint64
   669  	}
   671  	return types.Int64
   672  }
// CollationCoercibility implements the interface sql.CollationCoercible. An integer division always reports
// the binary collation with a coercibility of 5, regardless of its operands.
func (*IntDiv) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
	return sql.Collation_binary, 5
}
   679  // WithChildren implements the Expression interface.
   680  func (i *IntDiv) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   681  	if len(children) != 2 {
   682  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 2)
   683  	}
   684  	return NewIntDiv(children[0], children[1]), nil
   685  }
   687  // Eval implements the Expression interface.
   688  func (i *IntDiv) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   689  	lval, rval, err := i.evalLeftRight(ctx, row)
   690  	if err != nil {
   691  		return nil, err
   692  	}
   694  	if lval == nil || rval == nil {
   695  		return nil, nil
   696  	}
   698  	lval, rval = i.convertLeftRight(ctx, lval, rval)
   700  	return intDiv(ctx, lval, rval)
   701  }
   703  func (i *IntDiv) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {
   704  	var lval, rval interface{}
   705  	var err error
   707  	// int division used with Interval error is caught at parsing the query
   708  	lval, err = i.LeftChild.Eval(ctx, row)
   709  	if err != nil {
   710  		return nil, nil, err
   711  	}
   713  	rval, err = i.RightChild.Eval(ctx, row)
   714  	if err != nil {
   715  		return nil, nil, err
   716  	}
   718  	return lval, rval, nil
   719  }
   721  // convertLeftRight return most appropriate value for left and right from evaluated value,
   722  // which can might or might not be converted from its original value.
   723  // It checks for float type column reference, then the both values converted to the same float types.
   724  // If there is no float type column reference, both values should be handled as decimal type
   725  // The decimal types of left and right value does NOT need to be the same. Both the types
   726  // should be preserved.
   727  func (i *IntDiv) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) {
   728  	var typ sql.Type
   729  	lTyp, rTyp := i.LeftChild.Type(), i.RightChild.Type()
   730  	lIsTimeType := types.IsTime(lTyp)
   731  	rIsTimeType := types.IsTime(rTyp)
   733  	if types.IsText(lTyp) || types.IsText(rTyp) {
   734  		typ = types.Float64
   735  	} else if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
   736  		typ = types.Uint64
   737  	} else if (lIsTimeType && rIsTimeType) || (types.IsSigned(lTyp) && types.IsSigned(rTyp)) {
   738  		typ = types.Int64
   739  	} else {
   740  		typ = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 0)
   741  	}
   743  	if types.IsInteger(typ) || types.IsFloat(typ) {
   744  		left = convertValueToType(ctx, typ, left, lIsTimeType)
   745  		right = convertValueToType(ctx, typ, right, rIsTimeType)
   746  	} else {
   747  		left = convertToDecimalValue(left, lIsTimeType)
   748  		right = convertToDecimalValue(right, rIsTimeType)
   749  	}
   751  	return left, right
   752  }
   754  func intDiv(ctx *sql.Context, lval, rval interface{}) (interface{}, error) {
   755  	switch l := lval.(type) {
   756  	case uint64:
   757  		switch r := rval.(type) {
   758  		case uint64:
   759  			if r == 0 {
   760  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   761  				return nil, nil
   762  			}
   763  			return l / r, nil
   764  		}
   765  	case int64:
   766  		switch r := rval.(type) {
   767  		case int64:
   768  			if r == 0 {
   769  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   770  				return nil, nil
   771  			}
   772  			return l / r, nil
   773  		}
   774  	case float64:
   775  		switch r := rval.(type) {
   776  		case float64:
   777  			if r == 0 {
   778  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   779  				return nil, nil
   780  			}
   781  			res := l / r
   782  			return int64(math.Floor(res)), nil
   783  		}
   784  	case decimal.Decimal:
   785  		switch r := rval.(type) {
   786  		case decimal.Decimal:
   787  			if r.Equal(decimal.NewFromInt(0)) {
   788  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   789  				return nil, nil
   790  			}
   792  			// intDiv operation gets the integer part of the divided value without rounding the result with 0 precision
   793  			// We get division result with non-zero precision and then truncate it to get integer part without it being rounded
   794  			divRes := l.DivRound(r, 2).Truncate(0)
   796  			// cannot use IntPart() function of decimal.Decimal package as it returns 0 as undefined value for out of range value
   797  			// it causes valid result value of 0 to be the same as invalid out of range value of 0. The fraction part
   798  			// should not be rounded, so truncate the result wih 0 precision.
   799  			intPart, err := strconv.ParseInt(divRes.String(), 10, 64)
   800  			if err != nil {
   801  				return nil, ErrIntDivDataOutOfRange.New(l.StringFixed(l.Exponent()), r.StringFixed(r.Exponent()))
   802  			}
   804  			return intPart, nil
   805  		}
   806  	}
   808  	return nil, errUnableToCast.New(lval, rval)
   809  }