github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datatype/udt_test.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datatype 16 17 import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "testing" 22 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 26 "github.com/datastax/go-cassandra-native-protocol/primitive" 27 ) 28 29 func TestUserDefinedType(t *testing.T) { 30 fieldNames := []string{"f1", "f2"} 31 fieldTypes := []DataType{Varchar, Int} 32 udtType, err := NewUserDefined("ks1", "udt1", fieldNames, fieldTypes) 33 assert.Nil(t, err) 34 assert.Equal(t, primitive.DataTypeCodeUdt, udtType.Code()) 35 assert.Equal(t, fieldTypes, udtType.FieldTypes) 36 udtType2, err2 := NewUserDefined("ks1", "udt1", fieldNames, []DataType{Varchar, Int, Boolean}) 37 assert.Nil(t, udtType2) 38 assert.Errorf(t, err2, "field names and field types length mismatch: 2 != 3") 39 } 40 41 func TestUserDefinedTypeDeepCopy(t *testing.T) { 42 fieldNames := []string{"f1", "f2"} 43 fieldTypes := []DataType{Varchar, Int} 44 udtType, err := NewUserDefined("ks1", "udt1", fieldNames, fieldTypes) 45 assert.Nil(t, err) 46 47 cloned := udtType.DeepCopy() 48 assert.Equal(t, udtType, cloned) 49 cloned.Name = "udt2" 50 cloned.Keyspace = "ks2" 51 cloned.FieldNames = []string{"f5", "field6", "f7"} 52 cloned.FieldTypes = []DataType{Uuid, Float, Varchar} 53 assert.NotEqual(t, udtType, cloned) 54 55 assert.Equal(t, primitive.DataTypeCodeUdt, udtType.Code()) 56 assert.Equal(t, []DataType{Varchar, Int}, udtType.FieldTypes) 57 assert.Equal(t, []string{"f1", "f2"}, udtType.FieldNames) 58 assert.Equal(t, "ks1", udtType.Keyspace) 59 assert.Equal(t, "udt1", udtType.Name) 60 61 assert.Equal(t, primitive.DataTypeCodeUdt, cloned.Code()) 62 assert.Equal(t, []DataType{Uuid, Float, Varchar}, cloned.FieldTypes) 63 assert.Equal(t, []string{"f5", "field6", "f7"}, cloned.FieldNames) 64 assert.Equal(t, "ks2", cloned.Keyspace) 65 assert.Equal(t, "udt2", cloned.Name) 66 } 67 68 func TestUserDefinedTypeDeepCopy_NilFieldTypesSlice(t *testing.T) { 69 fieldNames := []string{"f1", "f2", "f3"} 70 fieldTypes := []DataType{Int, Uuid, Float} 71 udtType, err := NewUserDefined("ks1", "udt1", fieldNames, fieldTypes) 72 assert.Nil(t, err) 73 udtType.FieldTypes = nil 74 75 cloned := udtType.DeepCopy() 76 assert.Equal(t, udtType, cloned) 77 cloned.FieldTypes = []DataType{Uuid, Float, Varchar} 78 assert.NotEqual(t, udtType, cloned) 79 80 assert.Nil(t, udtType.FieldTypes) 81 assert.Equal(t, []DataType{Uuid, Float, Varchar}, cloned.FieldTypes) 82 } 83 84 func TestUserDefinedTypeDeepCopy_NilFieldType(t *testing.T) { 85 fieldNames := []string{"f1", "f2", "f3"} 86 fieldTypes := []DataType{nil, Uuid, Float} 87 udtType, err := NewUserDefined("ks1", "udt1", fieldNames, fieldTypes) 88 assert.Nil(t, err) 89 90 cloned := udtType.DeepCopy() 91 assert.Equal(t, udtType, cloned) 92 cloned.FieldTypes = []DataType{Uuid, Float, Varchar} 93 assert.NotEqual(t, udtType, cloned) 94 95 assert.Equal(t, []DataType{nil, Uuid, Float}, udtType.FieldTypes) 96 assert.Equal(t, []DataType{Uuid, Float, Varchar}, cloned.FieldTypes) 97 } 98 99 func TestUserDefinedTypeDeepCopy_ComplexFieldTypes(t *testing.T) { 100 fieldNames := []string{"f1", "f2", "f3"} 101 fieldTypes := []DataType{NewList(NewTuple(Varchar)), Uuid, Float} 102 udtType, err := NewUserDefined("ks1", "udt1", fieldNames, fieldTypes) 103 assert.Nil(t, err) 104 105 cloned := udtType.DeepCopy() 106 assert.Equal(t, udtType, cloned) 107 cloned.FieldTypes[0].(*List).ElementType = NewTuple(Int) 108 assert.NotEqual(t, udtType, cloned) 109 110 assert.Equal(t, []DataType{NewList(NewTuple(Varchar)), Uuid, Float}, udtType.FieldTypes) 111 assert.Equal(t, []DataType{NewList(NewTuple(Int)), Uuid, Float}, cloned.FieldTypes) 112 } 113 114 var udt1, _ = NewUserDefined("ks1", "udt1", []string{"f1", "f2"}, []DataType{Varchar, Int}) 115 var udt2, _ = NewUserDefined("ks1", "udt2", []string{"f1"}, []DataType{udt1}) 116 117 func TestWriteUserDefinedType(t *testing.T) { 118 tests := []struct { 119 name string 120 input DataType 121 expected []byte 122 err error 123 }{ 124 { 125 "simple udt", 126 udt1, 127 []byte{ 128 0, byte(primitive.DataTypeCodeUdt & 0xFF), 129 0, 3, byte('k'), byte('s'), byte('1'), 130 0, 4, byte('u'), byte('d'), byte('t'), byte('1'), 131 0, 2, // field count 132 0, 2, byte('f'), byte('1'), 133 0, byte(primitive.DataTypeCodeVarchar & 0xFF), 134 0, 2, byte('f'), byte('2'), 135 0, byte(primitive.DataTypeCodeInt & 0xFF), 136 }, 137 nil, 138 }, 139 { 140 "complex udt", 141 udt2, 142 []byte{ 143 0, byte(primitive.DataTypeCodeUdt & 0xFF), 144 0, 3, byte('k'), byte('s'), byte('1'), 145 0, 4, byte('u'), byte('d'), byte('t'), byte('2'), 146 0, 1, // field count 147 0, 2, byte('f'), byte('1'), 148 0, byte(primitive.DataTypeCodeUdt & 0xFF), 149 0, 3, byte('k'), byte('s'), byte('1'), 150 0, 4, byte('u'), byte('d'), byte('t'), byte('1'), 151 0, 2, // field count 152 0, 2, byte('f'), byte('1'), 153 0, byte(primitive.DataTypeCodeVarchar & 0xFF), 154 0, 2, byte('f'), byte('2'), 155 0, byte(primitive.DataTypeCodeInt & 0xFF), 156 }, 157 nil, 158 }, 159 {"nil udt", nil, nil, errors.New("DataType can not be nil")}, 160 } 161 162 t.Run("versions_with_udt_support", func(t *testing.T) { 163 for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion3) { 164 t.Run(version.String(), func(t *testing.T) { 165 for _, test := range tests { 166 t.Run(test.name, func(t *testing.T) { 167 var dest = &bytes.Buffer{} 168 var err error 169 err = WriteDataType(test.input, dest, version) 170 assert.Equal(t, test.err, err) 171 actual := dest.Bytes() 172 assert.Equal(t, test.expected, actual) 173 }) 174 } 175 }) 176 } 177 }) 178 179 t.Run("versions_without_udt_support", func(t *testing.T) { 180 for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { 181 t.Run(version.String(), func(t *testing.T) { 182 for _, test := range tests { 183 t.Run(test.name, func(t *testing.T) { 184 var dest = &bytes.Buffer{} 185 var err error 186 err = WriteDataType(test.input, dest, version) 187 actual := dest.Bytes() 188 require.NotNil(t, err) 189 if test.err != nil { 190 assert.Equal(t, test.err, err) 191 } else { 192 assert.Contains(t, err.Error(), 193 fmt.Sprintf("invalid data type code for %s: DataTypeCode Udt", version)) 194 } 195 assert.Equal(t, 0, len(actual)) 196 }) 197 } 198 }) 199 } 200 }) 201 } 202 203 func TestLengthOfUserDefinedType(t *testing.T) { 204 for _, version := range primitive.SupportedProtocolVersions() { 205 t.Run(version.String(), func(t *testing.T) { 206 tests := []struct { 207 name string 208 input DataType 209 expected int 210 err error 211 }{ 212 { 213 "simple udt", 214 udt1, 215 primitive.LengthOfString("ks1") + 216 primitive.LengthOfString("udt1") + 217 primitive.LengthOfShort + // field count 218 primitive.LengthOfString("f1") + 219 primitive.LengthOfShort + // varchar 220 primitive.LengthOfString("f2") + 221 primitive.LengthOfShort, // int 222 nil, 223 }, 224 { 225 "complex udt", 226 udt2, 227 primitive.LengthOfString("ks1") + 228 primitive.LengthOfString("udt2") + 229 primitive.LengthOfShort + // field count 230 primitive.LengthOfString("f1") + 231 primitive.LengthOfShort + // UDT 232 primitive.LengthOfString("ks1") + 233 primitive.LengthOfString("udt1") + 234 primitive.LengthOfShort + // field count 235 primitive.LengthOfString("f1") + 236 primitive.LengthOfShort + // varchar 237 primitive.LengthOfString("f2") + 238 primitive.LengthOfShort, // int 239 nil, 240 }, 241 {"nil udt", nil, -1, errors.New("expected *UserDefined, got <nil>")}, 242 } 243 for _, test := range tests { 244 t.Run(test.name, func(t *testing.T) { 245 var actual int 246 var err error 247 actual, err = lengthOfUserDefinedType(test.input, version) 248 assert.Equal(t, test.expected, actual) 249 assert.Equal(t, test.err, err) 250 }) 251 } 252 }) 253 } 254 } 255 256 func TestReadUserDefinedType(t *testing.T) { 257 tests := []struct { 258 name string 259 input []byte 260 expected DataType 261 err error 262 }{ 263 { 264 "simple udt", 265 []byte{ 266 0, byte(primitive.DataTypeCodeUdt & 0xFF), 267 0, 3, byte('k'), byte('s'), byte('1'), 268 0, 4, byte('u'), byte('d'), byte('t'), byte('1'), 269 0, 2, // field count 270 0, 2, byte('f'), byte('1'), 271 0, byte(primitive.DataTypeCodeVarchar & 0xFF), 272 0, 2, byte('f'), byte('2'), 273 0, byte(primitive.DataTypeCodeInt & 0xFF), 274 }, 275 udt1, 276 nil, 277 }, 278 { 279 "complex udt", 280 []byte{ 281 0, byte(primitive.DataTypeCodeUdt & 0xFF), 282 0, 3, byte('k'), byte('s'), byte('1'), 283 0, 4, byte('u'), byte('d'), byte('t'), byte('2'), 284 0, 1, // field count 285 0, 2, byte('f'), byte('1'), 286 0, byte(primitive.DataTypeCodeUdt & 0xFF), 287 0, 3, byte('k'), byte('s'), byte('1'), 288 0, 4, byte('u'), byte('d'), byte('t'), byte('1'), 289 0, 2, // field count 290 0, 2, byte('f'), byte('1'), 291 0, byte(primitive.DataTypeCodeVarchar & 0xFF), 292 0, 2, byte('f'), byte('2'), 293 0, byte(primitive.DataTypeCodeInt & 0xFF), 294 }, 295 udt2, 296 nil, 297 }, 298 { 299 "cannot read udt", 300 []byte{0, byte(primitive.DataTypeCodeUdt & 0xFF)}, 301 nil, 302 fmt.Errorf("cannot read udt keyspace: %w", 303 fmt.Errorf("cannot read [string] length: %w", 304 fmt.Errorf("cannot read [short]: %w", 305 errors.New("EOF")))), 306 }, 307 } 308 309 t.Run("versions_with_udt_support", func(t *testing.T) { 310 for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion3) { 311 t.Run(version.String(), func(t *testing.T) { 312 for _, test := range tests { 313 t.Run(test.name, func(t *testing.T) { 314 var source = bytes.NewBuffer(test.input) 315 var actual DataType 316 var err error 317 actual, err = ReadDataType(source, version) 318 assert.Equal(t, test.err, err) 319 assert.Equal(t, test.expected, actual) 320 }) 321 } 322 }) 323 } 324 }) 325 326 t.Run("versions_without_udt_support", func(t *testing.T) { 327 for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { 328 t.Run(version.String(), func(t *testing.T) { 329 for _, test := range tests { 330 t.Run(test.name, func(t *testing.T) { 331 var source = bytes.NewBuffer(test.input) 332 var actual DataType 333 var err error 334 actual, err = ReadDataType(source, version) 335 require.NotNil(t, err) 336 assert.Contains(t, err.Error(), 337 fmt.Sprintf("invalid data type code for %s: DataTypeCode Udt", version)) 338 assert.Nil(t, actual) 339 }) 340 } 341 }) 342 } 343 }) 344 } 345 346 func Test_userDefinedType_String(t1 *testing.T) { 347 tests := []struct { 348 name string 349 keyspace string 350 udtName string 351 fieldNames []string 352 fieldTypes []DataType 353 want string 354 }{ 355 {"empty", "ks1", "type1", []string{}, []DataType{}, "ks1.type1<>"}, 356 {"simple", "ks1", "type1", []string{"f1", "f2"}, []DataType{Int, Varchar}, "ks1.type1<f1:int,f2:varchar>"}, 357 { 358 "complex", 359 "ks1", 360 "type1", 361 []string{"f1", "f2"}, 362 []DataType{Int, func() DataType { 363 udt2, _ := NewUserDefined("ks1", "type2", []string{"f2a", "f2b"}, []DataType{Varchar, Boolean}) 364 return udt2 365 }()}, 366 "ks1.type1<f1:int,f2:ks1.type2<f2a:varchar,f2b:boolean>>", 367 }, 368 } 369 for _, tt := range tests { 370 t1.Run(tt.name, func(t *testing.T) { 371 udt, err := NewUserDefined(tt.keyspace, tt.udtName, tt.fieldNames, tt.fieldTypes) 372 require.NoError(t, err) 373 got := udt.AsCql() 374 assert.Equal(t, tt.want, got) 375 }) 376 } 377 }