github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datatype/set_test.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datatype 16 17 import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "testing" 22 23 "github.com/stretchr/testify/assert" 24 25 "github.com/datastax/go-cassandra-native-protocol/primitive" 26 ) 27 28 func TestSetType(t *testing.T) { 29 setType := NewSet(Varchar) 30 assert.Equal(t, primitive.DataTypeCodeSet, setType.Code()) 31 assert.Equal(t, Varchar, setType.ElementType) 32 } 33 34 func TestSetTypeDeepCopy(t *testing.T) { 35 st := NewSet(Varchar) 36 cloned := st.DeepCopy() 37 assert.Equal(t, st, cloned) 38 cloned.ElementType = Int 39 assert.Equal(t, primitive.DataTypeCodeSet, st.Code()) 40 assert.Equal(t, Varchar, st.ElementType) 41 assert.Equal(t, primitive.DataTypeCodeSet, cloned.Code()) 42 assert.Equal(t, Int, cloned.ElementType) 43 } 44 45 func TestWriteSetType(t *testing.T) { 46 for _, version := range primitive.SupportedProtocolVersions() { 47 t.Run(version.String(), func(t *testing.T) { 48 tests := []struct { 49 name string 50 input DataType 51 expected []byte 52 err error 53 }{ 54 { 55 "simple set", 56 NewSet(Varchar), 57 []byte{0, byte(primitive.DataTypeCodeVarchar & 0xFF)}, 58 nil, 59 }, 60 { 61 "complex set", 62 NewSet(NewSet(Varchar)), 63 []byte{ 64 0, byte(primitive.DataTypeCodeSet & 0xFF), 65 0, byte(primitive.DataTypeCodeVarchar & 0xFF)}, 66 nil, 67 }, 68 {"nil set", nil, nil, errors.New("expected *Set, got <nil>")}, 69 } 70 for _, test := range tests { 71 t.Run(test.name, func(t *testing.T) { 72 var dest = &bytes.Buffer{} 73 var err error 74 err = writeSetType(test.input, dest, version) 75 actual := dest.Bytes() 76 assert.Equal(t, test.expected, actual) 77 assert.Equal(t, test.err, err) 78 }) 79 } 80 }) 81 } 82 } 83 84 func TestLengthOfSetType(t *testing.T) { 85 for _, version := range primitive.SupportedProtocolVersions() { 86 t.Run(version.String(), func(t *testing.T) { 87 tests := []struct { 88 name string 89 input DataType 90 expected int 91 err error 92 }{ 93 {"simple set", NewSet(Varchar), primitive.LengthOfShort, nil}, 94 {"complex set", NewSet(NewSet(Varchar)), primitive.LengthOfShort + primitive.LengthOfShort, nil}, 95 {"nil set", nil, -1, errors.New("expected *Set, got <nil>")}, 96 } 97 for _, test := range tests { 98 t.Run(test.name, func(t *testing.T) { 99 var actual int 100 var err error 101 actual, err = lengthOfSetType(test.input, version) 102 assert.Equal(t, test.expected, actual) 103 assert.Equal(t, test.err, err) 104 }) 105 } 106 }) 107 } 108 } 109 110 func TestReadSetType(t *testing.T) { 111 for _, version := range primitive.SupportedProtocolVersions() { 112 t.Run(version.String(), func(t *testing.T) { 113 tests := []struct { 114 name string 115 input []byte 116 expected DataType 117 err error 118 }{ 119 { 120 "simple set", 121 []byte{0, byte(primitive.DataTypeCodeVarchar & 0xff)}, 122 NewSet(Varchar), 123 nil, 124 }, 125 { 126 "complex set", 127 []byte{ 128 0, byte(primitive.DataTypeCodeSet & 0xff), 129 0, byte(primitive.DataTypeCodeVarchar & 0xff)}, 130 NewSet(NewSet(Varchar)), 131 nil, 132 }, 133 { 134 "cannot read set", 135 []byte{}, 136 nil, 137 fmt.Errorf("cannot read set element type: %w", 138 fmt.Errorf("cannot read data type code: %w", 139 fmt.Errorf("cannot read [short]: %w", 140 errors.New("EOF")))), 141 }, 142 } 143 for _, test := range tests { 144 t.Run(test.name, func(t *testing.T) { 145 var source = bytes.NewBuffer(test.input) 146 var actual DataType 147 var err error 148 actual, err = readSetType(source, version) 149 assert.Equal(t, test.expected, actual) 150 assert.Equal(t, test.err, err) 151 }) 152 } 153 }) 154 } 155 }