github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/function_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "reflect" 20 "testing" 21 "time" 22 23 "github.com/shopspring/decimal" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/expression" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 type FuncTest struct { 33 name string 34 expr sql.Expression 35 expected interface{} 36 expectErr bool 37 } 38 39 func (ft FuncTest) Run(t *testing.T, ctx *sql.Context, r sql.Row) { 40 t.Run(ft.name, func(t *testing.T) { 41 expr, err := ft.expr.WithChildren(ft.expr.Children()...) 42 require.NoError(t, err) 43 44 res, err := expr.Eval(ctx, r) 45 46 if ft.expectErr { 47 assert.Error(t, err) 48 } else { 49 assert.NoError(t, err) 50 assert.Equal(t, ft.expected, res) 51 assertResultType(t, ft.expr.Type(), res) 52 } 53 }) 54 } 55 56 func assertResultType(t *testing.T, expectedType sql.Type, result interface{}) { 57 if result == nil { 58 return 59 } 60 61 switch result.(type) { 62 case uint8: 63 assert.Equal(t, expectedType, types.Uint8) 64 case int8: 65 assert.Equal(t, expectedType, types.Int8) 66 case uint16: 67 assert.Equal(t, expectedType, types.Uint16) 68 case int16: 69 assert.Equal(t, expectedType, types.Int16) 70 case uint32, uint: 71 assert.Equal(t, expectedType, types.Uint32) 72 case int32, int: 73 assert.Equal(t, expectedType, types.Int32) 74 case uint64: 75 assert.Equal(t, expectedType, types.Uint64) 76 case int64: 77 assert.Equal(t, expectedType, types.Int64) 78 case float64: 79 assert.Equal(t, expectedType, types.Float64) 80 case float32: 81 assert.Equal(t, expectedType, types.Float32) 82 case string: 83 assert.True(t, types.IsText(expectedType)) 84 case time.Time: 85 assert.Equal(t, expectedType, types.DatetimeMaxPrecision) 86 case bool: 87 assert.Equal(t, expectedType, types.Boolean) 88 case []byte: 89 assert.Equal(t, expectedType, types.LongBlob) 90 default: 91 assert.Fail(t, "unhandled case") 92 } 93 } 94 95 type TestFactory struct { 96 createFunc interface{} 97 Tests []FuncTest 98 } 99 100 func NewTestFactory(createFunc interface{}) *TestFactory { 101 switch createFunc.(type) { 102 case sql.CreateFunc0Args, sql.CreateFunc1Args, sql.CreateFunc2Args, sql.CreateFunc3Args, sql.CreateFunc4Args, sql.CreateFunc5Args, sql.CreateFunc6Args, sql.CreateFunc7Args, sql.CreateFuncNArgs: 103 default: 104 panic("Unsupported create method") 105 } 106 107 return &TestFactory{createFunc: createFunc} 108 } 109 110 func (tf *TestFactory) Test(t *testing.T, ctx *sql.Context, r sql.Row) { 111 for _, test := range tf.Tests { 112 test.Run(t, ctx, r) 113 } 114 } 115 116 func (tf *TestFactory) AddTest(name string, expected interface{}, expectErr bool, inputs ...interface{}) { 117 var test FuncTest 118 119 inputExprs := toLiteralExpressions(inputs) 120 switch fn := tf.createFunc.(type) { 121 case sql.CreateFunc0Args: 122 if len(inputExprs) != 0 { 123 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 0)) 124 } 125 sqlFuncExpr := fn() 126 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 127 128 case sql.CreateFunc1Args: 129 if len(inputExprs) != 1 { 130 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 1)) 131 } 132 sqlFuncExpr := fn(inputExprs[0]) 133 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 134 135 case sql.CreateFunc2Args: 136 if len(inputExprs) != 2 { 137 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 2)) 138 } 139 sqlFuncExpr := fn(inputExprs[0], inputExprs[1]) 140 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 141 142 case sql.CreateFunc3Args: 143 if len(inputExprs) != 3 { 144 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 3)) 145 } 146 sqlFuncExpr := fn(inputExprs[0], inputExprs[1], inputExprs[2]) 147 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 148 149 case sql.CreateFunc4Args: 150 if len(inputExprs) != 4 { 151 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 4)) 152 } 153 sqlFuncExpr := fn(inputExprs[0], inputExprs[1], inputExprs[2], inputExprs[3]) 154 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 155 156 case sql.CreateFunc5Args: 157 if len(inputExprs) != 5 { 158 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 5)) 159 } 160 sqlFuncExpr := fn(inputExprs[0], inputExprs[1], inputExprs[2], inputExprs[3], inputExprs[4]) 161 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 162 163 case sql.CreateFunc6Args: 164 if len(inputExprs) != 6 { 165 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 6)) 166 } 167 sqlFuncExpr := fn(inputExprs[0], inputExprs[1], inputExprs[2], inputExprs[3], inputExprs[4], inputExprs[5]) 168 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 169 170 case sql.CreateFunc7Args: 171 if len(inputExprs) != 7 { 172 panic(fmt.Sprintf("error in test: %s. params provided: %d, params required: %d", name, len(inputExprs), 7)) 173 } 174 sqlFuncExpr := fn(inputExprs[0], inputExprs[1], inputExprs[2], inputExprs[3], inputExprs[4], inputExprs[5], inputExprs[6]) 175 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 176 177 case sql.CreateFuncNArgs: 178 sqlFuncExpr, err := fn(inputExprs...) 179 180 if err != nil { 181 panic(fmt.Sprintf("error in test: %s. %s", name, err.Error())) 182 } 183 184 test = FuncTest{name, sqlFuncExpr, expected, expectErr} 185 186 default: 187 panic("should never get here. Should have failed in NewTestFactory already.") 188 } 189 190 tf.Tests = append(tf.Tests, test) 191 } 192 193 func (tf *TestFactory) AddSucceeding(expected interface{}, args ...interface{}) { 194 name := generateNameFromArgs(args...) 195 tf.AddTest(name, expected, false, args...) 196 } 197 198 func (tf *TestFactory) AddFailing(args ...interface{}) { 199 name := generateNameFromArgs(args...) 200 tf.AddTest(name, nil, true, args...) 201 } 202 203 func isValidIntOfSize(bits int, n int64) bool { 204 var max int64 = 1<<(bits-1) - 1 205 var min int64 = -(1 << (bits - 1)) 206 return n < max && n > min 207 } 208 209 func (tf *TestFactory) AddSignedVariations(expected interface{}, arg int64) { 210 if isValidIntOfSize(8, arg) { 211 tf.AddSucceeding(expected, int8(arg)) 212 } 213 214 if isValidIntOfSize(16, arg) { 215 tf.AddSucceeding(expected, int16(arg)) 216 } 217 218 if isValidIntOfSize(32, arg) { 219 tf.AddSucceeding(expected, int32(arg)) 220 tf.AddSucceeding(expected, int(arg)) 221 } 222 223 tf.AddSucceeding(expected, arg) 224 } 225 226 func isValidUintOfSize(bits int, n uint64) bool { 227 var max uint64 = 1<<bits - 1 228 return n < max 229 } 230 231 func (tf *TestFactory) AddUnsignedVariations(expected interface{}, arg uint64) { 232 if isValidUintOfSize(8, arg) { 233 tf.AddSucceeding(expected, uint8(arg)) 234 } 235 236 if isValidUintOfSize(16, arg) { 237 tf.AddSucceeding(expected, uint16(arg)) 238 } 239 240 if isValidUintOfSize(32, arg) { 241 tf.AddSucceeding(expected, uint32(arg)) 242 tf.AddSucceeding(expected, uint(arg)) 243 } 244 245 tf.AddSucceeding(expected, arg) 246 } 247 248 func (tf *TestFactory) AddFloatVariations(expected interface{}, f float64) { 249 tf.AddSucceeding(expected, f) 250 tf.AddSucceeding(expected, float32(f)) 251 tf.AddSucceeding(expected, decimal.NewFromFloat(f)) 252 } 253 254 func generateNameFromArgs(args ...interface{}) string { 255 name := "(" 256 257 for i, arg := range args { 258 if i > 0 { 259 name += "," 260 } 261 262 if arg == nil { 263 name += "nil" 264 } else { 265 name += fmt.Sprintf("%s{%v}", reflect.TypeOf(arg).String(), arg) 266 } 267 } 268 269 name += ")" 270 271 return name 272 } 273 274 func toLiteralExpressions(inputs []interface{}) []sql.Expression { 275 literals := make([]sql.Expression, len(inputs)) 276 for i, in := range inputs { 277 literals[i] = toLiteralExpression(in) 278 } 279 280 return literals 281 } 282 283 func toLiteralExpression(input interface{}) *expression.Literal { 284 if input == nil { 285 return expression.NewLiteral(nil, types.Null) 286 } 287 288 switch val := input.(type) { 289 case bool: 290 return expression.NewLiteral(val, types.Boolean) 291 case int8: 292 return expression.NewLiteral(val, types.Int8) 293 case uint8: 294 return expression.NewLiteral(val, types.Int8) 295 case int16: 296 return expression.NewLiteral(val, types.Int16) 297 case uint16: 298 return expression.NewLiteral(val, types.Uint16) 299 case int32: 300 return expression.NewLiteral(val, types.Int32) 301 case uint32: 302 return expression.NewLiteral(val, types.Uint32) 303 case int64: 304 return expression.NewLiteral(val, types.Int64) 305 case uint64: 306 return expression.NewLiteral(val, types.Uint64) 307 case int: 308 return expression.NewLiteral(val, types.Int32) 309 case uint: 310 return expression.NewLiteral(val, types.Uint32) 311 case float32: 312 return expression.NewLiteral(val, types.Float32) 313 case float64: 314 return expression.NewLiteral(val, types.Float64) 315 case decimal.Decimal: 316 return expression.NewLiteral(val, types.Float64) 317 case string: 318 return expression.NewLiteral(val, types.Text) 319 case time.Time: 320 return expression.NewLiteral(val, types.DatetimeMaxPrecision) 321 case []byte: 322 return expression.NewLiteral(string(val), types.Blob) 323 default: 324 panic("unsupported type") 325 } 326 }