github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/udt.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"reflect"
    21  
    22  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    23  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    24  )
    25  
    26  func NewUserDefined(dataType *datatype.UserDefined) (Codec, error) {
    27  	if dataType == nil {
    28  		return nil, ErrNilDataType
    29  	}
    30  	fieldCodecs := make([]Codec, len(dataType.FieldTypes))
    31  	for i, fieldType := range dataType.FieldTypes {
    32  		if fieldCodec, err := NewCodec(fieldType); err != nil {
    33  			return nil, fmt.Errorf("cannot create codec for user-defined type field %d (%s): %w", i, dataType.FieldNames[i], err)
    34  		} else {
    35  			fieldCodecs[i] = fieldCodec
    36  		}
    37  	}
    38  	return &udtCodec{dataType, fieldCodecs}, nil
    39  }
    40  
// udtCodec is the Codec implementation for user-defined types. It is created
// by NewUserDefined.
type udtCodec struct {
	// dataType is the user-defined type handled by this codec; its FieldNames
	// are used to look up fields by name when encoding and decoding.
	dataType    *datatype.UserDefined
	// fieldCodecs holds one codec per entry of dataType.FieldTypes, in the
	// same order.
	fieldCodecs []Codec
}
    45  
// DataType returns the user-defined data type this codec was created with.
func (c *udtCodec) DataType() datatype.DataType {
	return c.dataType
}
    49  
    50  func (c *udtCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
    51  	if !version.SupportsDataType(c.DataType().Code()) {
    52  		err = errDataTypeNotSupported(c.DataType(), version)
    53  	} else {
    54  		var ext extractor
    55  		if ext, err = c.createExtractor(source); err == nil && ext != nil {
    56  			dest, err = writeUdt(ext, c.dataType.FieldNames, c.fieldCodecs, version)
    57  		}
    58  	}
    59  	if err != nil {
    60  		err = errCannotEncode(source, c.DataType(), version, err)
    61  	}
    62  	return
    63  }
    64  
    65  func (c *udtCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
    66  	wasNull = len(source) == 0
    67  	if !version.SupportsDataType(c.DataType().Code()) {
    68  		err = errDataTypeNotSupported(c.DataType(), version)
    69  	} else {
    70  		var inj injector
    71  		if inj, err = c.createInjector(dest, wasNull); err == nil && inj != nil {
    72  			err = readUdt(source, inj, c.dataType.FieldNames, c.fieldCodecs, version)
    73  		}
    74  	}
    75  	if err != nil {
    76  		err = errCannotDecode(dest, c.DataType(), version, err)
    77  	}
    78  	return
    79  }
    80  
    81  func (c *udtCodec) createExtractor(source interface{}) (ext extractor, err error) {
    82  	sourceValue, sourceType, wasNil := reflectSource(source)
    83  	if sourceType != nil {
    84  		switch sourceType.Kind() {
    85  		case reflect.Struct:
    86  			if !wasNil {
    87  				ext, err = newStructExtractor(sourceValue)
    88  			}
    89  		case reflect.Map:
    90  			if !wasNil {
    91  				keyType := sourceValue.Type().Key()
    92  				if keyType.Kind() != reflect.String {
    93  					err = errWrongElementType("map key", typeOfString, keyType)
    94  				} else {
    95  					ext, err = newMapExtractor(sourceValue)
    96  				}
    97  			}
    98  		case reflect.Slice, reflect.Array:
    99  			if !wasNil {
   100  				ext, err = newSliceExtractor(sourceValue)
   101  			}
   102  		default:
   103  			err = ErrSourceTypeNotSupported
   104  		}
   105  	}
   106  	return
   107  }
   108  
   109  func (c *udtCodec) createInjector(dest interface{}, wasNull bool) (inj injector, err error) {
   110  	destValue, err := reflectDest(dest, wasNull)
   111  	if err == nil {
   112  		switch destValue.Kind() {
   113  		case reflect.Struct:
   114  			if !wasNull {
   115  				inj, err = newStructInjector(destValue)
   116  			}
   117  		case reflect.Map:
   118  			if !wasNull {
   119  				keyType := destValue.Type().Key()
   120  				if keyType.Kind() != reflect.String {
   121  					err = errWrongElementType("map key", typeOfString, keyType)
   122  				} else {
   123  					adjustMapSize(destValue, len(c.fieldCodecs))
   124  					inj, err = newMapInjector(destValue)
   125  				}
   126  			}
   127  		case reflect.Slice:
   128  			if !wasNull {
   129  				adjustSliceLength(destValue, len(c.fieldCodecs))
   130  				inj, err = newSliceInjector(destValue)
   131  			}
   132  		case reflect.Array:
   133  			if !wasNull {
   134  				inj, err = newSliceInjector(destValue)
   135  			}
   136  		case reflect.Interface:
   137  			if !wasNull {
   138  				target := make(map[string]interface{}, len(c.fieldCodecs))
   139  				*dest.(*interface{}) = target
   140  				inj, err = newMapInjector(reflect.ValueOf(target))
   141  			}
   142  		default:
   143  			err = ErrDestinationTypeNotSupported
   144  		}
   145  	}
   146  	return
   147  }
   148  
   149  func writeUdt(ext extractor, fieldNames []string, fieldCodecs []Codec, version primitive.ProtocolVersion) ([]byte, error) {
   150  	buf := &bytes.Buffer{}
   151  	for i, fieldCodec := range fieldCodecs {
   152  		name := fieldNames[i]
   153  		if value, err := ext.getElem(i, name); err != nil {
   154  			return nil, errCannotExtractUdtField(i, name, err)
   155  		} else if encodedField, err := fieldCodec.Encode(value, version); err != nil {
   156  			return nil, errCannotEncodeUdtField(i, name, err)
   157  		} else {
   158  			_ = primitive.WriteBytes(encodedField, buf)
   159  		}
   160  	}
   161  	return buf.Bytes(), nil
   162  }
   163  
   164  func readUdt(source []byte, inj injector, fieldNames []string, fieldCodecs []Codec, version primitive.ProtocolVersion) error {
   165  	reader := bytes.NewReader(source)
   166  	total := reader.Len()
   167  	for i, fieldCodec := range fieldCodecs {
   168  		name := fieldNames[i]
   169  		if encodedField, err := primitive.ReadBytes(reader); err != nil {
   170  			return errCannotReadUdtField(i, name, err)
   171  		} else if decodedField, err := inj.zeroElem(i, name); err != nil {
   172  			return errCannotCreateUdtField(i, name, err)
   173  		} else if fieldWasNull, err := fieldCodec.Decode(encodedField, decodedField, version); err != nil {
   174  			return errCannotDecodeUdtField(i, name, err)
   175  		} else if err = inj.setElem(i, name, decodedField, false, fieldWasNull); err != nil {
   176  			return errCannotInjectUdtField(i, name, err)
   177  		}
   178  	}
   179  	if remaining := reader.Len(); remaining != 0 {
   180  		return errBytesRemaining(total, remaining)
   181  	}
   182  	return nil
   183  }