github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/div_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  	"testing"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	"github.com/shopspring/decimal"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  	"gopkg.in/src-d/go-errors.v1"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  func TestDiv(t *testing.T) {
    32  	var testCases = []struct {
    33  		name  string
    34  		left  sql.Expression
    35  		right sql.Expression
    36  		exp   interface{}
    37  		err   *errors.Kind
    38  		skip  bool
    39  	}{
    40  		{
    41  			left:  NewLiteral(1, types.Int64),
    42  			right: NewLiteral(0, types.Int64),
    43  			exp:   nil,
    44  		},
    45  
    46  		// Unsigned Integers
    47  		{
    48  			left:  NewLiteral(1, types.Uint32),
    49  			right: NewLiteral(1, types.Uint32),
    50  			exp:   "1.0000",
    51  		},
    52  		{
    53  			left:  NewLiteral(1, types.Uint32),
    54  			right: NewLiteral(2, types.Uint32),
    55  			exp:   "0.5000",
    56  		},
    57  		{
    58  			left:  NewLiteral(1, types.Uint64),
    59  			right: NewLiteral(1, types.Uint64),
    60  			exp:   "1.0000",
    61  		},
    62  		{
    63  			left:  NewLiteral(1, types.Uint64),
    64  			right: NewLiteral(2, types.Uint64),
    65  			exp:   "0.5000",
    66  		},
    67  
    68  		// Signed Integers
    69  		{
    70  			left:  NewLiteral(1, types.Int32),
    71  			right: NewLiteral(1, types.Int32),
    72  			exp:   "1.0000",
    73  		},
    74  		{
    75  			left:  NewLiteral(1, types.Int32),
    76  			right: NewLiteral(2, types.Int32),
    77  			exp:   "0.5000",
    78  		},
    79  		{
    80  			left:  NewLiteral(-1, types.Int32),
    81  			right: NewLiteral(2, types.Int32),
    82  			exp:   "-0.5000",
    83  		},
    84  		{
    85  			left:  NewLiteral(1, types.Int32),
    86  			right: NewLiteral(-2, types.Int32),
    87  			exp:   "-0.5000",
    88  		},
    89  		{
    90  			left:  NewLiteral(1, types.Int64),
    91  			right: NewLiteral(1, types.Int64),
    92  			exp:   "1.0000",
    93  		},
    94  		{
    95  			left:  NewLiteral(1, types.Int64),
    96  			right: NewLiteral(2, types.Int64),
    97  			exp:   "0.5000",
    98  		},
    99  		{
   100  			left:  NewLiteral(-1, types.Int64),
   101  			right: NewLiteral(2, types.Int64),
   102  			exp:   "-0.5000",
   103  		},
   104  		{
   105  			left:  NewLiteral(1, types.Int64),
   106  			right: NewLiteral(-2, types.Int64),
   107  			exp:   "-0.5000",
   108  		},
   109  
   110  		// Unsigned and Signed Integers
   111  		{
   112  			left:  NewLiteral(1, types.Uint32),
   113  			right: NewLiteral(-2, types.Int32),
   114  			exp:   "-0.5000",
   115  		},
   116  		{
   117  			left:  NewLiteral(-1, types.Int64),
   118  			right: NewLiteral(2, types.Uint32),
   119  			exp:   "-0.5000",
   120  		},
   121  		{
   122  			left:  NewLiteral(1, types.Int64),
   123  			right: NewLiteral(123456789, types.Int64),
   124  			exp:   "0.0000",
   125  		},
   126  
   127  		// Repeating Decimals
   128  		{
   129  			left:  NewLiteral(1, types.Int64),
   130  			right: NewLiteral(3, types.Int64),
   131  			exp:   "0.3333",
   132  		},
   133  		{
   134  			left:  NewLiteral(1, types.Int64),
   135  			right: NewLiteral(9, types.Int64),
   136  			exp:   "0.1111",
   137  		},
   138  		{
   139  			left:  NewLiteral(1, types.Int64),
   140  			right: NewLiteral(6, types.Int64),
   141  			exp:   "0.1667",
   142  		},
   143  
   144  		// Floats
   145  		{
   146  			left:  NewLiteral(1.0, types.Float32),
   147  			right: NewLiteral(3.0, types.Float32),
   148  			exp:   0.3333333333333333,
   149  		},
   150  		{
   151  			left:  NewLiteral(1.0, types.Float32),
   152  			right: NewLiteral(9.0, types.Float32),
   153  			exp:   0.1111111111111111,
   154  		},
   155  		{
   156  			left:  NewLiteral(1.0, types.Float64),
   157  			right: NewLiteral(3.0, types.Float64),
   158  			exp:   0.3333333333333333,
   159  		},
   160  		{
   161  			left:  NewLiteral(1.0, types.Float64),
   162  			right: NewLiteral(9.0, types.Float64),
   163  			exp:   0.1111111111111111,
   164  		},
   165  		{
   166  			// MySQL treats float32 a little differently
   167  			skip:  true,
   168  			left:  NewLiteral(3.14159, types.Float32),
   169  			right: NewLiteral(3.0, types.Float32),
   170  			exp:   1.0471967061360676,
   171  		},
   172  		{
   173  			left:  NewLiteral(3.14159, types.Float64),
   174  			right: NewLiteral(3.0, types.Float64),
   175  			exp:   1.0471966666666666,
   176  		},
   177  
   178  		// Decimals
   179  		{
   180  			left:  NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)),
   181  			right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)),
   182  			exp:   "0.3333",
   183  		},
   184  		{
   185  			left:  NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)),
   186  			right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)),
   187  			exp:   "0.3333333",
   188  		},
   189  		{
   190  			left:  NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)),
   191  			right: NewLiteral(decimal.New(3000, -3), types.MustCreateDecimalType(10, 3)),
   192  			exp:   "0.3333",
   193  		},
   194  		{
   195  			left:  NewLiteral(decimal.New(314159, -5), types.MustCreateDecimalType(10, 5)),
   196  			right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)),
   197  			exp:   "1.047196666",
   198  		},
   199  		{
   200  			left:  NewLiteral(decimal.NewFromFloat(3.14159), types.MustCreateDecimalType(10, 5)),
   201  			right: NewLiteral(3, types.Int64),
   202  			exp:   "1.047196666",
   203  		},
   204  
   205  		// Bit
   206  		{
   207  			left:  NewLiteral(0, types.MustCreateBitType(1)),
   208  			right: NewLiteral(1, types.MustCreateBitType(1)),
   209  			exp:   "0.0000",
   210  		},
   211  		{
   212  			left:  NewLiteral(1, types.MustCreateBitType(1)),
   213  			right: NewLiteral(1, types.MustCreateBitType(1)),
   214  			exp:   "1.0000",
   215  		},
   216  
   217  		// Year
   218  		{
   219  			left:  NewLiteral(2001, types.YearType_{}),
   220  			right: NewLiteral(2002, types.YearType_{}),
   221  			exp:   "0.9995",
   222  		},
   223  
   224  		// Time
   225  		{
   226  			left:  NewLiteral("2001-01-01", types.Date),
   227  			right: NewLiteral("2001-01-01", types.Date),
   228  			exp:   "1.0000",
   229  		},
   230  		{
   231  			left:  NewLiteral("2001-01-01 12:00:00", types.Date),
   232  			right: NewLiteral("2001-01-01 12:00:00", types.Date),
   233  			exp:   "1.0000",
   234  		},
   235  		{
   236  			skip:  true, // need to trim just the date portion
   237  			left:  NewLiteral("2001-01-01 12:00:00.123456", types.Date),
   238  			right: NewLiteral("2001-01-01 12:00:00.123456", types.Date),
   239  			exp:   "1.0000",
   240  		},
   241  		{
   242  			left:  NewLiteral("2001-01-01 12:00:00", types.Datetime),
   243  			right: NewLiteral("2001-01-01 12:00:00", types.Datetime),
   244  			exp:   "1.0000",
   245  		},
   246  		{
   247  			skip:  true, // need to trim just the datetime portion according to precision and use as exponent
   248  			left:  NewLiteral("2001-01-01 12:00:00.123456", types.Datetime),
   249  			right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime),
   250  			exp:   "1.0000",
   251  		},
   252  		{
   253  			skip:  true, // need to trim just the datetime portion according to precision and use as exponent
   254  			left:  NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)),
   255  			right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)),
   256  			exp:   "1.0000000",
   257  		},
   258  		{
   259  			left:  NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision),
   260  			right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision),
   261  			exp:   "1.0000000000",
   262  		},
   263  
   264  		// Text
   265  		{
   266  			left:  NewLiteral("1", types.Text),
   267  			right: NewLiteral("3", types.Text),
   268  			exp:   0.3333333333333333,
   269  		},
   270  		{
   271  			left:  NewLiteral("1.000", types.Text),
   272  			right: NewLiteral("3", types.Text),
   273  			exp:   0.3333333333333333,
   274  		},
   275  		{
   276  			left:  NewLiteral("1", types.Text),
   277  			right: NewLiteral("3.000", types.Text),
   278  			exp:   0.3333333333333333,
   279  		},
   280  		{
   281  			left:  NewLiteral("3.14159", types.Text),
   282  			right: NewLiteral("3", types.Text),
   283  			exp:   1.0471966666666666,
   284  		},
   285  		{
   286  			left:  NewLiteral("1", types.Text),
   287  			right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)),
   288  			exp:   0.3333333333333333,
   289  		},
   290  		{
   291  			left:  NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)),
   292  			right: NewLiteral("3", types.Text),
   293  			exp:   0.3333333333333333,
   294  		},
   295  	}
   296  
   297  	for _, tt := range testCases {
   298  		name := fmt.Sprintf("%s(%v)/%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right)
   299  		t.Run(name, func(t *testing.T) {
   300  			require := require.New(t)
   301  			if tt.skip {
   302  				t.Skip()
   303  			}
   304  			f := NewDiv(tt.left, tt.right)
   305  			result, err := f.Eval(sql.NewEmptyContext(), nil)
   306  			if tt.err != nil {
   307  				require.Error(err)
   308  				require.True(tt.err.Is(err), err.Error())
   309  				return
   310  			}
   311  			require.NoError(err)
   312  			if dec, ok := result.(decimal.Decimal); ok {
   313  				result = dec.StringFixed(dec.Exponent() * -1)
   314  			}
   315  			assert.Equal(t, tt.exp, result)
   316  		})
   317  	}
   318  }
   319  
   320  // TestDivUsesFloatsInternally tests that division expression trees internally use floating point types when operating
   321  // on integers, but when returning the final result from the expression tree, it is returned as a Decimal.
   322  func TestDivUsesFloatsInternally(t *testing.T) {
   323  	t.Skip("TODO: see if we can actually enable this")
   324  	bottomDiv := NewDiv(NewGetField(0, types.Int32, "", false), NewGetField(1, types.Int64, "", false))
   325  	middleDiv := NewDiv(bottomDiv, NewGetField(2, types.Int64, "", false))
   326  	topDiv := NewDiv(middleDiv, NewGetField(3, types.Int64, "", false))
   327  
   328  	result, err := topDiv.Eval(sql.NewEmptyContext(), sql.NewRow(250, 2, 5, 2))
   329  	require.NoError(t, err)
   330  	dec, isDecimal := result.(decimal.Decimal)
   331  	require.True(t, isDecimal)
   332  	require.Equal(t, "12.5", dec.String())
   333  
   334  	// Internal nodes should use floats for division with integers (for performance reasons), but the top node
   335  	// should return a Decimal (to match MySQL's behavior).
   336  	require.Equal(t, types.Float64, bottomDiv.Type())
   337  	require.Equal(t, types.Float64, middleDiv.Type())
   338  	require.True(t, types.IsDecimal(topDiv.Type()))
   339  }
   340  
   341  func TestIntDiv(t *testing.T) {
   342  	var testCases = []struct {
   343  		name                string
   344  		left, right         interface{}
   345  		leftType, rightType sql.Type
   346  		expected            int64
   347  		null                bool
   348  	}{
   349  		{"1 div 1", 1, 1, types.Int64, types.Int64, 1, false},
   350  		{"8 div 3", 8, 3, types.Int64, types.Int64, 2, false},
   351  		{"1 div 3", 1, 3, types.Int64, types.Int64, 0, false},
   352  		{"0 div -1024", 0, -1024, types.Int64, types.Int64, 0, false},
   353  		{"1 div 0", 1, 0, types.Int64, types.Int64, 0, true},
   354  		{"0 div 0", 1, 0, types.Int64, types.Int64, 0, true},
   355  		{"10.24 div 0.6", 10.24, 0.6, types.Float64, types.Float64, 17, false},
   356  		{"-10.24 div 0.6", -10.24, 0.6, types.Float64, types.Float64, -17, false},
   357  		{"-10.24 div -0.6", -10.24, -0.6, types.Float64, types.Float64, 17, false},
   358  	}
   359  
   360  	for _, tt := range testCases {
   361  		t.Run(tt.name, func(t *testing.T) {
   362  			require := require.New(t)
   363  			result, err := NewIntDiv(
   364  				NewLiteral(tt.left, tt.leftType),
   365  				NewLiteral(tt.right, tt.rightType),
   366  			).Eval(sql.NewEmptyContext(), sql.NewRow())
   367  			require.NoError(err)
   368  			if tt.null {
   369  				assert.Equal(t, nil, result)
   370  			} else {
   371  				assert.Equal(t, tt.expected, result)
   372  			}
   373  		})
   374  	}
   375  }
   376  
   377  // Results:
   378  // BenchmarkDivInt-16        365416              3117 ns/op
   379  func BenchmarkDivInt(b *testing.B) {
   380  	require := require.New(b)
   381  	ctx := sql.NewEmptyContext()
   382  	div := NewDiv(
   383  		NewLiteral(1, types.Int64),
   384  		NewLiteral(3, types.Int64),
   385  	)
   386  	var res interface{}
   387  	var err error
   388  	for i := 0; i < b.N; i++ {
   389  		res, err = div.Eval(ctx, nil)
   390  		require.NoError(err)
   391  	}
   392  	if dec, ok := res.(decimal.Decimal); ok {
   393  		res = dec.StringFixed(dec.Exponent() * -1)
   394  	}
   395  	exp := "0.3333"
   396  	if res != exp {
   397  		b.Logf("Expected %v, got %v", exp, res)
   398  	}
   399  }
   400  
   401  // Results:
   402  // BenchmarkDivFloat-16             1521937               787.7 ns/op
   403  func BenchmarkDivFloat(b *testing.B) {
   404  	require := require.New(b)
   405  	ctx := sql.NewEmptyContext()
   406  	div := NewDiv(
   407  		NewLiteral(1.0, types.Float64),
   408  		NewLiteral(3.0, types.Float64),
   409  	)
   410  	var res interface{}
   411  	var err error
   412  	for i := 0; i < b.N; i++ {
   413  		res, err = div.Eval(ctx, nil)
   414  		require.NoError(err)
   415  	}
   416  	exp := 1.0 / 3.0
   417  	if res != exp {
   418  		b.Logf("Expected %v, got %v", exp, res)
   419  	}
   420  }
   421  
   422  // Results:
   423  // BenchmarkDivHighScaleDecimals-16          294921              3901 ns/op
   424  func BenchmarkDivHighScaleDecimals(b *testing.B) {
   425  	require := require.New(b)
   426  	ctx := sql.NewEmptyContext()
   427  	div := NewDiv(
   428  		NewLiteral(decimal.NewFromFloat(0.123456789), types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)),
   429  		NewLiteral(decimal.NewFromFloat(0.987654321), types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)),
   430  	)
   431  	var res interface{}
   432  	var err error
   433  	for i := 0; i < b.N; i++ {
   434  		res, err = div.Eval(ctx, nil)
   435  		require.NoError(err)
   436  	}
   437  	if dec, ok := res.(decimal.Decimal); ok {
   438  		res = dec.StringFixed(dec.Exponent() * -1)
   439  	}
   440  	exp := "0.124999998860937500014238281250"
   441  	if res != exp {
   442  		b.Logf("Expected %v, got %v", exp, res)
   443  	}
   444  }
   445  
   446  // Results:
   447  // BenchmarkDivManyInts-16            40711             29372 ns/op
   448  func BenchmarkDivManyInts(b *testing.B) {
   449  	require := require.New(b)
   450  	var div sql.Expression = NewLiteral(1, types.Int64)
   451  	for i := 2; i < 10; i++ {
   452  		div = NewDiv(div, NewLiteral(int64(i), types.Int64))
   453  	}
   454  	ctx := sql.NewEmptyContext()
   455  	var res interface{}
   456  	var err error
   457  	for i := 0; i < b.N; i++ {
   458  		res, err = div.Eval(ctx, nil)
   459  		require.NoError(err)
   460  	}
   461  	if dec, ok := res.(decimal.Decimal); ok {
   462  		res = dec.StringFixed(dec.Exponent() * -1)
   463  	}
   464  	exp := "0.000002755731922398589054232804"
   465  	if res != exp {
   466  		b.Logf("Expected %v, got %v", exp, res)
   467  	}
   468  }
   469  
   470  // Results:
   471  // BenchmarkManyFloats-16            174555              6666 ns/op
   472  func BenchmarkManyFloats(b *testing.B) {
   473  	require := require.New(b)
   474  	ctx := sql.NewEmptyContext()
   475  	var div sql.Expression = NewLiteral(1.0, types.Float64)
   476  	for i := 2; i < 10; i++ {
   477  		div = NewDiv(div, NewLiteral(float64(i), types.Float64))
   478  	}
   479  	var res interface{}
   480  	var err error
   481  	for i := 0; i < b.N; i++ {
   482  		res, err = div.Eval(ctx, nil)
   483  		require.NoError(err)
   484  	}
   485  	exp := 1.0 / 2.0 / 3.0 / 4.0 / 5.0 / 6.0 / 7.0 / 8.0 / 9.0
   486  	if res != exp {
   487  		b.Logf("Expected %v, got %v", exp, res)
   488  	}
   489  }
   490  
   491  // Results:
   492  // BenchmarkDivManyDecimals-16        52053             23134 ns/op
   493  func BenchmarkDivManyDecimals(b *testing.B) {
   494  	require := require.New(b)
   495  	var div sql.Expression = NewLiteral(decimal.NewFromInt(int64(1)), types.DecimalType_{})
   496  	for i := 2; i < 10; i++ {
   497  		div = NewDiv(div, NewLiteral(decimal.NewFromInt(int64(i)), types.DecimalType_{}))
   498  	}
   499  	ctx := sql.NewEmptyContext()
   500  	var res interface{}
   501  	var err error
   502  	for i := 0; i < b.N; i++ {
   503  		res, err = div.Eval(ctx, nil)
   504  		require.NoError(err)
   505  	}
   506  	if dec, ok := res.(decimal.Decimal); ok {
   507  		res = dec.StringFixed(dec.Exponent() * -1)
   508  	}
   509  	exp := "0.000002755731922398589054232804"
   510  	if res != exp {
   511  		b.Logf("Expected %v, got %v", exp, res)
   512  	}
   513  }