github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/builtin_arithmetic.go (about) 1 // Copyright 2020 WHTCORPS INC, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package memex 15 16 import ( 17 "fmt" 18 "math" 19 20 "github.com/cznic/mathutil" 21 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 22 "github.com/whtcorpsinc/BerolinaSQL/terror" 23 "github.com/whtcorpsinc/milevadb/stochastikctx" 24 "github.com/whtcorpsinc/milevadb/types" 25 "github.com/whtcorpsinc/milevadb/soliton/chunk" 26 "github.com/whtcorpsinc/fidelpb/go-fidelpb" 27 ) 28 29 var ( 30 _ functionClass = &arithmeticPlusFunctionClass{} 31 _ functionClass = &arithmeticMinusFunctionClass{} 32 _ functionClass = &arithmeticDivideFunctionClass{} 33 _ functionClass = &arithmeticMultiplyFunctionClass{} 34 _ functionClass = &arithmeticIntDivideFunctionClass{} 35 _ functionClass = &arithmeticModFunctionClass{} 36 ) 37 38 var ( 39 _ builtinFunc = &builtinArithmeticPlusRealSig{} 40 _ builtinFunc = &builtinArithmeticPlusDecimalSig{} 41 _ builtinFunc = &builtinArithmeticPlusIntSig{} 42 _ builtinFunc = &builtinArithmeticMinusRealSig{} 43 _ builtinFunc = &builtinArithmeticMinusDecimalSig{} 44 _ builtinFunc = &builtinArithmeticMinusIntSig{} 45 _ builtinFunc = &builtinArithmeticDivideRealSig{} 46 _ builtinFunc = &builtinArithmeticDivideDecimalSig{} 47 _ builtinFunc = &builtinArithmeticMultiplyRealSig{} 48 _ builtinFunc = &builtinArithmeticMultiplyDecimalSig{} 49 _ builtinFunc = &builtinArithmeticMultiplyIntUnsignedSig{} 50 
_ builtinFunc = &builtinArithmeticMultiplyIntSig{} 51 _ builtinFunc = &builtinArithmeticIntDivideIntSig{} 52 _ builtinFunc = &builtinArithmeticIntDivideDecimalSig{} 53 _ builtinFunc = &builtinArithmeticModIntSig{} 54 _ builtinFunc = &builtinArithmeticModRealSig{} 55 _ builtinFunc = &builtinArithmeticModDecimalSig{} 56 ) 57 58 // precIncrement indicates the number of digits by which to increase the scale of the result of division operations 59 // performed with the / operator. 60 const precIncrement = 4 61 62 // numericContextResultType returns types.EvalType for numeric function's parameters. 63 // the returned types.EvalType should be one of: types.ETInt, types.ETDecimal, types.ETReal 64 func numericContextResultType(ft *types.FieldType) types.EvalType { 65 if types.IsTypeTemporal(ft.Tp) { 66 if ft.Decimal > 0 { 67 return types.ETDecimal 68 } 69 return types.ETInt 70 } 71 if types.IsBinaryStr(ft) { 72 return types.ETInt 73 } 74 evalTp4Ft := types.ETReal 75 if !ft.Hybrid() { 76 evalTp4Ft = ft.EvalType() 77 if evalTp4Ft != types.ETDecimal && evalTp4Ft != types.ETInt { 78 evalTp4Ft = types.ETReal 79 } 80 } 81 return evalTp4Ft 82 } 83 84 // setFlenDecimal4Int is called to set proper `Flen` and `Decimal` of return 85 // type according to the two input parameter's types. 86 func setFlenDecimal4Int(retTp, a, b *types.FieldType) { 87 retTp.Decimal = 0 88 retTp.Flen = allegrosql.MaxIntWidth 89 } 90 91 // setFlenDecimal4RealOrDecimal is called to set proper `Flen` and `Decimal` of return 92 // type according to the two input parameter's types. 
93 func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool, isMultiply bool) { 94 if a.Decimal != types.UnspecifiedLength && b.Decimal != types.UnspecifiedLength { 95 retTp.Decimal = a.Decimal + b.Decimal 96 if !isMultiply { 97 retTp.Decimal = mathutil.Max(a.Decimal, b.Decimal) 98 } 99 if !isReal && retTp.Decimal > allegrosql.MaxDecimalScale { 100 retTp.Decimal = allegrosql.MaxDecimalScale 101 } 102 if a.Flen == types.UnspecifiedLength || b.Flen == types.UnspecifiedLength { 103 retTp.Flen = types.UnspecifiedLength 104 return 105 } 106 digitsInt := mathutil.Max(a.Flen-a.Decimal, b.Flen-b.Decimal) 107 if isMultiply { 108 digitsInt = a.Flen - a.Decimal + b.Flen - b.Decimal 109 } 110 retTp.Flen = digitsInt + retTp.Decimal + 3 111 if isReal { 112 retTp.Flen = mathutil.Min(retTp.Flen, allegrosql.MaxRealWidth) 113 return 114 } 115 retTp.Flen = mathutil.Min(retTp.Flen, allegrosql.MaxDecimalWidth) 116 return 117 } 118 if isReal { 119 retTp.Flen, retTp.Decimal = types.UnspecifiedLength, types.UnspecifiedLength 120 } else { 121 retTp.Flen, retTp.Decimal = allegrosql.MaxDecimalWidth, allegrosql.MaxDecimalScale 122 } 123 } 124 125 func (c *arithmeticDivideFunctionClass) setType4DivDecimal(retTp, a, b *types.FieldType) { 126 var deca, decb = a.Decimal, b.Decimal 127 if deca == int(types.UnspecifiedFsp) { 128 deca = 0 129 } 130 if decb == int(types.UnspecifiedFsp) { 131 decb = 0 132 } 133 retTp.Decimal = deca + precIncrement 134 if retTp.Decimal > allegrosql.MaxDecimalScale { 135 retTp.Decimal = allegrosql.MaxDecimalScale 136 } 137 if a.Flen == types.UnspecifiedLength { 138 retTp.Flen = allegrosql.MaxDecimalWidth 139 return 140 } 141 retTp.Flen = a.Flen + decb + precIncrement 142 if retTp.Flen > allegrosql.MaxDecimalWidth { 143 retTp.Flen = allegrosql.MaxDecimalWidth 144 } 145 } 146 147 func (c *arithmeticDivideFunctionClass) setType4DivReal(retTp *types.FieldType) { 148 retTp.Decimal = types.UnspecifiedLength 149 retTp.Flen = allegrosql.MaxRealWidth 150 } 151 
152 type arithmeticPlusFunctionClass struct { 153 baseFunctionClass 154 } 155 156 func (c *arithmeticPlusFunctionClass) getFunction(ctx stochastikctx.Context, args []Expression) (builtinFunc, error) { 157 if err := c.verifyArgs(args); err != nil { 158 return nil, err 159 } 160 lhsTp, rhsTp := args[0].GetType(), args[1].GetType() 161 lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) 162 if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { 163 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) 164 if err != nil { 165 return nil, err 166 } 167 setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true, false) 168 sig := &builtinArithmeticPlusRealSig{bf} 169 sig.setPbCode(fidelpb.ScalarFuncSig_PlusReal) 170 return sig, nil 171 } else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal { 172 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) 173 if err != nil { 174 return nil, err 175 } 176 setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false, false) 177 sig := &builtinArithmeticPlusDecimalSig{bf} 178 sig.setPbCode(fidelpb.ScalarFuncSig_PlusDecimal) 179 return sig, nil 180 } else { 181 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) 182 if err != nil { 183 return nil, err 184 } 185 if allegrosql.HasUnsignedFlag(args[0].GetType().Flag) || allegrosql.HasUnsignedFlag(args[1].GetType().Flag) { 186 bf.tp.Flag |= allegrosql.UnsignedFlag 187 } 188 setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) 189 sig := &builtinArithmeticPlusIntSig{bf} 190 sig.setPbCode(fidelpb.ScalarFuncSig_PlusInt) 191 return sig, nil 192 } 193 } 194 195 type builtinArithmeticPlusIntSig struct { 196 baseBuiltinFunc 197 } 198 199 func (s *builtinArithmeticPlusIntSig) Clone() builtinFunc { 200 newSig := &builtinArithmeticPlusIntSig{} 
201 newSig.cloneFrom(&s.baseBuiltinFunc) 202 return newSig 203 } 204 205 func (s *builtinArithmeticPlusIntSig) evalInt(event chunk.Event) (val int64, isNull bool, err error) { 206 a, isNull, err := s.args[0].EvalInt(s.ctx, event) 207 if isNull || err != nil { 208 return 0, isNull, err 209 } 210 211 b, isNull, err := s.args[1].EvalInt(s.ctx, event) 212 if isNull || err != nil { 213 return 0, isNull, err 214 } 215 216 isLHSUnsigned := allegrosql.HasUnsignedFlag(s.args[0].GetType().Flag) 217 isRHSUnsigned := allegrosql.HasUnsignedFlag(s.args[1].GetType().Flag) 218 219 switch { 220 case isLHSUnsigned && isRHSUnsigned: 221 if uint64(a) > math.MaxUint64-uint64(b) { 222 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String())) 223 } 224 case isLHSUnsigned && !isRHSUnsigned: 225 if b < 0 && uint64(-b) > uint64(a) { 226 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String())) 227 } 228 if b > 0 && uint64(a) > math.MaxUint64-uint64(b) { 229 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String())) 230 } 231 case !isLHSUnsigned && isRHSUnsigned: 232 if a < 0 && uint64(-a) > uint64(b) { 233 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String())) 234 } 235 if a > 0 && uint64(b) > math.MaxUint64-uint64(a) { 236 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String())) 237 } 238 case !isLHSUnsigned && !isRHSUnsigned: 239 if (a > 0 && b > math.MaxInt64-a) || (a < 0 && b < math.MinInt64-a) { 240 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String())) 241 } 242 } 243 244 return a + b, false, nil 245 } 246 247 type 
builtinArithmeticPlusDecimalSig struct { 248 baseBuiltinFunc 249 } 250 251 func (s *builtinArithmeticPlusDecimalSig) Clone() builtinFunc { 252 newSig := &builtinArithmeticPlusDecimalSig{} 253 newSig.cloneFrom(&s.baseBuiltinFunc) 254 return newSig 255 } 256 257 func (s *builtinArithmeticPlusDecimalSig) evalDecimal(event chunk.Event) (*types.MyDecimal, bool, error) { 258 a, isNull, err := s.args[0].EvalDecimal(s.ctx, event) 259 if isNull || err != nil { 260 return nil, isNull, err 261 } 262 b, isNull, err := s.args[1].EvalDecimal(s.ctx, event) 263 if isNull || err != nil { 264 return nil, isNull, err 265 } 266 c := &types.MyDecimal{} 267 err = types.DecimalAdd(a, b, c) 268 if err != nil { 269 return nil, true, err 270 } 271 return c, false, nil 272 } 273 274 type builtinArithmeticPlusRealSig struct { 275 baseBuiltinFunc 276 } 277 278 func (s *builtinArithmeticPlusRealSig) Clone() builtinFunc { 279 newSig := &builtinArithmeticPlusRealSig{} 280 newSig.cloneFrom(&s.baseBuiltinFunc) 281 return newSig 282 } 283 284 func (s *builtinArithmeticPlusRealSig) evalReal(event chunk.Event) (float64, bool, error) { 285 a, isLHSNull, err := s.args[0].EvalReal(s.ctx, event) 286 if err != nil { 287 return 0, isLHSNull, err 288 } 289 b, isRHSNull, err := s.args[1].EvalReal(s.ctx, event) 290 if err != nil { 291 return 0, isRHSNull, err 292 } 293 if isLHSNull || isRHSNull { 294 return 0, true, nil 295 } 296 if (a > 0 && b > math.MaxFloat64-a) || (a < 0 && b < -math.MaxFloat64-a) { 297 return 0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s + %s)", s.args[0].String(), s.args[1].String())) 298 } 299 return a + b, false, nil 300 } 301 302 type arithmeticMinusFunctionClass struct { 303 baseFunctionClass 304 } 305 306 func (c *arithmeticMinusFunctionClass) getFunction(ctx stochastikctx.Context, args []Expression) (builtinFunc, error) { 307 if err := c.verifyArgs(args); err != nil { 308 return nil, err 309 } 310 lhsTp, rhsTp := args[0].GetType(), args[1].GetType() 311 
lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) 312 if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { 313 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) 314 if err != nil { 315 return nil, err 316 } 317 setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true, false) 318 sig := &builtinArithmeticMinusRealSig{bf} 319 sig.setPbCode(fidelpb.ScalarFuncSig_MinusReal) 320 return sig, nil 321 } else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal { 322 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) 323 if err != nil { 324 return nil, err 325 } 326 setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false, false) 327 sig := &builtinArithmeticMinusDecimalSig{bf} 328 sig.setPbCode(fidelpb.ScalarFuncSig_MinusDecimal) 329 return sig, nil 330 } else { 331 332 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) 333 if err != nil { 334 return nil, err 335 } 336 setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) 337 if (allegrosql.HasUnsignedFlag(args[0].GetType().Flag) || allegrosql.HasUnsignedFlag(args[1].GetType().Flag)) && !ctx.GetStochastikVars().ALLEGROSQLMode.HasNoUnsignedSubtractionMode() { 338 bf.tp.Flag |= allegrosql.UnsignedFlag 339 } 340 sig := &builtinArithmeticMinusIntSig{baseBuiltinFunc: bf} 341 sig.setPbCode(fidelpb.ScalarFuncSig_MinusInt) 342 return sig, nil 343 } 344 } 345 346 type builtinArithmeticMinusRealSig struct { 347 baseBuiltinFunc 348 } 349 350 func (s *builtinArithmeticMinusRealSig) Clone() builtinFunc { 351 newSig := &builtinArithmeticMinusRealSig{} 352 newSig.cloneFrom(&s.baseBuiltinFunc) 353 return newSig 354 } 355 356 func (s *builtinArithmeticMinusRealSig) evalReal(event chunk.Event) (float64, bool, error) { 357 a, isNull, err := s.args[0].EvalReal(s.ctx, event) 358 
if isNull || err != nil { 359 return 0, isNull, err 360 } 361 b, isNull, err := s.args[1].EvalReal(s.ctx, event) 362 if isNull || err != nil { 363 return 0, isNull, err 364 } 365 if (a > 0 && -b > math.MaxFloat64-a) || (a < 0 && -b < -math.MaxFloat64-a) { 366 return 0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 367 } 368 return a - b, false, nil 369 } 370 371 type builtinArithmeticMinusDecimalSig struct { 372 baseBuiltinFunc 373 } 374 375 func (s *builtinArithmeticMinusDecimalSig) Clone() builtinFunc { 376 newSig := &builtinArithmeticMinusDecimalSig{} 377 newSig.cloneFrom(&s.baseBuiltinFunc) 378 return newSig 379 } 380 381 func (s *builtinArithmeticMinusDecimalSig) evalDecimal(event chunk.Event) (*types.MyDecimal, bool, error) { 382 a, isNull, err := s.args[0].EvalDecimal(s.ctx, event) 383 if isNull || err != nil { 384 return nil, isNull, err 385 } 386 b, isNull, err := s.args[1].EvalDecimal(s.ctx, event) 387 if isNull || err != nil { 388 return nil, isNull, err 389 } 390 c := &types.MyDecimal{} 391 err = types.DecimalSub(a, b, c) 392 if err != nil { 393 return nil, true, err 394 } 395 return c, false, nil 396 } 397 398 type builtinArithmeticMinusIntSig struct { 399 baseBuiltinFunc 400 } 401 402 func (s *builtinArithmeticMinusIntSig) Clone() builtinFunc { 403 newSig := &builtinArithmeticMinusIntSig{} 404 newSig.cloneFrom(&s.baseBuiltinFunc) 405 return newSig 406 } 407 408 func (s *builtinArithmeticMinusIntSig) evalInt(event chunk.Event) (val int64, isNull bool, err error) { 409 a, isNull, err := s.args[0].EvalInt(s.ctx, event) 410 if isNull || err != nil { 411 return 0, isNull, err 412 } 413 414 b, isNull, err := s.args[1].EvalInt(s.ctx, event) 415 if isNull || err != nil { 416 return 0, isNull, err 417 } 418 forceToSigned := s.ctx.GetStochastikVars().ALLEGROSQLMode.HasNoUnsignedSubtractionMode() 419 isLHSUnsigned := !forceToSigned && allegrosql.HasUnsignedFlag(s.args[0].GetType().Flag) 420 
isRHSUnsigned := !forceToSigned && allegrosql.HasUnsignedFlag(s.args[1].GetType().Flag) 421 422 if forceToSigned && allegrosql.HasUnsignedFlag(s.args[0].GetType().Flag) { 423 if a < 0 { 424 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 425 } 426 } 427 if forceToSigned && allegrosql.HasUnsignedFlag(s.args[1].GetType().Flag) { 428 if b < 0 { 429 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 430 } 431 } 432 433 switch { 434 case isLHSUnsigned && isRHSUnsigned: 435 if uint64(a) < uint64(b) { 436 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 437 } 438 case isLHSUnsigned && !isRHSUnsigned: 439 if b >= 0 && uint64(a) < uint64(b) { 440 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 441 } 442 if b < 0 && uint64(a) > math.MaxUint64-uint64(-b) { 443 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 444 } 445 case !isLHSUnsigned && isRHSUnsigned: 446 if a < 0 || uint64(a) < uint64(b) { 447 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 448 } 449 case !isLHSUnsigned && !isRHSUnsigned: 450 // We need `(a >= 0 && b == math.MinInt64)` due to `-(math.MinInt64) == math.MinInt64`. 451 // If `a<0 && b<=0`: `a-b` will not overflow even though b==math.MinInt64. 
452 // If `a<0 && b>0`: `a-b` will not overflow only if `math.MinInt64<=a-b` satisfied 453 if (a >= 0 && b == math.MinInt64) || (a > 0 && -b > math.MaxInt64-a) || (a < 0 && -b < math.MinInt64-a) { 454 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s - %s)", s.args[0].String(), s.args[1].String())) 455 } 456 } 457 return a - b, false, nil 458 } 459 460 type arithmeticMultiplyFunctionClass struct { 461 baseFunctionClass 462 } 463 464 func (c *arithmeticMultiplyFunctionClass) getFunction(ctx stochastikctx.Context, args []Expression) (builtinFunc, error) { 465 if err := c.verifyArgs(args); err != nil { 466 return nil, err 467 } 468 lhsTp, rhsTp := args[0].GetType(), args[1].GetType() 469 lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) 470 if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { 471 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) 472 if err != nil { 473 return nil, err 474 } 475 setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true, true) 476 sig := &builtinArithmeticMultiplyRealSig{bf} 477 sig.setPbCode(fidelpb.ScalarFuncSig_MultiplyReal) 478 return sig, nil 479 } else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal { 480 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) 481 if err != nil { 482 return nil, err 483 } 484 setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false, true) 485 sig := &builtinArithmeticMultiplyDecimalSig{bf} 486 sig.setPbCode(fidelpb.ScalarFuncSig_MultiplyDecimal) 487 return sig, nil 488 } else { 489 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) 490 if err != nil { 491 return nil, err 492 } 493 if allegrosql.HasUnsignedFlag(lhsTp.Flag) || allegrosql.HasUnsignedFlag(rhsTp.Flag) { 494 bf.tp.Flag |= allegrosql.UnsignedFlag 495 
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) 496 sig := &builtinArithmeticMultiplyIntUnsignedSig{bf} 497 sig.setPbCode(fidelpb.ScalarFuncSig_MultiplyIntUnsigned) 498 return sig, nil 499 } 500 setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) 501 sig := &builtinArithmeticMultiplyIntSig{bf} 502 sig.setPbCode(fidelpb.ScalarFuncSig_MultiplyInt) 503 return sig, nil 504 } 505 } 506 507 type builtinArithmeticMultiplyRealSig struct{ baseBuiltinFunc } 508 509 func (s *builtinArithmeticMultiplyRealSig) Clone() builtinFunc { 510 newSig := &builtinArithmeticMultiplyRealSig{} 511 newSig.cloneFrom(&s.baseBuiltinFunc) 512 return newSig 513 } 514 515 type builtinArithmeticMultiplyDecimalSig struct{ baseBuiltinFunc } 516 517 func (s *builtinArithmeticMultiplyDecimalSig) Clone() builtinFunc { 518 newSig := &builtinArithmeticMultiplyDecimalSig{} 519 newSig.cloneFrom(&s.baseBuiltinFunc) 520 return newSig 521 } 522 523 type builtinArithmeticMultiplyIntUnsignedSig struct{ baseBuiltinFunc } 524 525 func (s *builtinArithmeticMultiplyIntUnsignedSig) Clone() builtinFunc { 526 newSig := &builtinArithmeticMultiplyIntUnsignedSig{} 527 newSig.cloneFrom(&s.baseBuiltinFunc) 528 return newSig 529 } 530 531 type builtinArithmeticMultiplyIntSig struct{ baseBuiltinFunc } 532 533 func (s *builtinArithmeticMultiplyIntSig) Clone() builtinFunc { 534 newSig := &builtinArithmeticMultiplyIntSig{} 535 newSig.cloneFrom(&s.baseBuiltinFunc) 536 return newSig 537 } 538 539 func (s *builtinArithmeticMultiplyRealSig) evalReal(event chunk.Event) (float64, bool, error) { 540 a, isNull, err := s.args[0].EvalReal(s.ctx, event) 541 if isNull || err != nil { 542 return 0, isNull, err 543 } 544 b, isNull, err := s.args[1].EvalReal(s.ctx, event) 545 if isNull || err != nil { 546 return 0, isNull, err 547 } 548 result := a * b 549 if math.IsInf(result, 0) { 550 return 0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s * %s)", s.args[0].String(), s.args[1].String())) 
551 } 552 return result, false, nil 553 } 554 555 func (s *builtinArithmeticMultiplyDecimalSig) evalDecimal(event chunk.Event) (*types.MyDecimal, bool, error) { 556 a, isNull, err := s.args[0].EvalDecimal(s.ctx, event) 557 if isNull || err != nil { 558 return nil, isNull, err 559 } 560 b, isNull, err := s.args[1].EvalDecimal(s.ctx, event) 561 if isNull || err != nil { 562 return nil, isNull, err 563 } 564 c := &types.MyDecimal{} 565 err = types.DecimalMul(a, b, c) 566 if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) { 567 return nil, true, err 568 } 569 return c, false, nil 570 } 571 572 func (s *builtinArithmeticMultiplyIntUnsignedSig) evalInt(event chunk.Event) (val int64, isNull bool, err error) { 573 a, isNull, err := s.args[0].EvalInt(s.ctx, event) 574 if isNull || err != nil { 575 return 0, isNull, err 576 } 577 unsignedA := uint64(a) 578 b, isNull, err := s.args[1].EvalInt(s.ctx, event) 579 if isNull || err != nil { 580 return 0, isNull, err 581 } 582 unsignedB := uint64(b) 583 result := unsignedA * unsignedB 584 if unsignedA != 0 && result/unsignedA != unsignedB { 585 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s * %s)", s.args[0].String(), s.args[1].String())) 586 } 587 return int64(result), false, nil 588 } 589 590 func (s *builtinArithmeticMultiplyIntSig) evalInt(event chunk.Event) (val int64, isNull bool, err error) { 591 a, isNull, err := s.args[0].EvalInt(s.ctx, event) 592 if isNull || err != nil { 593 return 0, isNull, err 594 } 595 b, isNull, err := s.args[1].EvalInt(s.ctx, event) 596 if isNull || err != nil { 597 return 0, isNull, err 598 } 599 result := a * b 600 if a != 0 && result/a != b { 601 return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s * %s)", s.args[0].String(), s.args[1].String())) 602 } 603 return result, false, nil 604 } 605 606 type arithmeticDivideFunctionClass struct { 607 baseFunctionClass 608 } 609 610 func (c *arithmeticDivideFunctionClass) 
getFunction(ctx stochastikctx.Context, args []Expression) (builtinFunc, error) { 611 if err := c.verifyArgs(args); err != nil { 612 return nil, err 613 } 614 lhsTp, rhsTp := args[0].GetType(), args[1].GetType() 615 lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) 616 if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { 617 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) 618 if err != nil { 619 return nil, err 620 } 621 c.setType4DivReal(bf.tp) 622 sig := &builtinArithmeticDivideRealSig{bf} 623 sig.setPbCode(fidelpb.ScalarFuncSig_DivideReal) 624 return sig, nil 625 } 626 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) 627 if err != nil { 628 return nil, err 629 } 630 c.setType4DivDecimal(bf.tp, lhsTp, rhsTp) 631 sig := &builtinArithmeticDivideDecimalSig{bf} 632 sig.setPbCode(fidelpb.ScalarFuncSig_DivideDecimal) 633 return sig, nil 634 } 635 636 type builtinArithmeticDivideRealSig struct{ baseBuiltinFunc } 637 638 func (s *builtinArithmeticDivideRealSig) Clone() builtinFunc { 639 newSig := &builtinArithmeticDivideRealSig{} 640 newSig.cloneFrom(&s.baseBuiltinFunc) 641 return newSig 642 } 643 644 type builtinArithmeticDivideDecimalSig struct{ baseBuiltinFunc } 645 646 func (s *builtinArithmeticDivideDecimalSig) Clone() builtinFunc { 647 newSig := &builtinArithmeticDivideDecimalSig{} 648 newSig.cloneFrom(&s.baseBuiltinFunc) 649 return newSig 650 } 651 652 func (s *builtinArithmeticDivideRealSig) evalReal(event chunk.Event) (float64, bool, error) { 653 a, isNull, err := s.args[0].EvalReal(s.ctx, event) 654 if isNull || err != nil { 655 return 0, isNull, err 656 } 657 b, isNull, err := s.args[1].EvalReal(s.ctx, event) 658 if isNull || err != nil { 659 return 0, isNull, err 660 } 661 if b == 0 { 662 return 0, true, handleDivisionByZeroError(s.ctx) 663 } 664 result := a / b 665 if math.IsInf(result, 0) { 666 return 
0, true, types.ErrOverflow.GenWithStackByArgs("DOUBLE", fmt.Sprintf("(%s / %s)", s.args[0].String(), s.args[1].String())) 667 } 668 return result, false, nil 669 } 670 671 func (s *builtinArithmeticDivideDecimalSig) evalDecimal(event chunk.Event) (*types.MyDecimal, bool, error) { 672 a, isNull, err := s.args[0].EvalDecimal(s.ctx, event) 673 if isNull || err != nil { 674 return nil, isNull, err 675 } 676 677 b, isNull, err := s.args[1].EvalDecimal(s.ctx, event) 678 if isNull || err != nil { 679 return nil, isNull, err 680 } 681 682 c := &types.MyDecimal{} 683 err = types.DecimalDiv(a, b, c, types.DivFracIncr) 684 if err == types.ErrDivByZero { 685 return c, true, handleDivisionByZeroError(s.ctx) 686 } else if err == types.ErrTruncated { 687 sc := s.ctx.GetStochastikVars().StmtCtx 688 err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) 689 } else if err == nil { 690 _, frac := c.PrecisionAndFrac() 691 if frac < s.baseBuiltinFunc.tp.Decimal { 692 err = c.Round(c, s.baseBuiltinFunc.tp.Decimal, types.ModeHalfEven) 693 } 694 } 695 return c, false, err 696 } 697 698 type arithmeticIntDivideFunctionClass struct { 699 baseFunctionClass 700 } 701 702 func (c *arithmeticIntDivideFunctionClass) getFunction(ctx stochastikctx.Context, args []Expression) (builtinFunc, error) { 703 if err := c.verifyArgs(args); err != nil { 704 return nil, err 705 } 706 707 lhsTp, rhsTp := args[0].GetType(), args[1].GetType() 708 lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) 709 if lhsEvalTp == types.ETInt && rhsEvalTp == types.ETInt { 710 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) 711 if err != nil { 712 return nil, err 713 } 714 if allegrosql.HasUnsignedFlag(lhsTp.Flag) || allegrosql.HasUnsignedFlag(rhsTp.Flag) { 715 bf.tp.Flag |= allegrosql.UnsignedFlag 716 } 717 sig := &builtinArithmeticIntDivideIntSig{bf} 718 sig.setPbCode(fidelpb.ScalarFuncSig_IntDivideInt) 719 
return sig, nil 720 } 721 bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDecimal, types.ETDecimal) 722 if err != nil { 723 return nil, err 724 } 725 if allegrosql.HasUnsignedFlag(lhsTp.Flag) || allegrosql.HasUnsignedFlag(rhsTp.Flag) { 726 bf.tp.Flag |= allegrosql.UnsignedFlag 727 } 728 sig := &builtinArithmeticIntDivideDecimalSig{bf} 729 sig.setPbCode(fidelpb.ScalarFuncSig_IntDivideDecimal) 730 return sig, nil 731 } 732 733 type builtinArithmeticIntDivideIntSig struct{ baseBuiltinFunc } 734 735 func (s *builtinArithmeticIntDivideIntSig) Clone() builtinFunc { 736 newSig := &builtinArithmeticIntDivideIntSig{} 737 newSig.cloneFrom(&s.baseBuiltinFunc) 738 return newSig 739 } 740 741 type builtinArithmeticIntDivideDecimalSig struct{ baseBuiltinFunc } 742 743 func (s *builtinArithmeticIntDivideDecimalSig) Clone() builtinFunc { 744 newSig := &builtinArithmeticIntDivideDecimalSig{} 745 newSig.cloneFrom(&s.baseBuiltinFunc) 746 return newSig 747 } 748 749 func (s *builtinArithmeticIntDivideIntSig) evalInt(event chunk.Event) (int64, bool, error) { 750 return s.evalIntWithCtx(s.ctx, event) 751 } 752 753 func (s *builtinArithmeticIntDivideIntSig) evalIntWithCtx(sctx stochastikctx.Context, event chunk.Event) (int64, bool, error) { 754 b, isNull, err := s.args[1].EvalInt(sctx, event) 755 if isNull || err != nil { 756 return 0, isNull, err 757 } 758 759 if b == 0 { 760 return 0, true, handleDivisionByZeroError(sctx) 761 } 762 763 a, isNull, err := s.args[0].EvalInt(sctx, event) 764 if isNull || err != nil { 765 return 0, isNull, err 766 } 767 768 var ( 769 ret int64 770 val uint64 771 ) 772 isLHSUnsigned := allegrosql.HasUnsignedFlag(s.args[0].GetType().Flag) 773 isRHSUnsigned := allegrosql.HasUnsignedFlag(s.args[1].GetType().Flag) 774 775 switch { 776 case isLHSUnsigned && isRHSUnsigned: 777 ret = int64(uint64(a) / uint64(b)) 778 case isLHSUnsigned && !isRHSUnsigned: 779 val, err = types.DivUintWithInt(uint64(a), b) 780 ret = int64(val) 781 case 
!isLHSUnsigned && isRHSUnsigned:
		val, err = types.DivIntWithUint(a, uint64(b))
		ret = int64(val)
	case !isLHSUnsigned && !isRHSUnsigned:
		ret, err = types.DivInt64(a, b)
	}

	return ret, err != nil, err
}

// evalInt evaluates `a DIV b` when the operands are evaluated as DECIMAL.
// The quotient is computed with types.DecimalDiv and then converted to an
// integer, discarding any fractional part. It is read as BIGINT UNSIGNED when
// either operand's type carries the unsigned flag, otherwise as BIGINT; a
// quotient outside that range yields NULL and an ErrOverflow naming the
// expression. A zero divisor yields NULL and the error returned by
// handleDivisionByZeroError.
func (s *builtinArithmeticIntDivideDecimalSig) evalInt(event chunk.Event) (ret int64, isNull bool, err error) {
	sc := s.ctx.GetStochastikVars().StmtCtx
	var num [2]*types.MyDecimal
	for i, arg := range s.args {
		num[i], isNull, err = arg.EvalDecimal(s.ctx, event)
		if isNull || err != nil {
			return 0, isNull, err
		}
	}

	c := &types.MyDecimal{}
	err = types.DecimalDiv(num[0], num[1], c, types.DivFracIncr)
	if err == types.ErrDivByZero {
		return 0, true, handleDivisionByZeroError(s.ctx)
	}
	// Truncation and overflow of the decimal quotient are passed to the
	// statement context; whatever error it hands back is propagated below.
	if err == types.ErrTruncated {
		err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
	}
	if err == types.ErrOverflow {
		newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
		err = sc.HandleOverflow(newErr, newErr)
	}
	if err != nil {
		return 0, true, err
	}

	isLHSUnsigned := allegrosql.HasUnsignedFlag(s.args[0].GetType().Flag)
	isRHSUnsigned := allegrosql.HasUnsignedFlag(s.args[1].GetType().Flag)

	if isLHSUnsigned || isRHSUnsigned {
		// ToUint may report ErrTruncated or ErrOverflow; only ErrOverflow is
		// handled, a truncated fractional part is ignored. Distinct local
		// names keep the named result err from being shadowed.
		uval, uerr := c.ToUint()
		if uerr == types.ErrOverflow {
			// When the quotient lies in (-1, 0] it truncates to 0, which is
			// returned instead of the overflow error.
			ival, ierr := c.ToInt()
			if ival == 0 && ierr == types.ErrTruncated {
				return 0, false, nil
			}
			return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s DIV %s)", s.args[0].String(), s.args[1].String()))
		}
		return int64(uval), false, nil
	}

	// ToInt may report ErrTruncated or ErrOverflow; only ErrOverflow is
	// handled, a truncated fractional part is ignored.
	ret, err = c.ToInt()
	if err == types.ErrOverflow {
		return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s DIV %s)", s.args[0].String(), s.args[1].String()))
	}
	return ret, false, nil
}

// arithmeticModFunctionClass builds the builtin signatures for the modulo
// operator, choosing a REAL, DECIMAL or INT implementation from the operand
// types.
type arithmeticModFunctionClass struct {
	baseFunctionClass
}

// setType4ModRealOrDecimal fills in the scale and width of retTp, the result
// type of a REAL or DECIMAL modulo whose operand types are a and b.
//
// Decimal is the larger operand scale (capped at allegrosql.MaxDecimalScale
// when isDecimal), or UnspecifiedLength if either operand's scale is
// unspecified. Flen is the larger operand width, capped at
// allegrosql.MaxDecimalWidth when isDecimal and at allegrosql.MaxRealWidth
// otherwise, or UnspecifiedLength if either operand's width is unspecified.
func (c *arithmeticModFunctionClass) setType4ModRealOrDecimal(retTp, a, b *types.FieldType, isDecimal bool) {
	if a.Decimal == types.UnspecifiedLength || b.Decimal == types.UnspecifiedLength {
		retTp.Decimal = types.UnspecifiedLength
	} else {
		retTp.Decimal = mathutil.Max(a.Decimal, b.Decimal)
		if isDecimal && retTp.Decimal > allegrosql.MaxDecimalScale {
			retTp.Decimal = allegrosql.MaxDecimalScale
		}
	}

	if a.Flen == types.UnspecifiedLength || b.Flen == types.UnspecifiedLength {
		retTp.Flen = types.UnspecifiedLength
	} else {
		retTp.Flen = mathutil.Max(a.Flen, b.Flen)
		if isDecimal {
			retTp.Flen = mathutil.Min(retTp.Flen, allegrosql.MaxDecimalWidth)
			return
		}
		retTp.Flen = mathutil.Min(retTp.Flen, allegrosql.MaxRealWidth)
	}
}

// getFunction picks the modulo signature for args: REAL if either operand
// evaluates as REAL in a numeric context, else DECIMAL if either evaluates as
// DECIMAL, else INT. In every case the result type is flagged unsigned when
// the left operand is unsigned, because the eval methods below give the
// result the sign of the dividend.
func (c *arithmeticModFunctionClass) getFunction(ctx stochastikctx.Context, args []Expression) (builtinFunc, error) {
	if err := c.verifyArgs(args); err != nil {
		return nil, err
	}
	lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
	lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)

	switch {
	case lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal:
		bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal)
		if err != nil {
			return nil, err
		}
		c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, false)
		if allegrosql.HasUnsignedFlag(lhsTp.Flag) {
			bf.tp.Flag |= allegrosql.UnsignedFlag
		}
		sig := &builtinArithmeticModRealSig{bf}
		sig.setPbCode(fidelpb.ScalarFuncSig_ModReal)
		return sig, nil
	case lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal:
		bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
		if err != nil {
			return nil, err
		}
		c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, true)
		if allegrosql.HasUnsignedFlag(lhsTp.Flag) {
			bf.tp.Flag |= allegrosql.UnsignedFlag
		}
		sig := &builtinArithmeticModDecimalSig{bf}
		sig.setPbCode(fidelpb.ScalarFuncSig_ModDecimal)
		return sig, nil
	default:
		bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt)
		if err != nil {
			return nil, err
		}
		if allegrosql.HasUnsignedFlag(lhsTp.Flag) {
			bf.tp.Flag |= allegrosql.UnsignedFlag
		}
		sig := &builtinArithmeticModIntSig{bf}
		sig.setPbCode(fidelpb.ScalarFuncSig_ModInt)
		return sig, nil
	}
}

// builtinArithmeticModRealSig evaluates the modulo of two REAL operands.
type builtinArithmeticModRealSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticModRealSig whose base is copied from s
// via cloneFrom.
func (s *builtinArithmeticModRealSig) Clone() builtinFunc {
	newSig := &builtinArithmeticModRealSig{}
	newSig.cloneFrom(&s.baseBuiltinFunc)
	return newSig
}

// evalReal evaluates `a MOD b` with math.Mod, so the result has the sign of a.
// The divisor is evaluated first: a zero divisor yields NULL and the error
// returned by handleDivisionByZeroError, without evaluating the dividend.
func (s *builtinArithmeticModRealSig) evalReal(event chunk.Event) (float64, bool, error) {
	b, isNull, err := s.args[1].EvalReal(s.ctx, event)
	if isNull || err != nil {
		return 0, isNull, err
	}

	if b == 0 {
		return 0, true, handleDivisionByZeroError(s.ctx)
	}

	a, isNull, err := s.args[0].EvalReal(s.ctx, event)
	if isNull || err != nil {
		return 0, isNull, err
	}

	return math.Mod(a, b), false, nil
}

// builtinArithmeticModDecimalSig evaluates the modulo of two DECIMAL operands.
type builtinArithmeticModDecimalSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticModDecimalSig whose base is copied
// from s via cloneFrom.
func (s *builtinArithmeticModDecimalSig) Clone() builtinFunc {
	newSig := &builtinArithmeticModDecimalSig{}
	newSig.cloneFrom(&s.baseBuiltinFunc)
	return newSig
}

// evalDecimal evaluates `a MOD b` with types.DecimalMod. A zero divisor yields
// NULL and the error returned by handleDivisionByZeroError; any other error
// from DecimalMod is returned together with a NULL flag.
func (s *builtinArithmeticModDecimalSig) evalDecimal(event chunk.Event) (*types.MyDecimal, bool, error) {
	a, isNull, err := s.args[0].EvalDecimal(s.ctx, event)
	if isNull || err != nil {
		return nil, isNull, err
	}
	b, isNull, err := s.args[1].EvalDecimal(s.ctx, event)
	if isNull || err != nil {
		return nil, isNull, err
	}
	c := &types.MyDecimal{}
	err = types.DecimalMod(a, b, c)
	if err == types.ErrDivByZero {
		return c, true, handleDivisionByZeroError(s.ctx)
	}
	return c, err != nil, err
}

// builtinArithmeticModIntSig evaluates the modulo of two integer operands.
type builtinArithmeticModIntSig struct {
	baseBuiltinFunc
}

// Clone returns a new builtinArithmeticModIntSig whose base is copied from s
// via cloneFrom.
func (s *builtinArithmeticModIntSig) Clone() builtinFunc {
	newSig := &builtinArithmeticModIntSig{}
	newSig.cloneFrom(&s.baseBuiltinFunc)
	return newSig
}

// evalInt evaluates `a MOD b` for integer operands, reinterpreting an operand
// as uint64 when its type carries the unsigned flag. The result has the
// magnitude |a| mod |b| and the sign of the dividend a. The divisor is
// evaluated first: a zero divisor yields NULL and the error returned by
// handleDivisionByZeroError, without evaluating the dividend.
func (s *builtinArithmeticModIntSig) evalInt(event chunk.Event) (val int64, isNull bool, err error) {
	b, isNull, err := s.args[1].EvalInt(s.ctx, event)
	if isNull || err != nil {
		return 0, isNull, err
	}

	if b == 0 {
		return 0, true, handleDivisionByZeroError(s.ctx)
	}

	a, isNull, err := s.args[0].EvalInt(s.ctx, event)
	if isNull || err != nil {
		return 0, isNull, err
	}

	var ret int64
	isLHSUnsigned := allegrosql.HasUnsignedFlag(s.args[0].GetType().Flag)
	isRHSUnsigned := allegrosql.HasUnsignedFlag(s.args[1].GetType().Flag)

	switch {
	case isLHSUnsigned && isRHSUnsigned:
		ret = int64(uint64(a) % uint64(b))
	case isLHSUnsigned && !isRHSUnsigned:
		// Only |b| matters for an unsigned dividend. Negating math.MinInt64
		// wraps back to MinInt64, whose uint64 conversion is 2^63 = |MinInt64|,
		// so the magnitude is still correct.
		if b < 0 {
			ret = int64(uint64(a) % uint64(-b))
		} else {
			ret = int64(uint64(a) % uint64(b))
		}
	case !isLHSUnsigned && isRHSUnsigned:
		// Work on |a| in uint64 (correct for math.MinInt64 as above) and
		// restore the dividend's sign afterwards.
		if a < 0 {
			ret = -int64(uint64(-a) % uint64(b))
		} else {
			ret = int64(uint64(a) % uint64(b))
		}
	case !isLHSUnsigned && !isRHSUnsigned:
		// Go's % truncates toward zero, so the result already has a's sign;
		// math.MinInt64 % -1 is defined as 0 and does not panic.
		ret = a % b
	}

	return ret, false, nil
}