github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/decimal_test.go (about)

     1  // Copyright 2021 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"math/big"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/assert"
    24  
    25  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    26  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    27  )
    28  
    29  var (
    30  	decimalZero      = CqlDecimal{}
    31  	decimalOne       = CqlDecimal{Unscaled: big.NewInt(1), Scale: 0}
    32  	decimalMaxUint64 = CqlDecimal{Unscaled: new(big.Int).SetUint64(math.MaxUint64), Scale: 0}
    33  	decimalSimple    = CqlDecimal{big.NewInt(123), -1}
    34  )
    35  
    36  var (
    37  	decimalZeroBytes = []byte{
    38  		0, 0, 0, 0,
    39  		0,
    40  	}
    41  	decimalOneBytes = []byte{
    42  		0, 0, 0, 0,
    43  		1,
    44  	}
    45  	decimalMaxUint64Bytes = []byte{
    46  		0, 0, 0, 0, 0,
    47  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    48  	}
    49  	decimalSimpleBytes = []byte{
    50  		0xff, 0xff, 0xff, 0xff,
    51  		0x7b,
    52  	}
    53  )
    54  
    55  func Test_decimalCodec_DataType(t *testing.T) {
    56  	assert.Equal(t, datatype.Decimal, Decimal.DataType())
    57  }
    58  
    59  func Test_decimalCodec_Encode(t *testing.T) {
    60  	for _, version := range primitive.SupportedProtocolVersions() {
    61  		t.Run(version.String(), func(t *testing.T) {
    62  			tests := []struct {
    63  				name     string
    64  				source   interface{}
    65  				expected []byte
    66  				err      string
    67  			}{
    68  				{"nil", nil, nil, ""},
    69  				{"nil pointer", cqlDecimalNilPtr(), nil, ""},
    70  				{"non nil", decimalSimple, decimalSimpleBytes, ""},
    71  				{"non nil pointer", &decimalSimple, decimalSimpleBytes, ""},
    72  				{"conversion failed", 123, nil, fmt.Sprintf("cannot encode int as CQL decimal with %v: cannot convert from int to datacodec.CqlDecimal: conversion not supported", version)},
    73  			}
    74  			for _, tt := range tests {
    75  				t.Run(tt.name, func(t *testing.T) {
    76  					actual, err := Decimal.Encode(tt.source, version)
    77  					assert.Equal(t, tt.expected, actual)
    78  					assertErrorMessage(t, tt.err, err)
    79  				})
    80  			}
    81  		})
    82  	}
    83  }
    84  
    85  func Test_decimalCodec_Decode(t *testing.T) {
    86  	for _, version := range primitive.SupportedProtocolVersions() {
    87  		t.Run(version.String(), func(t *testing.T) {
    88  			tests := []struct {
    89  				name     string
    90  				source   []byte
    91  				dest     interface{}
    92  				expected interface{}
    93  				wasNull  bool
    94  				err      string
    95  			}{
    96  				{"null", nil, new(CqlDecimal), new(CqlDecimal), true, ""},
    97  				{"non null", decimalSimpleBytes, new(CqlDecimal), &decimalSimple, false, ""},
    98  				{"non null interface", decimalSimpleBytes, new(interface{}), interfacePtr(decimalSimple), false, ""},
    99  				{"read failed", []byte{1, 2, 3}, new(CqlDecimal), new(CqlDecimal), false, fmt.Sprintf("cannot decode CQL decimal as *datacodec.CqlDecimal with %v: cannot read datacodec.CqlDecimal: expected at least 4 bytes but got: 3", version)},
   100  				{"conversion failed", decimalSimpleBytes, new(float64), new(float64), false, fmt.Sprintf("cannot decode CQL decimal as *float64 with %v: cannot convert from datacodec.CqlDecimal to *float64: conversion not supported", version)},
   101  			}
   102  			for _, tt := range tests {
   103  				t.Run(tt.name, func(t *testing.T) {
   104  					wasNull, err := Decimal.Decode(tt.source, tt.dest, version)
   105  					assert.Equal(t, tt.expected, tt.dest)
   106  					assert.Equal(t, tt.wasNull, wasNull)
   107  					assertErrorMessage(t, tt.err, err)
   108  				})
   109  			}
   110  		})
   111  	}
   112  }
   113  
   114  func Test_convertToDecimal(t *testing.T) {
   115  	tests := []struct {
   116  		name       string
   117  		source     interface{}
   118  		wantDest   CqlDecimal
   119  		wantWasNil bool
   120  		wantErr    string
   121  	}{
   122  		{"from CqlDecimal", decimalSimple, decimalSimple, false, ""},
   123  		{"from *CqlDecimal", &decimalSimple, decimalSimple, false, ""},
   124  		{"from *CqlDecimal nil", cqlDecimalNilPtr(), decimalZero, true, ""},
   125  		{"from untyped nil", nil, decimalZero, true, ""},
   126  		{"from unsupported value type", 123, decimalZero, false, "cannot convert from int to datacodec.CqlDecimal: conversion not supported"},
   127  		{"from unsupported pointer type", intPtr(123), decimalZero, false, "cannot convert from *int to datacodec.CqlDecimal: conversion not supported"},
   128  	}
   129  	for _, tt := range tests {
   130  		t.Run(tt.name, func(t *testing.T) {
   131  			gotDest, gotWasNil, gotErr := convertToDecimal(tt.source)
   132  			assert.Equal(t, tt.wantDest, gotDest)
   133  			assert.Equal(t, tt.wantWasNil, gotWasNil)
   134  			assertErrorMessage(t, tt.wantErr, gotErr)
   135  		})
   136  	}
   137  }
   138  
   139  func Test_convertFromDecimal(t *testing.T) {
   140  	tests := []struct {
   141  		name     string
   142  		val      CqlDecimal
   143  		wasNull  bool
   144  		dest     interface{}
   145  		expected interface{}
   146  		err      string
   147  	}{
   148  		{"to *interface{} nil dest", decimalSimple, false, interfaceNilPtr(), interfaceNilPtr(), "cannot convert from datacodec.CqlDecimal to *interface {}: destination is nil"},
   149  		{"to *interface{} nil source", decimalZero, true, new(interface{}), new(interface{}), ""},
   150  		{"to *interface{} non nil", decimalSimple, false, new(interface{}), interfacePtr(decimalSimple), ""},
   151  		{"to *CqlDecimal nil dest", decimalZero, false, cqlDecimalNilPtr(), cqlDecimalNilPtr(), "cannot convert from datacodec.CqlDecimal to *datacodec.CqlDecimal: destination is nil"},
   152  		{"to *CqlDecimal nil source", decimalZero, true, new(CqlDecimal), new(CqlDecimal), ""},
   153  		{"to *CqlDecimal empty source", decimalZero, false, new(CqlDecimal), new(CqlDecimal), ""},
   154  		{"to *CqlDecimal non nil", decimalSimple, false, new(CqlDecimal), &decimalSimple, ""},
   155  		{"to untyped nil", decimalSimple, false, nil, nil, "cannot convert from datacodec.CqlDecimal to <nil>: destination is nil"},
   156  		{"to non pointer", decimalSimple, false, CqlDecimal{}, CqlDecimal{}, "cannot convert from datacodec.CqlDecimal to datacodec.CqlDecimal: destination is not pointer"},
   157  		{"to unsupported pointer type", decimalSimple, false, new(float64), new(float64), "cannot convert from datacodec.CqlDecimal to *float64: conversion not supported"},
   158  	}
   159  	for _, tt := range tests {
   160  		t.Run(tt.name, func(t *testing.T) {
   161  			gotErr := convertFromDecimal(tt.val, tt.wasNull, tt.dest)
   162  			assert.Equal(t, tt.expected, tt.dest)
   163  			assertErrorMessage(t, tt.err, gotErr)
   164  		})
   165  	}
   166  }
   167  
   168  func Test_writeDecimal(t *testing.T) {
   169  	tests := []struct {
   170  		name     string
   171  		val      CqlDecimal
   172  		expected []byte
   173  	}{
   174  		{"zero", decimalZero, decimalZeroBytes},
   175  		{"one", decimalOne, decimalOneBytes},
   176  		{"simple", decimalSimple, decimalSimpleBytes},
   177  		{"max", decimalMaxUint64, decimalMaxUint64Bytes},
   178  	}
   179  	for _, tt := range tests {
   180  		t.Run(tt.name, func(t *testing.T) {
   181  			actual := writeDecimal(tt.val)
   182  			assert.Equal(t, tt.expected, actual)
   183  		})
   184  	}
   185  }
   186  
   187  func Test_readDecimal(t *testing.T) {
   188  	tests := []struct {
   189  		name     string
   190  		source   []byte
   191  		expected CqlDecimal
   192  		wasNull  bool
   193  		err      string
   194  	}{
   195  		{"nil", nil, decimalZero, true, ""},
   196  		{"empty", []byte{}, decimalZero, true, ""},
   197  		{"wrong length", []byte{1}, decimalZero, false, "cannot read datacodec.CqlDecimal: expected at least 4 bytes but got: 1"},
   198  		{"zero", decimalZeroBytes, CqlDecimal{zeroBigInt, 0}, false, ""},
   199  		{"simple", decimalSimpleBytes, decimalSimple, false, ""},
   200  		{"max", decimalMaxUint64Bytes, decimalMaxUint64, false, ""},
   201  	}
   202  	for _, tt := range tests {
   203  		t.Run(tt.name, func(t *testing.T) {
   204  			actual, wasNull, err := readDecimal(tt.source)
   205  			assert.Zero(t, tt.expected.Unscaled.Cmp(actual.Unscaled))
   206  			assert.Equal(t, tt.expected.Scale, actual.Scale)
   207  			assert.Equal(t, tt.wasNull, wasNull)
   208  			assertErrorMessage(t, tt.err, err)
   209  		})
   210  	}
   211  }