github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/case_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/shopspring/decimal"
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  func TestCase(t *testing.T) {
    28  	f1 := NewCase(
    29  		NewGetField(0, types.Int64, "foo", false),
    30  		[]CaseBranch{
    31  			{Cond: NewLiteral(int64(1), types.Int64), Value: NewLiteral(int64(2), types.Int64)},
    32  			{Cond: NewLiteral(int64(3), types.Int64), Value: NewLiteral(int64(4), types.Int64)},
    33  			{Cond: NewLiteral(int64(5), types.Int64), Value: NewLiteral(int64(6), types.Int64)},
    34  		},
    35  		NewLiteral(int64(7), types.Int64),
    36  	)
    37  
    38  	f2 := NewCase(
    39  		nil,
    40  		[]CaseBranch{
    41  			{
    42  				Cond: NewEquals(
    43  					NewGetField(0, types.Int64, "foo", false),
    44  					NewLiteral(int64(1), types.Int64),
    45  				),
    46  				Value: NewLiteral(int64(2), types.Int64),
    47  			},
    48  			{
    49  				Cond: NewEquals(
    50  					NewGetField(0, types.Int64, "foo", false),
    51  					NewLiteral(int64(3), types.Int64),
    52  				),
    53  				Value: NewLiteral(int64(4), types.Int64),
    54  			},
    55  			{
    56  				Cond: NewEquals(
    57  					NewGetField(0, types.Int64, "foo", false),
    58  					NewLiteral(int64(5), types.Int64),
    59  				),
    60  				Value: NewLiteral(int64(6), types.Int64),
    61  			},
    62  		},
    63  		NewLiteral(int64(7), types.Int64),
    64  	)
    65  
    66  	f3 := NewCase(
    67  		NewGetField(0, types.Int64, "foo", false),
    68  		[]CaseBranch{
    69  			{Cond: NewLiteral(int64(1), types.Int64), Value: NewLiteral(int64(2), types.Int64)},
    70  			{Cond: NewLiteral(int64(3), types.Int64), Value: NewLiteral(int64(4), types.Int64)},
    71  			{Cond: NewLiteral(int64(5), types.Int64), Value: NewLiteral(int64(6), types.Int64)},
    72  		},
    73  		nil,
    74  	)
    75  
    76  	testCases := []struct {
    77  		name     string
    78  		f        *Case
    79  		row      sql.Row
    80  		expected interface{}
    81  	}{
    82  		{
    83  			"with expr and else branch 1",
    84  			f1,
    85  			sql.Row{int64(1)},
    86  			int64(2),
    87  		},
    88  		{
    89  			"with expr and else branch 2",
    90  			f1,
    91  			sql.Row{int64(3)},
    92  			int64(4),
    93  		},
    94  		{
    95  			"with expr and else branch 3",
    96  			f1,
    97  			sql.Row{int64(5)},
    98  			int64(6),
    99  		},
   100  		{
   101  			"with expr and else, else branch",
   102  			f1,
   103  			sql.Row{int64(9)},
   104  			int64(7),
   105  		},
   106  		{
   107  			"without expr and else branch 1",
   108  			f2,
   109  			sql.Row{int64(1)},
   110  			int64(2),
   111  		},
   112  		{
   113  			"without expr and else branch 2",
   114  			f2,
   115  			sql.Row{int64(3)},
   116  			int64(4),
   117  		},
   118  		{
   119  			"without expr and else branch 3",
   120  			f2,
   121  			sql.Row{int64(5)},
   122  			int64(6),
   123  		},
   124  		{
   125  			"without expr and else, else branch",
   126  			f2,
   127  			sql.Row{int64(9)},
   128  			int64(7),
   129  		},
   130  		{
   131  			"without else, else branch",
   132  			f3,
   133  			sql.Row{int64(9)},
   134  			nil,
   135  		},
   136  	}
   137  
   138  	for _, tt := range testCases {
   139  		t.Run(tt.name, func(t *testing.T) {
   140  			require := require.New(t)
   141  			result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row)
   142  			require.NoError(err)
   143  			require.Equal(tt.expected, result)
   144  		})
   145  	}
   146  }
   147  
   148  func TestCaseType(t *testing.T) {
   149  	caseExpr := func(values ...sql.Expression) *Case {
   150  		var branches []CaseBranch
   151  		for i := 0; i < len(values)-1; i++ {
   152  			branches = append(branches, CaseBranch{
   153  				Cond:  NewLiteral(int64(i), types.Int64),
   154  				Value: values[i],
   155  			})
   156  		}
   157  		return &Case{
   158  			nil,
   159  			branches,
   160  			values[len(values)-1],
   161  		}
   162  	}
   163  
   164  	decimalType := types.MustCreateDecimalType(65, 10)
   165  
   166  	testCases := []struct {
   167  		name string
   168  		c    *Case
   169  		t    sql.Type
   170  	}{
   171  		{
   172  			"standalone else clause",
   173  			caseExpr(NewLiteral(int64(0), types.Int64)),
   174  			types.Int64,
   175  		},
   176  		{
   177  			"unsigned promoted and unsigned",
   178  			caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint32)),
   179  			types.Uint64,
   180  		},
   181  		{
   182  			"signed promoted and signed",
   183  			caseExpr(NewLiteral(int8(0), types.Int8), NewLiteral(int32(1), types.Int32)),
   184  			types.Int64,
   185  		},
   186  		{
   187  			"int and float to float",
   188  			caseExpr(NewLiteral(int64(0), types.Int64), NewLiteral(float64(1.0), types.Float64)),
   189  			types.Float64,
   190  		},
   191  		{
   192  			"float and int to float",
   193  			caseExpr(NewLiteral(float64(1.0), types.Float64), NewLiteral(int64(0), types.Int64)),
   194  			types.Float64,
   195  		},
   196  		{
   197  			"int and text to text",
   198  			caseExpr(NewLiteral(int64(0), types.Int64), NewLiteral("Hello, world!", types.Text)),
   199  			types.LongText,
   200  		},
   201  		{
   202  			"text and blob to blob",
   203  			caseExpr(NewLiteral("Hello, world!", types.Text), NewLiteral([]byte("0x480x650x6c0x6c0x6f"), types.Blob)),
   204  			types.LongBlob,
   205  		},
   206  		{
   207  			"int and null to int",
   208  			caseExpr(NewLiteral(int64(10), types.Int64), NewLiteral(nil, types.Null)),
   209  			types.Int64,
   210  		},
   211  		{
   212  			"null and int to int",
   213  			caseExpr(NewLiteral(nil, types.Null), NewLiteral(int64(10), types.Int64)),
   214  			types.Int64,
   215  		},
   216  		{
   217  			"uint64 and int8 to decimal",
   218  			caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral(int8(0), types.Int8)),
   219  			decimalType,
   220  		},
   221  		{
   222  			"int and text to text",
   223  			caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral("Hello, world!", types.LongText)),
   224  			types.LongText,
   225  		},
   226  		{
   227  			"uint and decimal to decimal",
   228  			caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral("Hello, world!", types.LongText)),
   229  			types.LongText,
   230  		},
   231  		{
   232  			"int and decimal to decimal",
   233  			caseExpr(NewLiteral(int32(10), types.Int32), NewLiteral(decimal.NewFromInt(1), decimalType)),
   234  			decimalType,
   235  		},
   236  		{
   237  			"date and date stays date",
   238  			caseExpr(NewLiteral("2020-04-07", types.Date), NewLiteral("2020-04-07", types.Date)),
   239  			types.Date,
   240  		},
   241  		{
   242  			"date and timestamp becomes datetime",
   243  			caseExpr(NewLiteral("2020-04-07", types.Date), NewLiteral("2020-04-07T00:00:00Z", types.Timestamp)),
   244  			types.DatetimeMaxPrecision,
   245  		},
   246  	}
   247  
   248  	for _, tt := range testCases {
   249  		t.Run(tt.name, func(t *testing.T) {
   250  			require.Equal(t, tt.t, tt.c.Type())
   251  		})
   252  	}
   253  }
   254  
   255  func TestCaseNullBranch(t *testing.T) {
   256  	require := require.New(t)
   257  	f := NewCase(
   258  		NewGetField(0, types.Int64, "x", false),
   259  		[]CaseBranch{
   260  			{
   261  				Cond:  NewLiteral(int64(1), types.Int64),
   262  				Value: NewLiteral(nil, types.Null),
   263  			},
   264  		},
   265  		nil,
   266  	)
   267  	result, err := f.Eval(sql.NewEmptyContext(), sql.Row{int64(1)})
   268  	require.NoError(err)
   269  	require.Nil(result)
   270  }