github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/message/prepare_test.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package message
    16  
    17  import (
    18  	"bytes"
    19  	"errors"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  
    24  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    25  )
    26  
    27  func TestPrepare_DeepCopy(t *testing.T) {
    28  	msg := &Prepare{
    29  		Query:    "query",
    30  		Keyspace: "ks1",
    31  	}
    32  
    33  	cloned := msg.DeepCopy()
    34  	assert.Equal(t, msg, cloned)
    35  
    36  	cloned.Query = "query2"
    37  	cloned.Keyspace = "ks2"
    38  
    39  	assert.NotEqual(t, msg, cloned)
    40  
    41  	assert.Equal(t, "query", msg.Query)
    42  	assert.Equal(t, "ks1", msg.Keyspace)
    43  
    44  	assert.Equal(t, "query2", cloned.Query)
    45  	assert.Equal(t, "ks2", cloned.Keyspace)
    46  }
    47  
    48  func TestPrepareCodec_Encode(t *testing.T) {
    49  	codec := &prepareCodec{}
    50  	// versions <= 4 + DSE v1
    51  	for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1} {
    52  		t.Run(version.String(), func(t *testing.T) {
    53  			tests := []encodeTestCase{
    54  				{
    55  					"prepare simple",
    56  					&Prepare{"SELECT", ""},
    57  					[]byte{
    58  						0, 0, 0, 6, S, E, L, E, C, T,
    59  					},
    60  					nil,
    61  				},
    62  				{
    63  					"not a prepare",
    64  					&Ready{},
    65  					nil,
    66  					errors.New("expected *message.Prepare, got *message.Ready"),
    67  				},
    68  			}
    69  			for _, tt := range tests {
    70  				t.Run(tt.name, func(t *testing.T) {
    71  					dest := &bytes.Buffer{}
    72  					err := codec.Encode(tt.input, dest, version)
    73  					assert.Equal(t, tt.expected, dest.Bytes())
    74  					assert.Equal(t, tt.err, err)
    75  				})
    76  			}
    77  		})
    78  	}
    79  	// versions 5, DSE v2
    80  	for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion5, primitive.ProtocolVersionDse2} {
    81  		t.Run(version.String(), func(t *testing.T) {
    82  			tests := []encodeTestCase{
    83  				{
    84  					"prepare simple",
    85  					&Prepare{"SELECT", ""},
    86  					[]byte{
    87  						0, 0, 0, 6, S, E, L, E, C, T,
    88  						0, 0, 0, 0, // flags
    89  					},
    90  					nil,
    91  				},
    92  				{
    93  					"prepare with keyspace",
    94  					&Prepare{"SELECT", "ks"},
    95  					[]byte{
    96  						0, 0, 0, 6, S, E, L, E, C, T,
    97  						0, 0, 0, 1, // flags
    98  						0, 2, k, s, // keyspace
    99  					},
   100  					nil,
   101  				},
   102  				{
   103  					"not a prepare",
   104  					&Ready{},
   105  					nil,
   106  					errors.New("expected *message.Prepare, got *message.Ready"),
   107  				},
   108  			}
   109  			for _, tt := range tests {
   110  				t.Run(tt.name, func(t *testing.T) {
   111  					dest := &bytes.Buffer{}
   112  					err := codec.Encode(tt.input, dest, version)
   113  					assert.Equal(t, tt.expected, dest.Bytes())
   114  					assert.Equal(t, tt.err, err)
   115  				})
   116  			}
   117  		})
   118  	}
   119  }
   120  
   121  func TestPrepareCodec_EncodedLength(t *testing.T) {
   122  	codec := &prepareCodec{}
   123  	// versions <= 4 + DSE v1
   124  	for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1} {
   125  		t.Run(version.String(), func(t *testing.T) {
   126  			tests := []encodedLengthTestCase{
   127  				{
   128  					"prepare simple",
   129  					&Prepare{"SELECT", ""},
   130  					primitive.LengthOfLongString("SELECT"),
   131  					nil,
   132  				},
   133  				{
   134  					"not a prepare",
   135  					&Ready{},
   136  					-1,
   137  					errors.New("expected *message.Prepare, got *message.Ready"),
   138  				},
   139  			}
   140  			for _, tt := range tests {
   141  				t.Run(tt.name, func(t *testing.T) {
   142  					actual, err := codec.EncodedLength(tt.input, version)
   143  					assert.Equal(t, tt.expected, actual)
   144  					assert.Equal(t, tt.err, err)
   145  				})
   146  			}
   147  		})
   148  	}
   149  	// versions 5, DSE v2
   150  	for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion5, primitive.ProtocolVersionDse2} {
   151  		t.Run(version.String(), func(t *testing.T) {
   152  			tests := []encodedLengthTestCase{
   153  				{
   154  					"prepare simple",
   155  					&Prepare{"SELECT", ""},
   156  					primitive.LengthOfLongString("SELECT") +
   157  						primitive.LengthOfInt, // flags
   158  					nil,
   159  				},
   160  				{
   161  					"prepare with keyspace",
   162  					&Prepare{"SELECT", "ks"},
   163  					primitive.LengthOfLongString("SELECT") +
   164  						primitive.LengthOfInt + // flags
   165  						primitive.LengthOfString("ks"), // keyspace
   166  					nil,
   167  				},
   168  				{
   169  					"not a prepare",
   170  					&Ready{},
   171  					-1,
   172  					errors.New("expected *message.Prepare, got *message.Ready"),
   173  				},
   174  			}
   175  			for _, tt := range tests {
   176  				t.Run(tt.name, func(t *testing.T) {
   177  					actual, err := codec.EncodedLength(tt.input, version)
   178  					assert.Equal(t, tt.expected, actual)
   179  					assert.Equal(t, tt.err, err)
   180  				})
   181  			}
   182  		})
   183  	}
   184  }
   185  
   186  func TestPrepareCodec_Decode(t *testing.T) {
   187  	codec := &prepareCodec{}
   188  	// versions <= 4 + DSE v1
   189  	for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion2, primitive.ProtocolVersion3, primitive.ProtocolVersion4, primitive.ProtocolVersionDse1} {
   190  		t.Run(version.String(), func(t *testing.T) {
   191  			tests := []decodeTestCase{
   192  				{
   193  					"prepare simple",
   194  					[]byte{
   195  						0, 0, 0, 6, S, E, L, E, C, T,
   196  					},
   197  					&Prepare{"SELECT", ""},
   198  					nil,
   199  				},
   200  			}
   201  			for _, tt := range tests {
   202  				t.Run(tt.name, func(t *testing.T) {
   203  					source := bytes.NewBuffer(tt.input)
   204  					actual, err := codec.Decode(source, version)
   205  					assert.Equal(t, tt.expected, actual)
   206  					assert.Equal(t, tt.err, err)
   207  				})
   208  			}
   209  		})
   210  	}
   211  	// versions 5, DSE v2
   212  	for _, version := range []primitive.ProtocolVersion{primitive.ProtocolVersion5, primitive.ProtocolVersionDse2} {
   213  		t.Run(version.String(), func(t *testing.T) {
   214  			tests := []decodeTestCase{
   215  				{
   216  					"prepare simple",
   217  					[]byte{
   218  						0, 0, 0, 6, S, E, L, E, C, T,
   219  						0, 0, 0, 0, // flags
   220  					},
   221  					&Prepare{"SELECT", ""},
   222  					nil,
   223  				},
   224  				{
   225  					"prepare with keyspace",
   226  					[]byte{
   227  						0, 0, 0, 6, S, E, L, E, C, T,
   228  						0, 0, 0, 1, // flags
   229  						0, 2, k, s, // keyspace
   230  					},
   231  					&Prepare{"SELECT", "ks"},
   232  					nil,
   233  				},
   234  			}
   235  			for _, tt := range tests {
   236  				t.Run(tt.name, func(t *testing.T) {
   237  					source := bytes.NewBuffer(tt.input)
   238  					actual, err := codec.Decode(source, version)
   239  					assert.Equal(t, tt.expected, actual)
   240  					assert.Equal(t, tt.err, err)
   241  				})
   242  			}
   243  		})
   244  	}
   245  }