github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/duration_test.go (about)

     1  // Copyright 2021 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  
    24  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    25  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    26  )
    27  
    28  var (
    29  	cqlDurationZero = CqlDuration{}
    30  	cqlDurationPos  = CqlDuration{1, 2, 3}
    31  	cqlDurationNeg  = CqlDuration{-1, -2, -3}
    32  	cqlDurationMax  = CqlDuration{math.MaxInt32, math.MaxInt32, math.MaxInt64}
    33  	cqlDurationMin  = CqlDuration{math.MinInt32, math.MinInt32, math.MinInt64}
    34  )
    35  var (
    36  	cqlDurationPosBytes = []byte{2, 4, 6}
    37  	cqlDurationNegBytes = []byte{1, 3, 5}
    38  	cqlDurationMaxBytes = []byte{0xf0, 0xff, 0xff, 0xff, 0xfe, 0xf0, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
    39  	cqlDurationMinBytes = []byte{0xf0, 0xff, 0xff, 0xff, 0xff, 0xf0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
    40  )
    41  
    42  func Test_durationCodec_DataType(t *testing.T) {
    43  	assert.Equal(t, datatype.Duration, Duration.DataType())
    44  }
    45  
    46  func Test_durationCodec_Encode(t *testing.T) {
    47  	for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion5) {
    48  		t.Run(version.String(), func(t *testing.T) {
    49  			tests := []struct {
    50  				name     string
    51  				source   interface{}
    52  				expected []byte
    53  				err      string
    54  			}{
    55  				{"nil", nil, nil, ""},
    56  				{"nil pointer", cqlDurationNilPtr(), nil, ""},
    57  				{"non nil", cqlDurationPos, cqlDurationPosBytes, ""},
    58  				{"non nil pointer", &cqlDurationPos, cqlDurationPosBytes, ""},
    59  				{"conversion failed", 123, nil, fmt.Sprintf("cannot encode int as CQL duration with %v: cannot convert from int to datacodec.CqlDuration: conversion not supported", version)},
    60  			}
    61  			for _, tt := range tests {
    62  				t.Run(tt.name, func(t *testing.T) {
    63  					actual, err := Duration.Encode(tt.source, version)
    64  					assert.Equal(t, tt.expected, actual)
    65  					assertErrorMessage(t, tt.err, err)
    66  				})
    67  			}
    68  		})
    69  	}
    70  	for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion5) {
    71  		t.Run(version.String(), func(t *testing.T) {
    72  			tests := []struct {
    73  				name     string
    74  				source   interface{}
    75  				expected []byte
    76  				err      string
    77  			}{
    78  				{"null", nil, nil, "data type duration not supported"},
    79  				{"non null", cqlDurationPos, nil, "data type duration not supported"},
    80  			}
    81  			for _, tt := range tests {
    82  				t.Run(tt.name, func(t *testing.T) {
    83  					actual, err := Duration.Encode(tt.source, version)
    84  					assert.Equal(t, tt.expected, actual)
    85  					assertErrorMessage(t, tt.err, err)
    86  				})
    87  			}
    88  		})
    89  	}
    90  }
    91  
    92  func Test_durationCodec_Decode(t *testing.T) {
    93  	for _, version := range primitive.SupportedProtocolVersionsGreaterThanOrEqualTo(primitive.ProtocolVersion5) {
    94  		t.Run(version.String(), func(t *testing.T) {
    95  			tests := []struct {
    96  				name     string
    97  				source   []byte
    98  				dest     interface{}
    99  				expected interface{}
   100  				wasNull  bool
   101  				err      string
   102  			}{
   103  				{"null", nil, new(CqlDuration), new(CqlDuration), true, ""},
   104  				{"non null", cqlDurationPosBytes, new(CqlDuration), &cqlDurationPos, false, ""},
   105  				{"non null interface", cqlDurationPosBytes, new(interface{}), interfacePtr(cqlDurationPos), false, ""},
   106  				{"read failed", []byte{1}, new(CqlDuration), new(CqlDuration), false, fmt.Sprintf("cannot decode CQL duration as *datacodec.CqlDuration with %v: cannot read datacodec.CqlDuration: cannot read duration days: cannot read [vint]: cannot read [unsigned vint]: EOF", version)},
   107  				{"conversion failed", cqlDurationPosBytes, new(float64), new(float64), false, fmt.Sprintf("cannot decode CQL duration as *float64 with %v: cannot convert from datacodec.CqlDuration to *float64: conversion not supported", version)},
   108  			}
   109  			for _, tt := range tests {
   110  				t.Run(tt.name, func(t *testing.T) {
   111  					wasNull, err := Duration.Decode(tt.source, tt.dest, version)
   112  					assert.Equal(t, tt.expected, tt.dest)
   113  					assert.Equal(t, tt.wasNull, wasNull)
   114  					assertErrorMessage(t, tt.err, err)
   115  				})
   116  			}
   117  		})
   118  	}
   119  	for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion5) {
   120  		t.Run(version.String(), func(t *testing.T) {
   121  			tests := []struct {
   122  				name     string
   123  				source   []byte
   124  				dest     interface{}
   125  				expected interface{}
   126  				wasNull  bool
   127  				err      string
   128  			}{
   129  				{"null", nil, new(CqlDuration), new(CqlDuration), true, "data type duration not supported"},
   130  				{"non null", cqlDurationPosBytes, new(CqlDuration), new(CqlDuration), false, "data type duration not supported"},
   131  			}
   132  			for _, tt := range tests {
   133  				t.Run(tt.name, func(t *testing.T) {
   134  					wasNull, err := Duration.Decode(tt.source, tt.dest, version)
   135  					assert.Equal(t, tt.expected, tt.dest)
   136  					assert.Equal(t, tt.wasNull, wasNull)
   137  					assertErrorMessage(t, tt.err, err)
   138  				})
   139  			}
   140  		})
   141  	}
   142  }
   143  
   144  func Test_convertToDuration(t *testing.T) {
   145  	tests := []struct {
   146  		name       string
   147  		source     interface{}
   148  		wantDest   CqlDuration
   149  		wantWasNil bool
   150  		wantErr    string
   151  	}{
   152  		{"from CqlDuration", cqlDurationPos, cqlDurationPos, false, ""},
   153  		{"from *CqlDuration", &cqlDurationPos, cqlDurationPos, false, ""},
   154  		{"from *CqlDuration nil", cqlDurationNilPtr(), cqlDurationZero, true, ""},
   155  		{"from untyped nil", nil, cqlDurationZero, true, ""},
   156  		{"from unsupported value type", 123, cqlDurationZero, false, "cannot convert from int to datacodec.CqlDuration: conversion not supported"},
   157  		{"from unsupported pointer type", intPtr(123), cqlDurationZero, false, "cannot convert from *int to datacodec.CqlDuration: conversion not supported"},
   158  	}
   159  	for _, tt := range tests {
   160  		t.Run(tt.name, func(t *testing.T) {
   161  			gotDest, gotWasNil, gotErr := convertToDuration(tt.source)
   162  			assert.Equal(t, tt.wantDest, gotDest)
   163  			assert.Equal(t, tt.wantWasNil, gotWasNil)
   164  			assertErrorMessage(t, tt.wantErr, gotErr)
   165  		})
   166  	}
   167  }
   168  
   169  func Test_convertFromDuration(t *testing.T) {
   170  	tests := []struct {
   171  		name     string
   172  		val      CqlDuration
   173  		wasNull  bool
   174  		dest     interface{}
   175  		expected interface{}
   176  		err      string
   177  	}{
   178  		{"to *interface{} nil dest", cqlDurationPos, false, interfaceNilPtr(), interfaceNilPtr(), "cannot convert from datacodec.CqlDuration to *interface {}: destination is nil"},
   179  		{"to *interface{} nil source", cqlDurationZero, true, new(interface{}), new(interface{}), ""},
   180  		{"to *interface{} non nil", cqlDurationPos, false, new(interface{}), interfacePtr(cqlDurationPos), ""},
   181  		{"to *CqlDuration nil dest", cqlDurationZero, false, cqlDurationNilPtr(), cqlDurationNilPtr(), "cannot convert from datacodec.CqlDuration to *datacodec.CqlDuration: destination is nil"},
   182  		{"to *CqlDuration nil source", cqlDurationZero, true, new(CqlDuration), new(CqlDuration), ""},
   183  		{"to *CqlDuration empty source", cqlDurationZero, false, new(CqlDuration), new(CqlDuration), ""},
   184  		{"to *CqlDuration non nil", cqlDurationPos, false, new(CqlDuration), &cqlDurationPos, ""},
   185  		{"to untyped nil", cqlDurationPos, false, nil, nil, "cannot convert from datacodec.CqlDuration to <nil>: destination is nil"},
   186  		{"to non pointer", cqlDurationPos, false, CqlDuration{}, CqlDuration{}, "cannot convert from datacodec.CqlDuration to datacodec.CqlDuration: destination is not pointer"},
   187  		{"to unsupported pointer type", cqlDurationPos, false, new(float64), new(float64), "cannot convert from datacodec.CqlDuration to *float64: conversion not supported"},
   188  	}
   189  	for _, tt := range tests {
   190  		t.Run(tt.name, func(t *testing.T) {
   191  			gotErr := convertFromDuration(tt.val, tt.wasNull, tt.dest)
   192  			assert.Equal(t, tt.expected, tt.dest)
   193  			assertErrorMessage(t, tt.err, gotErr)
   194  		})
   195  	}
   196  }
   197  
   198  func Test_writeDuration(t *testing.T) {
   199  	tests := []struct {
   200  		name string
   201  		val  CqlDuration
   202  		want []byte
   203  	}{
   204  		{"pos", cqlDurationPos, cqlDurationPosBytes},
   205  		{"neg", cqlDurationNeg, cqlDurationNegBytes},
   206  		{"max", cqlDurationMax, cqlDurationMaxBytes},
   207  		{"min", cqlDurationMin, cqlDurationMinBytes},
   208  	}
   209  	for _, tt := range tests {
   210  		t.Run(tt.name, func(t *testing.T) {
   211  			got := writeDuration(tt.val)
   212  			assert.Equal(t, tt.want, got)
   213  		})
   214  	}
   215  }
   216  
   217  func Test_readDuration(t *testing.T) {
   218  	tests := []struct {
   219  		name        string
   220  		source      []byte
   221  		wantVal     CqlDuration
   222  		wantWasNull bool
   223  		wantErr     string
   224  	}{
   225  		{"nil", nil, CqlDuration{}, true, ""},
   226  		{"empty", []byte{}, CqlDuration{}, true, ""},
   227  		{"pos", cqlDurationPosBytes, cqlDurationPos, false, ""},
   228  		{"neg", cqlDurationNegBytes, cqlDurationNeg, false, ""},
   229  		{"max", cqlDurationMaxBytes, cqlDurationMax, false, ""},
   230  		{"min", cqlDurationMinBytes, cqlDurationMin, false, ""},
   231  		{"wrong months", []byte{255}, CqlDuration{}, false, "cannot read datacodec.CqlDuration: cannot read duration months: cannot read [vint]: cannot read [unsigned vint]: EOF"},
   232  		{"wrong days", []byte{1}, CqlDuration{}, false, "cannot read datacodec.CqlDuration: cannot read duration days: cannot read [vint]: cannot read [unsigned vint]: EOF"},
   233  		{"wrong nanos", []byte{1, 2}, CqlDuration{}, false, "cannot read datacodec.CqlDuration: cannot read duration nanos: cannot read [vint]: cannot read [unsigned vint]: EOF"},
   234  		{"bytes remaining", []byte{2, 4, 6, 8}, CqlDuration{}, false, "cannot read datacodec.CqlDuration: source was not fully read: bytes total: 4, read: 3, remaining: 1"},
   235  	}
   236  	for _, tt := range tests {
   237  		t.Run(tt.name, func(t *testing.T) {
   238  			gotVal, gotWasNull, gotErr := readDuration(tt.source)
   239  			assert.Equal(t, tt.wantVal, gotVal)
   240  			assert.Equal(t, tt.wantWasNull, gotWasNull)
   241  			assertErrorMessage(t, tt.wantErr, gotErr)
   242  		})
   243  	}
   244  }