github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datatype/udt.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datatype 16 17 import ( 18 "bytes" 19 "fmt" 20 "io" 21 22 "github.com/datastax/go-cassandra-native-protocol/primitive" 23 ) 24 25 // UserDefined is a data type that represents a CQL user-defined type. 26 // +k8s:deepcopy-gen=true 27 // +k8s:deepcopy-gen:interfaces=github.com/datastax/go-cassandra-native-protocol/datatype.DataType 28 type UserDefined struct { 29 Keyspace string 30 Name string 31 FieldNames []string 32 FieldTypes []DataType 33 // Note: field names and field types are not modeled as a map because iteration order matters. 34 } 35 36 func NewUserDefined(keyspace string, name string, fieldNames []string, fieldTypes []DataType) (*UserDefined, error) { 37 fieldNamesLength := len(fieldNames) 38 fieldTypesLength := len(fieldTypes) 39 if fieldNamesLength != fieldTypesLength { 40 return nil, fmt.Errorf("field names and field types length mismatch: %d != %d", fieldNamesLength, fieldTypesLength) 41 } 42 return &UserDefined{Keyspace: keyspace, Name: name, FieldNames: fieldNames, FieldTypes: fieldTypes}, nil 43 } 44 45 func (t *UserDefined) Code() primitive.DataTypeCode { 46 return primitive.DataTypeCodeUdt 47 } 48 49 func (t *UserDefined) String() string { 50 return t.AsCql() 51 } 52 53 func (t *UserDefined) AsCql() string { 54 buf := &bytes.Buffer{} 55 buf.WriteString(t.Keyspace) 56 buf.WriteString(".") 57 buf.WriteString(t.Name) 58 buf.WriteString("<") 59 for i, fieldType := range t.FieldTypes { 60 if i > 0 { 61 buf.WriteString(",") 62 } 63 buf.WriteString(t.FieldNames[i]) 64 buf.WriteString(":") 65 buf.WriteString(fieldType.AsCql()) 66 } 67 buf.WriteString(">") 68 return buf.String() 69 } 70 71 func writeUserDefinedType(t DataType, dest io.Writer, version primitive.ProtocolVersion) (err error) { 72 userDefinedType, ok := t.(*UserDefined) 73 if !ok { 74 return fmt.Errorf("expected *UserDefined, got %T", t) 75 } else if err = primitive.WriteString(userDefinedType.Keyspace, dest); err != nil { 76 return fmt.Errorf("cannot write udt keyspace: %w", err) 77 } else if err = primitive.WriteString(userDefinedType.Name, dest); err != nil { 78 return fmt.Errorf("cannot write udt name: %w", err) 79 } else if err = primitive.WriteShort(uint16(len(userDefinedType.FieldTypes)), dest); err != nil { 80 return fmt.Errorf("cannot write udt field count: %w", err) 81 } 82 if len(userDefinedType.FieldNames) != len(userDefinedType.FieldTypes) { 83 return fmt.Errorf("invalid user-defined type: length of field names is not equal to length of field types") 84 } 85 for i, fieldName := range userDefinedType.FieldNames { 86 fieldType := userDefinedType.FieldTypes[i] 87 if err = primitive.WriteString(fieldName, dest); err != nil { 88 return fmt.Errorf("cannot write udt field %v name: %w", fieldName, err) 89 } else if err = WriteDataType(fieldType, dest, version); err != nil { 90 return fmt.Errorf("cannot write udt field %v: %w", fieldName, err) 91 } 92 } 93 return nil 94 } 95 96 func lengthOfUserDefinedType(t DataType, version primitive.ProtocolVersion) (length int, err error) { 97 userDefinedType, ok := t.(*UserDefined) 98 if !ok { 99 return -1, fmt.Errorf("expected *UserDefined, got %T", t) 100 } 101 length += primitive.LengthOfString(userDefinedType.Keyspace) 102 length += primitive.LengthOfString(userDefinedType.Name) 103 length += primitive.LengthOfShort // field count 104 if len(userDefinedType.FieldNames) != len(userDefinedType.FieldTypes) { 105 return -1, fmt.Errorf("invalid user-defined type: length of field names is not equal to length of field types") 106 } 107 for i, fieldName := range userDefinedType.FieldNames { 108 fieldType := userDefinedType.FieldTypes[i] 109 length += primitive.LengthOfString(fieldName) 110 if fieldLength, err := LengthOfDataType(fieldType, version); err != nil { 111 return -1, fmt.Errorf("cannot compute length of udt field %v: %w", fieldName, err) 112 } else { 113 length += fieldLength 114 } 115 } 116 return length, nil 117 } 118 119 func readUserDefinedType(source io.Reader, version primitive.ProtocolVersion) (decoded DataType, err error) { 120 userDefinedType := &UserDefined{} 121 if userDefinedType.Keyspace, err = primitive.ReadString(source); err != nil { 122 return nil, fmt.Errorf("cannot read udt keyspace: %w", err) 123 } else if userDefinedType.Name, err = primitive.ReadString(source); err != nil { 124 return nil, fmt.Errorf("cannot read udt name: %w", err) 125 } else if fieldCount, err := primitive.ReadShort(source); err != nil { 126 return nil, fmt.Errorf("cannot read udt field count: %w", err) 127 } else { 128 userDefinedType.FieldNames = make([]string, fieldCount) 129 userDefinedType.FieldTypes = make([]DataType, fieldCount) 130 for i := 0; i < int(fieldCount); i++ { 131 if userDefinedType.FieldNames[i], err = primitive.ReadString(source); err != nil { 132 return nil, fmt.Errorf("cannot read udt field %d name: %w", i, err) 133 } else if userDefinedType.FieldTypes[i], err = ReadDataType(source, version); err != nil { 134 return nil, fmt.Errorf("cannot read udt field %d: %w", i, err) 135 } 136 } 137 return userDefinedType, nil 138 } 139 }