github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/arithmetic_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "testing" 20 "time" 21 22 "github.com/dolthub/vitess/go/sqltypes" 23 "github.com/shopspring/decimal" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 "gopkg.in/src-d/go-errors.v1" 27 28 "github.com/dolthub/go-mysql-server/sql" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 _ "github.com/dolthub/go-mysql-server/sql/variables" 31 ) 32 33 func TestPlus(t *testing.T) { 34 var testCases = []struct { 35 name string 36 left sql.Expression 37 right sql.Expression 38 exp interface{} 39 skip bool 40 }{ 41 { 42 left: NewLiteral(1, types.Uint32), 43 right: NewLiteral(1, types.Uint32), 44 exp: uint64(2), 45 }, 46 { 47 left: NewLiteral(1, types.Uint64), 48 right: NewLiteral(1, types.Uint64), 49 exp: uint64(2), 50 }, 51 { 52 left: NewLiteral(1, types.Int32), 53 right: NewLiteral(1, types.Int32), 54 exp: int64(2), 55 }, 56 { 57 left: NewLiteral(1, types.Int64), 58 right: NewLiteral(1, types.Int64), 59 exp: int64(2), 60 }, 61 { 62 left: NewLiteral(0, types.Int64), 63 right: NewLiteral(0, types.Int64), 64 exp: int64(0), 65 }, 66 { 67 left: NewLiteral(-1, types.Int64), 68 right: NewLiteral(1, types.Int64), 69 exp: int64(0), 70 }, 71 { 72 left: NewLiteral(1, types.Float32), 73 right: NewLiteral(1, types.Float32), 74 exp: float64(2), 75 }, 76 { 77 left: NewLiteral(1, types.Float64), 78 right: NewLiteral(1, types.Float64), 79 exp: float64(2), 80 }, 81 { 82 left: NewLiteral(0.1459, types.Float64), 83 right: NewLiteral(3.0, types.Float64), 84 exp: 3.1459, 85 }, 86 { 87 left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 88 right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 89 exp: "2", 90 }, 91 { 92 left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 93 right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 94 exp: "2.000", 95 }, 96 { 97 left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 98 right: NewLiteral(decimal.New(100000, -5), types.MustCreateDecimalType(10, 5)), // 1.00000 99 exp: "2.00000", 100 }, 101 { 102 left: NewLiteral(decimal.New(1459, -4), types.MustCreateDecimalType(10, 4)), // 0.1459 103 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 104 exp: "3.1459", 105 }, 106 { 107 left: NewLiteral(2001, types.Year), 108 right: NewLiteral(2002, types.Year), 109 exp: uint64(4003), 110 }, 111 { 112 left: NewLiteral("2001-01-01", types.Date), 113 right: NewLiteral("2001-01-01", types.Date), 114 exp: int64(40020202), 115 }, 116 { 117 skip: true, // need to trim just the date portion 118 left: NewLiteral("2001-01-01 12:00:00", types.Date), 119 right: NewLiteral("2001-01-01 12:00:00", types.Date), 120 exp: int64(40020202), 121 }, 122 { 123 skip: true, // need to trim just the date portion 124 left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 125 right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 126 exp: int64(40020202), 127 }, 128 { 129 left: NewLiteral("2001-01-01 12:00:00", types.Datetime), 130 right: NewLiteral("2001-01-01 12:00:00", types.Datetime), 131 exp: int64(40020202240000), 132 }, 133 { 134 skip: true, // need to trim just the datetime portion according to precision 135 left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 136 right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 137 exp: int64(40020202240000), 138 }, 139 { 140 skip: true, // need to trim just the datetime portion according to precision and use as exponent 141 left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 142 right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 143 exp: "40020202240000.246", 144 }, 145 { 146 skip: true, // need to use precision as exponent 147 left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 148 right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 149 exp: "40020202240000.246912", 150 }, 151 { 152 left: NewLiteral("1", types.Text), 153 right: NewLiteral("1", types.Text), 154 exp: float64(2), 155 }, 156 { 157 left: NewLiteral("1", types.Text), 158 right: NewLiteral(1.0, types.Float64), 159 exp: float64(2), 160 }, 161 { 162 left: NewLiteral(1, types.MustCreateBitType(1)), 163 right: NewLiteral(0, types.MustCreateBitType(1)), 164 exp: int64(1), 165 }, 166 { 167 left: NewLiteral("2018-05-01", types.LongText), 168 right: NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), 169 exp: time.Date(2018, time.May, 2, 0, 0, 0, 0, time.UTC), 170 }, 171 { 172 left: NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), 173 right: NewLiteral("2018-05-01", types.LongText), 174 exp: time.Date(2018, time.May, 2, 0, 0, 0, 0, time.UTC), 175 }, 176 } 177 178 for _, tt := range testCases { 179 name := fmt.Sprintf("%s(%v)+%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) 180 t.Run(name, func(t *testing.T) { 181 require := require.New(t) 182 if tt.skip { 183 t.Skip() 184 } 185 f := NewPlus(tt.left, tt.right) 186 result, err := f.Eval(sql.NewEmptyContext(), nil) 187 require.NoError(err) 188 if dec, ok := result.(decimal.Decimal); ok { 189 result = dec.StringFixed(dec.Exponent() * -1) 190 } 191 assert.Equal(t, tt.exp, result) 192 }) 193 } 194 } 195 196 func TestMinus(t *testing.T) { 197 var testCases = []struct { 198 name string 199 left sql.Expression 200 right sql.Expression 201 exp interface{} 202 skip bool 203 }{ 204 { 205 left: NewLiteral(1, types.Uint32), 206 right: NewLiteral(1, types.Uint32), 207 exp: uint64(0), 208 }, 209 { 210 left: NewLiteral(1, types.Uint64), 211 right: NewLiteral(1, types.Uint64), 212 exp: uint64(0), 213 }, 214 { 215 left: NewLiteral(1, types.Int32), 216 right: NewLiteral(1, types.Int32), 217 exp: int64(0), 218 }, 219 { 220 left: NewLiteral(1, types.Int64), 221 right: NewLiteral(1, types.Int64), 222 exp: int64(0), 223 }, 224 { 225 left: NewLiteral(0, types.Int64), 226 right: NewLiteral(0, types.Int64), 227 exp: int64(0), 228 }, 229 { 230 left: NewLiteral(-1, types.Int64), 231 right: NewLiteral(1, types.Int64), 232 exp: int64(-2), 233 }, 234 { 235 left: NewLiteral(1, types.Float32), 236 right: NewLiteral(1, types.Float32), 237 exp: float64(0), 238 }, 239 { 240 left: NewLiteral(1, types.Float64), 241 right: NewLiteral(1, types.Float64), 242 exp: float64(0), 243 }, 244 { 245 left: NewLiteral(0.1459, types.Float64), 246 right: NewLiteral(3.0, types.Float64), 247 exp: -2.8541, 248 }, 249 { 250 left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 251 right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 252 exp: "0", 253 }, 254 { 255 left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 256 right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 257 exp: "0.000", 258 }, 259 { 260 left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 261 right: NewLiteral(decimal.New(100000, -5), types.MustCreateDecimalType(10, 5)), // 1.00000 262 exp: "0.00000", 263 }, 264 { 265 left: NewLiteral(decimal.New(1459, -4), types.MustCreateDecimalType(10, 4)), // 0.1459 266 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 267 exp: "-2.8541", 268 }, 269 { 270 left: NewLiteral(2002, types.Year), 271 right: NewLiteral(2001, types.Year), 272 exp: uint64(1), 273 }, 274 { 275 left: NewLiteral("2001-01-01", types.Date), 276 right: NewLiteral("2001-01-01", types.Date), 277 exp: int64(0), 278 }, 279 { 280 skip: true, // need to trim just the date portion 281 left: NewLiteral("2001-01-01 12:00:00", types.Date), 282 right: NewLiteral("2001-01-01 12:00:00", types.Date), 283 exp: int64(0), 284 }, 285 { 286 skip: true, // need to trim just the date portion 287 left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 288 right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 289 exp: int64(0), 290 }, 291 { 292 left: NewLiteral("2001-01-01 12:00:00", types.Datetime), 293 right: NewLiteral("2001-01-01 12:00:00", types.Datetime), 294 exp: int64(0), 295 }, 296 { 297 skip: true, // need to trim just the datetime portion according to precision 298 left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 299 right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 300 exp: int64(0), 301 }, 302 { 303 skip: true, // need to trim just the datetime portion according to precision and use as exponent 304 left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 305 right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 306 exp: "0.000", 307 }, 308 { 309 skip: true, // need to use precision as exponent 310 left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 311 right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 312 exp: "0.000000", 313 }, 314 { 315 left: NewLiteral("1", types.Text), 316 right: NewLiteral("1", types.Text), 317 exp: float64(0), 318 }, 319 { 320 left: NewLiteral("1", types.Text), 321 right: NewLiteral(1.0, types.Float64), 322 exp: float64(0), 323 }, 324 { 325 left: NewLiteral(1, types.MustCreateBitType(1)), 326 right: NewLiteral(0, types.MustCreateBitType(1)), 327 exp: int64(1), 328 }, 329 { 330 left: NewLiteral("2018-05-01", types.LongText), 331 right: NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), 332 exp: time.Date(2018, time.April, 30, 0, 0, 0, 0, time.UTC), 333 }, 334 } 335 336 for _, tt := range testCases { 337 name := fmt.Sprintf("%s(%v)-%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) 338 t.Run(name, func(t *testing.T) { 339 require := require.New(t) 340 if tt.skip { 341 t.Skip() 342 } 343 f := NewMinus(tt.left, tt.right) 344 result, err := f.Eval(sql.NewEmptyContext(), nil) 345 require.NoError(err) 346 if dec, ok := result.(decimal.Decimal); ok { 347 result = dec.StringFixed(dec.Exponent() * -1) 348 } 349 assert.Equal(t, tt.exp, result) 350 }) 351 } 352 } 353 354 func TestMult(t *testing.T) { 355 var testCases = []struct { 356 name string 357 left sql.Expression 358 right sql.Expression 359 exp interface{} 360 err *errors.Kind 361 skip bool 362 }{ 363 { 364 left: NewLiteral(1, types.Uint32), 365 right: NewLiteral(1, types.Uint32), 366 exp: uint64(1), 367 }, 368 { 369 left: NewLiteral(1, types.Uint64), 370 right: NewLiteral(1, types.Uint64), 371 exp: uint64(1), 372 }, 373 { 374 left: NewLiteral(1, types.Int32), 375 right: NewLiteral(1, types.Int32), 376 exp: int64(1), 377 }, 378 { 379 left: NewLiteral(1, types.Int64), 380 right: NewLiteral(1, types.Int64), 381 exp: int64(1), 382 }, 383 { 384 left: NewLiteral(0, types.Int64), 385 right: NewLiteral(0, types.Int64), 386 exp: int64(0), 387 }, 388 { 389 left: NewLiteral(-1, types.Int64), 390 right: NewLiteral(1, types.Int64), 391 exp: int64(-1), 392 }, 393 { 394 left: NewLiteral(1, types.Float32), 395 right: NewLiteral(1, types.Float32), 396 exp: float64(1), 397 }, 398 { 399 left: NewLiteral(1, types.Float64), 400 right: NewLiteral(1, types.Float64), 401 exp: float64(1), 402 }, 403 { 404 left: NewLiteral(0.1459, types.Float64), 405 right: NewLiteral(3.0, types.Float64), 406 exp: 0.4377, 407 }, 408 { 409 left: NewLiteral(3.1459, types.Float64), 410 right: NewLiteral(3.0, types.Float64), 411 exp: 9.4377, 412 }, 413 { 414 left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 415 right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 416 exp: "1", 417 }, 418 { 419 left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 420 right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 421 exp: "1.000", 422 }, 423 { 424 left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 425 right: NewLiteral(decimal.New(100000, -5), types.MustCreateDecimalType(10, 5)), // 1.00000 426 exp: "1.00000000", 427 }, 428 { 429 left: NewLiteral(decimal.New(1459, -4), types.MustCreateDecimalType(10, 4)), // 0.1459 430 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 431 exp: "0.4377", 432 }, 433 { 434 left: NewLiteral(decimal.New(31459, -4), types.MustCreateDecimalType(10, 4)), // 3.1459 435 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 436 exp: "9.4377", 437 }, 438 { 439 left: NewLiteral(2002, types.Year), 440 right: NewLiteral(2001, types.Year), 441 exp: uint64(4006002), 442 }, 443 { 444 left: NewLiteral("2001-01-01", types.Date), 445 right: NewLiteral("2001-01-01", types.Date), 446 exp: int64(400404142030201), 447 }, 448 { 449 skip: true, // need to trim just the date portion 450 left: NewLiteral("2001-01-01 12:00:00", types.Date), 451 right: NewLiteral("2001-01-01 12:00:00", types.Date), 452 exp: int64(400404142030201), 453 }, 454 { 455 skip: true, // need to trim just the date portion 456 left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 457 right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 458 exp: int64(400404142030201), 459 }, 460 { 461 // MySQL throws out of range 462 skip: true, 463 left: NewLiteral("2001-01-01 12:00:00", types.Datetime), 464 right: NewLiteral("2001-01-01 12:00:00", types.Datetime), 465 err: sql.ErrValueOutOfRange, 466 }, 467 { 468 skip: true, // need to trim just the datetime portion according to precision 469 left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 470 right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 471 err: sql.ErrValueOutOfRange, 472 }, 473 { 474 skip: true, // need to trim just the datetime portion according to precision and use as exponent 475 left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 476 right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 477 exp: "400404146832630176884875520.015129", 478 }, 479 { 480 skip: true, // need to use precision as exponent 481 left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 482 right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 483 exp: "400404146832630195134087741.455241383936", 484 }, 485 { 486 left: NewLiteral("10", types.Text), 487 right: NewLiteral("10", types.Text), 488 exp: float64(100), 489 }, 490 { 491 left: NewLiteral("10", types.Text), 492 right: NewLiteral(10.0, types.Float64), 493 exp: float64(100), 494 }, 495 { 496 left: NewLiteral(1, types.MustCreateBitType(1)), 497 right: NewLiteral(0, types.MustCreateBitType(1)), 498 exp: int64(0), 499 }, 500 } 501 502 for _, tt := range testCases { 503 name := fmt.Sprintf("%s(%v)*%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) 504 t.Run(name, func(t *testing.T) { 505 require := require.New(t) 506 if tt.skip { 507 t.Skip() 508 } 509 f := NewMult(tt.left, tt.right) 510 result, err := f.Eval(sql.NewEmptyContext(), nil) 511 if tt.err != nil { 512 require.Error(err) 513 require.True(tt.err.Is(err), err.Error()) 514 return 515 } 516 require.NoError(err) 517 if dec, ok := result.(decimal.Decimal); ok { 518 result = dec.StringFixed(dec.Exponent() * -1) 519 } 520 assert.Equal(t, tt.exp, result) 521 }) 522 } 523 } 524 525 func TestMod(t *testing.T) { 526 // TODO: make this match the others 527 var testCases = []struct { 528 name string 529 left, right int64 530 expected string 531 null bool 532 }{ 533 {"1 % 1", 1, 1, "0", false}, 534 {"8 % 3", 8, 3, "2", false}, 535 {"1 % 3", 1, 3, "1", false}, 536 {"0 % -1024", 0, -1024, "0", false}, 537 {"-1 % 2", -1, 2, "-1", false}, 538 {"1 % -2", 1, -2, "1", false}, 539 {"-1 % -2", -1, -2, "-1", false}, 540 {"1 % 0", 1, 0, "0", true}, 541 {"0 % 0", 0, 0, "0", true}, 542 {"0.5 % 0.24", 0, 0, "0.02", true}, 543 } 544 545 for _, tt := range testCases { 546 t.Run(tt.name, func(t *testing.T) { 547 require := require.New(t) 548 result, err := NewMod( 549 NewLiteral(tt.left, types.Int64), 550 NewLiteral(tt.right, types.Int64), 551 ).Eval(sql.NewEmptyContext(), sql.NewRow()) 552 require.NoError(err) 553 if tt.null { 554 require.Nil(result) 555 } else { 556 r, ok := result.(decimal.Decimal) 557 require.True(ok) 558 require.Equal(tt.expected, r.StringFixed(r.Exponent()*-1)) 559 } 560 }) 561 } 562 } 563 564 func TestUnaryMinus(t *testing.T) { 565 testCases := []struct { 566 name string 567 input interface{} 568 typ sql.Type 569 expected interface{} 570 }{ 571 {"int32", int32(1), types.Int32, int32(-1)}, 572 {"uint32", uint32(1), types.Uint32, int32(-1)}, 573 {"int64", int64(1), types.Int64, int64(-1)}, 574 {"uint64", uint64(1), types.Uint64, int64(-1)}, 575 {"float32", float32(1), types.Float32, float32(-1)}, 576 {"float64", float64(1), types.Float64, float64(-1)}, 577 {"int text", "1", types.LongText, "-1"}, 578 {"float text", "1.2", types.LongText, "-1.2"}, 579 {"nil", nil, types.LongText, nil}, 580 } 581 582 for _, tt := range testCases { 583 t.Run(tt.name, func(t *testing.T) { 584 f := NewUnaryMinus(NewLiteral(tt.input, tt.typ)) 585 result, err := f.Eval(sql.NewEmptyContext(), nil) 586 require.NoError(t, err) 587 if dt, ok := result.(decimal.Decimal); ok { 588 require.Equal(t, tt.expected, dt.StringFixed(dt.Exponent()*-1)) 589 } else { 590 require.Equal(t, tt.expected, result) 591 } 592 }) 593 } 594 }