github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/collection.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io"
    21  	"math"
    22  	"reflect"
    23  
    24  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    25  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    26  )
    27  
    28  func NewList(dataType *datatype.List) (Codec, error) {
    29  	if dataType == nil {
    30  		return nil, ErrNilDataType
    31  	}
    32  	codec, err := NewCodec(dataType.ElementType)
    33  	if err != nil {
    34  		return nil, fmt.Errorf("cannot create codec for list elements: %w", err)
    35  	}
    36  	return &collectionCodec{dataType, codec}, nil
    37  }
    38  
    39  func NewSet(dataType *datatype.Set) (Codec, error) {
    40  	if dataType == nil {
    41  		return nil, ErrNilDataType
    42  	}
    43  	codec, err := NewCodec(dataType.ElementType)
    44  	if err != nil {
    45  		return nil, fmt.Errorf("cannot create codec for set elements: %w", err)
    46  	}
    47  	return &collectionCodec{dataType, codec}, nil
    48  }
    49  
    50  type collectionCodec struct {
    51  	dataType     datatype.DataType
    52  	elementCodec Codec
    53  }
    54  
    55  func (c *collectionCodec) DataType() datatype.DataType {
    56  	return c.dataType
    57  }
    58  
    59  func (c *collectionCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
    60  	ext, size, err := c.createExtractor(source)
    61  	if err == nil && ext != nil {
    62  		dest, err = writeCollection(ext, c.elementCodec, size, version)
    63  	}
    64  	if err != nil {
    65  		err = errCannotEncode(source, c.DataType(), version, err)
    66  	}
    67  	return
    68  }
    69  
    70  func (c *collectionCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
    71  	wasNull = len(source) == 0
    72  	var injectorFactory func(int) (injector, error)
    73  	if injectorFactory, err = c.createInjector(dest, wasNull); err == nil && injectorFactory != nil {
    74  		err = readCollection(source, injectorFactory, c.elementCodec, version)
    75  	}
    76  	if err != nil {
    77  		err = errCannotDecode(dest, c.DataType(), version, err)
    78  	}
    79  	return
    80  }
    81  
    82  func (c *collectionCodec) createExtractor(source interface{}) (ext extractor, size int, err error) {
    83  	sourceValue, sourceType, wasNil := reflectSource(source)
    84  	if sourceType != nil {
    85  		switch sourceType.Kind() {
    86  		case reflect.Slice, reflect.Array:
    87  			if !wasNil {
    88  				ext, err = newSliceExtractor(sourceValue)
    89  				size = sourceValue.Len()
    90  			}
    91  		default:
    92  			err = ErrSourceTypeNotSupported
    93  		}
    94  	}
    95  	return
    96  }
    97  
    98  func (c *collectionCodec) createInjector(dest interface{}, wasNull bool) (injectorFactory func(int) (injector, error), err error) {
    99  	destValue, err := reflectDest(dest, wasNull)
   100  	if err == nil {
   101  		switch destValue.Kind() {
   102  		case reflect.Slice:
   103  			if !wasNull {
   104  				injectorFactory = func(size int) (injector, error) {
   105  					adjustSliceLength(destValue, size)
   106  					return newSliceInjector(destValue)
   107  				}
   108  			}
   109  		case reflect.Array:
   110  			if !wasNull {
   111  				injectorFactory = func(size int) (injector, error) {
   112  					return newSliceInjector(destValue)
   113  				}
   114  			}
   115  		case reflect.Interface:
   116  			if !wasNull {
   117  				var targetType reflect.Type
   118  				if targetType, err = PreferredGoType(c.DataType()); err == nil {
   119  					injectorFactory = func(size int) (injector, error) {
   120  						destValue.Set(reflect.MakeSlice(targetType, size, size))
   121  						return newSliceInjector(destValue.Elem())
   122  					}
   123  				}
   124  			}
   125  		default:
   126  			err = ErrDestinationTypeNotSupported
   127  		}
   128  	}
   129  	return
   130  }
   131  
   132  func writeCollection(ext extractor, elementCodec Codec, size int, version primitive.ProtocolVersion) ([]byte, error) {
   133  	buf := &bytes.Buffer{}
   134  	if err := writeCollectionSize(size, buf, version); err != nil {
   135  		return nil, err
   136  	}
   137  	for i := 0; i < size; i++ {
   138  		if elem, err := ext.getElem(i, i); err != nil {
   139  			return nil, errCannotExtractElement(i, err)
   140  		} else if encodedElem, err := elementCodec.Encode(elem, version); err != nil {
   141  			return nil, errCannotEncodeElement(i, err)
   142  		} else {
   143  			_ = primitive.WriteBytes(encodedElem, buf)
   144  		}
   145  	}
   146  	return buf.Bytes(), nil
   147  }
   148  
   149  func readCollection(source []byte, injectorFactory func(int) (injector, error), elementCodec Codec, version primitive.ProtocolVersion) error {
   150  	reader := bytes.NewReader(source)
   151  	total := len(source)
   152  	if size, err := readCollectionSize(reader, version); err != nil {
   153  		return err
   154  	} else if inj, err := injectorFactory(size); err != nil {
   155  		return err
   156  	} else {
   157  		for i := 0; i < size; i++ {
   158  			if encodedElem, err := primitive.ReadBytes(reader); err != nil {
   159  				return errCannotReadElement(i, err)
   160  			} else if decodedElem, err := inj.zeroElem(i, i); err != nil {
   161  				return errCannotCreateElement(i, err)
   162  			} else if elementWasNull, err := elementCodec.Decode(encodedElem, decodedElem, version); err != nil {
   163  				return errCannotDecodeElement(i, err)
   164  			} else if err = inj.setElem(i, i, decodedElem, false, elementWasNull); err != nil {
   165  				return errCannotInjectElement(i, err)
   166  			}
   167  		}
   168  		if remaining := reader.Len(); remaining != 0 {
   169  			return errBytesRemaining(total, remaining)
   170  		}
   171  	}
   172  	return nil
   173  }
   174  
   175  func writeCollectionSize(size int, dest io.Writer, version primitive.ProtocolVersion) (err error) {
   176  	if version.Uses4BytesCollectionLength() {
   177  		if size > math.MaxInt32 {
   178  			err = collectionSizeTooLarge(size, math.MaxInt32)
   179  		} else if size < 0 {
   180  			err = collectionSizeNegative(size)
   181  		} else {
   182  			err = primitive.WriteInt(int32(size), dest)
   183  		}
   184  	} else {
   185  		if size > math.MaxUint16 {
   186  			err = collectionSizeTooLarge(size, math.MaxUint16)
   187  		} else if size < 0 {
   188  			err = collectionSizeNegative(size)
   189  		} else {
   190  			err = primitive.WriteShort(uint16(size), dest)
   191  		}
   192  	}
   193  	if err != nil {
   194  		err = cannotWriteCollectionSize(err)
   195  	}
   196  	return
   197  }
   198  
   199  func readCollectionSize(source io.Reader, version primitive.ProtocolVersion) (size int, err error) {
   200  	if version.Uses4BytesCollectionLength() {
   201  		var sizeInt32 int32
   202  		sizeInt32, err = primitive.ReadInt(source)
   203  		size = int(sizeInt32)
   204  	} else {
   205  		var sizeInt16 uint16
   206  		sizeInt16, err = primitive.ReadShort(source)
   207  		size = int(sizeInt16)
   208  	}
   209  	if err != nil {
   210  		err = fmt.Errorf("cannot read collection size: %w", err)
   211  	}
   212  	return
   213  }