github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/udt_test.go (about) 1 // Copyright 2021 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datacodec 16 17 import ( 18 "errors" 19 "fmt" 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/mock" 24 25 "github.com/datastax/go-cassandra-native-protocol/datatype" 26 "github.com/datastax/go-cassandra-native-protocol/primitive" 27 ) 28 29 var ( 30 udtTypeSimple, _ = datatype.NewUserDefined( 31 "ks1", 32 "type1", 33 []string{"f1", "f2", "f3"}, 34 []datatype.DataType{datatype.Int, datatype.Boolean, datatype.Varchar}, 35 ) 36 udtTypeComplex, _ = datatype.NewUserDefined( 37 "ks1", 38 "type2", 39 []string{"f1", "f2"}, 40 []datatype.DataType{udtTypeSimple, udtTypeSimple}, 41 ) 42 udtTypeEmpty, _ = datatype.NewUserDefined("ks1", "type3", []string{}, []datatype.DataType{}) 43 udtTypeWrong, _ = datatype.NewUserDefined("ks1", "type4", []string{"f1"}, []datatype.DataType{wrongDataType{}}) 44 ) 45 46 var ( 47 udtCodecSimple, _ = NewUserDefined(udtTypeSimple) 48 udtCodecComplex, _ = NewUserDefined(udtTypeComplex) 49 udtCodecEmpty, _ = NewUserDefined(udtTypeEmpty) 50 ) 51 52 type ( 53 SimpleUdt struct { 54 F1 int 55 F2 bool 56 F3 *string 57 } 58 partialUdt struct { 59 F1 int 60 F2 bool 61 } 62 excessUdt struct { 63 F1 int 64 F2 bool 65 F3 *string 66 F4 float64 67 } 68 complexUdt struct { 69 F1 SimpleUdt 70 F2 *excessUdt 71 } 72 ) 73 74 var ( 75 nullElementsUdtBytes = []byte{ 76 255, 255, 255, 255, // nil int 77 255, 255, 255, 255, // nil boolean 78 255, 255, 255, 255, // nil string 79 } 80 oneTwoThreeAbcUdtBytes = []byte{ 81 0, 0, 0, 4, // length of int 82 0, 0, 0, 123, // int 83 0, 0, 0, 1, // length of boolean 84 1, // boolean 85 0, 0, 0, 3, // length of string 86 a, b, c, // string 87 } 88 udtWithNullFieldsBytes = []byte{ 89 0, 0, 0, 4, // length of int 90 0, 0, 0, 123, // int 91 0, 0, 0, 1, // length of boolean 92 0, // boolean 93 255, 255, 255, 255, // nil string 94 } 95 udtWithNullFieldsBytes2 = []byte{ 96 0, 0, 0, 4, // length of int 97 0, 0, 0, 123, // int 98 255, 255, 255, 255, // nil boolean 99 255, 255, 255, 255, // nil string 100 } 101 udtOneTwoThreeFalseAbcBytes = []byte{ 102 0, 0, 0, 4, // length of int 103 0, 0, 0, 123, // int 104 0, 0, 0, 1, // length of boolean 105 0, // boolean 106 0, 0, 0, 3, // length of string 107 a, b, c, // string 108 } 109 udtComplexBytes = []byte{ 110 0, 0, 0, 20, // length of element 1 111 // element 1 112 0, 0, 0, 4, // length of int 113 0, 0, 0, 12, // int 114 0, 0, 0, 1, // length of boolean 115 0, // boolean 116 0, 0, 0, 3, // length of string 117 a, b, c, // string 118 0, 0, 0, 20, // length of element 2 119 // element 2 120 0, 0, 0, 4, // length of int 121 0, 0, 0, 34, // int 122 0, 0, 0, 1, // length of boolean 123 1, // boolean 124 0, 0, 0, 3, // length of string 125 d, e, f, // string 126 } 127 udtZeroBytes = []byte{ 128 0, 0, 0, 4, // length of int 129 0, 0, 0, 0, // int 130 0, 0, 0, 1, // length of boolean 131 0, // boolean 132 255, 255, 255, 255, // nil string 133 } 134 udtComplexWithNullsBytes = []byte{ 135 0, 0, 0, 20, // length of element 1 136 // element 1 137 0, 0, 0, 4, // length of int 138 0, 0, 0, 12, // int 139 0, 0, 0, 1, // length of boolean 140 0, // boolean 141 0, 0, 0, 3, // length of string 142 a, b, c, // string 143 0, 0, 0, 17, // length of element 2 144 // element 2 145 0, 0, 0, 4, // length of int 146 0, 0, 0, 34, // int 147 0, 0, 0, 1, // length of boolean 148 1, // boolean 149 255, 255, 255, 255, // nil string 150 } 151 udtComplexWithNulls2Bytes = []byte{ 152 0, 0, 0, 20, // length of element 1 153 // element 1 154 0, 0, 0, 4, // length of int 155 0, 0, 0, 12, // int 156 0, 0, 0, 1, // length of boolean 157 0, // boolean 158 0, 0, 0, 3, // length of string 159 a, b, c, // string 160 255, 255, 255, 255, // nil element 2 161 } 162 udtMissingBytes = []byte{ 163 0, 0, 0, 4, // length of int 164 0, 0, 0, 123, // int 165 0, 0, 0, 1, // length of boolean 166 1, // boolean 167 0, 0, 0, 3, // length of string 168 // missing string 169 } 170 ) 171 172 func TestNewUserDefinedCodec(t *testing.T) { 173 tests := []struct { 174 name string 175 dataType *datatype.UserDefined 176 expected Codec 177 err string 178 }{ 179 { 180 "simple", 181 udtTypeSimple, 182 &udtCodec{dataType: udtTypeSimple, fieldCodecs: []Codec{Int, Boolean, Varchar}}, 183 "", 184 }, 185 { 186 "complex", 187 udtTypeComplex, 188 &udtCodec{ 189 dataType: udtTypeComplex, 190 fieldCodecs: []Codec{ 191 &udtCodec{dataType: udtTypeSimple, fieldCodecs: []Codec{Int, Boolean, Varchar}}, 192 &udtCodec{dataType: udtTypeSimple, fieldCodecs: []Codec{Int, Boolean, Varchar}}, 193 }, 194 }, 195 "", 196 }, 197 { 198 "empty", 199 udtTypeEmpty, 200 &udtCodec{dataType: udtTypeEmpty, fieldCodecs: []Codec{}}, 201 "", 202 }, 203 { 204 "wrong child", 205 udtTypeWrong, 206 nil, 207 "cannot create codec for user-defined type field 0 (f1): cannot create data codec for CQL type 666", 208 }, 209 { 210 "nil", 211 nil, 212 nil, 213 "data type is nil", 214 }, 215 } 216 for _, tt := range tests { 217 t.Run(tt.name, func(t *testing.T) { 218 actual, err := NewUserDefined(tt.dataType) 219 assert.Equal(t, tt.expected, actual) 220 assertErrorMessage(t, tt.err, err) 221 }) 222 } 223 } 224 225 func Test_udtCodec_Encode(t *testing.T) { 226 for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion3) { 227 t.Run(version.String(), func(t *testing.T) { 228 t.Run("[]interface{}", func(t *testing.T) { 229 tests := []struct { 230 name string 231 codec Codec 232 input *[]interface{} 233 expected []byte 234 err string 235 }{ 236 {"nil", udtCodecEmpty, nil, nil, ""}, 237 {"empty", udtCodecSimple, &[]interface{}{nil, nil, nil}, nullElementsUdtBytes, ""}, 238 {"simple", udtCodecSimple, &[]interface{}{123, true, "abc"}, oneTwoThreeAbcUdtBytes, ""}, 239 {"simple with pointers", udtCodecSimple, &[]interface{}{intPtr(123), boolPtr(true), stringPtr("abc")}, oneTwoThreeAbcUdtBytes, ""}, 240 {"nil element", udtCodecSimple, &[]interface{}{123, false, nil}, udtWithNullFieldsBytes, ""}, 241 {"not enough elements", udtCodecSimple, &[]interface{}{123}, nil, "slice index out of range: 1"}, 242 {"too many elements", udtCodecSimple, &[]interface{}{123, false, "abc", "extra"}, udtOneTwoThreeFalseAbcBytes, ""}, 243 {"complex", udtCodecComplex, &[]interface{}{[]interface{}{12, false, "abc"}, []interface{}{34, true, "def"}}, udtComplexBytes, ""}, 244 } 245 for _, tt := range tests { 246 t.Run(tt.name, func(t *testing.T) { 247 if tt.input != nil { 248 t.Run("value", func(t *testing.T) { 249 dest, err := tt.codec.Encode(*tt.input, version) 250 assert.Equal(t, tt.expected, dest) 251 assertErrorMessage(t, tt.err, err) 252 }) 253 } 254 t.Run("pointer", func(t *testing.T) { 255 dest, err := tt.codec.Encode(tt.input, version) 256 assert.Equal(t, tt.expected, dest) 257 assertErrorMessage(t, tt.err, err) 258 }) 259 }) 260 } 261 }) 262 t.Run("map[string]interface{}", func(t *testing.T) { 263 tests := []struct { 264 name string 265 codec Codec 266 input *map[string]interface{} 267 expected []byte 268 err string 269 }{ 270 {"nil", udtCodecEmpty, nil, nil, ""}, 271 {"empty", udtCodecSimple, &map[string]interface{}{"f1": nil, "f2": nil, "f3": nil}, nullElementsUdtBytes, ""}, 272 {"simple", udtCodecSimple, &map[string]interface{}{"f1": 123, "f2": true, "f3": "abc"}, oneTwoThreeAbcUdtBytes, ""}, 273 {"simple with pointers", udtCodecSimple, &map[string]interface{}{"f1": intPtr(123), "f2": boolPtr(true), "f3": stringPtr("abc")}, oneTwoThreeAbcUdtBytes, ""}, 274 {"nil element", udtCodecSimple, &map[string]interface{}{"f1": 123, "f2": false, "f3": nil}, udtWithNullFieldsBytes, ""}, 275 {"not enough elements", udtCodecSimple, &map[string]interface{}{"f1": 123}, udtWithNullFieldsBytes2, ""}, 276 {"too many elements", udtCodecSimple, &map[string]interface{}{"f1": 123, "f2": false, "f3": "abc", "f4": "extra"}, udtOneTwoThreeFalseAbcBytes, ""}, 277 {"complex", udtCodecComplex, &map[string]interface{}{"f1": map[string]interface{}{"f1": 12, "f2": false, "f3": "abc"}, "f2": map[string]interface{}{"f1": 34, "f2": true, "f3": "def"}}, udtComplexBytes, ""}, 278 } 279 for _, tt := range tests { 280 t.Run(tt.name, func(t *testing.T) { 281 if tt.input != nil { 282 t.Run("value", func(t *testing.T) { 283 dest, err := tt.codec.Encode(*tt.input, version) 284 assert.Equal(t, tt.expected, dest) 285 assertErrorMessage(t, tt.err, err) 286 }) 287 } 288 t.Run("pointer", func(t *testing.T) { 289 dest, err := tt.codec.Encode(tt.input, version) 290 assert.Equal(t, tt.expected, dest) 291 assertErrorMessage(t, tt.err, err) 292 }) 293 }) 294 } 295 }) 296 t.Run("struct simple", func(t *testing.T) { 297 tests := []struct { 298 name string 299 codec Codec 300 input *SimpleUdt 301 expected []byte 302 }{ 303 {"nil", udtCodecEmpty, nil, nil}, 304 {"empty", udtCodecSimple, &SimpleUdt{}, udtZeroBytes}, 305 {"simple", udtCodecSimple, &SimpleUdt{123, false, stringPtr("abc")}, udtOneTwoThreeFalseAbcBytes}, 306 {"nil element", udtCodecSimple, &SimpleUdt{123, false, nil}, udtWithNullFieldsBytes}, 307 } 308 for _, tt := range tests { 309 t.Run(tt.name, func(t *testing.T) { 310 if tt.input != nil { 311 t.Run("value", func(t *testing.T) { 312 dest, err := tt.codec.Encode(*tt.input, version) 313 assert.Equal(t, tt.expected, dest) 314 assert.NoError(t, err) 315 316 }) 317 } 318 t.Run("pointer", func(t *testing.T) { 319 dest, err := tt.codec.Encode(tt.input, version) 320 assert.Equal(t, tt.expected, dest) 321 assert.NoError(t, err) 322 }) 323 }) 324 } 325 }) 326 t.Run("struct partial", func(t *testing.T) { 327 tests := []struct { 328 name string 329 codec Codec 330 input *partialUdt 331 expected []byte 332 err string 333 }{ 334 {"simple", udtCodecSimple, &partialUdt{123, false}, nil, "no accessible field with name 'f3' found"}, 335 } 336 for _, tt := range tests { 337 t.Run(tt.name, func(t *testing.T) { 338 if tt.input != nil { 339 t.Run("value", func(t *testing.T) { 340 dest, err := tt.codec.Encode(*tt.input, version) 341 assert.Equal(t, tt.expected, dest) 342 assertErrorMessage(t, tt.err, err) 343 }) 344 } 345 t.Run("pointer", func(t *testing.T) { 346 dest, err := tt.codec.Encode(tt.input, version) 347 assert.Equal(t, tt.expected, dest) 348 assertErrorMessage(t, tt.err, err) 349 }) 350 }) 351 } 352 }) 353 t.Run("struct excess", func(t *testing.T) { 354 tests := []struct { 355 name string 356 codec Codec 357 input *excessUdt 358 expected []byte 359 }{ 360 {"nil", udtCodecEmpty, nil, nil}, 361 {"empty", udtCodecSimple, &excessUdt{}, udtZeroBytes}, 362 {"simple", udtCodecSimple, &excessUdt{123, false, stringPtr("abc"), 42.0}, udtOneTwoThreeFalseAbcBytes}, 363 } 364 for _, tt := range tests { 365 t.Run(tt.name, func(t *testing.T) { 366 if tt.input != nil { 367 t.Run("value", func(t *testing.T) { 368 dest, err := tt.codec.Encode(*tt.input, version) 369 assert.Equal(t, tt.expected, dest) 370 assert.NoError(t, err) 371 }) 372 } 373 t.Run("pointer", func(t *testing.T) { 374 dest, err := tt.codec.Encode(tt.input, version) 375 assert.Equal(t, tt.expected, dest) 376 assert.NoError(t, err) 377 }) 378 }) 379 } 380 }) 381 t.Run("struct complex", func(t *testing.T) { 382 tests := []struct { 383 name string 384 codec Codec 385 input *complexUdt 386 expected []byte 387 }{ 388 {"nil", udtCodecEmpty, nil, nil}, 389 {"empty", udtCodecEmpty, &complexUdt{}, nil}, 390 {"complex", udtCodecComplex, &complexUdt{ 391 SimpleUdt{12, false, stringPtr("abc")}, 392 &excessUdt{34, true, nil, 0.0}, 393 }, udtComplexWithNullsBytes}, 394 {"nil element", udtCodecComplex, &complexUdt{ 395 SimpleUdt{12, false, stringPtr("abc")}, 396 nil, 397 }, udtComplexWithNulls2Bytes}, 398 } 399 for _, tt := range tests { 400 t.Run(tt.name, func(t *testing.T) { 401 if tt.input != nil { 402 t.Run("value", func(t *testing.T) { 403 dest, err := tt.codec.Encode(*tt.input, version) 404 assert.Equal(t, tt.expected, dest) 405 assert.NoError(t, err) 406 }) 407 } 408 t.Run("pointer", func(t *testing.T) { 409 dest, err := tt.codec.Encode(tt.input, version) 410 assert.Equal(t, tt.expected, dest) 411 assert.NoError(t, err) 412 }) 413 }) 414 } 415 }) 416 }) 417 } 418 for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { 419 t.Run(version.String(), func(t *testing.T) { 420 dest, err := udtCodecSimple.Encode(nil, version) 421 assert.Nil(t, dest) 422 expectedMessage := fmt.Sprintf("data type %s not supported in %v", udtTypeSimple, version) 423 assertErrorMessage(t, expectedMessage, err) 424 }) 425 } 426 t.Run("invalid types", func(t *testing.T) { 427 dest, err := udtCodecSimple.Encode(123, primitive.ProtocolVersion5) 428 assert.Nil(t, dest) 429 assert.EqualError(t, err, "cannot encode int as CQL ks1.type1<f1:int,f2:boolean,f3:varchar> with ProtocolVersion OSS 5: source type not supported") 430 dest, err = udtCodecSimple.Encode(map[int]string{123: "abc"}, primitive.ProtocolVersion5) 431 assert.Nil(t, dest) 432 assert.EqualError(t, err, "cannot encode map[int]string as CQL ks1.type1<f1:int,f2:boolean,f3:varchar> with ProtocolVersion OSS 5: wrong map key, expected string, got: int") 433 // this can only be detected once the decoding started 434 dest, err = udtCodecSimple.Encode(map[string]int{"f3": 123}, primitive.ProtocolVersion5) 435 assert.Nil(t, dest) 436 assert.EqualError(t, err, "cannot encode map[string]int as CQL ks1.type1<f1:int,f2:boolean,f3:varchar> with ProtocolVersion OSS 5: cannot encode field 2 (f3): cannot encode int as CQL varchar with ProtocolVersion OSS 5: cannot convert from int to []uint8: conversion not supported") 437 }) 438 } 439 440 func Test_udtCodec_Decode(t *testing.T) { 441 for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion3) { 442 t.Run(version.String(), func(t *testing.T) { 443 t.Run("interface{}", func(t *testing.T) { 444 tests := []struct { 445 name string 446 codec Codec 447 input []byte 448 dest *interface{} 449 expected *interface{} 450 err string 451 wasNull bool 452 }{ 453 {"nil input", udtCodecSimple, nil, new(interface{}), new(interface{}), "", true}, 454 {"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new(interface{}), interfacePtr(map[string]interface{}{"f1": nil, "f2": nil, "f3": nil}), "", false}, 455 {"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new(interface{}), interfacePtr(map[string]interface{}{"f1": int32(123), "f2": true, "f3": "abc"}), "", false}, 456 {"complex", udtCodecComplex, udtComplexBytes, new(interface{}), interfacePtr(map[string]interface{}{ 457 "f1": map[string]interface{}{"f1": int32(12), "f2": false, "f3": "abc"}, 458 "f2": map[string]interface{}{"f1": int32(34), "f2": true, "f3": "def"}, 459 }), "", false}, 460 {"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false}, 461 {"not enough bytes", udtCodecSimple, udtMissingBytes, new(interface{}), interfacePtr(map[string]interface{}{"f1": int32(123), "f2": true}), "cannot read field 2 (f3)", false}, 462 {"slice dest -> map dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, interfacePtr([]interface{}{}), interfacePtr(map[string]interface{}{"f1": int32(123), "f2": true, "f3": "abc"}), "", false}, 463 } 464 for _, tt := range tests { 465 t.Run(tt.name, func(t *testing.T) { 466 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 467 assert.Equal(t, tt.expected, tt.dest) 468 assert.Equal(t, tt.wasNull, wasNull) 469 assertErrorMessage(t, tt.err, err) 470 }) 471 } 472 }) 473 t.Run("*[]interface{}", func(t *testing.T) { 474 tests := []struct { 475 name string 476 codec Codec 477 input []byte 478 dest *[]interface{} 479 expected *[]interface{} 480 err string 481 wasNull bool 482 }{ 483 {"nil input", udtCodecSimple, nil, new([]interface{}), new([]interface{}), "", true}, 484 {"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new([]interface{}), &[]interface{}{nil, nil, nil}, "", false}, 485 {"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new([]interface{}), &[]interface{}{int32(123), true, "abc"}, "", false}, 486 {"complex", udtCodecComplex, udtComplexBytes, new([]interface{}), &[]interface{}{ 487 map[string]interface{}{"f1": int32(12), "f2": false, "f3": "abc"}, 488 map[string]interface{}{"f1": int32(34), "f2": true, "f3": "def"}, 489 }, "", false}, 490 {"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false}, 491 {"not enough bytes", udtCodecSimple, udtMissingBytes, new([]interface{}), &[]interface{}{int32(123), true, nil}, "cannot read field 2 (f3)", false}, 492 {"slice length too large", udtCodecSimple, oneTwoThreeAbcUdtBytes, &[]interface{}{nil, nil, nil, 42.0}, &[]interface{}{int32(123), true, "abc"}, "", false}, 493 } 494 for _, tt := range tests { 495 t.Run(tt.name, func(t *testing.T) { 496 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 497 assert.Equal(t, tt.expected, tt.dest) 498 assert.Equal(t, tt.wasNull, wasNull) 499 assertErrorMessage(t, tt.err, err) 500 }) 501 } 502 }) 503 t.Run("*[3]interface{}", func(t *testing.T) { 504 tests := []struct { 505 name string 506 codec Codec 507 input []byte 508 dest *[3]interface{} 509 expected *[3]interface{} 510 err string 511 wasNull bool 512 }{ 513 {"nil input", udtCodecSimple, nil, new([3]interface{}), new([3]interface{}), "", true}, 514 {"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new([3]interface{}), &[3]interface{}{nil, nil, nil}, "", false}, 515 {"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new([3]interface{}), &[3]interface{}{int32(123), true, "abc"}, "", false}, 516 {"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false}, 517 {"not enough bytes", udtCodecSimple, udtMissingBytes, new([3]interface{}), &[3]interface{}{int32(123), true, nil}, "cannot read field 2 (f3)", false}, 518 } 519 for _, tt := range tests { 520 t.Run(tt.name, func(t *testing.T) { 521 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 522 assert.Equal(t, tt.expected, tt.dest) 523 assert.Equal(t, tt.wasNull, wasNull) 524 assertErrorMessage(t, tt.err, err) 525 }) 526 } 527 }) 528 t.Run("*[][]interface{}", func(t *testing.T) { 529 tests := []struct { 530 name string 531 codec Codec 532 input []byte 533 dest *[][]interface{} 534 expected *[][]interface{} 535 err string 536 wasNull bool 537 }{ 538 {"complex", udtCodecComplex, udtComplexBytes, new([][]interface{}), &[][]interface{}{ 539 {int32(12), false, "abc"}, 540 {int32(34), true, "def"}, 541 }, "", false}, 542 } 543 for _, tt := range tests { 544 t.Run(tt.name, func(t *testing.T) { 545 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 546 assert.Equal(t, tt.expected, tt.dest) 547 assert.Equal(t, tt.wasNull, wasNull) 548 assertErrorMessage(t, tt.err, err) 549 }) 550 } 551 }) 552 t.Run("*map[string]interface{}", func(t *testing.T) { 553 tests := []struct { 554 name string 555 codec Codec 556 input []byte 557 dest *map[string]interface{} 558 expected *map[string]interface{} 559 err string 560 wasNull bool 561 }{ 562 {"nil input", udtCodecSimple, nil, new(map[string]interface{}), new(map[string]interface{}), "", true}, 563 {"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new(map[string]interface{}), &map[string]interface{}{"f1": nil, "f2": nil, "f3": nil}, "", false}, 564 {"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new(map[string]interface{}), &map[string]interface{}{"f1": int32(123), "f2": true, "f3": "abc"}, "", false}, 565 {"complex", udtCodecComplex, udtComplexBytes, new(map[string]interface{}), &map[string]interface{}{ 566 "f1": map[string]interface{}{"f1": int32(12), "f2": false, "f3": "abc"}, 567 "f2": map[string]interface{}{"f1": int32(34), "f2": true, "f3": "def"}, 568 }, "", false}, 569 {"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false}, 570 {"not enough bytes", udtCodecSimple, udtMissingBytes, new(map[string]interface{}), &map[string]interface{}{"f1": int32(123), "f2": true}, "cannot read field 2 (f3)", false}, 571 } 572 for _, tt := range tests { 573 t.Run(tt.name, func(t *testing.T) { 574 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 575 assert.Equal(t, tt.expected, tt.dest) 576 assert.Equal(t, tt.wasNull, wasNull) 577 assertErrorMessage(t, tt.err, err) 578 }) 579 } 580 }) 581 t.Run("struct simple", func(t *testing.T) { 582 tests := []struct { 583 name string 584 codec Codec 585 input []byte 586 dest *SimpleUdt 587 expected *SimpleUdt 588 err string 589 wasNull bool 590 }{ 591 {"nil input", udtCodecSimple, nil, &SimpleUdt{}, &SimpleUdt{}, "", true}, 592 {"empty input", udtCodecSimple, []byte{}, &SimpleUdt{}, &SimpleUdt{}, "", true}, 593 {"nil elements", udtCodecSimple, nullElementsUdtBytes, &SimpleUdt{}, &SimpleUdt{F1: 0, F2: false, F3: nil}, "", false}, 594 {"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, &SimpleUdt{}, &SimpleUdt{F1: 123, F2: true, F3: stringPtr("abc")}, "", false}, 595 {"nil dest", udtCodecSimple, udtMissingBytes, nil, nil, "destination is nil", false}, 596 } 597 for _, tt := range tests { 598 t.Run(tt.name, func(t *testing.T) { 599 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 600 if tt.expected != nil && tt.dest != nil { 601 assert.Equal(t, *tt.expected, *tt.dest) 602 } 603 assert.Equal(t, tt.wasNull, wasNull) 604 assertErrorMessage(t, tt.err, err) 605 }) 606 } 607 }) 608 t.Run("struct partial", func(t *testing.T) { 609 tests := []struct { 610 name string 611 codec Codec 612 input []byte 613 dest *partialUdt 614 expected *partialUdt 615 err string 616 wasNull bool 617 }{ 618 {"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, &partialUdt{}, &partialUdt{F1: 123, F2: true}, "no accessible field with name 'f3' found", false}, 619 } 620 for _, tt := range tests { 621 t.Run(tt.name, func(t *testing.T) { 622 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 623 if tt.expected != nil && tt.dest != nil { 624 assert.Equal(t, *tt.expected, *tt.dest) 625 } 626 assert.Equal(t, tt.wasNull, wasNull) 627 assertErrorMessage(t, tt.err, err) 628 }) 629 } 630 }) 631 t.Run("struct excess", func(t *testing.T) { 632 tests := []struct { 633 name string 634 codec Codec 635 input []byte 636 dest *excessUdt 637 expected *excessUdt 638 err string 639 wasNull bool 640 }{ 641 {"nil input", udtCodecSimple, nil, &excessUdt{}, &excessUdt{}, "", true}, 642 {"empty input", udtCodecSimple, []byte{}, &excessUdt{}, &excessUdt{}, "", true}, 643 {"nil elements", udtCodecSimple, nullElementsUdtBytes, &excessUdt{}, &excessUdt{}, "", false}, 644 {"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, &excessUdt{}, &excessUdt{F1: 123, F2: true, F3: stringPtr("abc")}, "", false}, 645 {"nil dest", udtCodecSimple, udtMissingBytes, nil, nil, "destination is nil", false}, 646 } 647 for _, tt := range tests { 648 t.Run(tt.name, func(t *testing.T) { 649 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 650 if tt.expected != nil && tt.dest != nil { 651 assert.Equal(t, *tt.expected, *tt.dest) 652 } 653 assert.Equal(t, tt.wasNull, wasNull) 654 assertErrorMessage(t, tt.err, err) 655 }) 656 } 657 }) 658 t.Run("struct complex", func(t *testing.T) { 659 tests := []struct { 660 name string 661 codec Codec 662 input []byte 663 dest *complexUdt 664 expected *complexUdt 665 err string 666 wasNull bool 667 }{ 668 {"nil", udtCodecComplex, nil, &complexUdt{}, &complexUdt{}, "", true}, 669 {"empty", udtCodecComplex, []byte{}, &complexUdt{}, &complexUdt{}, "", true}, 670 {"complex", udtCodecComplex, udtComplexWithNullsBytes, &complexUdt{}, &complexUdt{ 671 SimpleUdt{12, false, stringPtr("abc")}, 672 &excessUdt{34, true, nil, 0.0}, 673 }, "", false}, 674 {"nil element", udtCodecComplex, udtComplexWithNulls2Bytes, &complexUdt{}, &complexUdt{ 675 SimpleUdt{12, false, stringPtr("abc")}, 676 nil, 677 }, "", false}, 678 } 679 for _, tt := range tests { 680 t.Run(tt.name, func(t *testing.T) { 681 wasNull, err := tt.codec.Decode(tt.input, tt.dest, version) 682 assert.Equal(t, *tt.expected, *tt.dest) 683 assert.Equal(t, tt.wasNull, wasNull) 684 assertErrorMessage(t, tt.err, err) 685 }) 686 } 687 }) 688 }) 689 } 690 for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { 691 t.Run(version.String(), func(t *testing.T) { 692 _, err := udtCodecSimple.Decode(nil, nil, version) 693 expectedMessage := fmt.Sprintf("data type %s not supported in %v", udtTypeSimple, version) 694 assertErrorMessage(t, expectedMessage, err) 695 }) 696 } 697 t.Run("invalid types", func(t *testing.T) { 698 wasNull, err := udtCodecSimple.Decode([]byte{1, 2, 3}, new(int), primitive.ProtocolVersion5) 699 assert.False(t, wasNull) 700 assert.EqualError(t, err, "cannot decode CQL ks1.type1<f1:int,f2:boolean,f3:varchar> as *int with ProtocolVersion OSS 5: destination type not supported") 701 wasNull, err = udtCodecSimple.Decode([]byte{1, 2, 3}, new(map[int]string), primitive.ProtocolVersion5) 702 assert.False(t, wasNull) 703 assert.EqualError(t, err, "cannot decode CQL ks1.type1<f1:int,f2:boolean,f3:varchar> as *map[int]string with ProtocolVersion OSS 5: wrong map key, expected string, got: int") 704 }) 705 } 706 707 func Test_writeUdt(t *testing.T) { 708 type args struct { 709 ext extractor 710 fieldNames []string 711 fieldCodecs []Codec 712 version primitive.ProtocolVersion 713 } 714 tests := []struct { 715 name string 716 args args 717 want []byte 718 wantErr string 719 }{ 720 { 721 "cannot extract elem", 722 args{ 723 func() extractor { 724 ext := &mockExtractor{} 725 ext.On("getElem", 0, "f1").Return(nil, errors.New("wrong type")) 726 return ext 727 }(), 728 []string{"f1"}, 729 []Codec{nil}, 730 primitive.ProtocolVersion5, 731 }, 732 nil, 733 "cannot extract field 0 (f1): wrong type", 734 }, 735 { 736 "cannot encode", 737 args{ 738 func() extractor { 739 ext := &mockExtractor{} 740 ext.On("getElem", 0, "f1").Return(123, nil) 741 return ext 742 }(), 743 []string{"f1"}, 744 func() []Codec { 745 codec := &mockCodec{} 746 codec.On("Encode", 123, primitive.ProtocolVersion5).Return(nil, errors.New("write failed")) 747 return []Codec{codec} 748 }(), 749 primitive.ProtocolVersion5, 750 }, 751 nil, 752 "cannot encode field 0 (f1): write failed", 753 }, 754 {"success", args{ 755 func() extractor { 756 ext := &mockExtractor{} 757 ext.On("getElem", 0, "f1").Return(123, nil) 758 ext.On("getElem", 1, "f2").Return("abc", nil) 759 ext.On("getElem", 2, "f3").Return(true, nil) 760 return ext 761 }(), 762 []string{"f1", "f2", "f3"}, 763 func() []Codec { 764 codec1 := &mockCodec{} 765 codec1.On("Encode", 123, primitive.ProtocolVersion5).Return([]byte{1}, nil) 766 codec2 := &mockCodec{} 767 codec2.On("Encode", "abc", primitive.ProtocolVersion5).Return([]byte{2}, nil) 768 codec3 := &mockCodec{} 769 codec3.On("Encode", true, primitive.ProtocolVersion5).Return(nil, nil) 770 return []Codec{codec1, codec2, codec3} 771 }(), 772 primitive.ProtocolVersion5, 773 }, []byte{ 774 0, 0, 0, 1, // field 1 775 1, 776 0, 0, 0, 1, // field 2 777 2, 778 255, 255, 255, 255, // field 3 (nil) 779 }, ""}, 780 } 781 for _, tt := range tests { 782 t.Run(tt.name, func(t *testing.T) { 783 got, gotErr := writeUdt(tt.args.ext, tt.args.fieldNames, tt.args.fieldCodecs, tt.args.version) 784 assert.Equal(t, tt.want, got) 785 assertErrorMessage(t, tt.wantErr, gotErr) 786 }) 787 } 788 } 789 790 func Test_readUdt(t *testing.T) { 791 type args struct { 792 source []byte 793 inj injector 794 fieldNames []string 795 fieldCodecs []Codec 796 version primitive.ProtocolVersion 797 } 798 tests := []struct { 799 name string 800 args args 801 wantErr string 802 }{ 803 { 804 "cannot read element", 805 args{ 806 []byte{ 807 0, // wrong [bytes] 808 }, 809 nil, 810 []string{"f1"}, 811 []Codec{nil}, 812 primitive.ProtocolVersion5, 813 }, 814 "cannot read field 0 (f1): cannot read [bytes] length: cannot read [int]: unexpected EOF", 815 }, 816 { 817 "cannot create element", 818 args{ 819 []byte{ 820 0, 0, 0, 1, 123, // [bytes] 821 }, 822 func() injector { 823 inj := &mockInjector{} 824 inj.On("zeroElem", 0, "f1").Return(nil, errors.New("wrong data type")) 825 return inj 826 }(), 827 []string{"f1"}, 828 func() []Codec { 829 codec := &mockCodec{} 830 codec.On("DataType").Return(datatype.Int) 831 return []Codec{codec} 832 }(), 833 primitive.ProtocolVersion5, 834 }, 835 "cannot create zero field 0 (f1): wrong data type", 836 }, 837 { 838 "cannot decode element", 839 args{ 840 []byte{ 841 0, 0, 0, 1, 123, // [bytes] 842 }, 843 func() injector { 844 inj := &mockInjector{} 845 inj.On("zeroElem", 0, "f1").Return(new(int), nil) 846 return inj 847 }(), 848 []string{"f1"}, 849 func() []Codec { 850 codec := &mockCodec{} 851 codec.On("DataType").Return(datatype.Int) 852 codec.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Return(false, errors.New("decode failed")) 853 return []Codec{codec} 854 }(), 855 primitive.ProtocolVersion5, 856 }, 857 "cannot decode field 0 (f1): decode failed", 858 }, 859 { 860 "cannot set element", 861 args{ 862 []byte{ 863 0, 0, 0, 1, 123, // [bytes] 864 }, 865 func() injector { 866 inj := &mockInjector{} 867 inj.On("zeroElem", 0, "f1").Return(new(int), nil) 868 inj.On("setElem", 0, "f1", intPtr(123), false, false).Return(errors.New("cannot set elem")) 869 return inj 870 }(), 871 []string{"f1"}, 872 func() []Codec { 873 codec := &mockCodec{} 874 codec.On("DataType").Return(datatype.Int) 875 codec.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Run(func(args mock.Arguments) { 876 decodedElement := args.Get(1).(*int) 877 *decodedElement = 123 878 }).Return(false, nil) 879 return []Codec{codec} 880 }(), 881 primitive.ProtocolVersion5, 882 }, 883 "cannot inject field 0 (f1): cannot set elem", 884 }, 885 { 886 "bytes remaining", 887 args{ 888 []byte{ 889 0, 0, 0, 1, 123, // [bytes] 890 1, // trailing bytes 891 }, 892 func() injector { 893 inj := &mockInjector{} 894 inj.On("zeroElem", 0, "f1").Return(new(int), nil) 895 inj.On("setElem", 0, "f1", intPtr(123), false, false).Return(nil) 896 return inj 897 }(), 898 []string{"f1"}, 899 func() []Codec { 900 codec := &mockCodec{} 901 codec.On("DataType").Return(datatype.Int) 902 codec.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Run(func(args mock.Arguments) { 903 decodedElement := args.Get(1).(*int) 904 *decodedElement = 123 905 }).Return(false, nil) 906 return []Codec{codec} 907 }(), 908 primitive.ProtocolVersion5, 909 }, 910 "source was not fully read: bytes total: 6, read: 5, remaining: 1", 911 }, 912 { 913 "success", 914 args{ 915 []byte{ 916 0, 0, 0, 1, 123, // 1st elem 917 0, 0, 0, 3, a, b, c, // 2nd elem 918 255, 255, 255, 255, // 3rd elem (nil) 919 }, 920 func() injector { 921 inj := &mockInjector{} 922 inj.On("zeroElem", 0, "f1").Return(new(int), nil) 923 inj.On("zeroElem", 1, "f2").Return(new(string), nil) 924 inj.On("zeroElem", 2, "f3").Return(new(bool), nil) 925 inj.On("setElem", 0, "f1", intPtr(123), false, false).Return(nil) 926 inj.On("setElem", 1, "f2", stringPtr("abc"), false, false).Return(nil) 927 inj.On("setElem", 2, "f3", new(bool), false, true).Return(nil) 928 return inj 929 }(), 930 []string{"f1", "f2", "f3"}, 931 func() []Codec { 932 codec1 := &mockCodec{} 933 codec1.On("DataType").Return(datatype.Int) 934 codec1.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Run(func(args mock.Arguments) { 935 decodedElement := args.Get(1).(*int) 936 *decodedElement = 123 937 }).Return(false, nil) 938 codec2 := &mockCodec{} 939 codec2.On("DataType").Return(datatype.Varchar) 940 codec2.On("Decode", []byte{a, b, c}, new(string), primitive.ProtocolVersion5).Run(func(args mock.Arguments) { 941 decodedElement := args.Get(1).(*string) 942 *decodedElement = "abc" 943 }).Return(false, nil) 944 codec3 := &mockCodec{} 945 codec3.On("DataType").Return(datatype.Boolean) 946 codec3.On("Decode", []byte(nil), new(bool), primitive.ProtocolVersion5).Return(true, nil) 947 return []Codec{codec1, codec2, codec3} 948 }(), 949 primitive.ProtocolVersion5, 950 }, 951 "", 952 }, 953 } 954 for _, tt := range tests { 955 t.Run(tt.name, func(t *testing.T) { 956 gotErr := readUdt(tt.args.source, tt.args.inj, tt.args.fieldNames, tt.args.fieldCodecs, tt.args.version) 957 assertErrorMessage(t, tt.wantErr, gotErr) 958 }) 959 } 960 }