github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/decimal.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"encoding/binary"
    19  	"math/big"
    20  
    21  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    22  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    23  )
    24  
// CqlDecimal is the poor man's representation in Go of a CQL decimal value, since there is no built-in representation
// of arbitrary-precision decimal values in Go's standard library.
// Note that this value is pretty useless as is. It's highly recommended to convert this value to some other type using
// a dedicated library. The most popular one is: https://pkg.go.dev/github.com/ericlagergren/decimal/v3.
// The zero value of a CqlDecimal is encoded as zero, with zero scale: a nil Unscaled is treated as zero when encoding.
type CqlDecimal struct {

	// Unscaled is a big.Int representing the unscaled decimal value. A nil Unscaled is encoded as zero.
	Unscaled *big.Int

	// Scale is the decimal value scale; on the wire it precedes the unscaled value as a 4-byte big-endian int.
	// Per the CQL native protocol specification, the represented number is Unscaled × 10^(-Scale)
	// (this type itself does not enforce or compute that).
	Scale int32
}
    38  
    39  // Decimal is a codec for the CQL decimal type. There is no built-in representation of arbitrary-precision
    40  // decimal values in Go's standard library. This is why this codec can only encode from and decode to CqlDecimal.
    41  var Decimal Codec = &decimalCodec{}
    42  
    43  type decimalCodec struct{}
    44  
    45  func (c *decimalCodec) DataType() datatype.DataType {
    46  	return datatype.Decimal
    47  }
    48  
    49  func (c *decimalCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
    50  	var val CqlDecimal
    51  	var wasNil bool
    52  	if val, wasNil, err = convertToDecimal(source); err == nil && !wasNil {
    53  		dest = writeDecimal(val)
    54  	}
    55  	if err != nil {
    56  		err = errCannotEncode(source, c.DataType(), version, err)
    57  	}
    58  	return
    59  }
    60  
    61  func (c *decimalCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
    62  	var val CqlDecimal
    63  	if val, wasNull, err = readDecimal(source); err == nil {
    64  		err = convertFromDecimal(val, wasNull, dest)
    65  	}
    66  	if err != nil {
    67  		err = errCannotDecode(dest, c.DataType(), version, err)
    68  	}
    69  	return
    70  }
    71  
    72  func convertToDecimal(source interface{}) (val CqlDecimal, wasNil bool, err error) {
    73  	switch s := source.(type) {
    74  	case CqlDecimal:
    75  		val = s
    76  	case *CqlDecimal:
    77  		if wasNil = s == nil; !wasNil {
    78  			val = *s
    79  		}
    80  	case nil:
    81  		wasNil = true
    82  	default:
    83  		err = ErrConversionNotSupported
    84  	}
    85  	if err != nil {
    86  		err = errSourceConversionFailed(source, val, err)
    87  	}
    88  	return
    89  }
    90  
    91  func convertFromDecimal(val CqlDecimal, wasNull bool, dest interface{}) (err error) {
    92  	switch d := dest.(type) {
    93  	case *interface{}:
    94  		if d == nil {
    95  			err = ErrNilDestination
    96  		} else if wasNull {
    97  			*d = nil
    98  		} else {
    99  			*d = val
   100  		}
   101  	case *CqlDecimal:
   102  		if d == nil {
   103  			err = ErrNilDestination
   104  		} else if wasNull {
   105  			*d = CqlDecimal{}
   106  		} else {
   107  			*d = val
   108  		}
   109  	default:
   110  		err = errDestinationInvalid(dest)
   111  	}
   112  	if err != nil {
   113  		err = errDestinationConversionFailed(val, dest, err)
   114  	}
   115  	return
   116  }
   117  
   118  func writeDecimal(val CqlDecimal) []byte {
   119  	n := val.Unscaled
   120  	if n == nil {
   121  		n = zeroBigInt
   122  	}
   123  	unscaled := writeBigInt(n)
   124  	dest := make([]byte, primitive.LengthOfInt, primitive.LengthOfInt+len(unscaled))
   125  	binary.BigEndian.PutUint32(dest, uint32(val.Scale))
   126  	return append(dest, unscaled...)
   127  }
   128  
   129  func readDecimal(source []byte) (val CqlDecimal, wasNull bool, err error) {
   130  	length := len(source)
   131  	if length == 0 {
   132  		wasNull = true
   133  	} else if length <= primitive.LengthOfInt {
   134  		err = errWrongMinimumLength(primitive.LengthOfInt, length)
   135  	} else {
   136  		val.Scale = int32(binary.BigEndian.Uint32(source))
   137  		val.Unscaled = readBigInt(source[primitive.LengthOfInt:])
   138  	}
   139  	if err != nil {
   140  		err = errCannotRead(val, err)
   141  	}
   142  	return
   143  }