github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/frame/codec_test.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frame 16 17 import ( 18 "bytes" 19 "testing" 20 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/require" 23 24 "github.com/datastax/go-cassandra-native-protocol/compression/lz4" 25 "github.com/datastax/go-cassandra-native-protocol/compression/snappy" 26 "github.com/datastax/go-cassandra-native-protocol/message" 27 "github.com/datastax/go-cassandra-native-protocol/primitive" 28 ) 29 30 // The tests in this file are meant to focus on encoding / decoding of frame headers and other common parts of 31 // the frame body, such as custom payloads, query warnings and tracing ids. They do not focus on encoding / decoding 32 // specific messages. 33 34 func TestFrameEncodeDecode(t *testing.T) { 35 codecs := createCodecs() 36 for _, version := range primitive.SupportedProtocolVersions() { 37 t.Run(version.String(), func(t *testing.T) { 38 request, response := createFrames(version) 39 for algorithm, codec := range codecs { 40 t.Run(algorithm, func(t *testing.T) { 41 tests := []struct { 42 name string 43 frame *Frame 44 }{ 45 {"request", request}, 46 {"response", response}, 47 } 48 for _, test := range tests { 49 t.Run(test.name, func(t *testing.T) { 50 encodedFrame := bytes.Buffer{} 51 err := codec.EncodeFrame(test.frame, &encodedFrame) 52 require.Nil(t, err) 53 decodedFrame, err := codec.DecodeFrame(&encodedFrame) 54 require.Nil(t, err) 55 require.Equal(t, test.frame, decodedFrame) 56 }) 57 } 58 }) 59 } 60 }) 61 } 62 } 63 64 func TestRawFrameEncodeDecode(t *testing.T) { 65 codecs := createCodecs() 66 for _, version := range primitive.SupportedProtocolVersions() { 67 t.Run(version.String(), func(t *testing.T) { 68 request, response := createFrames(version) 69 for algorithm, codec := range codecs { 70 t.Run(algorithm, func(t *testing.T) { 71 tests := []struct { 72 name string 73 frame *Frame 74 }{ 75 {"request", request}, 76 {"response", response}, 77 } 78 for _, test := range tests { 79 t.Run(test.name, func(t *testing.T) { 80 rawFrame, err := codec.ConvertToRawFrame(test.frame) 81 require.Nil(t, err) 82 encodedFrame := &bytes.Buffer{} 83 err = codec.EncodeRawFrame(rawFrame, encodedFrame) 84 require.Nil(t, err) 85 decodedFrame, err := codec.DecodeRawFrame(encodedFrame) 86 require.Nil(t, err) 87 assert.Equal(t, rawFrame, decodedFrame) 88 }) 89 } 90 }) 91 } 92 }) 93 } 94 } 95 96 func TestConvertToRawFrame(t *testing.T) { 97 codec := NewRawCodec() 98 for _, version := range primitive.SupportedProtocolVersions() { 99 t.Run(version.String(), func(t *testing.T) { 100 request, response := createFrames(version) 101 tests := []struct { 102 name string 103 frame *Frame 104 }{ 105 {"request", request}, 106 {"response", response}, 107 } 108 for _, test := range tests { 109 t.Run(test.name, func(t *testing.T) { 110 var rawFrame *RawFrame 111 var err error 112 rawFrame, err = codec.ConvertToRawFrame(test.frame) 113 assert.Nil(t, err) 114 assert.Equal(t, test.frame.Header, rawFrame.Header) 115 assert.Equal(t, test.frame.Body.Message.GetOpCode(), rawFrame.Header.OpCode) 116 assert.Equal(t, test.frame.Body.Message.IsResponse(), rawFrame.Header.IsResponse) 117 118 encodedBody := &bytes.Buffer{} 119 err = codec.EncodeBody(test.frame.Header, test.frame.Body, encodedBody) 120 assert.Nil(t, err) 121 encodedBodyBytes := encodedBody.Bytes() 122 assert.Equal(t, encodedBodyBytes, rawFrame.Body) 123 assert.Equal(t, int32(encodedBody.Len()), rawFrame.Header.BodyLength) 124 125 rawBody, err := codec.DecodeRawBody(test.frame.Header, encodedBody) 126 assert.Nil(t, err) 127 assert.Equal(t, encodedBodyBytes, rawBody) 128 129 fullFrame, err := codec.ConvertFromRawFrame(rawFrame) 130 assert.Nil(t, err) 131 assert.Equal(t, test.frame, fullFrame) 132 }) 133 } 134 }) 135 } 136 } 137 138 func createCodecs() map[string]RawCodec { 139 codecs := map[string]RawCodec{ 140 "NONE": NewRawCodec(), 141 "LZ4": NewRawCodecWithCompression(lz4.Compressor{}), 142 "SNAPPY": NewRawCodecWithCompression(snappy.Compressor{}), 143 } 144 return codecs 145 } 146 147 func createFrames(version primitive.ProtocolVersion) (*Frame, *Frame) { 148 var request = NewFrame(version, 1, message.NewStartup()) 149 var response = NewFrame(version, 1, &message.RowsResult{ 150 Metadata: &message.RowsMetadata{ColumnCount: 1}, 151 Data: [][][]byte{}, 152 }) 153 request.RequestTracingId(true) 154 var uuid = primitive.UUID{0xC0, 0xD1, 0xD2, 0x1E, 0xBB, 0x01, 0x41, 0x96, 0x86, 0xDB, 0xBC, 0x31, 0x7B, 0xC1, 0x79, 0x6A} 155 response.SetTracingId(&uuid) 156 if version >= primitive.ProtocolVersion4 { 157 request.SetCustomPayload(map[string][]byte{"hello": {0xca, 0xfe, 0xba, 0xbe}}) 158 response.SetCustomPayload(map[string][]byte{"hello": {0xca, 0xfe, 0xba, 0xbe}}) 159 response.SetWarnings([]string{"I'm warning you!!"}) 160 } 161 return request, response 162 }