github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/map.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datacodec 16 17 import ( 18 "bytes" 19 "fmt" 20 "reflect" 21 22 "github.com/datastax/go-cassandra-native-protocol/datatype" 23 "github.com/datastax/go-cassandra-native-protocol/primitive" 24 ) 25 26 func NewMap(dataType *datatype.Map) (Codec, error) { 27 if dataType == nil { 28 return nil, ErrNilDataType 29 } 30 keyCodec, err := NewCodec(dataType.KeyType) 31 if err != nil { 32 return nil, fmt.Errorf("cannot create codec for map keys: %w", err) 33 } 34 valueCodec, err := NewCodec(dataType.ValueType) 35 if err != nil { 36 return nil, fmt.Errorf("cannot create codec for map values: %w", err) 37 } 38 return &mapCodec{dataType, keyCodec, valueCodec}, nil 39 } 40 41 type mapCodec struct { 42 dataType *datatype.Map 43 keyCodec Codec 44 valueCodec Codec 45 } 46 47 func (c *mapCodec) DataType() datatype.DataType { 48 return c.dataType 49 } 50 51 func (c *mapCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { 52 ext, size, err := c.createExtractor(source) 53 if err == nil && ext != nil { 54 dest, err = writeMap(ext, size, c.keyCodec, c.valueCodec, version) 55 } 56 if err != nil { 57 err = errCannotEncode(source, c.DataType(), version, err) 58 } 59 return 60 } 61 62 func (c *mapCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { 63 wasNull = len(source) == 0 64 var injectorFactory func(int) (keyValueInjector, error) 65 if injectorFactory, err = c.createInjector(dest, wasNull); err == nil && injectorFactory != nil { 66 err = readMap(source, injectorFactory, c.keyCodec, c.valueCodec, version) 67 } 68 if err != nil { 69 err = errCannotDecode(dest, c.DataType(), version, err) 70 } 71 return 72 } 73 74 func (c *mapCodec) createExtractor(source interface{}) (ext keyValueExtractor, size int, err error) { 75 sourceValue, sourceType, wasNil := reflectSource(source) 76 if sourceType != nil { 77 switch sourceType.Kind() { 78 case reflect.Map: 79 if !wasNil { 80 size = sourceValue.Len() 81 ext, err = newMapExtractor(sourceValue) 82 } 83 case reflect.Struct: 84 if !wasNil { 85 if c.keyCodec.DataType() != datatype.Varchar && c.keyCodec.DataType() != datatype.Ascii { 86 err = errWrongDataType("map key", datatype.Varchar, datatype.Ascii, c.keyCodec.DataType()) 87 } else { 88 size = sourceValue.NumField() 89 ext, err = newStructExtractor(sourceValue) 90 } 91 } 92 default: 93 err = ErrSourceTypeNotSupported 94 } 95 } 96 return 97 } 98 99 func (c *mapCodec) createInjector(dest interface{}, wasNull bool) (injectorFactory func(int) (keyValueInjector, error), err error) { 100 destValue, err := reflectDest(dest, wasNull) 101 if err == nil { 102 switch destValue.Kind() { 103 case reflect.Map: 104 if !wasNull { 105 injectorFactory = func(size int) (keyValueInjector, error) { 106 adjustMapSize(destValue, size) 107 return newMapInjector(destValue) 108 } 109 } 110 case reflect.Struct: 111 if !wasNull { 112 if c.keyCodec.DataType() != datatype.Varchar && c.keyCodec.DataType() != datatype.Ascii { 113 err = errWrongDataType("map key", datatype.Varchar, datatype.Ascii, c.keyCodec.DataType()) 114 } else { 115 injectorFactory = func(size int) (keyValueInjector, error) { 116 return newStructInjector(destValue) 117 } 118 } 119 } 120 case reflect.Interface: 121 if !wasNull { 122 var targetType reflect.Type 123 if targetType, err = PreferredGoType(c.DataType()); err == nil { 124 injectorFactory = func(size int) (keyValueInjector, error) { 125 destValue.Set(reflect.MakeMapWithSize(targetType, size)) 126 return newMapInjector(destValue.Elem()) 127 } 128 } 129 } 130 default: 131 err = ErrDestinationTypeNotSupported 132 } 133 } 134 return 135 } 136 137 func writeMap(ext keyValueExtractor, size int, keyCodec Codec, valueCodec Codec, version primitive.ProtocolVersion) ([]byte, error) { 138 buf := &bytes.Buffer{} 139 if err := writeCollectionSize(size, buf, version); err != nil { 140 return nil, err 141 } 142 for i := 0; i < size; i++ { 143 key := ext.getKey(i) 144 if value, err := ext.getElem(i, key); err != nil { 145 return nil, errCannotExtractMapValue(i, err) 146 } else if encodedKey, err := keyCodec.Encode(key, version); err != nil { 147 return nil, errCannotEncodeMapKey(i, err) 148 } else if encodedValue, err := valueCodec.Encode(value, version); err != nil { 149 return nil, errCannotEncodeMapValue(i, err) 150 } else { 151 _ = primitive.WriteBytes(encodedKey, buf) 152 _ = primitive.WriteBytes(encodedValue, buf) 153 } 154 } 155 return buf.Bytes(), nil 156 } 157 158 func readMap(source []byte, injectorFactory func(int) (keyValueInjector, error), keyCodec Codec, valueCodec Codec, version primitive.ProtocolVersion) error { 159 reader := bytes.NewReader(source) 160 total := len(source) 161 if size, err := readCollectionSize(reader, version); err != nil { 162 return err 163 } else if inj, err := injectorFactory(size); err != nil { 164 return err 165 } else { 166 for i := 0; i < size; i++ { 167 if encodedKey, err := primitive.ReadBytes(reader); err != nil { 168 return errCannotReadMapKey(i, err) 169 } else if encodedValue, err := primitive.ReadBytes(reader); err != nil { 170 return errCannotReadMapValue(i, err) 171 } else if decodedKey, err := inj.zeroKey(i); err != nil { 172 return errCannotCreateMapKey(i, err) 173 } else if keyWasNull, err := keyCodec.Decode(encodedKey, decodedKey, version); err != nil { 174 return errCannotDecodeMapKey(i, err) 175 } else if decodedValue, err := inj.zeroElem(i, decodedKey); err != nil { 176 return errCannotCreateMapValue(i, err) 177 } else if valueWasNull, err := valueCodec.Decode(encodedValue, decodedValue, version); err != nil { 178 return errCannotDecodeMapValue(i, err) 179 } else if err = inj.setElem(i, decodedKey, decodedValue, keyWasNull, valueWasNull); err != nil { 180 return errCannotInjectMapEntry(i, err) 181 } 182 } 183 if remaining := reader.Len(); remaining != 0 { 184 return errBytesRemaining(total, remaining) 185 } 186 } 187 return nil 188 }