github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/udt_test.go (about)

     1  // Copyright 2021 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/mock"
    24  
    25  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    26  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    27  )
    28  
    29  var (
    30  	udtTypeSimple, _ = datatype.NewUserDefined(
    31  		"ks1",
    32  		"type1",
    33  		[]string{"f1", "f2", "f3"},
    34  		[]datatype.DataType{datatype.Int, datatype.Boolean, datatype.Varchar},
    35  	)
    36  	udtTypeComplex, _ = datatype.NewUserDefined(
    37  		"ks1",
    38  		"type2",
    39  		[]string{"f1", "f2"},
    40  		[]datatype.DataType{udtTypeSimple, udtTypeSimple},
    41  	)
    42  	udtTypeEmpty, _ = datatype.NewUserDefined("ks1", "type3", []string{}, []datatype.DataType{})
    43  	udtTypeWrong, _ = datatype.NewUserDefined("ks1", "type4", []string{"f1"}, []datatype.DataType{wrongDataType{}})
    44  )
    45  
    46  var (
    47  	udtCodecSimple, _  = NewUserDefined(udtTypeSimple)
    48  	udtCodecComplex, _ = NewUserDefined(udtTypeComplex)
    49  	udtCodecEmpty, _   = NewUserDefined(udtTypeEmpty)
    50  )
    51  
    52  type (
    53  	SimpleUdt struct {
    54  		F1 int
    55  		F2 bool
    56  		F3 *string
    57  	}
    58  	partialUdt struct {
    59  		F1 int
    60  		F2 bool
    61  	}
    62  	excessUdt struct {
    63  		F1 int
    64  		F2 bool
    65  		F3 *string
    66  		F4 float64
    67  	}
    68  	complexUdt struct {
    69  		F1 SimpleUdt
    70  		F2 *excessUdt
    71  	}
    72  )
    73  
    74  var (
    75  	nullElementsUdtBytes = []byte{
    76  		255, 255, 255, 255, // nil int
    77  		255, 255, 255, 255, // nil boolean
    78  		255, 255, 255, 255, // nil string
    79  	}
    80  	oneTwoThreeAbcUdtBytes = []byte{
    81  		0, 0, 0, 4, // length of int
    82  		0, 0, 0, 123, // int
    83  		0, 0, 0, 1, // length of boolean
    84  		1,          // boolean
    85  		0, 0, 0, 3, // length of string
    86  		a, b, c, // string
    87  	}
    88  	udtWithNullFieldsBytes = []byte{
    89  		0, 0, 0, 4, // length of int
    90  		0, 0, 0, 123, // int
    91  		0, 0, 0, 1, // length of boolean
    92  		0,                  // boolean
    93  		255, 255, 255, 255, // nil string
    94  	}
    95  	udtWithNullFieldsBytes2 = []byte{
    96  		0, 0, 0, 4, // length of int
    97  		0, 0, 0, 123, // int
    98  		255, 255, 255, 255, // nil boolean
    99  		255, 255, 255, 255, // nil string
   100  	}
   101  	udtOneTwoThreeFalseAbcBytes = []byte{
   102  		0, 0, 0, 4, // length of int
   103  		0, 0, 0, 123, // int
   104  		0, 0, 0, 1, // length of boolean
   105  		0,          // boolean
   106  		0, 0, 0, 3, // length of string
   107  		a, b, c, // string
   108  	}
   109  	udtComplexBytes = []byte{
   110  		0, 0, 0, 20, // length of element 1
   111  		// element 1
   112  		0, 0, 0, 4, // length of int
   113  		0, 0, 0, 12, // int
   114  		0, 0, 0, 1, // length of boolean
   115  		0,          // boolean
   116  		0, 0, 0, 3, // length of string
   117  		a, b, c, // string
   118  		0, 0, 0, 20, // length of element 2
   119  		// element 2
   120  		0, 0, 0, 4, // length of int
   121  		0, 0, 0, 34, // int
   122  		0, 0, 0, 1, // length of boolean
   123  		1,          // boolean
   124  		0, 0, 0, 3, // length of string
   125  		d, e, f, // string
   126  	}
   127  	udtZeroBytes = []byte{
   128  		0, 0, 0, 4, // length of int
   129  		0, 0, 0, 0, // int
   130  		0, 0, 0, 1, // length of boolean
   131  		0,                  // boolean
   132  		255, 255, 255, 255, // nil string
   133  	}
   134  	udtComplexWithNullsBytes = []byte{
   135  		0, 0, 0, 20, // length of element 1
   136  		// element 1
   137  		0, 0, 0, 4, // length of int
   138  		0, 0, 0, 12, // int
   139  		0, 0, 0, 1, // length of boolean
   140  		0,          // boolean
   141  		0, 0, 0, 3, // length of string
   142  		a, b, c, // string
   143  		0, 0, 0, 17, // length of element 2
   144  		// element 2
   145  		0, 0, 0, 4, // length of int
   146  		0, 0, 0, 34, // int
   147  		0, 0, 0, 1, // length of boolean
   148  		1,                  // boolean
   149  		255, 255, 255, 255, // nil string
   150  	}
   151  	udtComplexWithNulls2Bytes = []byte{
   152  		0, 0, 0, 20, // length of element 1
   153  		// element 1
   154  		0, 0, 0, 4, // length of int
   155  		0, 0, 0, 12, // int
   156  		0, 0, 0, 1, // length of boolean
   157  		0,          // boolean
   158  		0, 0, 0, 3, // length of string
   159  		a, b, c, // string
   160  		255, 255, 255, 255, // nil element 2
   161  	}
   162  	udtMissingBytes = []byte{
   163  		0, 0, 0, 4, // length of int
   164  		0, 0, 0, 123, // int
   165  		0, 0, 0, 1, // length of boolean
   166  		1,          // boolean
   167  		0, 0, 0, 3, // length of string
   168  		// missing string
   169  	}
   170  )
   171  
   172  func TestNewUserDefinedCodec(t *testing.T) {
   173  	tests := []struct {
   174  		name     string
   175  		dataType *datatype.UserDefined
   176  		expected Codec
   177  		err      string
   178  	}{
   179  		{
   180  			"simple",
   181  			udtTypeSimple,
   182  			&udtCodec{dataType: udtTypeSimple, fieldCodecs: []Codec{Int, Boolean, Varchar}},
   183  			"",
   184  		},
   185  		{
   186  			"complex",
   187  			udtTypeComplex,
   188  			&udtCodec{
   189  				dataType: udtTypeComplex,
   190  				fieldCodecs: []Codec{
   191  					&udtCodec{dataType: udtTypeSimple, fieldCodecs: []Codec{Int, Boolean, Varchar}},
   192  					&udtCodec{dataType: udtTypeSimple, fieldCodecs: []Codec{Int, Boolean, Varchar}},
   193  				},
   194  			},
   195  			"",
   196  		},
   197  		{
   198  			"empty",
   199  			udtTypeEmpty,
   200  			&udtCodec{dataType: udtTypeEmpty, fieldCodecs: []Codec{}},
   201  			"",
   202  		},
   203  		{
   204  			"wrong child",
   205  			udtTypeWrong,
   206  			nil,
   207  			"cannot create codec for user-defined type field 0 (f1): cannot create data codec for CQL type 666",
   208  		},
   209  		{
   210  			"nil",
   211  			nil,
   212  			nil,
   213  			"data type is nil",
   214  		},
   215  	}
   216  	for _, tt := range tests {
   217  		t.Run(tt.name, func(t *testing.T) {
   218  			actual, err := NewUserDefined(tt.dataType)
   219  			assert.Equal(t, tt.expected, actual)
   220  			assertErrorMessage(t, tt.err, err)
   221  		})
   222  	}
   223  }
   224  
   225  func Test_udtCodec_Encode(t *testing.T) {
   226  	for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion3) {
   227  		t.Run(version.String(), func(t *testing.T) {
   228  			t.Run("[]interface{}", func(t *testing.T) {
   229  				tests := []struct {
   230  					name     string
   231  					codec    Codec
   232  					input    *[]interface{}
   233  					expected []byte
   234  					err      string
   235  				}{
   236  					{"nil", udtCodecEmpty, nil, nil, ""},
   237  					{"empty", udtCodecSimple, &[]interface{}{nil, nil, nil}, nullElementsUdtBytes, ""},
   238  					{"simple", udtCodecSimple, &[]interface{}{123, true, "abc"}, oneTwoThreeAbcUdtBytes, ""},
   239  					{"simple with pointers", udtCodecSimple, &[]interface{}{intPtr(123), boolPtr(true), stringPtr("abc")}, oneTwoThreeAbcUdtBytes, ""},
   240  					{"nil element", udtCodecSimple, &[]interface{}{123, false, nil}, udtWithNullFieldsBytes, ""},
   241  					{"not enough elements", udtCodecSimple, &[]interface{}{123}, nil, "slice index out of range: 1"},
   242  					{"too many elements", udtCodecSimple, &[]interface{}{123, false, "abc", "extra"}, udtOneTwoThreeFalseAbcBytes, ""},
   243  					{"complex", udtCodecComplex, &[]interface{}{[]interface{}{12, false, "abc"}, []interface{}{34, true, "def"}}, udtComplexBytes, ""},
   244  				}
   245  				for _, tt := range tests {
   246  					t.Run(tt.name, func(t *testing.T) {
   247  						if tt.input != nil {
   248  							t.Run("value", func(t *testing.T) {
   249  								dest, err := tt.codec.Encode(*tt.input, version)
   250  								assert.Equal(t, tt.expected, dest)
   251  								assertErrorMessage(t, tt.err, err)
   252  							})
   253  						}
   254  						t.Run("pointer", func(t *testing.T) {
   255  							dest, err := tt.codec.Encode(tt.input, version)
   256  							assert.Equal(t, tt.expected, dest)
   257  							assertErrorMessage(t, tt.err, err)
   258  						})
   259  					})
   260  				}
   261  			})
   262  			t.Run("map[string]interface{}", func(t *testing.T) {
   263  				tests := []struct {
   264  					name     string
   265  					codec    Codec
   266  					input    *map[string]interface{}
   267  					expected []byte
   268  					err      string
   269  				}{
   270  					{"nil", udtCodecEmpty, nil, nil, ""},
   271  					{"empty", udtCodecSimple, &map[string]interface{}{"f1": nil, "f2": nil, "f3": nil}, nullElementsUdtBytes, ""},
   272  					{"simple", udtCodecSimple, &map[string]interface{}{"f1": 123, "f2": true, "f3": "abc"}, oneTwoThreeAbcUdtBytes, ""},
   273  					{"simple with pointers", udtCodecSimple, &map[string]interface{}{"f1": intPtr(123), "f2": boolPtr(true), "f3": stringPtr("abc")}, oneTwoThreeAbcUdtBytes, ""},
   274  					{"nil element", udtCodecSimple, &map[string]interface{}{"f1": 123, "f2": false, "f3": nil}, udtWithNullFieldsBytes, ""},
   275  					{"not enough elements", udtCodecSimple, &map[string]interface{}{"f1": 123}, udtWithNullFieldsBytes2, ""},
   276  					{"too many elements", udtCodecSimple, &map[string]interface{}{"f1": 123, "f2": false, "f3": "abc", "f4": "extra"}, udtOneTwoThreeFalseAbcBytes, ""},
   277  					{"complex", udtCodecComplex, &map[string]interface{}{"f1": map[string]interface{}{"f1": 12, "f2": false, "f3": "abc"}, "f2": map[string]interface{}{"f1": 34, "f2": true, "f3": "def"}}, udtComplexBytes, ""},
   278  				}
   279  				for _, tt := range tests {
   280  					t.Run(tt.name, func(t *testing.T) {
   281  						if tt.input != nil {
   282  							t.Run("value", func(t *testing.T) {
   283  								dest, err := tt.codec.Encode(*tt.input, version)
   284  								assert.Equal(t, tt.expected, dest)
   285  								assertErrorMessage(t, tt.err, err)
   286  							})
   287  						}
   288  						t.Run("pointer", func(t *testing.T) {
   289  							dest, err := tt.codec.Encode(tt.input, version)
   290  							assert.Equal(t, tt.expected, dest)
   291  							assertErrorMessage(t, tt.err, err)
   292  						})
   293  					})
   294  				}
   295  			})
   296  			t.Run("struct simple", func(t *testing.T) {
   297  				tests := []struct {
   298  					name     string
   299  					codec    Codec
   300  					input    *SimpleUdt
   301  					expected []byte
   302  				}{
   303  					{"nil", udtCodecEmpty, nil, nil},
   304  					{"empty", udtCodecSimple, &SimpleUdt{}, udtZeroBytes},
   305  					{"simple", udtCodecSimple, &SimpleUdt{123, false, stringPtr("abc")}, udtOneTwoThreeFalseAbcBytes},
   306  					{"nil element", udtCodecSimple, &SimpleUdt{123, false, nil}, udtWithNullFieldsBytes},
   307  				}
   308  				for _, tt := range tests {
   309  					t.Run(tt.name, func(t *testing.T) {
   310  						if tt.input != nil {
   311  							t.Run("value", func(t *testing.T) {
   312  								dest, err := tt.codec.Encode(*tt.input, version)
   313  								assert.Equal(t, tt.expected, dest)
   314  								assert.NoError(t, err)
   315  
   316  							})
   317  						}
   318  						t.Run("pointer", func(t *testing.T) {
   319  							dest, err := tt.codec.Encode(tt.input, version)
   320  							assert.Equal(t, tt.expected, dest)
   321  							assert.NoError(t, err)
   322  						})
   323  					})
   324  				}
   325  			})
   326  			t.Run("struct partial", func(t *testing.T) {
   327  				tests := []struct {
   328  					name     string
   329  					codec    Codec
   330  					input    *partialUdt
   331  					expected []byte
   332  					err      string
   333  				}{
   334  					{"simple", udtCodecSimple, &partialUdt{123, false}, nil, "no accessible field with name 'f3' found"},
   335  				}
   336  				for _, tt := range tests {
   337  					t.Run(tt.name, func(t *testing.T) {
   338  						if tt.input != nil {
   339  							t.Run("value", func(t *testing.T) {
   340  								dest, err := tt.codec.Encode(*tt.input, version)
   341  								assert.Equal(t, tt.expected, dest)
   342  								assertErrorMessage(t, tt.err, err)
   343  							})
   344  						}
   345  						t.Run("pointer", func(t *testing.T) {
   346  							dest, err := tt.codec.Encode(tt.input, version)
   347  							assert.Equal(t, tt.expected, dest)
   348  							assertErrorMessage(t, tt.err, err)
   349  						})
   350  					})
   351  				}
   352  			})
   353  			t.Run("struct excess", func(t *testing.T) {
   354  				tests := []struct {
   355  					name     string
   356  					codec    Codec
   357  					input    *excessUdt
   358  					expected []byte
   359  				}{
   360  					{"nil", udtCodecEmpty, nil, nil},
   361  					{"empty", udtCodecSimple, &excessUdt{}, udtZeroBytes},
   362  					{"simple", udtCodecSimple, &excessUdt{123, false, stringPtr("abc"), 42.0}, udtOneTwoThreeFalseAbcBytes},
   363  				}
   364  				for _, tt := range tests {
   365  					t.Run(tt.name, func(t *testing.T) {
   366  						if tt.input != nil {
   367  							t.Run("value", func(t *testing.T) {
   368  								dest, err := tt.codec.Encode(*tt.input, version)
   369  								assert.Equal(t, tt.expected, dest)
   370  								assert.NoError(t, err)
   371  							})
   372  						}
   373  						t.Run("pointer", func(t *testing.T) {
   374  							dest, err := tt.codec.Encode(tt.input, version)
   375  							assert.Equal(t, tt.expected, dest)
   376  							assert.NoError(t, err)
   377  						})
   378  					})
   379  				}
   380  			})
   381  			t.Run("struct complex", func(t *testing.T) {
   382  				tests := []struct {
   383  					name     string
   384  					codec    Codec
   385  					input    *complexUdt
   386  					expected []byte
   387  				}{
   388  					{"nil", udtCodecEmpty, nil, nil},
   389  					{"empty", udtCodecEmpty, &complexUdt{}, nil},
   390  					{"complex", udtCodecComplex, &complexUdt{
   391  						SimpleUdt{12, false, stringPtr("abc")},
   392  						&excessUdt{34, true, nil, 0.0},
   393  					}, udtComplexWithNullsBytes},
   394  					{"nil element", udtCodecComplex, &complexUdt{
   395  						SimpleUdt{12, false, stringPtr("abc")},
   396  						nil,
   397  					}, udtComplexWithNulls2Bytes},
   398  				}
   399  				for _, tt := range tests {
   400  					t.Run(tt.name, func(t *testing.T) {
   401  						if tt.input != nil {
   402  							t.Run("value", func(t *testing.T) {
   403  								dest, err := tt.codec.Encode(*tt.input, version)
   404  								assert.Equal(t, tt.expected, dest)
   405  								assert.NoError(t, err)
   406  							})
   407  						}
   408  						t.Run("pointer", func(t *testing.T) {
   409  							dest, err := tt.codec.Encode(tt.input, version)
   410  							assert.Equal(t, tt.expected, dest)
   411  							assert.NoError(t, err)
   412  						})
   413  					})
   414  				}
   415  			})
   416  		})
   417  	}
   418  	for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) {
   419  		t.Run(version.String(), func(t *testing.T) {
   420  			dest, err := udtCodecSimple.Encode(nil, version)
   421  			assert.Nil(t, dest)
   422  			expectedMessage := fmt.Sprintf("data type %s not supported in %v", udtTypeSimple, version)
   423  			assertErrorMessage(t, expectedMessage, err)
   424  		})
   425  	}
   426  	t.Run("invalid types", func(t *testing.T) {
   427  		dest, err := udtCodecSimple.Encode(123, primitive.ProtocolVersion5)
   428  		assert.Nil(t, dest)
   429  		assert.EqualError(t, err, "cannot encode int as CQL ks1.type1<f1:int,f2:boolean,f3:varchar> with ProtocolVersion OSS 5: source type not supported")
   430  		dest, err = udtCodecSimple.Encode(map[int]string{123: "abc"}, primitive.ProtocolVersion5)
   431  		assert.Nil(t, dest)
   432  		assert.EqualError(t, err, "cannot encode map[int]string as CQL ks1.type1<f1:int,f2:boolean,f3:varchar> with ProtocolVersion OSS 5: wrong map key, expected string, got: int")
   433  		// this can only be detected once the decoding started
   434  		dest, err = udtCodecSimple.Encode(map[string]int{"f3": 123}, primitive.ProtocolVersion5)
   435  		assert.Nil(t, dest)
   436  		assert.EqualError(t, err, "cannot encode map[string]int as CQL ks1.type1<f1:int,f2:boolean,f3:varchar> with ProtocolVersion OSS 5: cannot encode field 2 (f3): cannot encode int as CQL varchar with ProtocolVersion OSS 5: cannot convert from int to []uint8: conversion not supported")
   437  	})
   438  }
   439  
   440  func Test_udtCodec_Decode(t *testing.T) {
   441  	for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion3) {
   442  		t.Run(version.String(), func(t *testing.T) {
   443  			t.Run("interface{}", func(t *testing.T) {
   444  				tests := []struct {
   445  					name     string
   446  					codec    Codec
   447  					input    []byte
   448  					dest     *interface{}
   449  					expected *interface{}
   450  					err      string
   451  					wasNull  bool
   452  				}{
   453  					{"nil input", udtCodecSimple, nil, new(interface{}), new(interface{}), "", true},
   454  					{"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new(interface{}), interfacePtr(map[string]interface{}{"f1": nil, "f2": nil, "f3": nil}), "", false},
   455  					{"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new(interface{}), interfacePtr(map[string]interface{}{"f1": int32(123), "f2": true, "f3": "abc"}), "", false},
   456  					{"complex", udtCodecComplex, udtComplexBytes, new(interface{}), interfacePtr(map[string]interface{}{
   457  						"f1": map[string]interface{}{"f1": int32(12), "f2": false, "f3": "abc"},
   458  						"f2": map[string]interface{}{"f1": int32(34), "f2": true, "f3": "def"},
   459  					}), "", false},
   460  					{"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false},
   461  					{"not enough bytes", udtCodecSimple, udtMissingBytes, new(interface{}), interfacePtr(map[string]interface{}{"f1": int32(123), "f2": true}), "cannot read field 2 (f3)", false},
   462  					{"slice dest -> map dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, interfacePtr([]interface{}{}), interfacePtr(map[string]interface{}{"f1": int32(123), "f2": true, "f3": "abc"}), "", false},
   463  				}
   464  				for _, tt := range tests {
   465  					t.Run(tt.name, func(t *testing.T) {
   466  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   467  						assert.Equal(t, tt.expected, tt.dest)
   468  						assert.Equal(t, tt.wasNull, wasNull)
   469  						assertErrorMessage(t, tt.err, err)
   470  					})
   471  				}
   472  			})
   473  			t.Run("*[]interface{}", func(t *testing.T) {
   474  				tests := []struct {
   475  					name     string
   476  					codec    Codec
   477  					input    []byte
   478  					dest     *[]interface{}
   479  					expected *[]interface{}
   480  					err      string
   481  					wasNull  bool
   482  				}{
   483  					{"nil input", udtCodecSimple, nil, new([]interface{}), new([]interface{}), "", true},
   484  					{"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new([]interface{}), &[]interface{}{nil, nil, nil}, "", false},
   485  					{"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new([]interface{}), &[]interface{}{int32(123), true, "abc"}, "", false},
   486  					{"complex", udtCodecComplex, udtComplexBytes, new([]interface{}), &[]interface{}{
   487  						map[string]interface{}{"f1": int32(12), "f2": false, "f3": "abc"},
   488  						map[string]interface{}{"f1": int32(34), "f2": true, "f3": "def"},
   489  					}, "", false},
   490  					{"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false},
   491  					{"not enough bytes", udtCodecSimple, udtMissingBytes, new([]interface{}), &[]interface{}{int32(123), true, nil}, "cannot read field 2 (f3)", false},
   492  					{"slice length too large", udtCodecSimple, oneTwoThreeAbcUdtBytes, &[]interface{}{nil, nil, nil, 42.0}, &[]interface{}{int32(123), true, "abc"}, "", false},
   493  				}
   494  				for _, tt := range tests {
   495  					t.Run(tt.name, func(t *testing.T) {
   496  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   497  						assert.Equal(t, tt.expected, tt.dest)
   498  						assert.Equal(t, tt.wasNull, wasNull)
   499  						assertErrorMessage(t, tt.err, err)
   500  					})
   501  				}
   502  			})
   503  			t.Run("*[3]interface{}", func(t *testing.T) {
   504  				tests := []struct {
   505  					name     string
   506  					codec    Codec
   507  					input    []byte
   508  					dest     *[3]interface{}
   509  					expected *[3]interface{}
   510  					err      string
   511  					wasNull  bool
   512  				}{
   513  					{"nil input", udtCodecSimple, nil, new([3]interface{}), new([3]interface{}), "", true},
   514  					{"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new([3]interface{}), &[3]interface{}{nil, nil, nil}, "", false},
   515  					{"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new([3]interface{}), &[3]interface{}{int32(123), true, "abc"}, "", false},
   516  					{"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false},
   517  					{"not enough bytes", udtCodecSimple, udtMissingBytes, new([3]interface{}), &[3]interface{}{int32(123), true, nil}, "cannot read field 2 (f3)", false},
   518  				}
   519  				for _, tt := range tests {
   520  					t.Run(tt.name, func(t *testing.T) {
   521  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   522  						assert.Equal(t, tt.expected, tt.dest)
   523  						assert.Equal(t, tt.wasNull, wasNull)
   524  						assertErrorMessage(t, tt.err, err)
   525  					})
   526  				}
   527  			})
   528  			t.Run("*[][]interface{}", func(t *testing.T) {
   529  				tests := []struct {
   530  					name     string
   531  					codec    Codec
   532  					input    []byte
   533  					dest     *[][]interface{}
   534  					expected *[][]interface{}
   535  					err      string
   536  					wasNull  bool
   537  				}{
   538  					{"complex", udtCodecComplex, udtComplexBytes, new([][]interface{}), &[][]interface{}{
   539  						{int32(12), false, "abc"},
   540  						{int32(34), true, "def"},
   541  					}, "", false},
   542  				}
   543  				for _, tt := range tests {
   544  					t.Run(tt.name, func(t *testing.T) {
   545  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   546  						assert.Equal(t, tt.expected, tt.dest)
   547  						assert.Equal(t, tt.wasNull, wasNull)
   548  						assertErrorMessage(t, tt.err, err)
   549  					})
   550  				}
   551  			})
   552  			t.Run("*map[string]interface{}", func(t *testing.T) {
   553  				tests := []struct {
   554  					name     string
   555  					codec    Codec
   556  					input    []byte
   557  					dest     *map[string]interface{}
   558  					expected *map[string]interface{}
   559  					err      string
   560  					wasNull  bool
   561  				}{
   562  					{"nil input", udtCodecSimple, nil, new(map[string]interface{}), new(map[string]interface{}), "", true},
   563  					{"nil elements map to zero values", udtCodecSimple, nullElementsUdtBytes, new(map[string]interface{}), &map[string]interface{}{"f1": nil, "f2": nil, "f3": nil}, "", false},
   564  					{"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, new(map[string]interface{}), &map[string]interface{}{"f1": int32(123), "f2": true, "f3": "abc"}, "", false},
   565  					{"complex", udtCodecComplex, udtComplexBytes, new(map[string]interface{}), &map[string]interface{}{
   566  						"f1": map[string]interface{}{"f1": int32(12), "f2": false, "f3": "abc"},
   567  						"f2": map[string]interface{}{"f1": int32(34), "f2": true, "f3": "def"},
   568  					}, "", false},
   569  					{"nil dest", udtCodecSimple, oneTwoThreeAbcUdtBytes, nil, nil, "destination is nil", false},
   570  					{"not enough bytes", udtCodecSimple, udtMissingBytes, new(map[string]interface{}), &map[string]interface{}{"f1": int32(123), "f2": true}, "cannot read field 2 (f3)", false},
   571  				}
   572  				for _, tt := range tests {
   573  					t.Run(tt.name, func(t *testing.T) {
   574  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   575  						assert.Equal(t, tt.expected, tt.dest)
   576  						assert.Equal(t, tt.wasNull, wasNull)
   577  						assertErrorMessage(t, tt.err, err)
   578  					})
   579  				}
   580  			})
   581  			t.Run("struct simple", func(t *testing.T) {
   582  				tests := []struct {
   583  					name     string
   584  					codec    Codec
   585  					input    []byte
   586  					dest     *SimpleUdt
   587  					expected *SimpleUdt
   588  					err      string
   589  					wasNull  bool
   590  				}{
   591  					{"nil input", udtCodecSimple, nil, &SimpleUdt{}, &SimpleUdt{}, "", true},
   592  					{"empty input", udtCodecSimple, []byte{}, &SimpleUdt{}, &SimpleUdt{}, "", true},
   593  					{"nil elements", udtCodecSimple, nullElementsUdtBytes, &SimpleUdt{}, &SimpleUdt{F1: 0, F2: false, F3: nil}, "", false},
   594  					{"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, &SimpleUdt{}, &SimpleUdt{F1: 123, F2: true, F3: stringPtr("abc")}, "", false},
   595  					{"nil dest", udtCodecSimple, udtMissingBytes, nil, nil, "destination is nil", false},
   596  				}
   597  				for _, tt := range tests {
   598  					t.Run(tt.name, func(t *testing.T) {
   599  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   600  						if tt.expected != nil && tt.dest != nil {
   601  							assert.Equal(t, *tt.expected, *tt.dest)
   602  						}
   603  						assert.Equal(t, tt.wasNull, wasNull)
   604  						assertErrorMessage(t, tt.err, err)
   605  					})
   606  				}
   607  			})
   608  			t.Run("struct partial", func(t *testing.T) {
   609  				tests := []struct {
   610  					name     string
   611  					codec    Codec
   612  					input    []byte
   613  					dest     *partialUdt
   614  					expected *partialUdt
   615  					err      string
   616  					wasNull  bool
   617  				}{
   618  					{"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, &partialUdt{}, &partialUdt{F1: 123, F2: true}, "no accessible field with name 'f3' found", false},
   619  				}
   620  				for _, tt := range tests {
   621  					t.Run(tt.name, func(t *testing.T) {
   622  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   623  						if tt.expected != nil && tt.dest != nil {
   624  							assert.Equal(t, *tt.expected, *tt.dest)
   625  						}
   626  						assert.Equal(t, tt.wasNull, wasNull)
   627  						assertErrorMessage(t, tt.err, err)
   628  					})
   629  				}
   630  			})
   631  			t.Run("struct excess", func(t *testing.T) {
   632  				tests := []struct {
   633  					name     string
   634  					codec    Codec
   635  					input    []byte
   636  					dest     *excessUdt
   637  					expected *excessUdt
   638  					err      string
   639  					wasNull  bool
   640  				}{
   641  					{"nil input", udtCodecSimple, nil, &excessUdt{}, &excessUdt{}, "", true},
   642  					{"empty input", udtCodecSimple, []byte{}, &excessUdt{}, &excessUdt{}, "", true},
   643  					{"nil elements", udtCodecSimple, nullElementsUdtBytes, &excessUdt{}, &excessUdt{}, "", false},
   644  					{"simple", udtCodecSimple, oneTwoThreeAbcUdtBytes, &excessUdt{}, &excessUdt{F1: 123, F2: true, F3: stringPtr("abc")}, "", false},
   645  					{"nil dest", udtCodecSimple, udtMissingBytes, nil, nil, "destination is nil", false},
   646  				}
   647  				for _, tt := range tests {
   648  					t.Run(tt.name, func(t *testing.T) {
   649  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   650  						if tt.expected != nil && tt.dest != nil {
   651  							assert.Equal(t, *tt.expected, *tt.dest)
   652  						}
   653  						assert.Equal(t, tt.wasNull, wasNull)
   654  						assertErrorMessage(t, tt.err, err)
   655  					})
   656  				}
   657  			})
   658  			t.Run("struct complex", func(t *testing.T) {
   659  				tests := []struct {
   660  					name     string
   661  					codec    Codec
   662  					input    []byte
   663  					dest     *complexUdt
   664  					expected *complexUdt
   665  					err      string
   666  					wasNull  bool
   667  				}{
   668  					{"nil", udtCodecComplex, nil, &complexUdt{}, &complexUdt{}, "", true},
   669  					{"empty", udtCodecComplex, []byte{}, &complexUdt{}, &complexUdt{}, "", true},
   670  					{"complex", udtCodecComplex, udtComplexWithNullsBytes, &complexUdt{}, &complexUdt{
   671  						SimpleUdt{12, false, stringPtr("abc")},
   672  						&excessUdt{34, true, nil, 0.0},
   673  					}, "", false},
   674  					{"nil element", udtCodecComplex, udtComplexWithNulls2Bytes, &complexUdt{}, &complexUdt{
   675  						SimpleUdt{12, false, stringPtr("abc")},
   676  						nil,
   677  					}, "", false},
   678  				}
   679  				for _, tt := range tests {
   680  					t.Run(tt.name, func(t *testing.T) {
   681  						wasNull, err := tt.codec.Decode(tt.input, tt.dest, version)
   682  						assert.Equal(t, *tt.expected, *tt.dest)
   683  						assert.Equal(t, tt.wasNull, wasNull)
   684  						assertErrorMessage(t, tt.err, err)
   685  					})
   686  				}
   687  			})
   688  		})
   689  	}
   690  	for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) {
   691  		t.Run(version.String(), func(t *testing.T) {
   692  			_, err := udtCodecSimple.Decode(nil, nil, version)
   693  			expectedMessage := fmt.Sprintf("data type %s not supported in %v", udtTypeSimple, version)
   694  			assertErrorMessage(t, expectedMessage, err)
   695  		})
   696  	}
   697  	t.Run("invalid types", func(t *testing.T) {
   698  		wasNull, err := udtCodecSimple.Decode([]byte{1, 2, 3}, new(int), primitive.ProtocolVersion5)
   699  		assert.False(t, wasNull)
   700  		assert.EqualError(t, err, "cannot decode CQL ks1.type1<f1:int,f2:boolean,f3:varchar> as *int with ProtocolVersion OSS 5: destination type not supported")
   701  		wasNull, err = udtCodecSimple.Decode([]byte{1, 2, 3}, new(map[int]string), primitive.ProtocolVersion5)
   702  		assert.False(t, wasNull)
   703  		assert.EqualError(t, err, "cannot decode CQL ks1.type1<f1:int,f2:boolean,f3:varchar> as *map[int]string with ProtocolVersion OSS 5: wrong map key, expected string, got: int")
   704  	})
   705  }
   706  
   707  func Test_writeUdt(t *testing.T) {
   708  	type args struct {
   709  		ext         extractor
   710  		fieldNames  []string
   711  		fieldCodecs []Codec
   712  		version     primitive.ProtocolVersion
   713  	}
   714  	tests := []struct {
   715  		name    string
   716  		args    args
   717  		want    []byte
   718  		wantErr string
   719  	}{
   720  		{
   721  			"cannot extract elem",
   722  			args{
   723  				func() extractor {
   724  					ext := &mockExtractor{}
   725  					ext.On("getElem", 0, "f1").Return(nil, errors.New("wrong type"))
   726  					return ext
   727  				}(),
   728  				[]string{"f1"},
   729  				[]Codec{nil},
   730  				primitive.ProtocolVersion5,
   731  			},
   732  			nil,
   733  			"cannot extract field 0 (f1): wrong type",
   734  		},
   735  		{
   736  			"cannot encode",
   737  			args{
   738  				func() extractor {
   739  					ext := &mockExtractor{}
   740  					ext.On("getElem", 0, "f1").Return(123, nil)
   741  					return ext
   742  				}(),
   743  				[]string{"f1"},
   744  				func() []Codec {
   745  					codec := &mockCodec{}
   746  					codec.On("Encode", 123, primitive.ProtocolVersion5).Return(nil, errors.New("write failed"))
   747  					return []Codec{codec}
   748  				}(),
   749  				primitive.ProtocolVersion5,
   750  			},
   751  			nil,
   752  			"cannot encode field 0 (f1): write failed",
   753  		},
   754  		{"success", args{
   755  			func() extractor {
   756  				ext := &mockExtractor{}
   757  				ext.On("getElem", 0, "f1").Return(123, nil)
   758  				ext.On("getElem", 1, "f2").Return("abc", nil)
   759  				ext.On("getElem", 2, "f3").Return(true, nil)
   760  				return ext
   761  			}(),
   762  			[]string{"f1", "f2", "f3"},
   763  			func() []Codec {
   764  				codec1 := &mockCodec{}
   765  				codec1.On("Encode", 123, primitive.ProtocolVersion5).Return([]byte{1}, nil)
   766  				codec2 := &mockCodec{}
   767  				codec2.On("Encode", "abc", primitive.ProtocolVersion5).Return([]byte{2}, nil)
   768  				codec3 := &mockCodec{}
   769  				codec3.On("Encode", true, primitive.ProtocolVersion5).Return(nil, nil)
   770  				return []Codec{codec1, codec2, codec3}
   771  			}(),
   772  			primitive.ProtocolVersion5,
   773  		}, []byte{
   774  			0, 0, 0, 1, // field 1
   775  			1,
   776  			0, 0, 0, 1, // field 2
   777  			2,
   778  			255, 255, 255, 255, // field 3 (nil)
   779  		}, ""},
   780  	}
   781  	for _, tt := range tests {
   782  		t.Run(tt.name, func(t *testing.T) {
   783  			got, gotErr := writeUdt(tt.args.ext, tt.args.fieldNames, tt.args.fieldCodecs, tt.args.version)
   784  			assert.Equal(t, tt.want, got)
   785  			assertErrorMessage(t, tt.wantErr, gotErr)
   786  		})
   787  	}
   788  }
   789  
   790  func Test_readUdt(t *testing.T) {
   791  	type args struct {
   792  		source      []byte
   793  		inj         injector
   794  		fieldNames  []string
   795  		fieldCodecs []Codec
   796  		version     primitive.ProtocolVersion
   797  	}
   798  	tests := []struct {
   799  		name    string
   800  		args    args
   801  		wantErr string
   802  	}{
   803  		{
   804  			"cannot read element",
   805  			args{
   806  				[]byte{
   807  					0, // wrong [bytes]
   808  				},
   809  				nil,
   810  				[]string{"f1"},
   811  				[]Codec{nil},
   812  				primitive.ProtocolVersion5,
   813  			},
   814  			"cannot read field 0 (f1): cannot read [bytes] length: cannot read [int]: unexpected EOF",
   815  		},
   816  		{
   817  			"cannot create element",
   818  			args{
   819  				[]byte{
   820  					0, 0, 0, 1, 123, // [bytes]
   821  				},
   822  				func() injector {
   823  					inj := &mockInjector{}
   824  					inj.On("zeroElem", 0, "f1").Return(nil, errors.New("wrong data type"))
   825  					return inj
   826  				}(),
   827  				[]string{"f1"},
   828  				func() []Codec {
   829  					codec := &mockCodec{}
   830  					codec.On("DataType").Return(datatype.Int)
   831  					return []Codec{codec}
   832  				}(),
   833  				primitive.ProtocolVersion5,
   834  			},
   835  			"cannot create zero field 0 (f1): wrong data type",
   836  		},
   837  		{
   838  			"cannot decode element",
   839  			args{
   840  				[]byte{
   841  					0, 0, 0, 1, 123, // [bytes]
   842  				},
   843  				func() injector {
   844  					inj := &mockInjector{}
   845  					inj.On("zeroElem", 0, "f1").Return(new(int), nil)
   846  					return inj
   847  				}(),
   848  				[]string{"f1"},
   849  				func() []Codec {
   850  					codec := &mockCodec{}
   851  					codec.On("DataType").Return(datatype.Int)
   852  					codec.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Return(false, errors.New("decode failed"))
   853  					return []Codec{codec}
   854  				}(),
   855  				primitive.ProtocolVersion5,
   856  			},
   857  			"cannot decode field 0 (f1): decode failed",
   858  		},
   859  		{
   860  			"cannot set element",
   861  			args{
   862  				[]byte{
   863  					0, 0, 0, 1, 123, // [bytes]
   864  				},
   865  				func() injector {
   866  					inj := &mockInjector{}
   867  					inj.On("zeroElem", 0, "f1").Return(new(int), nil)
   868  					inj.On("setElem", 0, "f1", intPtr(123), false, false).Return(errors.New("cannot set elem"))
   869  					return inj
   870  				}(),
   871  				[]string{"f1"},
   872  				func() []Codec {
   873  					codec := &mockCodec{}
   874  					codec.On("DataType").Return(datatype.Int)
   875  					codec.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Run(func(args mock.Arguments) {
   876  						decodedElement := args.Get(1).(*int)
   877  						*decodedElement = 123
   878  					}).Return(false, nil)
   879  					return []Codec{codec}
   880  				}(),
   881  				primitive.ProtocolVersion5,
   882  			},
   883  			"cannot inject field 0 (f1): cannot set elem",
   884  		},
   885  		{
   886  			"bytes remaining",
   887  			args{
   888  				[]byte{
   889  					0, 0, 0, 1, 123, // [bytes]
   890  					1, // trailing bytes
   891  				},
   892  				func() injector {
   893  					inj := &mockInjector{}
   894  					inj.On("zeroElem", 0, "f1").Return(new(int), nil)
   895  					inj.On("setElem", 0, "f1", intPtr(123), false, false).Return(nil)
   896  					return inj
   897  				}(),
   898  				[]string{"f1"},
   899  				func() []Codec {
   900  					codec := &mockCodec{}
   901  					codec.On("DataType").Return(datatype.Int)
   902  					codec.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Run(func(args mock.Arguments) {
   903  						decodedElement := args.Get(1).(*int)
   904  						*decodedElement = 123
   905  					}).Return(false, nil)
   906  					return []Codec{codec}
   907  				}(),
   908  				primitive.ProtocolVersion5,
   909  			},
   910  			"source was not fully read: bytes total: 6, read: 5, remaining: 1",
   911  		},
   912  		{
   913  			"success",
   914  			args{
   915  				[]byte{
   916  					0, 0, 0, 1, 123, // 1st elem
   917  					0, 0, 0, 3, a, b, c, // 2nd elem
   918  					255, 255, 255, 255, // 3rd elem (nil)
   919  				},
   920  				func() injector {
   921  					inj := &mockInjector{}
   922  					inj.On("zeroElem", 0, "f1").Return(new(int), nil)
   923  					inj.On("zeroElem", 1, "f2").Return(new(string), nil)
   924  					inj.On("zeroElem", 2, "f3").Return(new(bool), nil)
   925  					inj.On("setElem", 0, "f1", intPtr(123), false, false).Return(nil)
   926  					inj.On("setElem", 1, "f2", stringPtr("abc"), false, false).Return(nil)
   927  					inj.On("setElem", 2, "f3", new(bool), false, true).Return(nil)
   928  					return inj
   929  				}(),
   930  				[]string{"f1", "f2", "f3"},
   931  				func() []Codec {
   932  					codec1 := &mockCodec{}
   933  					codec1.On("DataType").Return(datatype.Int)
   934  					codec1.On("Decode", []byte{123}, new(int), primitive.ProtocolVersion5).Run(func(args mock.Arguments) {
   935  						decodedElement := args.Get(1).(*int)
   936  						*decodedElement = 123
   937  					}).Return(false, nil)
   938  					codec2 := &mockCodec{}
   939  					codec2.On("DataType").Return(datatype.Varchar)
   940  					codec2.On("Decode", []byte{a, b, c}, new(string), primitive.ProtocolVersion5).Run(func(args mock.Arguments) {
   941  						decodedElement := args.Get(1).(*string)
   942  						*decodedElement = "abc"
   943  					}).Return(false, nil)
   944  					codec3 := &mockCodec{}
   945  					codec3.On("DataType").Return(datatype.Boolean)
   946  					codec3.On("Decode", []byte(nil), new(bool), primitive.ProtocolVersion5).Return(true, nil)
   947  					return []Codec{codec1, codec2, codec3}
   948  				}(),
   949  				primitive.ProtocolVersion5,
   950  			},
   951  			"",
   952  		},
   953  	}
   954  	for _, tt := range tests {
   955  		t.Run(tt.name, func(t *testing.T) {
   956  			gotErr := readUdt(tt.args.source, tt.args.inj, tt.args.fieldNames, tt.args.fieldCodecs, tt.args.version)
   957  			assertErrorMessage(t, tt.wantErr, gotErr)
   958  		})
   959  	}
   960  }