github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/math_test.go (about) 1 // Copyright 2020-2024 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "math" 19 "testing" 20 "time" 21 22 "github.com/shopspring/decimal" 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 26 "github.com/dolthub/go-mysql-server/sql" 27 "github.com/dolthub/go-mysql-server/sql/expression" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 func TestRand(t *testing.T) { 32 r, _ := NewRand() 33 34 assert.Equal(t, types.Float64, r.Type()) 35 assert.Equal(t, "rand()", r.String()) 36 37 f, err := r.Eval(nil, nil) 38 require.NoError(t, err) 39 f64, ok := f.(float64) 40 require.True(t, ok, "not a float64") 41 42 assert.GreaterOrEqual(t, f64, float64(0)) 43 assert.Less(t, f64, float64(1)) 44 45 f, err = r.Eval(nil, nil) 46 require.NoError(t, err) 47 f642, ok := f.(float64) 48 require.True(t, ok, "not a float64") 49 50 assert.NotEqual(t, f64, f642) // i guess this could fail, but come on 51 } 52 53 func TestRandWithSeed(t *testing.T) { 54 r, _ := NewRand(expression.NewLiteral(10, types.Int8)) 55 56 assert.Equal(t, types.Float64, r.Type()) 57 assert.Equal(t, "rand(10)", r.String()) 58 59 f, err := r.Eval(nil, nil) 60 require.NoError(t, err) 61 f64 := f.(float64) 62 63 assert.GreaterOrEqual(t, f64, float64(0)) 64 assert.Less(t, f64, float64(1)) 65 66 f, err = r.Eval(nil, nil) 67 require.NoError(t, err) 68 f642 := f.(float64) 69 70 assert.Equal(t, f64, f642) 71 72 r, _ = NewRand(expression.NewLiteral("not a number", types.LongText)) 73 assert.Equal(t, `rand('not a number')`, r.String()) 74 75 f, err = r.Eval(nil, nil) 76 require.NoError(t, err) 77 f64 = f.(float64) 78 79 assert.GreaterOrEqual(t, f64, float64(0)) 80 assert.Less(t, f64, float64(1)) 81 82 f, err = r.Eval(nil, nil) 83 require.NoError(t, err) 84 f642 = f.(float64) 85 86 assert.Equal(t, f64, f642) 87 } 88 89 func TestRadians(t *testing.T) { 90 f := sql.Function1{Name: "radians", Fn: NewRadians} 91 tf := NewTestFactory(f.Fn) 92 tf.AddSucceeding(0.0, "0") 93 tf.AddSucceeding(-math.Pi, "-180") 94 tf.AddSucceeding(math.Pi, int16(180)) 95 tf.AddSucceeding(math.Pi/2.0, (90)) 96 tf.AddSucceeding(2*math.Pi, 360.0) 97 tf.Test(t, nil, nil) 98 } 99 100 func TestDegrees(t *testing.T) { 101 tests := []struct { 102 name string 103 input interface{} 104 expected float64 105 }{ 106 {"string pi", "3.1415926536", 180.0}, 107 {"decimal 2pi", decimal.NewFromFloat(2 * math.Pi), 360.0}, 108 {"float64 pi/2", math.Pi / 2.0, 90.0}, 109 {"float32 3*pi/2", float32(3.0 * math.Pi / 2.0), 270.0}, 110 } 111 112 f := sql.Function1{Name: "degrees", Fn: NewDegrees} 113 114 for _, test := range tests { 115 t.Run(test.name, func(t *testing.T) { 116 degrees := f.Fn(expression.NewLiteral(test.input, nil)) 117 res, err := degrees.Eval(nil, nil) 118 require.NoError(t, err) 119 assert.True(t, withinRoundingErr(test.expected, res.(float64))) 120 }) 121 } 122 } 123 124 func TestCRC32(t *testing.T) { 125 tests := []struct { 126 name string 127 input interface{} 128 expected uint32 129 }{ 130 {"CRC32('MySQL)", "MySQL", 3259397556}, 131 {"CRC32('mysql')", "mysql", 2501908538}, 132 133 {"CRC32('6')", "6", 498629140}, 134 {"CRC32(int 6)", 6, 498629140}, 135 {"CRC32(int8 6)", int8(6), 498629140}, 136 {"CRC32(int16 6)", int16(6), 498629140}, 137 {"CRC32(int32 6)", int32(6), 498629140}, 138 {"CRC32(int64 6)", int64(6), 498629140}, 139 {"CRC32(uint 6)", uint(6), 498629140}, 140 {"CRC32(uint8 6)", uint8(6), 498629140}, 141 {"CRC32(uint16 6)", uint16(6), 498629140}, 142 {"CRC32(uint32 6)", uint32(6), 498629140}, 143 {"CRC32(uint64 6)", uint64(6), 498629140}, 144 145 {"CRC32('6.0')", "6.0", 4068047280}, 146 {"CRC32(float32 6.0)", float32(6.0), 4068047280}, 147 {"CRC32(float64 6.0)", float64(6.0), 4068047280}, 148 } 149 150 f := sql.Function1{Name: "crc32", Fn: NewCrc32} 151 152 for _, test := range tests { 153 t.Run(test.name, func(t *testing.T) { 154 crc32 := f.Fn(expression.NewLiteral(test.input, nil)) 155 res, err := crc32.Eval(nil, nil) 156 assert.NoError(t, err) 157 assert.Equal(t, test.expected, res) 158 }) 159 } 160 161 crc32 := f.Fn(nil) 162 res, err := crc32.Eval(nil, nil) 163 assert.NoError(t, err) 164 assert.Equal(t, nil, res) 165 166 nullLiteral := expression.NewLiteral(nil, types.Null) 167 crc32 = f.Fn(nullLiteral) 168 res, err = crc32.Eval(nil, nil) 169 assert.NoError(t, err) 170 assert.Equal(t, nil, res) 171 } 172 173 func TestTrigFunctions(t *testing.T) { 174 asin := sql.Function1{Name: "asin", Fn: NewAsin} 175 acos := sql.Function1{Name: "acos", Fn: NewAcos} 176 atan := sql.FunctionN{Name: "atan", Fn: NewAtan} 177 atan2 := sql.FunctionN{Name: "atan2", Fn: NewAtan} 178 sin := sql.Function1{Name: "sin", Fn: NewSin} 179 cos := sql.Function1{Name: "cos", Fn: NewCos} 180 tan := sql.Function1{Name: "tan", Fn: NewTan} 181 182 const numChecks = 24 183 delta := (2 * math.Pi) / float64(numChecks) 184 for i := 0; i <= numChecks; i++ { 185 theta := delta * float64(i) 186 thetaLiteral := expression.NewLiteral(theta, nil) 187 sinVal, err := sin.Fn(thetaLiteral).Eval(nil, nil) 188 assert.NoError(t, err) 189 cosVal, err := cos.Fn(thetaLiteral).Eval(nil, nil) 190 assert.NoError(t, err) 191 tanVal, err := tan.Fn(thetaLiteral).Eval(nil, nil) 192 assert.NoError(t, err) 193 194 sinF, _ := sinVal.(float64) 195 cosF, _ := cosVal.(float64) 196 tanF, _ := tanVal.(float64) 197 198 assert.True(t, withinRoundingErr(math.Sin(theta), sinF)) 199 assert.True(t, withinRoundingErr(math.Cos(theta), cosF)) 200 assert.True(t, withinRoundingErr(math.Tan(theta), tanF)) 201 202 asinVal, err := asin.Fn(expression.NewLiteral(sinF, nil)).Eval(nil, nil) 203 assert.NoError(t, err) 204 acosVal, err := acos.Fn(expression.NewLiteral(cosF, nil)).Eval(nil, nil) 205 assert.NoError(t, err) 206 atanFn, err := atan.Fn(expression.NewLiteral(tanF, nil)) 207 assert.NoError(t, err) 208 atanVal, err := atanFn.Eval(nil, nil) 209 assert.NoError(t, err) 210 atan2Fn, err := atan2.Fn(expression.NewLiteral(tanF, nil), expression.NewLiteral(tanF-1, nil)) 211 assert.NoError(t, err) 212 atan2Val, err := atan2Fn.Eval(nil, nil) 213 assert.NoError(t, err) 214 215 assert.True(t, withinRoundingErr(math.Asin(sinF), asinVal.(float64))) 216 assert.True(t, withinRoundingErr(math.Acos(cosF), acosVal.(float64))) 217 assert.True(t, withinRoundingErr(math.Atan(tanF), atanVal.(float64))) 218 assert.True(t, withinRoundingErr(math.Atan2(tanF, tanF-1), atan2Val.(float64))) 219 } 220 } 221 222 func withinRoundingErr(v1, v2 float64) bool { 223 const roundingErr = 0.00001 224 diff := v1 - v2 225 226 if diff < 0 { 227 diff = -diff 228 } 229 230 return diff < roundingErr 231 } 232 233 func TestSignFunc(t *testing.T) { 234 f := sql.Function1{Name: "sign", Fn: NewSign} 235 tf := NewTestFactory(f.Fn) 236 tf.AddSucceeding(nil, nil) 237 tf.AddSignedVariations(int8(-1), -10) 238 tf.AddFloatVariations(int8(-1), -10.0) 239 tf.AddSignedVariations(int8(1), 100) 240 tf.AddUnsignedVariations(int8(1), 100) 241 tf.AddFloatVariations(int8(1), 100.0) 242 tf.AddSignedVariations(int8(0), 0) 243 tf.AddUnsignedVariations(int8(0), 0) 244 tf.AddFloatVariations(int8(0), 0) 245 tf.AddSucceeding(int8(1), time.Now()) 246 tf.AddSucceeding(int8(0), false) 247 tf.AddSucceeding(int8(1), true) 248 249 // string logic matches mysql. It's really odd. Uses the numeric portion of the string at the beginning. If 250 // it starts with a nonnumeric character then 251 tf.AddSucceeding(int8(0), "0-1z1Xaoebu") 252 tf.AddSucceeding(int8(-1), "-1z1Xaoebu") 253 tf.AddSucceeding(int8(1), "1z1Xaoebu") 254 tf.AddSucceeding(int8(0), "z1Xaoebu") 255 tf.AddSucceeding(int8(-1), "-.1a,1,1") 256 tf.AddSucceeding(int8(-1), "-0.1a,1,1") 257 tf.AddSucceeding(int8(1), "0.1a,1,1") 258 tf.AddSucceeding(int8(0), "-0,1,1") 259 tf.AddSucceeding(int8(0), "-.z1,1,1") 260 261 tf.Test(t, nil, nil) 262 } 263 264 func TestMod(t *testing.T) { 265 tests := []struct { 266 name string 267 left interface{} 268 right interface{} 269 expected interface{} 270 }{ 271 {"MOD(5,2)", 5, 2, "1"}, 272 {"MOD(2,5)", 2, 5, "2"}, 273 {"MOD(1,0.240)", 1, "0.240", "0.040"}, 274 {"MOD(NULL,2)", nil, 2, nil}, 275 {"MOD(5,NULL)", 5, nil, nil}, 276 {"MOD(NULL,NULL)", nil, nil, nil}, 277 } 278 279 f := sql.FunctionN{Name: "mod", Fn: NewMod} 280 281 for _, test := range tests { 282 t.Run(test.name, func(t *testing.T) { 283 mod, err := f.Fn(expression.NewLiteral(test.left, types.Int32), expression.NewLiteral(test.right, types.Int32)) 284 res, err := mod.Eval(nil, nil) 285 assert.NoError(t, err) 286 if r, ok := res.(decimal.Decimal); ok { 287 assert.Equal(t, test.expected, r.StringFixed(r.Exponent()*-1)) 288 } else { 289 assert.Equal(t, test.expected, res) 290 } 291 }) 292 } 293 } 294 295 func TestPi(t *testing.T) { 296 tests := []struct { 297 name string 298 exp interface{} 299 }{ 300 { 301 name: "call pi", 302 exp: math.Pi, 303 }, 304 } 305 306 for _, test := range tests { 307 t.Run(test.name, func(t *testing.T) { 308 ctx := sql.NewEmptyContext() 309 pi := NewPi() 310 res, err := pi.Eval(ctx, nil) 311 require.NoError(t, err) 312 assert.Equal(t, test.exp, res) 313 }) 314 } 315 316 var res interface{} 317 var err error 318 sin := NewSin(NewPi()) 319 res, err = sin.Eval(nil, nil) 320 require.NoError(t, err) 321 assert.Equal(t, 1.2246467991473515e-16, res) 322 323 cos := NewCos(NewPi()) 324 res, err = cos.Eval(nil, nil) 325 require.NoError(t, err) 326 assert.Equal(t, -1.0, res) 327 } 328 329 func TestExp(t *testing.T) { 330 tests := []struct { 331 name string 332 arg sql.Expression 333 exp interface{} 334 err bool 335 skip bool 336 }{ 337 { 338 name: "null argument", 339 arg: nil, 340 exp: nil, 341 }, 342 { 343 name: "zero", 344 arg: expression.NewLiteral(int64(0), types.Int64), 345 exp: math.Exp(0), 346 }, 347 { 348 name: "one", 349 arg: expression.NewLiteral(int64(1), types.Int64), 350 exp: math.Exp(1), 351 }, 352 { 353 name: "ten", 354 arg: expression.NewLiteral(int64(10), types.Int64), 355 exp: math.Exp(10), 356 }, 357 { 358 name: "negative", 359 arg: expression.NewLiteral(int64(-1), types.Int64), 360 exp: math.Exp(-1), 361 }, 362 { 363 name: "float64 1.1", 364 arg: expression.NewLiteral(1.1, types.Float64), 365 exp: math.Exp(1.1), 366 }, 367 { 368 name: "decimal 1.1", 369 arg: expression.NewLiteral(decimal.NewFromFloat(1.1), types.DecimalType_{}), 370 exp: math.Exp(1.1), 371 }, 372 { 373 name: "float64 -12.34", 374 arg: expression.NewLiteral(-12.34, types.Float64), 375 exp: math.Exp(-12.34), 376 }, 377 { 378 name: "decimal is -12.34", 379 arg: expression.NewLiteral(decimal.NewFromFloat(-12.34), types.DecimalType_{}), 380 exp: math.Exp(-12.34), 381 }, 382 { 383 name: "invalid string is 0", 384 arg: expression.NewLiteral("notanumber", types.Text), 385 exp: math.Exp(0), 386 }, 387 { 388 name: "empty string", 389 arg: expression.NewLiteral("", types.Text), 390 exp: math.Exp(0), 391 }, 392 { 393 name: "numerical string", 394 arg: expression.NewLiteral("10", types.Text), 395 exp: math.Exp(10), 396 }, 397 { 398 // we don't do truncation yet 399 // https://github.com/dolthub/dolt/issues/7302 400 name: "scientific string is truncated", 401 arg: expression.NewLiteral("1e1", types.Text), 402 exp: "", 403 err: false, 404 skip: true, 405 }, 406 } 407 408 for _, tt := range tests { 409 t.Run(tt.name, func(t *testing.T) { 410 if tt.skip { 411 t.Skip() 412 } 413 414 ctx := sql.NewEmptyContext() 415 f := NewExp(tt.arg) 416 417 res, err := f.Eval(ctx, nil) 418 if tt.err { 419 require.Error(t, err) 420 return 421 } 422 423 require.NoError(t, err) 424 require.Equal(t, tt.exp, res) 425 }) 426 } 427 }