github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/evaluator/evaluator_binop.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package evaluator
    15  
    16  import (
    17  	"math"
    18  
    19  	"github.com/insionng/yougam/libraries/juju/errors"
    20  	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
    21  	"github.com/insionng/yougam/libraries/pingcap/tidb/parser/opcode"
    22  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    23  )
    24  
    25  const (
    26  	zeroI64 int64 = 0
    27  	oneI64  int64 = 1
    28  )
    29  
    30  func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool {
    31  	switch o.Op {
    32  	case opcode.AndAnd, opcode.OrOr, opcode.LogicXor:
    33  		return e.handleLogicOperation(o)
    34  	case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
    35  		return e.handleComparisonOp(o)
    36  	case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor:
    37  		return e.handleBitOp(o)
    38  	case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv:
    39  		return e.handleArithmeticOp(o)
    40  	default:
    41  		e.err = ErrInvalidOperation
    42  		return false
    43  	}
    44  }
    45  
    46  func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool {
    47  	switch o.Op {
    48  	case opcode.AndAnd:
    49  		return e.handleAndAnd(o)
    50  	case opcode.OrOr:
    51  		return e.handleOrOr(o)
    52  	case opcode.LogicXor:
    53  		return e.handleXor(o)
    54  	default:
    55  		e.err = ErrInvalidOperation.Gen("unkown operator %s", o.Op)
    56  		return false
    57  	}
    58  }
    59  
    60  func (e *Evaluator) handleAndAnd(o *ast.BinaryOperationExpr) bool {
    61  	leftDatum := o.L.GetDatum()
    62  	rightDatum := o.R.GetDatum()
    63  	if leftDatum.Kind() != types.KindNull {
    64  		x, err := leftDatum.ToBool()
    65  		if err != nil {
    66  			e.err = errors.Trace(err)
    67  			return false
    68  		} else if x == 0 {
    69  			// false && any other types is false
    70  			o.SetInt64(x)
    71  			return true
    72  		}
    73  	}
    74  	if rightDatum.Kind() != types.KindNull {
    75  		y, err := rightDatum.ToBool()
    76  		if err != nil {
    77  			e.err = errors.Trace(err)
    78  			return false
    79  		} else if y == 0 {
    80  			o.SetInt64(y)
    81  			return true
    82  		}
    83  	}
    84  	if leftDatum.Kind() == types.KindNull || rightDatum.Kind() == types.KindNull {
    85  		o.SetNull()
    86  		return true
    87  	}
    88  	o.SetInt64(int64(1))
    89  	return true
    90  }
    91  
    92  func (e *Evaluator) handleOrOr(o *ast.BinaryOperationExpr) bool {
    93  	leftDatum := o.L.GetDatum()
    94  	if leftDatum.Kind() != types.KindNull {
    95  		x, err := leftDatum.ToBool()
    96  		if err != nil {
    97  			e.err = errors.Trace(err)
    98  			return false
    99  		} else if x == 1 {
   100  			// true || any other types is true.
   101  			o.SetInt64(x)
   102  			return true
   103  		}
   104  	}
   105  	righDatum := o.R.GetDatum()
   106  	if righDatum.Kind() != types.KindNull {
   107  		y, err := righDatum.ToBool()
   108  		if err != nil {
   109  			e.err = errors.Trace(err)
   110  			return false
   111  		} else if y == 1 {
   112  			o.SetInt64(y)
   113  			return true
   114  		}
   115  	}
   116  	if leftDatum.Kind() == types.KindNull || righDatum.Kind() == types.KindNull {
   117  		o.SetNull()
   118  		return true
   119  	}
   120  	o.SetInt64(int64(0))
   121  	return true
   122  }
   123  
   124  func (e *Evaluator) handleXor(o *ast.BinaryOperationExpr) bool {
   125  	leftDatum := o.L.GetDatum()
   126  	righDatum := o.R.GetDatum()
   127  	if leftDatum.Kind() == types.KindNull || righDatum.Kind() == types.KindNull {
   128  		o.SetNull()
   129  		return true
   130  	}
   131  	x, err := leftDatum.ToBool()
   132  	if err != nil {
   133  		e.err = errors.Trace(err)
   134  		return false
   135  	}
   136  
   137  	y, err := righDatum.ToBool()
   138  	if err != nil {
   139  		e.err = errors.Trace(err)
   140  		return false
   141  	}
   142  	if x == y {
   143  		o.SetInt64(int64(0))
   144  	} else {
   145  		o.SetInt64(int64(1))
   146  	}
   147  	return true
   148  }
   149  
   150  func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool {
   151  	a, b := types.CoerceDatum(*o.L.GetDatum(), *o.R.GetDatum())
   152  	if a.Kind() == types.KindNull || b.Kind() == types.KindNull {
   153  		// for <=>, if a and b are both nil, return true.
   154  		// if a or b is nil, return false.
   155  		if o.Op == opcode.NullEQ {
   156  			if a.Kind() == types.KindNull && b.Kind() == types.KindNull {
   157  				o.SetInt64(oneI64)
   158  			} else {
   159  				o.SetInt64(zeroI64)
   160  			}
   161  		} else {
   162  			o.SetNull()
   163  		}
   164  		return true
   165  	}
   166  
   167  	n, err := a.CompareDatum(b)
   168  
   169  	if err != nil {
   170  		e.err = errors.Trace(err)
   171  		return false
   172  	}
   173  
   174  	r, err := getCompResult(o.Op, n)
   175  	if err != nil {
   176  		e.err = errors.Trace(err)
   177  		return false
   178  	}
   179  	if r {
   180  		o.SetInt64(oneI64)
   181  	} else {
   182  		o.SetInt64(zeroI64)
   183  	}
   184  	return true
   185  }
   186  
   187  func getCompResult(op opcode.Op, value int) (bool, error) {
   188  	switch op {
   189  	case opcode.LT:
   190  		return value < 0, nil
   191  	case opcode.LE:
   192  		return value <= 0, nil
   193  	case opcode.GE:
   194  		return value >= 0, nil
   195  	case opcode.GT:
   196  		return value > 0, nil
   197  	case opcode.EQ:
   198  		return value == 0, nil
   199  	case opcode.NE:
   200  		return value != 0, nil
   201  	case opcode.NullEQ:
   202  		return value == 0, nil
   203  	default:
   204  		return false, ErrInvalidOperation.Gen("invalid op %v in comparision operation", op)
   205  	}
   206  }
   207  
   208  func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool {
   209  	a, b := types.CoerceDatum(*o.L.GetDatum(), *o.R.GetDatum())
   210  
   211  	if a.Kind() == types.KindNull || b.Kind() == types.KindNull {
   212  		o.SetNull()
   213  		return true
   214  	}
   215  
   216  	x, err := a.ToInt64()
   217  	if err != nil {
   218  		e.err = errors.Trace(err)
   219  		return false
   220  	}
   221  
   222  	y, err := b.ToInt64()
   223  	if err != nil {
   224  		e.err = errors.Trace(err)
   225  		return false
   226  	}
   227  
   228  	// use a int64 for bit operator, return uint64
   229  	switch o.Op {
   230  	case opcode.And:
   231  		o.SetUint64(uint64(x & y))
   232  	case opcode.Or:
   233  		o.SetUint64(uint64(x | y))
   234  	case opcode.Xor:
   235  		o.SetUint64(uint64(x ^ y))
   236  	case opcode.RightShift:
   237  		o.SetUint64(uint64(x) >> uint64(y))
   238  	case opcode.LeftShift:
   239  		o.SetUint64(uint64(x) << uint64(y))
   240  	default:
   241  		e.err = ErrInvalidOperation.Gen("invalid op %v in bit operation", o.Op)
   242  		return false
   243  	}
   244  	return true
   245  }
   246  
   247  func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool {
   248  	a, err := coerceArithmetic(*o.L.GetDatum())
   249  	if err != nil {
   250  		e.err = errors.Trace(err)
   251  		return false
   252  	}
   253  
   254  	b, err := coerceArithmetic(*o.R.GetDatum())
   255  	if err != nil {
   256  		e.err = errors.Trace(err)
   257  		return false
   258  	}
   259  
   260  	a, b = types.CoerceDatum(a, b)
   261  	if a.Kind() == types.KindNull || b.Kind() == types.KindNull {
   262  		o.SetNull()
   263  		return true
   264  	}
   265  
   266  	var result types.Datum
   267  	switch o.Op {
   268  	case opcode.Plus:
   269  		result, e.err = computePlus(a, b)
   270  	case opcode.Minus:
   271  		result, e.err = computeMinus(a, b)
   272  	case opcode.Mul:
   273  		result, e.err = computeMul(a, b)
   274  	case opcode.Div:
   275  		result, e.err = computeDiv(a, b)
   276  	case opcode.Mod:
   277  		result, e.err = computeMod(a, b)
   278  	case opcode.IntDiv:
   279  		result, e.err = computeIntDiv(a, b)
   280  	default:
   281  		e.err = ErrInvalidOperation.Gen("invalid op %v in arithmetic operation", o.Op)
   282  		return false
   283  	}
   284  	o.SetDatum(result)
   285  	return e.err == nil
   286  }
   287  
   288  func computePlus(a, b types.Datum) (d types.Datum, err error) {
   289  	switch a.Kind() {
   290  	case types.KindInt64:
   291  		switch b.Kind() {
   292  		case types.KindInt64:
   293  			r, err1 := types.AddInt64(a.GetInt64(), b.GetInt64())
   294  			d.SetInt64(r)
   295  			return d, errors.Trace(err1)
   296  		case types.KindUint64:
   297  			r, err1 := types.AddInteger(b.GetUint64(), a.GetInt64())
   298  			d.SetUint64(r)
   299  			return d, errors.Trace(err1)
   300  		}
   301  	case types.KindUint64:
   302  		switch b.Kind() {
   303  		case types.KindInt64:
   304  			r, err1 := types.AddInteger(a.GetUint64(), b.GetInt64())
   305  			d.SetUint64(r)
   306  			return d, errors.Trace(err1)
   307  		case types.KindUint64:
   308  			r, err1 := types.AddUint64(a.GetUint64(), b.GetUint64())
   309  			d.SetUint64(r)
   310  			return d, errors.Trace(err1)
   311  		}
   312  	case types.KindFloat64:
   313  		switch b.Kind() {
   314  		case types.KindFloat64:
   315  			r := a.GetFloat64() + b.GetFloat64()
   316  			d.SetFloat64(r)
   317  			return d, nil
   318  		}
   319  	case types.KindMysqlDecimal:
   320  		switch b.Kind() {
   321  		case types.KindMysqlDecimal:
   322  			r := a.GetMysqlDecimal().Add(b.GetMysqlDecimal())
   323  			d.SetMysqlDecimal(r)
   324  			return d, nil
   325  		}
   326  	}
   327  	_, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Plus)
   328  	return d, err
   329  }
   330  
   331  func computeMinus(a, b types.Datum) (d types.Datum, err error) {
   332  	switch a.Kind() {
   333  	case types.KindInt64:
   334  		switch b.Kind() {
   335  		case types.KindInt64:
   336  			r, err1 := types.SubInt64(a.GetInt64(), b.GetInt64())
   337  			d.SetInt64(r)
   338  			return d, errors.Trace(err1)
   339  		case types.KindUint64:
   340  			r, err1 := types.SubIntWithUint(a.GetInt64(), b.GetUint64())
   341  			d.SetUint64(r)
   342  			return d, errors.Trace(err1)
   343  		}
   344  	case types.KindUint64:
   345  		switch b.Kind() {
   346  		case types.KindInt64:
   347  			r, err1 := types.SubUintWithInt(a.GetUint64(), b.GetInt64())
   348  			d.SetUint64(r)
   349  			return d, errors.Trace(err1)
   350  		case types.KindUint64:
   351  			r, err1 := types.SubUint64(a.GetUint64(), b.GetUint64())
   352  			d.SetUint64(r)
   353  			return d, errors.Trace(err1)
   354  		}
   355  	case types.KindFloat64:
   356  		switch b.Kind() {
   357  		case types.KindFloat64:
   358  			r := a.GetFloat64() - b.GetFloat64()
   359  			d.SetFloat64(r)
   360  			return d, nil
   361  		}
   362  	case types.KindMysqlDecimal:
   363  		switch b.Kind() {
   364  		case types.KindMysqlDecimal:
   365  			r := a.GetMysqlDecimal().Sub(b.GetMysqlDecimal())
   366  			d.SetMysqlDecimal(r)
   367  			return d, nil
   368  		}
   369  	}
   370  	_, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Minus)
   371  	return d, errors.Trace(err)
   372  }
   373  
   374  func computeMul(a, b types.Datum) (d types.Datum, err error) {
   375  	switch a.Kind() {
   376  	case types.KindInt64:
   377  		switch b.Kind() {
   378  		case types.KindInt64:
   379  			r, err1 := types.MulInt64(a.GetInt64(), b.GetInt64())
   380  			d.SetInt64(r)
   381  			return d, errors.Trace(err1)
   382  		case types.KindUint64:
   383  			r, err1 := types.MulInteger(b.GetUint64(), a.GetInt64())
   384  			d.SetUint64(r)
   385  			return d, errors.Trace(err1)
   386  		}
   387  	case types.KindUint64:
   388  		switch b.Kind() {
   389  		case types.KindInt64:
   390  			r, err1 := types.MulInteger(a.GetUint64(), b.GetInt64())
   391  			d.SetUint64(r)
   392  			return d, errors.Trace(err1)
   393  		case types.KindUint64:
   394  			r, err1 := types.MulUint64(a.GetUint64(), b.GetUint64())
   395  			d.SetUint64(r)
   396  			return d, errors.Trace(err1)
   397  		}
   398  	case types.KindFloat64:
   399  		switch b.Kind() {
   400  		case types.KindFloat64:
   401  			r := a.GetFloat64() * b.GetFloat64()
   402  			d.SetFloat64(r)
   403  			return d, nil
   404  		}
   405  	case types.KindMysqlDecimal:
   406  		switch b.Kind() {
   407  		case types.KindMysqlDecimal:
   408  			r := a.GetMysqlDecimal().Mul(b.GetMysqlDecimal())
   409  			d.SetMysqlDecimal(r)
   410  			return d, nil
   411  		}
   412  	}
   413  
   414  	_, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Mul)
   415  	return d, errors.Trace(err)
   416  }
   417  
   418  func computeDiv(a, b types.Datum) (d types.Datum, err error) {
   419  	// MySQL support integer division Div and division operator /
   420  	// we use opcode.Div for division operator and will use another for integer division later.
   421  	// for division operator, we will use float64 for calculation.
   422  	switch a.Kind() {
   423  	case types.KindFloat64:
   424  		y, err1 := b.ToFloat64()
   425  		if err1 != nil {
   426  			return d, errors.Trace(err1)
   427  		}
   428  
   429  		if y == 0 {
   430  			return d, nil
   431  		}
   432  
   433  		x := a.GetFloat64()
   434  		d.SetFloat64(x / y)
   435  		return d, nil
   436  	default:
   437  		// the scale of the result is the scale of the first operand plus
   438  		// the value of the div_precision_increment system variable (which is 4 by default)
   439  		// we will use 4 here
   440  		xa, err1 := a.ToDecimal()
   441  		if err != nil {
   442  			return d, errors.Trace(err1)
   443  		}
   444  
   445  		xb, err1 := b.ToDecimal()
   446  		if err1 != nil {
   447  			return d, errors.Trace(err1)
   448  		}
   449  		if f, _ := xb.Float64(); f == 0 {
   450  			// division by zero return null
   451  			return d, nil
   452  		}
   453  
   454  		d.SetMysqlDecimal(xa.Div(xb))
   455  		return d, nil
   456  	}
   457  }
   458  
   459  func computeMod(a, b types.Datum) (d types.Datum, err error) {
   460  	switch a.Kind() {
   461  	case types.KindInt64:
   462  		x := a.GetInt64()
   463  		switch b.Kind() {
   464  		case types.KindInt64:
   465  			y := b.GetInt64()
   466  			if y == 0 {
   467  				return d, nil
   468  			}
   469  			d.SetInt64(x % y)
   470  			return d, nil
   471  		case types.KindUint64:
   472  			y := b.GetUint64()
   473  			if y == 0 {
   474  				return d, nil
   475  			} else if x < 0 {
   476  				d.SetInt64(-int64(uint64(-x) % y))
   477  				// first is int64, return int64.
   478  				return d, nil
   479  			}
   480  			d.SetInt64(int64(uint64(x) % y))
   481  			return d, nil
   482  		}
   483  	case types.KindUint64:
   484  		x := a.GetUint64()
   485  		switch b.Kind() {
   486  		case types.KindInt64:
   487  			y := b.GetInt64()
   488  			if y == 0 {
   489  				return d, nil
   490  			} else if y < 0 {
   491  				// first is uint64, return uint64.
   492  				d.SetUint64(uint64(x % uint64(-y)))
   493  				return d, nil
   494  			}
   495  			d.SetUint64(x % uint64(y))
   496  			return d, nil
   497  		case types.KindUint64:
   498  			y := b.GetUint64()
   499  			if y == 0 {
   500  				return d, nil
   501  			}
   502  			d.SetUint64(x % y)
   503  			return d, nil
   504  		}
   505  	case types.KindFloat64:
   506  		x := a.GetFloat64()
   507  		switch b.Kind() {
   508  		case types.KindFloat64:
   509  			y := b.GetFloat64()
   510  			if y == 0 {
   511  				return d, nil
   512  			}
   513  			d.SetFloat64(math.Mod(x, y))
   514  			return d, nil
   515  		}
   516  	case types.KindMysqlDecimal:
   517  		x := a.GetMysqlDecimal()
   518  		switch b.Kind() {
   519  		case types.KindMysqlDecimal:
   520  			y := b.GetMysqlDecimal()
   521  			xf, _ := x.Float64()
   522  			yf, _ := y.Float64()
   523  			if yf == 0 {
   524  				return d, nil
   525  			}
   526  			d.SetFloat64(math.Mod(xf, yf))
   527  			return d, nil
   528  		}
   529  	}
   530  	_, err = types.InvOp2(a.GetValue(), b.GetValue(), opcode.Mod)
   531  	return d, errors.Trace(err)
   532  }
   533  
   534  func computeIntDiv(a, b types.Datum) (d types.Datum, err error) {
   535  	switch a.Kind() {
   536  	case types.KindInt64:
   537  		x := a.GetInt64()
   538  		switch b.Kind() {
   539  		case types.KindInt64:
   540  			y := b.GetInt64()
   541  			if y == 0 {
   542  				return d, nil
   543  			}
   544  			r, err1 := types.DivInt64(x, y)
   545  			d.SetInt64(r)
   546  			return d, errors.Trace(err1)
   547  		case types.KindUint64:
   548  			y := b.GetUint64()
   549  			if y == 0 {
   550  				return d, nil
   551  			}
   552  			r, err1 := types.DivIntWithUint(x, y)
   553  			d.SetUint64(r)
   554  			return d, errors.Trace(err1)
   555  		}
   556  	case types.KindUint64:
   557  		x := a.GetUint64()
   558  		switch b.Kind() {
   559  		case types.KindInt64:
   560  			y := b.GetInt64()
   561  			if y == 0 {
   562  				return d, nil
   563  			}
   564  			r, err1 := types.DivUintWithInt(x, y)
   565  			d.SetUint64(r)
   566  			return d, errors.Trace(err1)
   567  		case types.KindUint64:
   568  			y := b.GetUint64()
   569  			if y == 0 {
   570  				return d, nil
   571  			}
   572  			d.SetUint64(x / y)
   573  			return d, nil
   574  		}
   575  	}
   576  
   577  	// if any is none integer, use decimal to calculate
   578  	x, err := a.ToDecimal()
   579  	if err != nil {
   580  		return d, errors.Trace(err)
   581  	}
   582  
   583  	y, err := b.ToDecimal()
   584  	if err != nil {
   585  		return d, errors.Trace(err)
   586  	}
   587  
   588  	if f, _ := y.Float64(); f == 0 {
   589  		return d, nil
   590  	}
   591  
   592  	d.SetInt64(x.Div(y).IntPart())
   593  	return d, nil
   594  }
   595  
   596  func coerceArithmetic(a types.Datum) (d types.Datum, err error) {
   597  	switch a.Kind() {
   598  	case types.KindString, types.KindBytes:
   599  		// MySQL will convert string to float for arithmetic operation
   600  		f, err := types.StrToFloat(a.GetString())
   601  		if err != nil {
   602  			return d, errors.Trace(err)
   603  		}
   604  		d.SetFloat64(f)
   605  		return d, errors.Trace(err)
   606  	case types.KindMysqlTime:
   607  		// if time has no precision, return int64
   608  		t := a.GetMysqlTime()
   609  		de := t.ToNumber()
   610  		if t.Fsp == 0 {
   611  			d.SetInt64(de.IntPart())
   612  			return d, nil
   613  		}
   614  		d.SetMysqlDecimal(de)
   615  		return d, nil
   616  	case types.KindMysqlDuration:
   617  		// if duration has no precision, return int64
   618  		du := a.GetMysqlDuration()
   619  		de := du.ToNumber()
   620  		if du.Fsp == 0 {
   621  			d.SetInt64(de.IntPart())
   622  			return d, nil
   623  		}
   624  		d.SetMysqlDecimal(de)
   625  		return d, nil
   626  	case types.KindMysqlHex:
   627  		d.SetFloat64(a.GetMysqlHex().ToNumber())
   628  		return d, nil
   629  	case types.KindMysqlBit:
   630  		d.SetFloat64(a.GetMysqlBit().ToNumber())
   631  		return d, nil
   632  	case types.KindMysqlEnum:
   633  		d.SetFloat64(a.GetMysqlEnum().ToNumber())
   634  		return d, nil
   635  	case types.KindMysqlSet:
   636  		d.SetFloat64(a.GetMysqlSet().ToNumber())
   637  		return d, nil
   638  	default:
   639  		return a, nil
   640  	}
   641  }