github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/math_test.go (about)

     1  // Copyright 2020-2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"math"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/shopspring/decimal"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/expression"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  func TestRand(t *testing.T) {
    32  	r, _ := NewRand()
    33  
    34  	assert.Equal(t, types.Float64, r.Type())
    35  	assert.Equal(t, "rand()", r.String())
    36  
    37  	f, err := r.Eval(nil, nil)
    38  	require.NoError(t, err)
    39  	f64, ok := f.(float64)
    40  	require.True(t, ok, "not a float64")
    41  
    42  	assert.GreaterOrEqual(t, f64, float64(0))
    43  	assert.Less(t, f64, float64(1))
    44  
    45  	f, err = r.Eval(nil, nil)
    46  	require.NoError(t, err)
    47  	f642, ok := f.(float64)
    48  	require.True(t, ok, "not a float64")
    49  
    50  	assert.NotEqual(t, f64, f642) // i guess this could fail, but come on
    51  }
    52  
    53  func TestRandWithSeed(t *testing.T) {
    54  	r, _ := NewRand(expression.NewLiteral(10, types.Int8))
    55  
    56  	assert.Equal(t, types.Float64, r.Type())
    57  	assert.Equal(t, "rand(10)", r.String())
    58  
    59  	f, err := r.Eval(nil, nil)
    60  	require.NoError(t, err)
    61  	f64 := f.(float64)
    62  
    63  	assert.GreaterOrEqual(t, f64, float64(0))
    64  	assert.Less(t, f64, float64(1))
    65  
    66  	f, err = r.Eval(nil, nil)
    67  	require.NoError(t, err)
    68  	f642 := f.(float64)
    69  
    70  	assert.Equal(t, f64, f642)
    71  
    72  	r, _ = NewRand(expression.NewLiteral("not a number", types.LongText))
    73  	assert.Equal(t, `rand('not a number')`, r.String())
    74  
    75  	f, err = r.Eval(nil, nil)
    76  	require.NoError(t, err)
    77  	f64 = f.(float64)
    78  
    79  	assert.GreaterOrEqual(t, f64, float64(0))
    80  	assert.Less(t, f64, float64(1))
    81  
    82  	f, err = r.Eval(nil, nil)
    83  	require.NoError(t, err)
    84  	f642 = f.(float64)
    85  
    86  	assert.Equal(t, f64, f642)
    87  }
    88  
    89  func TestRadians(t *testing.T) {
    90  	f := sql.Function1{Name: "radians", Fn: NewRadians}
    91  	tf := NewTestFactory(f.Fn)
    92  	tf.AddSucceeding(0.0, "0")
    93  	tf.AddSucceeding(-math.Pi, "-180")
    94  	tf.AddSucceeding(math.Pi, int16(180))
    95  	tf.AddSucceeding(math.Pi/2.0, (90))
    96  	tf.AddSucceeding(2*math.Pi, 360.0)
    97  	tf.Test(t, nil, nil)
    98  }
    99  
   100  func TestDegrees(t *testing.T) {
   101  	tests := []struct {
   102  		name     string
   103  		input    interface{}
   104  		expected float64
   105  	}{
   106  		{"string pi", "3.1415926536", 180.0},
   107  		{"decimal 2pi", decimal.NewFromFloat(2 * math.Pi), 360.0},
   108  		{"float64 pi/2", math.Pi / 2.0, 90.0},
   109  		{"float32 3*pi/2", float32(3.0 * math.Pi / 2.0), 270.0},
   110  	}
   111  
   112  	f := sql.Function1{Name: "degrees", Fn: NewDegrees}
   113  
   114  	for _, test := range tests {
   115  		t.Run(test.name, func(t *testing.T) {
   116  			degrees := f.Fn(expression.NewLiteral(test.input, nil))
   117  			res, err := degrees.Eval(nil, nil)
   118  			require.NoError(t, err)
   119  			assert.True(t, withinRoundingErr(test.expected, res.(float64)))
   120  		})
   121  	}
   122  }
   123  
   124  func TestCRC32(t *testing.T) {
   125  	tests := []struct {
   126  		name     string
   127  		input    interface{}
   128  		expected uint32
   129  	}{
   130  		{"CRC32('MySQL)", "MySQL", 3259397556},
   131  		{"CRC32('mysql')", "mysql", 2501908538},
   132  
   133  		{"CRC32('6')", "6", 498629140},
   134  		{"CRC32(int 6)", 6, 498629140},
   135  		{"CRC32(int8 6)", int8(6), 498629140},
   136  		{"CRC32(int16 6)", int16(6), 498629140},
   137  		{"CRC32(int32 6)", int32(6), 498629140},
   138  		{"CRC32(int64 6)", int64(6), 498629140},
   139  		{"CRC32(uint 6)", uint(6), 498629140},
   140  		{"CRC32(uint8 6)", uint8(6), 498629140},
   141  		{"CRC32(uint16 6)", uint16(6), 498629140},
   142  		{"CRC32(uint32 6)", uint32(6), 498629140},
   143  		{"CRC32(uint64 6)", uint64(6), 498629140},
   144  
   145  		{"CRC32('6.0')", "6.0", 4068047280},
   146  		{"CRC32(float32 6.0)", float32(6.0), 4068047280},
   147  		{"CRC32(float64 6.0)", float64(6.0), 4068047280},
   148  	}
   149  
   150  	f := sql.Function1{Name: "crc32", Fn: NewCrc32}
   151  
   152  	for _, test := range tests {
   153  		t.Run(test.name, func(t *testing.T) {
   154  			crc32 := f.Fn(expression.NewLiteral(test.input, nil))
   155  			res, err := crc32.Eval(nil, nil)
   156  			assert.NoError(t, err)
   157  			assert.Equal(t, test.expected, res)
   158  		})
   159  	}
   160  
   161  	crc32 := f.Fn(nil)
   162  	res, err := crc32.Eval(nil, nil)
   163  	assert.NoError(t, err)
   164  	assert.Equal(t, nil, res)
   165  
   166  	nullLiteral := expression.NewLiteral(nil, types.Null)
   167  	crc32 = f.Fn(nullLiteral)
   168  	res, err = crc32.Eval(nil, nil)
   169  	assert.NoError(t, err)
   170  	assert.Equal(t, nil, res)
   171  }
   172  
   173  func TestTrigFunctions(t *testing.T) {
   174  	asin := sql.Function1{Name: "asin", Fn: NewAsin}
   175  	acos := sql.Function1{Name: "acos", Fn: NewAcos}
   176  	atan := sql.FunctionN{Name: "atan", Fn: NewAtan}
   177  	atan2 := sql.FunctionN{Name: "atan2", Fn: NewAtan}
   178  	sin := sql.Function1{Name: "sin", Fn: NewSin}
   179  	cos := sql.Function1{Name: "cos", Fn: NewCos}
   180  	tan := sql.Function1{Name: "tan", Fn: NewTan}
   181  
   182  	const numChecks = 24
   183  	delta := (2 * math.Pi) / float64(numChecks)
   184  	for i := 0; i <= numChecks; i++ {
   185  		theta := delta * float64(i)
   186  		thetaLiteral := expression.NewLiteral(theta, nil)
   187  		sinVal, err := sin.Fn(thetaLiteral).Eval(nil, nil)
   188  		assert.NoError(t, err)
   189  		cosVal, err := cos.Fn(thetaLiteral).Eval(nil, nil)
   190  		assert.NoError(t, err)
   191  		tanVal, err := tan.Fn(thetaLiteral).Eval(nil, nil)
   192  		assert.NoError(t, err)
   193  
   194  		sinF, _ := sinVal.(float64)
   195  		cosF, _ := cosVal.(float64)
   196  		tanF, _ := tanVal.(float64)
   197  
   198  		assert.True(t, withinRoundingErr(math.Sin(theta), sinF))
   199  		assert.True(t, withinRoundingErr(math.Cos(theta), cosF))
   200  		assert.True(t, withinRoundingErr(math.Tan(theta), tanF))
   201  
   202  		asinVal, err := asin.Fn(expression.NewLiteral(sinF, nil)).Eval(nil, nil)
   203  		assert.NoError(t, err)
   204  		acosVal, err := acos.Fn(expression.NewLiteral(cosF, nil)).Eval(nil, nil)
   205  		assert.NoError(t, err)
   206  		atanFn, err := atan.Fn(expression.NewLiteral(tanF, nil))
   207  		assert.NoError(t, err)
   208  		atanVal, err := atanFn.Eval(nil, nil)
   209  		assert.NoError(t, err)
   210  		atan2Fn, err := atan2.Fn(expression.NewLiteral(tanF, nil), expression.NewLiteral(tanF-1, nil))
   211  		assert.NoError(t, err)
   212  		atan2Val, err := atan2Fn.Eval(nil, nil)
   213  		assert.NoError(t, err)
   214  
   215  		assert.True(t, withinRoundingErr(math.Asin(sinF), asinVal.(float64)))
   216  		assert.True(t, withinRoundingErr(math.Acos(cosF), acosVal.(float64)))
   217  		assert.True(t, withinRoundingErr(math.Atan(tanF), atanVal.(float64)))
   218  		assert.True(t, withinRoundingErr(math.Atan2(tanF, tanF-1), atan2Val.(float64)))
   219  	}
   220  }
   221  
   222  func withinRoundingErr(v1, v2 float64) bool {
   223  	const roundingErr = 0.00001
   224  	diff := v1 - v2
   225  
   226  	if diff < 0 {
   227  		diff = -diff
   228  	}
   229  
   230  	return diff < roundingErr
   231  }
   232  
   233  func TestSignFunc(t *testing.T) {
   234  	f := sql.Function1{Name: "sign", Fn: NewSign}
   235  	tf := NewTestFactory(f.Fn)
   236  	tf.AddSucceeding(nil, nil)
   237  	tf.AddSignedVariations(int8(-1), -10)
   238  	tf.AddFloatVariations(int8(-1), -10.0)
   239  	tf.AddSignedVariations(int8(1), 100)
   240  	tf.AddUnsignedVariations(int8(1), 100)
   241  	tf.AddFloatVariations(int8(1), 100.0)
   242  	tf.AddSignedVariations(int8(0), 0)
   243  	tf.AddUnsignedVariations(int8(0), 0)
   244  	tf.AddFloatVariations(int8(0), 0)
   245  	tf.AddSucceeding(int8(1), time.Now())
   246  	tf.AddSucceeding(int8(0), false)
   247  	tf.AddSucceeding(int8(1), true)
   248  
   249  	// string logic matches mysql.  It's really odd.  Uses the numeric portion of the string at the beginning.  If
   250  	// it starts with a nonnumeric character then
   251  	tf.AddSucceeding(int8(0), "0-1z1Xaoebu")
   252  	tf.AddSucceeding(int8(-1), "-1z1Xaoebu")
   253  	tf.AddSucceeding(int8(1), "1z1Xaoebu")
   254  	tf.AddSucceeding(int8(0), "z1Xaoebu")
   255  	tf.AddSucceeding(int8(-1), "-.1a,1,1")
   256  	tf.AddSucceeding(int8(-1), "-0.1a,1,1")
   257  	tf.AddSucceeding(int8(1), "0.1a,1,1")
   258  	tf.AddSucceeding(int8(0), "-0,1,1")
   259  	tf.AddSucceeding(int8(0), "-.z1,1,1")
   260  
   261  	tf.Test(t, nil, nil)
   262  }
   263  
   264  func TestMod(t *testing.T) {
   265  	tests := []struct {
   266  		name     string
   267  		left     interface{}
   268  		right    interface{}
   269  		expected interface{}
   270  	}{
   271  		{"MOD(5,2)", 5, 2, "1"},
   272  		{"MOD(2,5)", 2, 5, "2"},
   273  		{"MOD(1,0.240)", 1, "0.240", "0.040"},
   274  		{"MOD(NULL,2)", nil, 2, nil},
   275  		{"MOD(5,NULL)", 5, nil, nil},
   276  		{"MOD(NULL,NULL)", nil, nil, nil},
   277  	}
   278  
   279  	f := sql.FunctionN{Name: "mod", Fn: NewMod}
   280  
   281  	for _, test := range tests {
   282  		t.Run(test.name, func(t *testing.T) {
   283  			mod, err := f.Fn(expression.NewLiteral(test.left, types.Int32), expression.NewLiteral(test.right, types.Int32))
   284  			res, err := mod.Eval(nil, nil)
   285  			assert.NoError(t, err)
   286  			if r, ok := res.(decimal.Decimal); ok {
   287  				assert.Equal(t, test.expected, r.StringFixed(r.Exponent()*-1))
   288  			} else {
   289  				assert.Equal(t, test.expected, res)
   290  			}
   291  		})
   292  	}
   293  }
   294  
   295  func TestPi(t *testing.T) {
   296  	tests := []struct {
   297  		name string
   298  		exp  interface{}
   299  	}{
   300  		{
   301  			name: "call pi",
   302  			exp:  math.Pi,
   303  		},
   304  	}
   305  
   306  	for _, test := range tests {
   307  		t.Run(test.name, func(t *testing.T) {
   308  			ctx := sql.NewEmptyContext()
   309  			pi := NewPi()
   310  			res, err := pi.Eval(ctx, nil)
   311  			require.NoError(t, err)
   312  			assert.Equal(t, test.exp, res)
   313  		})
   314  	}
   315  
   316  	var res interface{}
   317  	var err error
   318  	sin := NewSin(NewPi())
   319  	res, err = sin.Eval(nil, nil)
   320  	require.NoError(t, err)
   321  	assert.Equal(t, 1.2246467991473515e-16, res)
   322  
   323  	cos := NewCos(NewPi())
   324  	res, err = cos.Eval(nil, nil)
   325  	require.NoError(t, err)
   326  	assert.Equal(t, -1.0, res)
   327  }
   328  
   329  func TestExp(t *testing.T) {
   330  	tests := []struct {
   331  		name string
   332  		arg  sql.Expression
   333  		exp  interface{}
   334  		err  bool
   335  		skip bool
   336  	}{
   337  		{
   338  			name: "null argument",
   339  			arg:  nil,
   340  			exp:  nil,
   341  		},
   342  		{
   343  			name: "zero",
   344  			arg:  expression.NewLiteral(int64(0), types.Int64),
   345  			exp:  math.Exp(0),
   346  		},
   347  		{
   348  			name: "one",
   349  			arg:  expression.NewLiteral(int64(1), types.Int64),
   350  			exp:  math.Exp(1),
   351  		},
   352  		{
   353  			name: "ten",
   354  			arg:  expression.NewLiteral(int64(10), types.Int64),
   355  			exp:  math.Exp(10),
   356  		},
   357  		{
   358  			name: "negative",
   359  			arg:  expression.NewLiteral(int64(-1), types.Int64),
   360  			exp:  math.Exp(-1),
   361  		},
   362  		{
   363  			name: "float64 1.1",
   364  			arg:  expression.NewLiteral(1.1, types.Float64),
   365  			exp:  math.Exp(1.1),
   366  		},
   367  		{
   368  			name: "decimal 1.1",
   369  			arg:  expression.NewLiteral(decimal.NewFromFloat(1.1), types.DecimalType_{}),
   370  			exp:  math.Exp(1.1),
   371  		},
   372  		{
   373  			name: "float64 -12.34",
   374  			arg:  expression.NewLiteral(-12.34, types.Float64),
   375  			exp:  math.Exp(-12.34),
   376  		},
   377  		{
   378  			name: "decimal is -12.34",
   379  			arg:  expression.NewLiteral(decimal.NewFromFloat(-12.34), types.DecimalType_{}),
   380  			exp:  math.Exp(-12.34),
   381  		},
   382  		{
   383  			name: "invalid string is 0",
   384  			arg:  expression.NewLiteral("notanumber", types.Text),
   385  			exp:  math.Exp(0),
   386  		},
   387  		{
   388  			name: "empty string",
   389  			arg:  expression.NewLiteral("", types.Text),
   390  			exp:  math.Exp(0),
   391  		},
   392  		{
   393  			name: "numerical string",
   394  			arg:  expression.NewLiteral("10", types.Text),
   395  			exp:  math.Exp(10),
   396  		},
   397  		{
   398  			// we don't do truncation yet
   399  			// https://github.com/dolthub/dolt/issues/7302
   400  			name: "scientific string is truncated",
   401  			arg:  expression.NewLiteral("1e1", types.Text),
   402  			exp:  "",
   403  			err:  false,
   404  			skip: true,
   405  		},
   406  	}
   407  
   408  	for _, tt := range tests {
   409  		t.Run(tt.name, func(t *testing.T) {
   410  			if tt.skip {
   411  				t.Skip()
   412  			}
   413  
   414  			ctx := sql.NewEmptyContext()
   415  			f := NewExp(tt.arg)
   416  
   417  			res, err := f.Eval(ctx, nil)
   418  			if tt.err {
   419  				require.Error(t, err)
   420  				return
   421  			}
   422  
   423  			require.NoError(t, err)
   424  			require.Equal(t, tt.exp, res)
   425  		})
   426  	}
   427  }