github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/map.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"reflect"
    21  
    22  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    23  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    24  )
    25  
    26  func NewMap(dataType *datatype.Map) (Codec, error) {
    27  	if dataType == nil {
    28  		return nil, ErrNilDataType
    29  	}
    30  	keyCodec, err := NewCodec(dataType.KeyType)
    31  	if err != nil {
    32  		return nil, fmt.Errorf("cannot create codec for map keys: %w", err)
    33  	}
    34  	valueCodec, err := NewCodec(dataType.ValueType)
    35  	if err != nil {
    36  		return nil, fmt.Errorf("cannot create codec for map values: %w", err)
    37  	}
    38  	return &mapCodec{dataType, keyCodec, valueCodec}, nil
    39  }
    40  
    41  type mapCodec struct {
    42  	dataType   *datatype.Map
    43  	keyCodec   Codec
    44  	valueCodec Codec
    45  }
    46  
    47  func (c *mapCodec) DataType() datatype.DataType {
    48  	return c.dataType
    49  }
    50  
    51  func (c *mapCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
    52  	ext, size, err := c.createExtractor(source)
    53  	if err == nil && ext != nil {
    54  		dest, err = writeMap(ext, size, c.keyCodec, c.valueCodec, version)
    55  	}
    56  	if err != nil {
    57  		err = errCannotEncode(source, c.DataType(), version, err)
    58  	}
    59  	return
    60  }
    61  
    62  func (c *mapCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
    63  	wasNull = len(source) == 0
    64  	var injectorFactory func(int) (keyValueInjector, error)
    65  	if injectorFactory, err = c.createInjector(dest, wasNull); err == nil && injectorFactory != nil {
    66  		err = readMap(source, injectorFactory, c.keyCodec, c.valueCodec, version)
    67  	}
    68  	if err != nil {
    69  		err = errCannotDecode(dest, c.DataType(), version, err)
    70  	}
    71  	return
    72  }
    73  
    74  func (c *mapCodec) createExtractor(source interface{}) (ext keyValueExtractor, size int, err error) {
    75  	sourceValue, sourceType, wasNil := reflectSource(source)
    76  	if sourceType != nil {
    77  		switch sourceType.Kind() {
    78  		case reflect.Map:
    79  			if !wasNil {
    80  				size = sourceValue.Len()
    81  				ext, err = newMapExtractor(sourceValue)
    82  			}
    83  		case reflect.Struct:
    84  			if !wasNil {
    85  				if c.keyCodec.DataType() != datatype.Varchar && c.keyCodec.DataType() != datatype.Ascii {
    86  					err = errWrongDataType("map key", datatype.Varchar, datatype.Ascii, c.keyCodec.DataType())
    87  				} else {
    88  					size = sourceValue.NumField()
    89  					ext, err = newStructExtractor(sourceValue)
    90  				}
    91  			}
    92  		default:
    93  			err = ErrSourceTypeNotSupported
    94  		}
    95  	}
    96  	return
    97  }
    98  
    99  func (c *mapCodec) createInjector(dest interface{}, wasNull bool) (injectorFactory func(int) (keyValueInjector, error), err error) {
   100  	destValue, err := reflectDest(dest, wasNull)
   101  	if err == nil {
   102  		switch destValue.Kind() {
   103  		case reflect.Map:
   104  			if !wasNull {
   105  				injectorFactory = func(size int) (keyValueInjector, error) {
   106  					adjustMapSize(destValue, size)
   107  					return newMapInjector(destValue)
   108  				}
   109  			}
   110  		case reflect.Struct:
   111  			if !wasNull {
   112  				if c.keyCodec.DataType() != datatype.Varchar && c.keyCodec.DataType() != datatype.Ascii {
   113  					err = errWrongDataType("map key", datatype.Varchar, datatype.Ascii, c.keyCodec.DataType())
   114  				} else {
   115  					injectorFactory = func(size int) (keyValueInjector, error) {
   116  						return newStructInjector(destValue)
   117  					}
   118  				}
   119  			}
   120  		case reflect.Interface:
   121  			if !wasNull {
   122  				var targetType reflect.Type
   123  				if targetType, err = PreferredGoType(c.DataType()); err == nil {
   124  					injectorFactory = func(size int) (keyValueInjector, error) {
   125  						destValue.Set(reflect.MakeMapWithSize(targetType, size))
   126  						return newMapInjector(destValue.Elem())
   127  					}
   128  				}
   129  			}
   130  		default:
   131  			err = ErrDestinationTypeNotSupported
   132  		}
   133  	}
   134  	return
   135  }
   136  
   137  func writeMap(ext keyValueExtractor, size int, keyCodec Codec, valueCodec Codec, version primitive.ProtocolVersion) ([]byte, error) {
   138  	buf := &bytes.Buffer{}
   139  	if err := writeCollectionSize(size, buf, version); err != nil {
   140  		return nil, err
   141  	}
   142  	for i := 0; i < size; i++ {
   143  		key := ext.getKey(i)
   144  		if value, err := ext.getElem(i, key); err != nil {
   145  			return nil, errCannotExtractMapValue(i, err)
   146  		} else if encodedKey, err := keyCodec.Encode(key, version); err != nil {
   147  			return nil, errCannotEncodeMapKey(i, err)
   148  		} else if encodedValue, err := valueCodec.Encode(value, version); err != nil {
   149  			return nil, errCannotEncodeMapValue(i, err)
   150  		} else {
   151  			_ = primitive.WriteBytes(encodedKey, buf)
   152  			_ = primitive.WriteBytes(encodedValue, buf)
   153  		}
   154  	}
   155  	return buf.Bytes(), nil
   156  }
   157  
   158  func readMap(source []byte, injectorFactory func(int) (keyValueInjector, error), keyCodec Codec, valueCodec Codec, version primitive.ProtocolVersion) error {
   159  	reader := bytes.NewReader(source)
   160  	total := len(source)
   161  	if size, err := readCollectionSize(reader, version); err != nil {
   162  		return err
   163  	} else if inj, err := injectorFactory(size); err != nil {
   164  		return err
   165  	} else {
   166  		for i := 0; i < size; i++ {
   167  			if encodedKey, err := primitive.ReadBytes(reader); err != nil {
   168  				return errCannotReadMapKey(i, err)
   169  			} else if encodedValue, err := primitive.ReadBytes(reader); err != nil {
   170  				return errCannotReadMapValue(i, err)
   171  			} else if decodedKey, err := inj.zeroKey(i); err != nil {
   172  				return errCannotCreateMapKey(i, err)
   173  			} else if keyWasNull, err := keyCodec.Decode(encodedKey, decodedKey, version); err != nil {
   174  				return errCannotDecodeMapKey(i, err)
   175  			} else if decodedValue, err := inj.zeroElem(i, decodedKey); err != nil {
   176  				return errCannotCreateMapValue(i, err)
   177  			} else if valueWasNull, err := valueCodec.Decode(encodedValue, decodedValue, version); err != nil {
   178  				return errCannotDecodeMapValue(i, err)
   179  			} else if err = inj.setElem(i, decodedKey, decodedValue, keyWasNull, valueWasNull); err != nil {
   180  				return errCannotInjectMapEntry(i, err)
   181  			}
   182  		}
   183  		if remaining := reader.Len(); remaining != 0 {
   184  			return errBytesRemaining(total, remaining)
   185  		}
   186  	}
   187  	return nil
   188  }