github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/message/prepare_test.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package message 16 17 import ( 18 "bytes" 19 "errors" 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 24 "github.com/datastax/go-cassandra-native-protocol/primitive" 25 ) 26 27 func TestPrepare_DeepCopy(t *testing.T) { 28 msg := &Prepare{ 29 Query: "query", 30 Keyspace: "ks1", 31 } 32 33 cloned := msg.DeepCopy() 34 assert.Equal(t, msg, cloned) 35 36 cloned.Query = "query2" 37 cloned.Keyspace = "ks2" 38 39 assert.NotEqual(t, msg, cloned) 40 41 assert.Equal(t, "query", msg.Query) 42 assert.Equal(t, "ks1", msg.Keyspace) 43 44 assert.Equal(t, "query2", cloned.Query) 45 assert.Equal(t, "ks2", cloned.Keyspace) 46 } 47 48 func TestPrepareCodec_Encode(t *testing.T) { 49 codec := &prepareCodec{} 50 // versions <= 4 + DSE v1 51 for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1} { 52 t.Run(version.String(), func(t *testing.T) { 53 tests := []encodeTestCase{ 54 { 55 "prepare simple", 56 &Prepare{"SELECT", ""}, 57 []byte{ 58 0, 0, 0, 6, S, E, L, E, C, T, 59 }, 60 nil, 61 }, 62 { 63 "not a prepare", 64 &Ready{}, 65 nil, 66 errors.New("expected *message.Prepare, got *message.Ready"), 67 }, 68 } 69 for _, tt := range tests { 70 t.Run(tt.name, func(t *testing.T) { 71 dest := &bytes.Buffer{} 72 err := codec.Encode(tt.input, dest, version) 73 assert.Equal(t, tt.expected, dest.Bytes()) 74 assert.Equal(t, tt.err, err) 75 }) 76 } 77 }) 78 } 79 // versions 5, DSE v2 80 for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion5, primitive.ProtocolVersionDse2} { 81 t.Run(version.String(), func(t *testing.T) { 82 tests := []encodeTestCase{ 83 { 84 "prepare simple", 85 &Prepare{"SELECT", ""}, 86 []byte{ 87 0, 0, 0, 6, S, E, L, E, C, T, 88 0, 0, 0, 0, // flags 89 }, 90 nil, 91 }, 92 { 93 "prepare with keyspace", 94 &Prepare{"SELECT", "ks"}, 95 []byte{ 96 0, 0, 0, 6, S, E, L, E, C, T, 97 0, 0, 0, 1, // flags 98 0, 2, k, s, // keyspace 99 }, 100 nil, 101 }, 102 { 103 "not a prepare", 104 &Ready{}, 105 nil, 106 errors.New("expected *message.Prepare, got *message.Ready"), 107 }, 108 } 109 for _, tt := range tests { 110 t.Run(tt.name, func(t *testing.T) { 111 dest := &bytes.Buffer{} 112 err := codec.Encode(tt.input, dest, version) 113 assert.Equal(t, tt.expected, dest.Bytes()) 114 assert.Equal(t, tt.err, err) 115 }) 116 } 117 }) 118 } 119 } 120 121 func TestPrepareCodec_EncodedLength(t *testing.T) { 122 codec := &prepareCodec{} 123 // versions <= 4 + DSE v1 124 for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1} { 125 t.Run(version.String(), func(t *testing.T) { 126 tests := []encodedLengthTestCase{ 127 { 128 "prepare simple", 129 &Prepare{"SELECT", ""}, 130 primitive.LengthOfLongString("SELECT"), 131 nil, 132 }, 133 { 134 "not a prepare", 135 &Ready{}, 136 -1, 137 errors.New("expected *message.Prepare, got *message.Ready"), 138 }, 139 } 140 for _, tt := range tests { 141 t.Run(tt.name, func(t *testing.T) { 142 actual, err := codec.EncodedLength(tt.input, version) 143 assert.Equal(t, tt.expected, actual) 144 assert.Equal(t, tt.err, err) 145 }) 146 } 147 }) 148 } 149 // versions 5, DSE v2 150 for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion5, primitive.ProtocolVersionDse2} { 151 t.Run(version.String(), func(t *testing.T) { 152 tests := []encodedLengthTestCase{ 153 { 154 "prepare simple", 155 &Prepare{"SELECT", ""}, 156 primitive.LengthOfLongString("SELECT") + 157 primitive.LengthOfInt, // flags 158 nil, 159 }, 160 { 161 "prepare with keyspace", 162 &Prepare{"SELECT", "ks"}, 163 primitive.LengthOfLongString("SELECT") + 164 primitive.LengthOfInt + // flags 165 primitive.LengthOfString("ks"), // keyspace 166 nil, 167 }, 168 { 169 "not a prepare", 170 &Ready{}, 171 -1, 172 errors.New("expected *message.Prepare, got *message.Ready"), 173 }, 174 } 175 for _, tt := range tests { 176 t.Run(tt.name, func(t *testing.T) { 177 actual, err := codec.EncodedLength(tt.input, version) 178 assert.Equal(t, tt.expected, actual) 179 assert.Equal(t, tt.err, err) 180 }) 181 } 182 }) 183 } 184 } 185 186 func TestPrepareCodec_Decode(t *testing.T) { 187 codec := &prepareCodec{} 188 // versions <= 4 + DSE v1 189 for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1} { 190 t.Run(version.String(), func(t *testing.T) { 191 tests := []decodeTestCase{ 192 { 193 "prepare simple", 194 []byte{ 195 0, 0, 0, 6, S, E, L, E, C, T, 196 }, 197 &Prepare{"SELECT", ""}, 198 nil, 199 }, 200 } 201 for _, tt := range tests { 202 t.Run(tt.name, func(t *testing.T) { 203 source := bytes.NewBuffer(tt.input) 204 actual, err := codec.Decode(source, version) 205 assert.Equal(t, tt.expected, actual) 206 assert.Equal(t, tt.err, err) 207 }) 208 } 209 }) 210 } 211 // versions 5, DSE v2 212 for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion5, primitive.ProtocolVersionDse2} { 213 t.Run(version.String(), func(t *testing.T) { 214 tests := []decodeTestCase{ 215 { 216 "prepare simple", 217 []byte{ 218 0, 0, 0, 6, S, E, L, E, C, T, 219 0, 0, 0, 0, // flags 220 }, 221 &Prepare{"SELECT", ""}, 222 nil, 223 }, 224 { 225 "prepare with keyspace", 226 []byte{ 227 0, 0, 0, 6, S, E, L, E, C, T, 228 0, 0, 0, 1, // flags 229 0, 2, k, s, // keyspace 230 }, 231 &Prepare{"SELECT", "ks"}, 232 nil, 233 }, 234 } 235 for _, tt := range tests { 236 t.Run(tt.name, func(t *testing.T) { 237 source := bytes.NewBuffer(tt.input) 238 actual, err := codec.Decode(source, version) 239 assert.Equal(t, tt.expected, actual) 240 assert.Equal(t, tt.err, err) 241 }) 242 } 243 }) 244 } 245 }