github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/tuple.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"reflect"
    21  
    22  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    23  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    24  )
    25  
    26  func NewTuple(tupleType *datatype.Tuple) (Codec, error) {
    27  	if tupleType == nil {
    28  		return nil, ErrNilDataType
    29  	}
    30  	elementCodecs := make([]Codec, len(tupleType.FieldTypes))
    31  	for i, elementType := range tupleType.FieldTypes {
    32  		if elementCodec, err := NewCodec(elementType); err != nil {
    33  			return nil, fmt.Errorf("cannot create codec for tuple element %d: %w", i, err)
    34  		} else {
    35  			elementCodecs[i] = elementCodec
    36  		}
    37  	}
    38  	return &tupleCodec{tupleType, elementCodecs}, nil
    39  }
    40  
// tupleCodec is the Codec implementation returned by NewTuple.
type tupleCodec struct {
	// dataType is the tuple type this codec was created for; DataType returns it.
	dataType *datatype.Tuple
	// elementCodecs holds one codec per tuple element, in the same order as dataType.FieldTypes.
	elementCodecs []Codec
}
    45  
    46  func (c *tupleCodec) DataType() datatype.DataType {
    47  	return c.dataType
    48  }
    49  
    50  func (c *tupleCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
    51  	if !version.SupportsDataType(c.DataType().Code()) {
    52  		err = errDataTypeNotSupported(c.DataType(), version)
    53  	} else {
    54  		var ext extractor
    55  		if ext, err = c.createExtractor(source); err == nil && ext != nil {
    56  			dest, err = writeTuple(ext, c.elementCodecs, version)
    57  		}
    58  	}
    59  	if err != nil {
    60  		err = errCannotEncode(source, c.DataType(), version, err)
    61  	}
    62  	return
    63  }
    64  
    65  func (c *tupleCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
    66  	wasNull = len(source) == 0
    67  	if !version.SupportsDataType(c.DataType().Code()) {
    68  		err = errDataTypeNotSupported(c.DataType(), version)
    69  	} else {
    70  		var inj injector
    71  		if inj, err = c.createInjector(dest, wasNull); err == nil && inj != nil {
    72  			err = readTuple(source, inj, c.elementCodecs, version)
    73  		}
    74  	}
    75  	if err != nil {
    76  		err = errCannotDecode(dest, c.DataType(), version, err)
    77  	}
    78  	return
    79  }
    80  
    81  func (c *tupleCodec) createExtractor(source interface{}) (ext extractor, err error) {
    82  	sourceValue, sourceType, wasNil := reflectSource(source)
    83  	if sourceType != nil {
    84  		switch sourceType.Kind() {
    85  		case reflect.Struct:
    86  			if !wasNil {
    87  				ext, err = newStructExtractor(sourceValue)
    88  			}
    89  		case reflect.Slice, reflect.Array:
    90  			if !wasNil {
    91  				ext, err = newSliceExtractor(sourceValue)
    92  			}
    93  		default:
    94  			err = ErrSourceTypeNotSupported
    95  		}
    96  	}
    97  	return
    98  }
    99  
   100  func (c *tupleCodec) createInjector(dest interface{}, wasNull bool) (inj injector, err error) {
   101  	destValue, err := reflectDest(dest, wasNull)
   102  	if err == nil {
   103  		switch destValue.Kind() {
   104  		case reflect.Struct:
   105  			if !wasNull {
   106  				inj, err = newStructInjector(destValue)
   107  			}
   108  		case reflect.Slice:
   109  			if !wasNull {
   110  				adjustSliceLength(destValue, len(c.elementCodecs))
   111  				inj, err = newSliceInjector(destValue)
   112  			}
   113  		case reflect.Array:
   114  			if !wasNull {
   115  				inj, err = newSliceInjector(destValue)
   116  			}
   117  		case reflect.Interface:
   118  			if !wasNull {
   119  				target := make([]interface{}, len(c.elementCodecs))
   120  				*dest.(*interface{}) = target
   121  				inj, err = newSliceInjector(reflect.ValueOf(target))
   122  			}
   123  		default:
   124  			err = ErrDestinationTypeNotSupported
   125  		}
   126  	}
   127  	return
   128  }
   129  
   130  func writeTuple(ext extractor, elementCodecs []Codec, version primitive.ProtocolVersion) ([]byte, error) {
   131  	buf := &bytes.Buffer{}
   132  	for i, elementCodec := range elementCodecs {
   133  		if value, err := ext.getElem(i, i); err != nil {
   134  			return nil, errCannotExtractElement(i, err)
   135  		} else if encodedElement, err := elementCodec.Encode(value, version); err != nil {
   136  			return nil, errCannotEncodeElement(i, err)
   137  		} else {
   138  			_ = primitive.WriteBytes(encodedElement, buf)
   139  		}
   140  	}
   141  	return buf.Bytes(), nil
   142  }
   143  
   144  func readTuple(source []byte, inj injector, elementCodecs []Codec, version primitive.ProtocolVersion) error {
   145  	reader := bytes.NewReader(source)
   146  	total := reader.Len()
   147  	for i, elementCodec := range elementCodecs {
   148  		if encodedElement, err := primitive.ReadBytes(reader); err != nil {
   149  			return errCannotReadElement(i, err)
   150  		} else if decodedElement, err := inj.zeroElem(i, i); err != nil {
   151  			return errCannotCreateElement(i, err)
   152  		} else if elementWasNull, err := elementCodec.Decode(encodedElement, decodedElement, version); err != nil {
   153  			return errCannotDecodeElement(i, err)
   154  		} else if err = inj.setElem(i, i, decodedElement, false, elementWasNull); err != nil {
   155  			return errCannotInjectElement(i, err)
   156  		}
   157  	}
   158  	if remaining := reader.Len(); remaining != 0 {
   159  		return errBytesRemaining(total, remaining)
   160  	}
   161  	return nil
   162  }