github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/case_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "testing" 19 20 "github.com/shopspring/decimal" 21 "github.com/stretchr/testify/require" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/types" 25 ) 26 27 func TestCase(t *testing.T) { 28 f1 := NewCase( 29 NewGetField(0, types.Int64, "foo", false), 30 []CaseBranch{ 31 {Cond: NewLiteral(int64(1), types.Int64), Value: NewLiteral(int64(2), types.Int64)}, 32 {Cond: NewLiteral(int64(3), types.Int64), Value: NewLiteral(int64(4), types.Int64)}, 33 {Cond: NewLiteral(int64(5), types.Int64), Value: NewLiteral(int64(6), types.Int64)}, 34 }, 35 NewLiteral(int64(7), types.Int64), 36 ) 37 38 f2 := NewCase( 39 nil, 40 []CaseBranch{ 41 { 42 Cond: NewEquals( 43 NewGetField(0, types.Int64, "foo", false), 44 NewLiteral(int64(1), types.Int64), 45 ), 46 Value: NewLiteral(int64(2), types.Int64), 47 }, 48 { 49 Cond: NewEquals( 50 NewGetField(0, types.Int64, "foo", false), 51 NewLiteral(int64(3), types.Int64), 52 ), 53 Value: NewLiteral(int64(4), types.Int64), 54 }, 55 { 56 Cond: NewEquals( 57 NewGetField(0, types.Int64, "foo", false), 58 NewLiteral(int64(5), types.Int64), 59 ), 60 Value: NewLiteral(int64(6), types.Int64), 61 }, 62 }, 63 NewLiteral(int64(7), types.Int64), 64 ) 65 66 f3 := NewCase( 67 NewGetField(0, types.Int64, "foo", false), 68 []CaseBranch{ 69 {Cond: NewLiteral(int64(1), types.Int64), Value: NewLiteral(int64(2), types.Int64)}, 70 {Cond: NewLiteral(int64(3), types.Int64), Value: NewLiteral(int64(4), types.Int64)}, 71 {Cond: NewLiteral(int64(5), types.Int64), Value: NewLiteral(int64(6), types.Int64)}, 72 }, 73 nil, 74 ) 75 76 testCases := []struct { 77 name string 78 f *Case 79 row sql.Row 80 expected interface{} 81 }{ 82 { 83 "with expr and else branch 1", 84 f1, 85 sql.Row{int64(1)}, 86 int64(2), 87 }, 88 { 89 "with expr and else branch 2", 90 f1, 91 sql.Row{int64(3)}, 92 int64(4), 93 }, 94 { 95 "with expr and else branch 3", 96 f1, 97 sql.Row{int64(5)}, 98 int64(6), 99 }, 100 { 101 "with expr and else, else branch", 102 f1, 103 sql.Row{int64(9)}, 104 int64(7), 105 }, 106 { 107 "without expr and else branch 1", 108 f2, 109 sql.Row{int64(1)}, 110 int64(2), 111 }, 112 { 113 "without expr and else branch 2", 114 f2, 115 sql.Row{int64(3)}, 116 int64(4), 117 }, 118 { 119 "without expr and else branch 3", 120 f2, 121 sql.Row{int64(5)}, 122 int64(6), 123 }, 124 { 125 "without expr and else, else branch", 126 f2, 127 sql.Row{int64(9)}, 128 int64(7), 129 }, 130 { 131 "without else, else branch", 132 f3, 133 sql.Row{int64(9)}, 134 nil, 135 }, 136 } 137 138 for _, tt := range testCases { 139 t.Run(tt.name, func(t *testing.T) { 140 require := require.New(t) 141 result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) 142 require.NoError(err) 143 require.Equal(tt.expected, result) 144 }) 145 } 146 } 147 148 func TestCaseType(t *testing.T) { 149 caseExpr := func(values ...sql.Expression) *Case { 150 var branches []CaseBranch 151 for i := 0; i < len(values)-1; i++ { 152 branches = append(branches, CaseBranch{ 153 Cond: NewLiteral(int64(i), types.Int64), 154 Value: values[i], 155 }) 156 } 157 return &Case{ 158 nil, 159 branches, 160 values[len(values)-1], 161 } 162 } 163 164 decimalType := types.MustCreateDecimalType(65, 10) 165 166 testCases := []struct { 167 name string 168 c *Case 169 t sql.Type 170 }{ 171 { 172 "standalone else clause", 173 caseExpr(NewLiteral(int64(0), types.Int64)), 174 types.Int64, 175 }, 176 { 177 "unsigned promoted and unsigned", 178 caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint32)), 179 types.Uint64, 180 }, 181 { 182 "signed promoted and signed", 183 caseExpr(NewLiteral(int8(0), types.Int8), NewLiteral(int32(1), types.Int32)), 184 types.Int64, 185 }, 186 { 187 "int and float to float", 188 caseExpr(NewLiteral(int64(0), types.Int64), NewLiteral(float64(1.0), types.Float64)), 189 types.Float64, 190 }, 191 { 192 "float and int to float", 193 caseExpr(NewLiteral(float64(1.0), types.Float64), NewLiteral(int64(0), types.Int64)), 194 types.Float64, 195 }, 196 { 197 "int and text to text", 198 caseExpr(NewLiteral(int64(0), types.Int64), NewLiteral("Hello, world!", types.Text)), 199 types.LongText, 200 }, 201 { 202 "text and blob to blob", 203 caseExpr(NewLiteral("Hello, world!", types.Text), NewLiteral([]byte("0x480x650x6c0x6c0x6f"), types.Blob)), 204 types.LongBlob, 205 }, 206 { 207 "int and null to int", 208 caseExpr(NewLiteral(int64(10), types.Int64), NewLiteral(nil, types.Null)), 209 types.Int64, 210 }, 211 { 212 "null and int to int", 213 caseExpr(NewLiteral(nil, types.Null), NewLiteral(int64(10), types.Int64)), 214 types.Int64, 215 }, 216 { 217 "uint64 and int8 to decimal", 218 caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral(int8(0), types.Int8)), 219 decimalType, 220 }, 221 { 222 "int and text to text", 223 caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral("Hello, world!", types.LongText)), 224 types.LongText, 225 }, 226 { 227 "uint and decimal to decimal", 228 caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral("Hello, world!", types.LongText)), 229 types.LongText, 230 }, 231 { 232 "int and decimal to decimal", 233 caseExpr(NewLiteral(int32(10), types.Int32), NewLiteral(decimal.NewFromInt(1), decimalType)), 234 decimalType, 235 }, 236 { 237 "date and date stays date", 238 caseExpr(NewLiteral("2020-04-07", types.Date), NewLiteral("2020-04-07", types.Date)), 239 types.Date, 240 }, 241 { 242 "date and timestamp becomes datetime", 243 caseExpr(NewLiteral("2020-04-07", types.Date), NewLiteral("2020-04-07T00:00:00Z", types.Timestamp)), 244 types.DatetimeMaxPrecision, 245 }, 246 } 247 248 for _, tt := range testCases { 249 t.Run(tt.name, func(t *testing.T) { 250 require.Equal(t, tt.t, tt.c.Type()) 251 }) 252 } 253 } 254 255 func TestCaseNullBranch(t *testing.T) { 256 require := require.New(t) 257 f := NewCase( 258 NewGetField(0, types.Int64, "x", false), 259 []CaseBranch{ 260 { 261 Cond: NewLiteral(int64(1), types.Int64), 262 Value: NewLiteral(nil, types.Null), 263 }, 264 }, 265 nil, 266 ) 267 result, err := f.Eval(sql.NewEmptyContext(), sql.Row{int64(1)}) 268 require.NoError(err) 269 require.Nil(result) 270 }