github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/collection.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datacodec 16 17 import ( 18 "bytes" 19 "fmt" 20 "io" 21 "math" 22 "reflect" 23 24 "github.com/datastax/go-cassandra-native-protocol/datatype" 25 "github.com/datastax/go-cassandra-native-protocol/primitive" 26 ) 27 28 func NewList(dataType *datatype.List) (Codec, error) { 29 if dataType == nil { 30 return nil, ErrNilDataType 31 } 32 codec, err := NewCodec(dataType.ElementType) 33 if err != nil { 34 return nil, fmt.Errorf("cannot create codec for list elements: %w", err) 35 } 36 return &collectionCodec{dataType, codec}, nil 37 } 38 39 func NewSet(dataType *datatype.Set) (Codec, error) { 40 if dataType == nil { 41 return nil, ErrNilDataType 42 } 43 codec, err := NewCodec(dataType.ElementType) 44 if err != nil { 45 return nil, fmt.Errorf("cannot create codec for set elements: %w", err) 46 } 47 return &collectionCodec{dataType, codec}, nil 48 } 49 50 type collectionCodec struct { 51 dataType datatype.DataType 52 elementCodec Codec 53 } 54 55 func (c *collectionCodec) DataType() datatype.DataType { 56 return c.dataType 57 } 58 59 func (c *collectionCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { 60 ext, size, err := c.createExtractor(source) 61 if err == nil && ext != nil { 62 dest, err = writeCollection(ext, c.elementCodec, size, version) 63 } 64 if err != nil { 65 err = errCannotEncode(source, c.DataType(), version, err) 66 } 67 return 68 } 69 70 func (c *collectionCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { 71 wasNull = len(source) == 0 72 var injectorFactory func(int) (injector, error) 73 if injectorFactory, err = c.createInjector(dest, wasNull); err == nil && injectorFactory != nil { 74 err = readCollection(source, injectorFactory, c.elementCodec, version) 75 } 76 if err != nil { 77 err = errCannotDecode(dest, c.DataType(), version, err) 78 } 79 return 80 } 81 82 func (c *collectionCodec) createExtractor(source interface{}) (ext extractor, size int, err error) { 83 sourceValue, sourceType, wasNil := reflectSource(source) 84 if sourceType != nil { 85 switch sourceType.Kind() { 86 case reflect.Slice, reflect.Array: 87 if !wasNil { 88 ext, err = newSliceExtractor(sourceValue) 89 size = sourceValue.Len() 90 } 91 default: 92 err = ErrSourceTypeNotSupported 93 } 94 } 95 return 96 } 97 98 func (c *collectionCodec) createInjector(dest interface{}, wasNull bool) (injectorFactory func(int) (injector, error), err error) { 99 destValue, err := reflectDest(dest, wasNull) 100 if err == nil { 101 switch destValue.Kind() { 102 case reflect.Slice: 103 if !wasNull { 104 injectorFactory = func(size int) (injector, error) { 105 adjustSliceLength(destValue, size) 106 return newSliceInjector(destValue) 107 } 108 } 109 case reflect.Array: 110 if !wasNull { 111 injectorFactory = func(size int) (injector, error) { 112 return newSliceInjector(destValue) 113 } 114 } 115 case reflect.Interface: 116 if !wasNull { 117 var targetType reflect.Type 118 if targetType, err = PreferredGoType(c.DataType()); err == nil { 119 injectorFactory = func(size int) (injector, error) { 120 destValue.Set(reflect.MakeSlice(targetType, size, size)) 121 return newSliceInjector(destValue.Elem()) 122 } 123 } 124 } 125 default: 126 err = ErrDestinationTypeNotSupported 127 } 128 } 129 return 130 } 131 132 func writeCollection(ext extractor, elementCodec Codec, size int, version primitive.ProtocolVersion) ([]byte, error) { 133 buf := &bytes.Buffer{} 134 if err := writeCollectionSize(size, buf, version); err != nil { 135 return nil, err 136 } 137 for i := 0; i < size; i++ { 138 if elem, err := ext.getElem(i, i); err != nil { 139 return nil, errCannotExtractElement(i, err) 140 } else if encodedElem, err := elementCodec.Encode(elem, version); err != nil { 141 return nil, errCannotEncodeElement(i, err) 142 } else { 143 _ = primitive.WriteBytes(encodedElem, buf) 144 } 145 } 146 return buf.Bytes(), nil 147 } 148 149 func readCollection(source []byte, injectorFactory func(int) (injector, error), elementCodec Codec, version primitive.ProtocolVersion) error { 150 reader := bytes.NewReader(source) 151 total := len(source) 152 if size, err := readCollectionSize(reader, version); err != nil { 153 return err 154 } else if inj, err := injectorFactory(size); err != nil { 155 return err 156 } else { 157 for i := 0; i < size; i++ { 158 if encodedElem, err := primitive.ReadBytes(reader); err != nil { 159 return errCannotReadElement(i, err) 160 } else if decodedElem, err := inj.zeroElem(i, i); err != nil { 161 return errCannotCreateElement(i, err) 162 } else if elementWasNull, err := elementCodec.Decode(encodedElem, decodedElem, version); err != nil { 163 return errCannotDecodeElement(i, err) 164 } else if err = inj.setElem(i, i, decodedElem, false, elementWasNull); err != nil { 165 return errCannotInjectElement(i, err) 166 } 167 } 168 if remaining := reader.Len(); remaining != 0 { 169 return errBytesRemaining(total, remaining) 170 } 171 } 172 return nil 173 } 174 175 func writeCollectionSize(size int, dest io.Writer, version primitive.ProtocolVersion) (err error) { 176 if version.Uses4BytesCollectionLength() { 177 if size > math.MaxInt32 { 178 err = collectionSizeTooLarge(size, math.MaxInt32) 179 } else if size < 0 { 180 err = collectionSizeNegative(size) 181 } else { 182 err = primitive.WriteInt(int32(size), dest) 183 } 184 } else { 185 if size > math.MaxUint16 { 186 err = collectionSizeTooLarge(size, math.MaxUint16) 187 } else if size < 0 { 188 err = collectionSizeNegative(size) 189 } else { 190 err = primitive.WriteShort(uint16(size), dest) 191 } 192 } 193 if err != nil { 194 err = cannotWriteCollectionSize(err) 195 } 196 return 197 } 198 199 func readCollectionSize(source io.Reader, version primitive.ProtocolVersion) (size int, err error) { 200 if version.Uses4BytesCollectionLength() { 201 var sizeInt32 int32 202 sizeInt32, err = primitive.ReadInt(source) 203 size = int(sizeInt32) 204 } else { 205 var sizeInt16 uint16 206 sizeInt16, err = primitive.ReadShort(source) 207 size = int(sizeInt16) 208 } 209 if err != nil { 210 err = fmt.Errorf("cannot read collection size: %w", err) 211 } 212 return 213 }