github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datatype/datatype.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datatype
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  
    21  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    22  )
    23  
    24  type DataType interface {
    25  	Code() primitive.DataTypeCode
    26  	AsCql() string
    27  	DeepCopyDataType() DataType
    28  }
    29  
    30  func WriteDataType(t DataType, dest io.Writer, version primitive.ProtocolVersion) (err error) {
    31  	if t == nil {
    32  		return fmt.Errorf("DataType can not be nil")
    33  	} else if err = primitive.CheckValidDataTypeCode(t.Code(), version); err != nil {
    34  		return err
    35  	} else if err = primitive.WriteShort(uint16(t.Code()), dest); err != nil {
    36  		return fmt.Errorf("cannot write data type code %v: %w", t.Code(), err)
    37  	} else {
    38  		switch t.Code() {
    39  		case primitive.DataTypeCodeCustom:
    40  			return writeCustomType(t, dest, version)
    41  		case primitive.DataTypeCodeList:
    42  			return writeListType(t, dest, version)
    43  		case primitive.DataTypeCodeMap:
    44  			return writeMapType(t, dest, version)
    45  		case primitive.DataTypeCodeSet:
    46  			return writeSetType(t, dest, version)
    47  		case primitive.DataTypeCodeUdt:
    48  			return writeUserDefinedType(t, dest, version)
    49  		case primitive.DataTypeCodeTuple:
    50  			return writeTupleType(t, dest, version)
    51  		}
    52  		return
    53  	}
    54  }
    55  
    56  func LengthOfDataType(t DataType, version primitive.ProtocolVersion) (length int, err error) {
    57  	length += primitive.LengthOfShort // type code
    58  	dataTypeLength := 0
    59  	switch t.Code() {
    60  	case primitive.DataTypeCodeCustom:
    61  		dataTypeLength, err = lengthOfCustomType(t, version)
    62  	case primitive.DataTypeCodeList:
    63  		dataTypeLength, err = lengthOfListType(t, version)
    64  	case primitive.DataTypeCodeMap:
    65  		dataTypeLength, err = lengthOfMapType(t, version)
    66  	case primitive.DataTypeCodeSet:
    67  		dataTypeLength, err = lengthOfSetType(t, version)
    68  	case primitive.DataTypeCodeUdt:
    69  		dataTypeLength, err = lengthOfUserDefinedType(t, version)
    70  	case primitive.DataTypeCodeTuple:
    71  		dataTypeLength, err = lengthOfTupleType(t, version)
    72  	}
    73  	if err != nil {
    74  		return -1, fmt.Errorf("cannot compute length of data type %v: %w", t, err)
    75  	}
    76  	return length + dataTypeLength, nil
    77  }
    78  
    79  func ReadDataType(source io.Reader, version primitive.ProtocolVersion) (decoded DataType, err error) {
    80  	var typeCode uint16
    81  	if typeCode, err = primitive.ReadShort(source); err != nil {
    82  		return nil, fmt.Errorf("cannot read data type code: %w", err)
    83  	} else if err := primitive.CheckValidDataTypeCode(primitive.DataTypeCode(typeCode), version); err != nil {
    84  		return nil, err
    85  	} else {
    86  		switch primitive.DataTypeCode(typeCode) {
    87  		case primitive.DataTypeCodeAscii:
    88  			return Ascii, nil
    89  		case primitive.DataTypeCodeBigint:
    90  			return Bigint, nil
    91  		case primitive.DataTypeCodeBlob:
    92  			return Blob, nil
    93  		case primitive.DataTypeCodeBoolean:
    94  			return Boolean, nil
    95  		case primitive.DataTypeCodeCounter:
    96  			return Counter, nil
    97  		case primitive.DataTypeCodeDecimal:
    98  			return Decimal, nil
    99  		case primitive.DataTypeCodeDouble:
   100  			return Double, nil
   101  		case primitive.DataTypeCodeFloat:
   102  			return Float, nil
   103  		case primitive.DataTypeCodeInt:
   104  			return Int, nil
   105  		case primitive.DataTypeCodeTimestamp:
   106  			return Timestamp, nil
   107  		case primitive.DataTypeCodeUuid:
   108  			return Uuid, nil
   109  		case primitive.DataTypeCodeVarchar:
   110  			return Varchar, nil
   111  		case primitive.DataTypeCodeVarint:
   112  			return Varint, nil
   113  		case primitive.DataTypeCodeTimeuuid:
   114  			return Timeuuid, nil
   115  		case primitive.DataTypeCodeInet:
   116  			return Inet, nil
   117  		case primitive.DataTypeCodeDate:
   118  			return Date, nil
   119  		case primitive.DataTypeCodeTime:
   120  			return Time, nil
   121  		case primitive.DataTypeCodeSmallint:
   122  			return Smallint, nil
   123  		case primitive.DataTypeCodeTinyint:
   124  			return Tinyint, nil
   125  		case primitive.DataTypeCodeDuration:
   126  			return Duration, nil
   127  		case primitive.DataTypeCodeCustom:
   128  			return readCustomType(source, version)
   129  		case primitive.DataTypeCodeList:
   130  			return readListType(source, version)
   131  		case primitive.DataTypeCodeMap:
   132  			return readMapType(source, version)
   133  		case primitive.DataTypeCodeSet:
   134  			return readSetType(source, version)
   135  		case primitive.DataTypeCodeUdt:
   136  			return readUserDefinedType(source, version)
   137  		case primitive.DataTypeCodeTuple:
   138  			return readTupleType(source, version)
   139  		}
   140  		return nil, fmt.Errorf("unknown type code: %w", err)
   141  	}
   142  }