github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/absval_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/shopspring/decimal"
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  	_ "github.com/dolthub/go-mysql-server/sql/variables"
    27  )
    28  
    29  func TestAbsValue(t *testing.T) {
    30  	type toTypeFunc func(float64) interface{}
    31  
    32  	decimal1616 := types.MustCreateDecimalType(16, 16)
    33  
    34  	toInt64 := func(x float64) interface{} { return int64(x) }
    35  	toInt32 := func(x float64) interface{} { return int32(x) }
    36  	toInt := func(x float64) interface{} { return int(x) }
    37  	toInt16 := func(x float64) interface{} { return int16(x) }
    38  	toInt8 := func(x float64) interface{} { return int8(x) }
    39  	toUint64 := func(x float64) interface{} { return uint64(x) }
    40  	toUint32 := func(x float64) interface{} { return uint32(x) }
    41  	toUint := func(x float64) interface{} { return uint(x) }
    42  	toUint16 := func(x float64) interface{} { return uint16(x) }
    43  	toUint8 := func(x float64) interface{} { return uint8(x) }
    44  	toFloat64 := func(x float64) interface{} { return x }
    45  	toFloat32 := func(x float64) interface{} { return float32(x) }
    46  	toDecimal1616 := func(x float64) interface{} { return decimal.NewFromFloat(x) }
    47  
    48  	signedTypes := map[sql.Type]toTypeFunc{
    49  		types.Int64: toInt64,
    50  		types.Int32: toInt32,
    51  		types.Int24: toInt,
    52  		types.Int16: toInt16,
    53  		types.Int8:  toInt8}
    54  	unsignedTypes := map[sql.Type]toTypeFunc{
    55  		types.Uint64: toUint64,
    56  		types.Uint32: toUint32,
    57  		types.Uint24: toUint,
    58  		types.Uint16: toUint16,
    59  		types.Uint8:  toUint8}
    60  	floatTypes := map[sql.Type]toTypeFunc{
    61  		types.Float64: toFloat64,
    62  		types.Float32: toFloat32,
    63  		decimal1616:   toDecimal1616,
    64  	}
    65  
    66  	testCases := []struct {
    67  		name       string
    68  		typeToConv map[sql.Type]toTypeFunc
    69  		val        float64
    70  		expected   float64
    71  		err        error
    72  	}{
    73  		{
    74  			"signed types positive int",
    75  			signedTypes,
    76  			5.0,
    77  			5.0,
    78  			nil,
    79  		}, {
    80  			"signed types negative int",
    81  			signedTypes,
    82  			-5.0,
    83  			5.0,
    84  			nil,
    85  		},
    86  		{
    87  			"unsigned types positive int",
    88  			unsignedTypes,
    89  			5.0,
    90  			5.0,
    91  			nil,
    92  		},
    93  		{
    94  			"float positive int",
    95  			floatTypes,
    96  			5.0,
    97  			5.0,
    98  			nil,
    99  		}, {
   100  			"float negative int",
   101  			floatTypes,
   102  			-5.0,
   103  			5.0,
   104  			nil,
   105  		},
   106  	}
   107  
   108  	for _, test := range testCases {
   109  		t.Run(test.name, func(t *testing.T) {
   110  			for sqlType, conv := range test.typeToConv {
   111  				f := NewAbsVal(expression.NewGetField(0, sqlType, "blob", true))
   112  
   113  				row := sql.NewRow(conv(test.val))
   114  				res, err := f.Eval(sql.NewEmptyContext(), row)
   115  
   116  				if test.err == nil {
   117  					require.NoError(t, err)
   118  					require.Equal(t, conv(test.expected), res)
   119  				} else {
   120  					require.Error(t, err)
   121  				}
   122  			}
   123  		})
   124  	}
   125  }