github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/mod.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "math" 20 21 "github.com/dolthub/vitess/go/vt/sqlparser" 22 "github.com/shopspring/decimal" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/types" 26 ) 27 28 var _ ArithmeticOp = (*Mod)(nil) 29 var _ sql.CollationCoercible = (*Mod)(nil) 30 31 // Mod expression represents "%" arithmetic operation 32 type Mod struct { 33 BinaryExpressionStub 34 ops int32 35 } 36 37 var _ sql.FunctionExpression = (*Mod)(nil) 38 var _ sql.CollationCoercible = (*Mod)(nil) 39 40 // NewMod creates a new Mod sql.Expression. 41 func NewMod(left, right sql.Expression) *Mod { 42 a := &Mod{BinaryExpressionStub{LeftChild: left, RightChild: right}, 0} 43 ops := countArithmeticOps(a) 44 setArithmeticOps(a, ops) 45 return a 46 } 47 48 func (m *Mod) FunctionName() string { 49 return "mod" 50 } 51 52 func (m *Mod) Description() string { 53 return "returns the remainder of the first argument divided by the second argument" 54 } 55 56 func (m *Mod) Operator() string { 57 return sqlparser.ModStr 58 } 59 60 func (m *Mod) SetOpCount(i int32) { 61 m.ops = i 62 } 63 64 func (m *Mod) String() string { 65 return fmt.Sprintf("(%s %% %s)", m.LeftChild, m.RightChild) 66 } 67 68 func (m *Mod) DebugString() string { 69 return fmt.Sprintf("(%s %% %s)", sql.DebugString(m.LeftChild), sql.DebugString(m.RightChild)) 70 } 71 72 // IsNullable implements the sql.Expression interface. 73 func (m *Mod) IsNullable() bool { 74 return m.BinaryExpressionStub.IsNullable() 75 } 76 77 // Type returns the greatest type for given operation. 78 func (m *Mod) Type() sql.Type { 79 //TODO: what if both BindVars? should be constant folded 80 rTyp := m.RightChild.Type() 81 if types.IsDeferredType(rTyp) { 82 return rTyp 83 } 84 lTyp := m.LeftChild.Type() 85 if types.IsDeferredType(lTyp) { 86 return lTyp 87 } 88 89 if types.IsText(lTyp) || types.IsText(rTyp) { 90 return types.Float64 91 } 92 93 // for division operation, it's either float or decimal.Decimal type 94 // except invalid value will result it either 0 or nil 95 return getFloatOrMaxDecimalType(m, false) 96 } 97 98 // CollationCoercibility implements the interface sql.CollationCoercible. 99 func (*Mod) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 100 return sql.Collation_binary, 5 101 } 102 103 // WithChildren implements the Expression interface. 104 func (m *Mod) WithChildren(children ...sql.Expression) (sql.Expression, error) { 105 if len(children) != 2 { 106 return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 2) 107 } 108 return NewMod(children[0], children[1]), nil 109 } 110 111 // Eval implements the Expression interface. 112 func (m *Mod) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 113 lval, rval, err := m.evalLeftRight(ctx, row) 114 if err != nil { 115 return nil, err 116 } 117 118 if lval == nil || rval == nil { 119 return nil, nil 120 } 121 122 lval, rval = m.convertLeftRight(ctx, lval, rval) 123 124 return mod(ctx, lval, rval) 125 } 126 127 func (m *Mod) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { 128 var lval, rval interface{} 129 var err error 130 131 // mod used with Interval error is caught at parsing the query 132 lval, err = m.LeftChild.Eval(ctx, row) 133 if err != nil { 134 return nil, nil, err 135 } 136 137 rval, err = m.RightChild.Eval(ctx, row) 138 if err != nil { 139 return nil, nil, err 140 } 141 142 return lval, rval, nil 143 } 144 145 func (m *Mod) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) { 146 typ := m.Type() 147 lIsTimeType := types.IsTime(m.LeftChild.Type()) 148 rIsTimeType := types.IsTime(m.RightChild.Type()) 149 150 if types.IsFloat(typ) { 151 left = convertValueToType(ctx, typ, left, lIsTimeType) 152 right = convertValueToType(ctx, typ, right, rIsTimeType) 153 } else { 154 left = convertToDecimalValue(left, lIsTimeType) 155 right = convertToDecimalValue(right, rIsTimeType) 156 } 157 158 return left, right 159 } 160 161 func mod(ctx *sql.Context, lval, rval interface{}) (interface{}, error) { 162 switch l := lval.(type) { 163 case float32: 164 switch r := rval.(type) { 165 case float32: 166 if r == 0 { 167 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 168 return nil, nil 169 } 170 return math.Mod(float64(l), float64(r)), nil 171 } 172 173 case float64: 174 switch r := rval.(type) { 175 case float64: 176 if r == 0 { 177 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 178 return nil, nil 179 } 180 return math.Mod(l, r), nil 181 } 182 case decimal.Decimal: 183 switch r := rval.(type) { 184 case decimal.Decimal: 185 if r.Equal(decimal.NewFromInt(0)) { 186 arithmeticWarning(ctx, ERDivisionByZero, "Division by 0") 187 return nil, nil 188 } 189 190 // Mod function from the decimal package takes care of precision and scale for the result value 191 return l.Mod(r), nil 192 } 193 } 194 195 return nil, errUnableToCast.New(lval, rval) 196 }