vitess.io/vitess@v0.16.2/go/vt/vtgate/evalengine/arithmetic_expr.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package evalengine
    18  
    19  import (
    20  	"vitess.io/vitess/go/sqltypes"
    21  )
    22  
    23  type (
    24  	ArithmeticExpr struct {
    25  		BinaryExpr
    26  		Op ArithmeticOp
    27  	}
    28  
    29  	ArithmeticOp interface {
    30  		eval(left, right, out *EvalResult) error
    31  		String() string
    32  	}
    33  
    34  	OpAddition       struct{}
    35  	OpSubstraction   struct{}
    36  	OpMultiplication struct{}
    37  	OpDivision       struct{}
    38  )
    39  
    40  var _ ArithmeticOp = (*OpAddition)(nil)
    41  var _ ArithmeticOp = (*OpSubstraction)(nil)
    42  var _ ArithmeticOp = (*OpMultiplication)(nil)
    43  var _ ArithmeticOp = (*OpDivision)(nil)
    44  
    45  func (b *ArithmeticExpr) eval(env *ExpressionEnv, out *EvalResult) {
    46  	var left, right EvalResult
    47  	left.init(env, b.Left)
    48  	right.init(env, b.Right)
    49  	if left.isNull() || right.isNull() {
    50  		out.setNull()
    51  		return
    52  	}
    53  	if left.typeof() == sqltypes.Tuple || right.typeof() == sqltypes.Tuple {
    54  		panic("failed to typecheck tuples")
    55  	}
    56  	if err := b.Op.eval(&left, &right, out); err != nil {
    57  		throwEvalError(err)
    58  	}
    59  }
    60  
    61  func makeNumericalType(t sqltypes.Type, f flag) sqltypes.Type {
    62  	if sqltypes.IsNumber(t) {
    63  		return t
    64  	}
    65  	if t == sqltypes.VarBinary && (f&flagHex) != 0 {
    66  		return sqltypes.Uint64
    67  	}
    68  	return sqltypes.Float64
    69  }
    70  
    71  // typeof implements the Expr interface
    72  func (b *ArithmeticExpr) typeof(env *ExpressionEnv) (sqltypes.Type, flag) {
    73  	t1, f1 := b.Left.typeof(env)
    74  	t2, f2 := b.Right.typeof(env)
    75  	flags := f1 | f2
    76  
    77  	t1 = makeNumericalType(t1, f1)
    78  	t2 = makeNumericalType(t2, f2)
    79  
    80  	switch b.Op.(type) {
    81  	case *OpDivision:
    82  		if t1 == sqltypes.Float64 || t2 == sqltypes.Float64 {
    83  			return sqltypes.Float64, flags
    84  		}
    85  		return sqltypes.Decimal, flags
    86  	}
    87  
    88  	switch t1 {
    89  	case sqltypes.Int64:
    90  		switch t2 {
    91  		case sqltypes.Uint64, sqltypes.Float64, sqltypes.Decimal:
    92  			return t2, flags
    93  		}
    94  	case sqltypes.Uint64:
    95  		switch t2 {
    96  		case sqltypes.Float64, sqltypes.Decimal:
    97  			return t2, flags
    98  		}
    99  	case sqltypes.Decimal:
   100  		if t2 == sqltypes.Float64 {
   101  			return t2, flags
   102  		}
   103  	}
   104  	return t1, flags
   105  }
   106  
   107  func (a *OpAddition) eval(left, right, out *EvalResult) error {
   108  	return addNumericWithError(left, right, out)
   109  }
   110  func (a *OpAddition) String() string { return "+" }
   111  
   112  func (s *OpSubstraction) eval(left, right, out *EvalResult) error {
   113  	return subtractNumericWithError(left, right, out)
   114  }
   115  func (s *OpSubstraction) String() string { return "-" }
   116  
   117  func (m *OpMultiplication) eval(left, right, out *EvalResult) error {
   118  	return multiplyNumericWithError(left, right, out)
   119  }
   120  func (m *OpMultiplication) String() string { return "*" }
   121  
   122  func (d *OpDivision) eval(left, right, out *EvalResult) error {
   123  	return divideNumericWithError(left, right, true, out)
   124  }
   125  func (d *OpDivision) String() string { return "/" }