github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/div_test.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "testing" 20 21 "github.com/dolthub/vitess/go/sqltypes" 22 "github.com/shopspring/decimal" 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 "gopkg.in/src-d/go-errors.v1" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 func TestDiv(t *testing.T) { 32 var testCases = []struct { 33 name string 34 left sql.Expression 35 right sql.Expression 36 exp interface{} 37 err *errors.Kind 38 skip bool 39 }{ 40 { 41 left: NewLiteral(1, types.Int64), 42 right: NewLiteral(0, types.Int64), 43 exp: nil, 44 }, 45 46 // Unsigned Integers 47 { 48 left: NewLiteral(1, types.Uint32), 49 right: NewLiteral(1, types.Uint32), 50 exp: "1.0000", 51 }, 52 { 53 left: NewLiteral(1, types.Uint32), 54 right: NewLiteral(2, types.Uint32), 55 exp: "0.5000", 56 }, 57 { 58 left: NewLiteral(1, types.Uint64), 59 right: NewLiteral(1, types.Uint64), 60 exp: "1.0000", 61 }, 62 { 63 left: NewLiteral(1, types.Uint64), 64 right: NewLiteral(2, types.Uint64), 65 exp: "0.5000", 66 }, 67 68 // Signed Integers 69 { 70 left: NewLiteral(1, types.Int32), 71 right: NewLiteral(1, types.Int32), 72 exp: "1.0000", 73 }, 74 { 75 left: NewLiteral(1, types.Int32), 76 right: NewLiteral(2, types.Int32), 77 exp: "0.5000", 78 }, 79 { 80 left: NewLiteral(-1, types.Int32), 81 right: NewLiteral(2, types.Int32), 82 exp: "-0.5000", 83 }, 84 { 85 left: NewLiteral(1, types.Int32), 86 right: NewLiteral(-2, types.Int32), 87 exp: "-0.5000", 88 }, 89 { 90 left: NewLiteral(1, types.Int64), 91 right: NewLiteral(1, types.Int64), 92 exp: "1.0000", 93 }, 94 { 95 left: NewLiteral(1, types.Int64), 96 right: NewLiteral(2, types.Int64), 97 exp: "0.5000", 98 }, 99 { 100 left: NewLiteral(-1, types.Int64), 101 right: NewLiteral(2, types.Int64), 102 exp: "-0.5000", 103 }, 104 { 105 left: NewLiteral(1, types.Int64), 106 right: NewLiteral(-2, types.Int64), 107 exp: "-0.5000", 108 }, 109 110 // Unsigned and Signed Integers 111 { 112 left: NewLiteral(1, types.Uint32), 113 right: NewLiteral(-2, types.Int32), 114 exp: "-0.5000", 115 }, 116 { 117 left: NewLiteral(-1, types.Int64), 118 right: NewLiteral(2, types.Uint32), 119 exp: "-0.5000", 120 }, 121 { 122 left: NewLiteral(1, types.Int64), 123 right: NewLiteral(123456789, types.Int64), 124 exp: "0.0000", 125 }, 126 127 // Repeating Decimals 128 { 129 left: NewLiteral(1, types.Int64), 130 right: NewLiteral(3, types.Int64), 131 exp: "0.3333", 132 }, 133 { 134 left: NewLiteral(1, types.Int64), 135 right: NewLiteral(9, types.Int64), 136 exp: "0.1111", 137 }, 138 { 139 left: NewLiteral(1, types.Int64), 140 right: NewLiteral(6, types.Int64), 141 exp: "0.1667", 142 }, 143 144 // Floats 145 { 146 left: NewLiteral(1.0, types.Float32), 147 right: NewLiteral(3.0, types.Float32), 148 exp: 0.3333333333333333, 149 }, 150 { 151 left: NewLiteral(1.0, types.Float32), 152 right: NewLiteral(9.0, types.Float32), 153 exp: 0.1111111111111111, 154 }, 155 { 156 left: NewLiteral(1.0, types.Float64), 157 right: NewLiteral(3.0, types.Float64), 158 exp: 0.3333333333333333, 159 }, 160 { 161 left: NewLiteral(1.0, types.Float64), 162 right: NewLiteral(9.0, types.Float64), 163 exp: 0.1111111111111111, 164 }, 165 { 166 // MySQL treats float32 a little differently 167 skip: true, 168 left: NewLiteral(3.14159, types.Float32), 169 right: NewLiteral(3.0, types.Float32), 170 exp: 1.0471967061360676, 171 }, 172 { 173 left: NewLiteral(3.14159, types.Float64), 174 right: NewLiteral(3.0, types.Float64), 175 exp: 1.0471966666666666, 176 }, 177 178 // Decimals 179 { 180 left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 181 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), 182 exp: "0.3333", 183 }, 184 { 185 left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), 186 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), 187 exp: "0.3333333", 188 }, 189 { 190 left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 191 right: NewLiteral(decimal.New(3000, -3), types.MustCreateDecimalType(10, 3)), 192 exp: "0.3333", 193 }, 194 { 195 left: NewLiteral(decimal.New(314159, -5), types.MustCreateDecimalType(10, 5)), 196 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), 197 exp: "1.047196666", 198 }, 199 { 200 left: NewLiteral(decimal.NewFromFloat(3.14159), types.MustCreateDecimalType(10, 5)), 201 right: NewLiteral(3, types.Int64), 202 exp: "1.047196666", 203 }, 204 205 // Bit 206 { 207 left: NewLiteral(0, types.MustCreateBitType(1)), 208 right: NewLiteral(1, types.MustCreateBitType(1)), 209 exp: "0.0000", 210 }, 211 { 212 left: NewLiteral(1, types.MustCreateBitType(1)), 213 right: NewLiteral(1, types.MustCreateBitType(1)), 214 exp: "1.0000", 215 }, 216 217 // Year 218 { 219 left: NewLiteral(2001, types.YearType_{}), 220 right: NewLiteral(2002, types.YearType_{}), 221 exp: "0.9995", 222 }, 223 224 // Time 225 { 226 left: NewLiteral("2001-01-01", types.Date), 227 right: NewLiteral("2001-01-01", types.Date), 228 exp: "1.0000", 229 }, 230 { 231 left: NewLiteral("2001-01-01 12:00:00", types.Date), 232 right: NewLiteral("2001-01-01 12:00:00", types.Date), 233 exp: "1.0000", 234 }, 235 { 236 skip: true, // need to trim just the date portion 237 left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 238 right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), 239 exp: "1.0000", 240 }, 241 { 242 left: NewLiteral("2001-01-01 12:00:00", types.Datetime), 243 right: NewLiteral("2001-01-01 12:00:00", types.Datetime), 244 exp: "1.0000", 245 }, 246 { 247 skip: true, // need to trim just the datetime portion according to precision and use as exponent 248 left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 249 right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), 250 exp: "1.0000", 251 }, 252 { 253 skip: true, // need to trim just the datetime portion according to precision and use as exponent 254 left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 255 right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), 256 exp: "1.0000000", 257 }, 258 { 259 left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 260 right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), 261 exp: "1.0000000000", 262 }, 263 264 // Text 265 { 266 left: NewLiteral("1", types.Text), 267 right: NewLiteral("3", types.Text), 268 exp: 0.3333333333333333, 269 }, 270 { 271 left: NewLiteral("1.000", types.Text), 272 right: NewLiteral("3", types.Text), 273 exp: 0.3333333333333333, 274 }, 275 { 276 left: NewLiteral("1", types.Text), 277 right: NewLiteral("3.000", types.Text), 278 exp: 0.3333333333333333, 279 }, 280 { 281 left: NewLiteral("3.14159", types.Text), 282 right: NewLiteral("3", types.Text), 283 exp: 1.0471966666666666, 284 }, 285 { 286 left: NewLiteral("1", types.Text), 287 right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), 288 exp: 0.3333333333333333, 289 }, 290 { 291 left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), 292 right: NewLiteral("3", types.Text), 293 exp: 0.3333333333333333, 294 }, 295 } 296 297 for _, tt := range testCases { 298 name := fmt.Sprintf("%s(%v)/%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) 299 t.Run(name, func(t *testing.T) { 300 require := require.New(t) 301 if tt.skip { 302 t.Skip() 303 } 304 f := NewDiv(tt.left, tt.right) 305 result, err := f.Eval(sql.NewEmptyContext(), nil) 306 if tt.err != nil { 307 require.Error(err) 308 require.True(tt.err.Is(err), err.Error()) 309 return 310 } 311 require.NoError(err) 312 if dec, ok := result.(decimal.Decimal); ok { 313 result = dec.StringFixed(dec.Exponent() * -1) 314 } 315 assert.Equal(t, tt.exp, result) 316 }) 317 } 318 } 319 320 // TestDivUsesFloatsInternally tests that division expression trees internally use floating point types when operating 321 // on integers, but when returning the final result from the expression tree, it is returned as a Decimal. 322 func TestDivUsesFloatsInternally(t *testing.T) { 323 t.Skip("TODO: see if we can actually enable this") 324 bottomDiv := NewDiv(NewGetField(0, types.Int32, "", false), NewGetField(1, types.Int64, "", false)) 325 middleDiv := NewDiv(bottomDiv, NewGetField(2, types.Int64, "", false)) 326 topDiv := NewDiv(middleDiv, NewGetField(3, types.Int64, "", false)) 327 328 result, err := topDiv.Eval(sql.NewEmptyContext(), sql.NewRow(250, 2, 5, 2)) 329 require.NoError(t, err) 330 dec, isDecimal := result.(decimal.Decimal) 331 require.True(t, isDecimal) 332 require.Equal(t, "12.5", dec.String()) 333 334 // Internal nodes should use floats for division with integers (for performance reasons), but the top node 335 // should return a Decimal (to match MySQL's behavior). 336 require.Equal(t, types.Float64, bottomDiv.Type()) 337 require.Equal(t, types.Float64, middleDiv.Type()) 338 require.True(t, types.IsDecimal(topDiv.Type())) 339 } 340 341 func TestIntDiv(t *testing.T) { 342 var testCases = []struct { 343 name string 344 left, right interface{} 345 leftType, rightType sql.Type 346 expected int64 347 null bool 348 }{ 349 {"1 div 1", 1, 1, types.Int64, types.Int64, 1, false}, 350 {"8 div 3", 8, 3, types.Int64, types.Int64, 2, false}, 351 {"1 div 3", 1, 3, types.Int64, types.Int64, 0, false}, 352 {"0 div -1024", 0, -1024, types.Int64, types.Int64, 0, false}, 353 {"1 div 0", 1, 0, types.Int64, types.Int64, 0, true}, 354 {"0 div 0", 1, 0, types.Int64, types.Int64, 0, true}, 355 {"10.24 div 0.6", 10.24, 0.6, types.Float64, types.Float64, 17, false}, 356 {"-10.24 div 0.6", -10.24, 0.6, types.Float64, types.Float64, -17, false}, 357 {"-10.24 div -0.6", -10.24, -0.6, types.Float64, types.Float64, 17, false}, 358 } 359 360 for _, tt := range testCases { 361 t.Run(tt.name, func(t *testing.T) { 362 require := require.New(t) 363 result, err := NewIntDiv( 364 NewLiteral(tt.left, tt.leftType), 365 NewLiteral(tt.right, tt.rightType), 366 ).Eval(sql.NewEmptyContext(), sql.NewRow()) 367 require.NoError(err) 368 if tt.null { 369 assert.Equal(t, nil, result) 370 } else { 371 assert.Equal(t, tt.expected, result) 372 } 373 }) 374 } 375 } 376 377 // Results: 378 // BenchmarkDivInt-16 365416 3117 ns/op 379 func BenchmarkDivInt(b *testing.B) { 380 require := require.New(b) 381 ctx := sql.NewEmptyContext() 382 div := NewDiv( 383 NewLiteral(1, types.Int64), 384 NewLiteral(3, types.Int64), 385 ) 386 var res interface{} 387 var err error 388 for i := 0; i < b.N; i++ { 389 res, err = div.Eval(ctx, nil) 390 require.NoError(err) 391 } 392 if dec, ok := res.(decimal.Decimal); ok { 393 res = dec.StringFixed(dec.Exponent() * -1) 394 } 395 exp := "0.3333" 396 if res != exp { 397 b.Logf("Expected %v, got %v", exp, res) 398 } 399 } 400 401 // Results: 402 // BenchmarkDivFloat-16 1521937 787.7 ns/op 403 func BenchmarkDivFloat(b *testing.B) { 404 require := require.New(b) 405 ctx := sql.NewEmptyContext() 406 div := NewDiv( 407 NewLiteral(1.0, types.Float64), 408 NewLiteral(3.0, types.Float64), 409 ) 410 var res interface{} 411 var err error 412 for i := 0; i < b.N; i++ { 413 res, err = div.Eval(ctx, nil) 414 require.NoError(err) 415 } 416 exp := 1.0 / 3.0 417 if res != exp { 418 b.Logf("Expected %v, got %v", exp, res) 419 } 420 } 421 422 // Results: 423 // BenchmarkDivHighScaleDecimals-16 294921 3901 ns/op 424 func BenchmarkDivHighScaleDecimals(b *testing.B) { 425 require := require.New(b) 426 ctx := sql.NewEmptyContext() 427 div := NewDiv( 428 NewLiteral(decimal.NewFromFloat(0.123456789), types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)), 429 NewLiteral(decimal.NewFromFloat(0.987654321), types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)), 430 ) 431 var res interface{} 432 var err error 433 for i := 0; i < b.N; i++ { 434 res, err = div.Eval(ctx, nil) 435 require.NoError(err) 436 } 437 if dec, ok := res.(decimal.Decimal); ok { 438 res = dec.StringFixed(dec.Exponent() * -1) 439 } 440 exp := "0.124999998860937500014238281250" 441 if res != exp { 442 b.Logf("Expected %v, got %v", exp, res) 443 } 444 } 445 446 // Results: 447 // BenchmarkDivManyInts-16 40711 29372 ns/op 448 func BenchmarkDivManyInts(b *testing.B) { 449 require := require.New(b) 450 var div sql.Expression = NewLiteral(1, types.Int64) 451 for i := 2; i < 10; i++ { 452 div = NewDiv(div, NewLiteral(int64(i), types.Int64)) 453 } 454 ctx := sql.NewEmptyContext() 455 var res interface{} 456 var err error 457 for i := 0; i < b.N; i++ { 458 res, err = div.Eval(ctx, nil) 459 require.NoError(err) 460 } 461 if dec, ok := res.(decimal.Decimal); ok { 462 res = dec.StringFixed(dec.Exponent() * -1) 463 } 464 exp := "0.000002755731922398589054232804" 465 if res != exp { 466 b.Logf("Expected %v, got %v", exp, res) 467 } 468 } 469 470 // Results: 471 // BenchmarkManyFloats-16 174555 6666 ns/op 472 func BenchmarkManyFloats(b *testing.B) { 473 require := require.New(b) 474 ctx := sql.NewEmptyContext() 475 var div sql.Expression = NewLiteral(1.0, types.Float64) 476 for i := 2; i < 10; i++ { 477 div = NewDiv(div, NewLiteral(float64(i), types.Float64)) 478 } 479 var res interface{} 480 var err error 481 for i := 0; i < b.N; i++ { 482 res, err = div.Eval(ctx, nil) 483 require.NoError(err) 484 } 485 exp := 1.0 / 2.0 / 3.0 / 4.0 / 5.0 / 6.0 / 7.0 / 8.0 / 9.0 486 if res != exp { 487 b.Logf("Expected %v, got %v", exp, res) 488 } 489 } 490 491 // Results: 492 // BenchmarkDivManyDecimals-16 52053 23134 ns/op 493 func BenchmarkDivManyDecimals(b *testing.B) { 494 require := require.New(b) 495 var div sql.Expression = NewLiteral(decimal.NewFromInt(int64(1)), types.DecimalType_{}) 496 for i := 2; i < 10; i++ { 497 div = NewDiv(div, NewLiteral(decimal.NewFromInt(int64(i)), types.DecimalType_{})) 498 } 499 ctx := sql.NewEmptyContext() 500 var res interface{} 501 var err error 502 for i := 0; i < b.N; i++ { 503 res, err = div.Eval(ctx, nil) 504 require.NoError(err) 505 } 506 if dec, ok := res.(decimal.Decimal); ok { 507 res = dec.StringFixed(dec.Exponent() * -1) 508 } 509 exp := "0.000002755731922398589054232804" 510 if res != exp { 511 b.Logf("Expected %v, got %v", exp, res) 512 } 513 }