github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/udt.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datacodec 16 17 import ( 18 "bytes" 19 "fmt" 20 "reflect" 21 22 "github.com/datastax/go-cassandra-native-protocol/datatype" 23 "github.com/datastax/go-cassandra-native-protocol/primitive" 24 ) 25 26 func NewUserDefined(dataType *datatype.UserDefined) (Codec, error) { 27 if dataType == nil { 28 return nil, ErrNilDataType 29 } 30 fieldCodecs := make([]Codec, len(dataType.FieldTypes)) 31 for i, fieldType := range dataType.FieldTypes { 32 if fieldCodec, err := NewCodec(fieldType); err != nil { 33 return nil, fmt.Errorf("cannot create codec for user-defined type field %d (%s): %w", i, dataType.FieldNames[i], err) 34 } else { 35 fieldCodecs[i] = fieldCodec 36 } 37 } 38 return &udtCodec{dataType, fieldCodecs}, nil 39 } 40 41 type udtCodec struct { 42 dataType *datatype.UserDefined 43 fieldCodecs []Codec 44 } 45 46 func (c *udtCodec) DataType() datatype.DataType { 47 return c.dataType 48 } 49 50 func (c *udtCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { 51 if !version.SupportsDataType(c.DataType().Code()) { 52 err = errDataTypeNotSupported(c.DataType(), version) 53 } else { 54 var ext extractor 55 if ext, err = c.createExtractor(source); err == nil && ext != nil { 56 dest, err = writeUdt(ext, c.dataType.FieldNames, c.fieldCodecs, version) 57 } 58 } 59 if err != nil { 60 err = errCannotEncode(source, c.DataType(), version, err) 61 } 62 return 63 } 64 65 func (c *udtCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { 66 wasNull = len(source) == 0 67 if !version.SupportsDataType(c.DataType().Code()) { 68 err = errDataTypeNotSupported(c.DataType(), version) 69 } else { 70 var inj injector 71 if inj, err = c.createInjector(dest, wasNull); err == nil && inj != nil { 72 err = readUdt(source, inj, c.dataType.FieldNames, c.fieldCodecs, version) 73 } 74 } 75 if err != nil { 76 err = errCannotDecode(dest, c.DataType(), version, err) 77 } 78 return 79 } 80 81 func (c *udtCodec) createExtractor(source interface{}) (ext extractor, err error) { 82 sourceValue, sourceType, wasNil := reflectSource(source) 83 if sourceType != nil { 84 switch sourceType.Kind() { 85 case reflect.Struct: 86 if !wasNil { 87 ext, err = newStructExtractor(sourceValue) 88 } 89 case reflect.Map: 90 if !wasNil { 91 keyType := sourceValue.Type().Key() 92 if keyType.Kind() != reflect.String { 93 err = errWrongElementType("map key", typeOfString, keyType) 94 } else { 95 ext, err = newMapExtractor(sourceValue) 96 } 97 } 98 case reflect.Slice, reflect.Array: 99 if !wasNil { 100 ext, err = newSliceExtractor(sourceValue) 101 } 102 default: 103 err = ErrSourceTypeNotSupported 104 } 105 } 106 return 107 } 108 109 func (c *udtCodec) createInjector(dest interface{}, wasNull bool) (inj injector, err error) { 110 destValue, err := reflectDest(dest, wasNull) 111 if err == nil { 112 switch destValue.Kind() { 113 case reflect.Struct: 114 if !wasNull { 115 inj, err = newStructInjector(destValue) 116 } 117 case reflect.Map: 118 if !wasNull { 119 keyType := destValue.Type().Key() 120 if keyType.Kind() != reflect.String { 121 err = errWrongElementType("map key", typeOfString, keyType) 122 } else { 123 adjustMapSize(destValue, len(c.fieldCodecs)) 124 inj, err = newMapInjector(destValue) 125 } 126 } 127 case reflect.Slice: 128 if !wasNull { 129 adjustSliceLength(destValue, len(c.fieldCodecs)) 130 inj, err = newSliceInjector(destValue) 131 } 132 case reflect.Array: 133 if !wasNull { 134 inj, err = newSliceInjector(destValue) 135 } 136 case reflect.Interface: 137 if !wasNull { 138 target := make(map[string]interface{}, len(c.fieldCodecs)) 139 *dest.(*interface{}) = target 140 inj, err = newMapInjector(reflect.ValueOf(target)) 141 } 142 default: 143 err = ErrDestinationTypeNotSupported 144 } 145 } 146 return 147 } 148 149 func writeUdt(ext extractor, fieldNames []string, fieldCodecs []Codec, version primitive.ProtocolVersion) ([]byte, error) { 150 buf := &bytes.Buffer{} 151 for i, fieldCodec := range fieldCodecs { 152 name := fieldNames[i] 153 if value, err := ext.getElem(i, name); err != nil { 154 return nil, errCannotExtractUdtField(i, name, err) 155 } else if encodedField, err := fieldCodec.Encode(value, version); err != nil { 156 return nil, errCannotEncodeUdtField(i, name, err) 157 } else { 158 _ = primitive.WriteBytes(encodedField, buf) 159 } 160 } 161 return buf.Bytes(), nil 162 } 163 164 func readUdt(source []byte, inj injector, fieldNames []string, fieldCodecs []Codec, version primitive.ProtocolVersion) error { 165 reader := bytes.NewReader(source) 166 total := reader.Len() 167 for i, fieldCodec := range fieldCodecs { 168 name := fieldNames[i] 169 if encodedField, err := primitive.ReadBytes(reader); err != nil { 170 return errCannotReadUdtField(i, name, err) 171 } else if decodedField, err := inj.zeroElem(i, name); err != nil { 172 return errCannotCreateUdtField(i, name, err) 173 } else if fieldWasNull, err := fieldCodec.Decode(encodedField, decodedField, version); err != nil { 174 return errCannotDecodeUdtField(i, name, err) 175 } else if err = inj.setElem(i, name, decodedField, false, fieldWasNull); err != nil { 176 return errCannotInjectUdtField(i, name, err) 177 } 178 } 179 if remaining := reader.Len(); remaining != 0 { 180 return errBytesRemaining(total, remaining) 181 } 182 return nil 183 }