github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/mod.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  
    21  	"github.com/dolthub/vitess/go/vt/sqlparser"
    22  	"github.com/shopspring/decimal"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  )
    27  
    28  var _ ArithmeticOp = (*Mod)(nil)
    29  var _ sql.CollationCoercible = (*Mod)(nil)
    30  
    31  // Mod expression represents "%" arithmetic operation
    32  type Mod struct {
    33  	BinaryExpressionStub
    34  	ops int32
    35  }
    36  
    37  var _ sql.FunctionExpression = (*Mod)(nil)
    38  var _ sql.CollationCoercible = (*Mod)(nil)
    39  
    40  // NewMod creates a new Mod sql.Expression.
    41  func NewMod(left, right sql.Expression) *Mod {
    42  	a := &Mod{BinaryExpressionStub{LeftChild: left, RightChild: right}, 0}
    43  	ops := countArithmeticOps(a)
    44  	setArithmeticOps(a, ops)
    45  	return a
    46  }
    47  
    48  func (m *Mod) FunctionName() string {
    49  	return "mod"
    50  }
    51  
    52  func (m *Mod) Description() string {
    53  	return "returns the remainder of the first argument divided by the second argument"
    54  }
    55  
    56  func (m *Mod) Operator() string {
    57  	return sqlparser.ModStr
    58  }
    59  
    60  func (m *Mod) SetOpCount(i int32) {
    61  	m.ops = i
    62  }
    63  
    64  func (m *Mod) String() string {
    65  	return fmt.Sprintf("(%s %% %s)", m.LeftChild, m.RightChild)
    66  }
    67  
    68  func (m *Mod) DebugString() string {
    69  	return fmt.Sprintf("(%s %% %s)", sql.DebugString(m.LeftChild), sql.DebugString(m.RightChild))
    70  }
    71  
    72  // IsNullable implements the sql.Expression interface.
    73  func (m *Mod) IsNullable() bool {
    74  	return m.BinaryExpressionStub.IsNullable()
    75  }
    76  
    77  // Type returns the greatest type for given operation.
    78  func (m *Mod) Type() sql.Type {
    79  	//TODO: what if both BindVars? should be constant folded
    80  	rTyp := m.RightChild.Type()
    81  	if types.IsDeferredType(rTyp) {
    82  		return rTyp
    83  	}
    84  	lTyp := m.LeftChild.Type()
    85  	if types.IsDeferredType(lTyp) {
    86  		return lTyp
    87  	}
    88  
    89  	if types.IsText(lTyp) || types.IsText(rTyp) {
    90  		return types.Float64
    91  	}
    92  
    93  	// for division operation, it's either float or decimal.Decimal type
    94  	// except invalid value will result it either 0 or nil
    95  	return getFloatOrMaxDecimalType(m, false)
    96  }
    97  
    98  // CollationCoercibility implements the interface sql.CollationCoercible.
    99  func (*Mod) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   100  	return sql.Collation_binary, 5
   101  }
   102  
   103  // WithChildren implements the Expression interface.
   104  func (m *Mod) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   105  	if len(children) != 2 {
   106  		return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 2)
   107  	}
   108  	return NewMod(children[0], children[1]), nil
   109  }
   110  
   111  // Eval implements the Expression interface.
   112  func (m *Mod) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   113  	lval, rval, err := m.evalLeftRight(ctx, row)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	if lval == nil || rval == nil {
   119  		return nil, nil
   120  	}
   121  
   122  	lval, rval = m.convertLeftRight(ctx, lval, rval)
   123  
   124  	return mod(ctx, lval, rval)
   125  }
   126  
   127  func (m *Mod) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {
   128  	var lval, rval interface{}
   129  	var err error
   130  
   131  	// mod used with Interval error is caught at parsing the query
   132  	lval, err = m.LeftChild.Eval(ctx, row)
   133  	if err != nil {
   134  		return nil, nil, err
   135  	}
   136  
   137  	rval, err = m.RightChild.Eval(ctx, row)
   138  	if err != nil {
   139  		return nil, nil, err
   140  	}
   141  
   142  	return lval, rval, nil
   143  }
   144  
   145  func (m *Mod) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) {
   146  	typ := m.Type()
   147  	lIsTimeType := types.IsTime(m.LeftChild.Type())
   148  	rIsTimeType := types.IsTime(m.RightChild.Type())
   149  
   150  	if types.IsFloat(typ) {
   151  		left = convertValueToType(ctx, typ, left, lIsTimeType)
   152  		right = convertValueToType(ctx, typ, right, rIsTimeType)
   153  	} else {
   154  		left = convertToDecimalValue(left, lIsTimeType)
   155  		right = convertToDecimalValue(right, rIsTimeType)
   156  	}
   157  
   158  	return left, right
   159  }
   160  
   161  func mod(ctx *sql.Context, lval, rval interface{}) (interface{}, error) {
   162  	switch l := lval.(type) {
   163  	case float32:
   164  		switch r := rval.(type) {
   165  		case float32:
   166  			if r == 0 {
   167  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   168  				return nil, nil
   169  			}
   170  			return math.Mod(float64(l), float64(r)), nil
   171  		}
   172  
   173  	case float64:
   174  		switch r := rval.(type) {
   175  		case float64:
   176  			if r == 0 {
   177  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   178  				return nil, nil
   179  			}
   180  			return math.Mod(l, r), nil
   181  		}
   182  	case decimal.Decimal:
   183  		switch r := rval.(type) {
   184  		case decimal.Decimal:
   185  			if r.Equal(decimal.NewFromInt(0)) {
   186  				arithmeticWarning(ctx, ERDivisionByZero, "Division by 0")
   187  				return nil, nil
   188  			}
   189  
   190  			// Mod function from the decimal package takes care of precision and scale for the result value
   191  			return l.Mod(r), nil
   192  		}
   193  	}
   194  
   195  	return nil, errUnableToCast.New(lval, rval)
   196  }