github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/frame/codec_test.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frame
    16  
    17  import (
    18  	"bytes"
    19  	"testing"
    20  
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/datastax/go-cassandra-native-protocol/compression/lz4"
    25  	"github.com/datastax/go-cassandra-native-protocol/compression/snappy"
    26  	"github.com/datastax/go-cassandra-native-protocol/message"
    27  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    28  )
    29  
    30  // The tests in this file are meant to focus on encoding / decoding of frame headers and other common parts of
    31  // the frame body, such as custom payloads, query warnings and tracing ids. They do not focus on encoding / decoding
    32  // specific messages.
    33  
    34  func TestFrameEncodeDecode(t *testing.T) {
    35  	codecs := createCodecs()
    36  	for _, version := range primitive.SupportedProtocolVersions() {
    37  		t.Run(version.String(), func(t *testing.T) {
    38  			request, response := createFrames(version)
    39  			for algorithm, codec := range codecs {
    40  				t.Run(algorithm, func(t *testing.T) {
    41  					tests := []struct {
    42  						name  string
    43  						frame *Frame
    44  					}{
    45  						{"request", request},
    46  						{"response", response},
    47  					}
    48  					for _, test := range tests {
    49  						t.Run(test.name, func(t *testing.T) {
    50  							encodedFrame := bytes.Buffer{}
    51  							err := codec.EncodeFrame(test.frame, &encodedFrame)
    52  							require.Nil(t, err)
    53  							decodedFrame, err := codec.DecodeFrame(&encodedFrame)
    54  							require.Nil(t, err)
    55  							require.Equal(t, test.frame, decodedFrame)
    56  						})
    57  					}
    58  				})
    59  			}
    60  		})
    61  	}
    62  }
    63  
    64  func TestRawFrameEncodeDecode(t *testing.T) {
    65  	codecs := createCodecs()
    66  	for _, version := range primitive.SupportedProtocolVersions() {
    67  		t.Run(version.String(), func(t *testing.T) {
    68  			request, response := createFrames(version)
    69  			for algorithm, codec := range codecs {
    70  				t.Run(algorithm, func(t *testing.T) {
    71  					tests := []struct {
    72  						name  string
    73  						frame *Frame
    74  					}{
    75  						{"request", request},
    76  						{"response", response},
    77  					}
    78  					for _, test := range tests {
    79  						t.Run(test.name, func(t *testing.T) {
    80  							rawFrame, err := codec.ConvertToRawFrame(test.frame)
    81  							require.Nil(t, err)
    82  							encodedFrame := &bytes.Buffer{}
    83  							err = codec.EncodeRawFrame(rawFrame, encodedFrame)
    84  							require.Nil(t, err)
    85  							decodedFrame, err := codec.DecodeRawFrame(encodedFrame)
    86  							require.Nil(t, err)
    87  							assert.Equal(t, rawFrame, decodedFrame)
    88  						})
    89  					}
    90  				})
    91  			}
    92  		})
    93  	}
    94  }
    95  
    96  func TestConvertToRawFrame(t *testing.T) {
    97  	codec := NewRawCodec()
    98  	for _, version := range primitive.SupportedProtocolVersions() {
    99  		t.Run(version.String(), func(t *testing.T) {
   100  			request, response := createFrames(version)
   101  			tests := []struct {
   102  				name  string
   103  				frame *Frame
   104  			}{
   105  				{"request", request},
   106  				{"response", response},
   107  			}
   108  			for _, test := range tests {
   109  				t.Run(test.name, func(t *testing.T) {
   110  					var rawFrame *RawFrame
   111  					var err error
   112  					rawFrame, err = codec.ConvertToRawFrame(test.frame)
   113  					assert.Nil(t, err)
   114  					assert.Equal(t, test.frame.Header, rawFrame.Header)
   115  					assert.Equal(t, test.frame.Body.Message.GetOpCode(), rawFrame.Header.OpCode)
   116  					assert.Equal(t, test.frame.Body.Message.IsResponse(), rawFrame.Header.IsResponse)
   117  
   118  					encodedBody := &bytes.Buffer{}
   119  					err = codec.EncodeBody(test.frame.Header, test.frame.Body, encodedBody)
   120  					assert.Nil(t, err)
   121  					encodedBodyBytes := encodedBody.Bytes()
   122  					assert.Equal(t, encodedBodyBytes, rawFrame.Body)
   123  					assert.Equal(t, int32(encodedBody.Len()), rawFrame.Header.BodyLength)
   124  
   125  					rawBody, err := codec.DecodeRawBody(test.frame.Header, encodedBody)
   126  					assert.Nil(t, err)
   127  					assert.Equal(t, encodedBodyBytes, rawBody)
   128  
   129  					fullFrame, err := codec.ConvertFromRawFrame(rawFrame)
   130  					assert.Nil(t, err)
   131  					assert.Equal(t, test.frame, fullFrame)
   132  				})
   133  			}
   134  		})
   135  	}
   136  }
   137  
   138  func createCodecs() map[string]RawCodec {
   139  	codecs := map[string]RawCodec{
   140  		"NONE":   NewRawCodec(),
   141  		"LZ4":    NewRawCodecWithCompression(lz4.Compressor{}),
   142  		"SNAPPY": NewRawCodecWithCompression(snappy.Compressor{}),
   143  	}
   144  	return codecs
   145  }
   146  
   147  func createFrames(version primitive.ProtocolVersion) (*Frame, *Frame) {
   148  	var request = NewFrame(version, 1, message.NewStartup())
   149  	var response = NewFrame(version, 1, &message.RowsResult{
   150  		Metadata: &message.RowsMetadata{ColumnCount: 1},
   151  		Data:     [][][]byte{},
   152  	})
   153  	request.RequestTracingId(true)
   154  	var uuid = primitive.UUID{0xC0, 0xD1, 0xD2, 0x1E, 0xBB, 0x01, 0x41, 0x96, 0x86, 0xDB, 0xBC, 0x31, 0x7B, 0xC1, 0x79, 0x6A}
   155  	response.SetTracingId(&uuid)
   156  	if version >= primitive.ProtocolVersion4 {
   157  		request.SetCustomPayload(map[string][]byte{"hello": {0xca, 0xfe, 0xba, 0xbe}})
   158  		response.SetCustomPayload(map[string][]byte{"hello": {0xca, 0xfe, 0xba, 0xbe}})
   159  		response.SetWarnings([]string{"I'm warning you!!"})
   160  	}
   161  	return request, response
   162  }