github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/prepare_test.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client_test
    16  
    17  import (
    18  	"github.com/datastax/go-cassandra-native-protocol/client"
    19  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    20  	"github.com/datastax/go-cassandra-native-protocol/frame"
    21  	"github.com/datastax/go-cassandra-native-protocol/message"
    22  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    23  	"github.com/stretchr/testify/require"
    24  	"testing"
    25  )
    26  
    27  func TestNewPreparedStatementHandler(t *testing.T) {
    28  
    29  	query := "SELECT v FROM ks.t1 WHERE pk = ?"
    30  
    31  	// bound variables in the prepared statement (pk)
    32  	var variables = &message.VariablesMetadata{
    33  		PkIndices: []uint16{0},
    34  		Columns: []*message.ColumnMetadata{{
    35  			Keyspace: "ks",
    36  			Table:    "t1",
    37  			Name:     "pk",
    38  			Index:    0,
    39  			Type:     datatype.Varchar,
    40  		}},
    41  	}
    42  
    43  	// columns in each row returned by the statement execution (v)
    44  	var columns = &message.RowsMetadata{
    45  		ColumnCount: 1,
    46  		Columns: []*message.ColumnMetadata{{
    47  			Keyspace: "ks",
    48  			Table:    "t1",
    49  			Name:     "v",
    50  			Index:    0,
    51  			Type:     datatype.Varchar,
    52  		}},
    53  	}
    54  
    55  	var pk1 = primitive.NewValue([]byte("pk1"))
    56  	var pk2 = primitive.NewValue([]byte("pk2"))
    57  
    58  	var v1 = message.Row{message.Column("v1")}
    59  	var v2 = message.Row{message.Column("v2")}
    60  
    61  	// if bound variable pk = pk1 then EXECUTE should return v1, otherwise EXECUTE should return v2
    62  	rows := func(options *message.QueryOptions) message.RowSet {
    63  		value := options.PositionalValues[0]
    64  		if string(value.Contents) == "pk1" {
    65  			return message.RowSet{v1}
    66  		} else {
    67  			return message.RowSet{v2}
    68  		}
    69  	}
    70  
    71  	handler := client.NewPreparedStatementHandler(query, variables, columns, rows)
    72  
    73  	server, clientConn, cancelFn := createServerAndClient(t, []client.RequestHandler{handler}, nil)
    74  
    75  	testUnprepared(t, clientConn, query)
    76  	testPrepare(t, clientConn, query, variables, columns)
    77  	testExecute(t, clientConn, query, columns, pk1, v1)
    78  	testExecute(t, clientConn, query, columns, pk2, v2)
    79  
    80  	cancelFn()
    81  	checkClosed(t, clientConn, server)
    82  
    83  }
    84  
    85  func testUnprepared(
    86  	t *testing.T,
    87  	clientConn *client.CqlClientConnection,
    88  	query string,
    89  ) {
    90  	execute := frame.NewFrame(
    91  		primitive.ProtocolVersion4,
    92  		client.ManagedStreamId,
    93  		&message.Execute{QueryId: []byte(query)},
    94  	)
    95  	response, err := clientConn.SendAndReceive(execute)
    96  	require.NotNil(t, response)
    97  	require.NoError(t, err)
    98  	require.Equal(t, primitive.OpCodeError, response.Header.OpCode)
    99  	require.IsType(t, &message.Unprepared{}, response.Body.Message)
   100  	result := response.Body.Message.(*message.Unprepared)
   101  	require.Equal(t, []byte(query), result.Id)
   102  }
   103  
   104  func testPrepare(
   105  	t *testing.T,
   106  	clientConn *client.CqlClientConnection,
   107  	query string,
   108  	variables *message.VariablesMetadata,
   109  	columns *message.RowsMetadata,
   110  ) {
   111  	prepare := frame.NewFrame(
   112  		primitive.ProtocolVersion4,
   113  		client.ManagedStreamId,
   114  		&message.Prepare{Query: query},
   115  	)
   116  	response, err := clientConn.SendAndReceive(prepare)
   117  	require.NotNil(t, response)
   118  	require.NoError(t, err)
   119  	require.Equal(t, primitive.OpCodeResult, response.Header.OpCode)
   120  	require.IsType(t, &message.PreparedResult{}, response.Body.Message)
   121  	result := response.Body.Message.(*message.PreparedResult)
   122  	require.Equal(t, []byte(query), result.PreparedQueryId)
   123  	require.Equal(t, variables, result.VariablesMetadata)
   124  	require.Equal(t, columns, result.ResultMetadata)
   125  
   126  }
   127  
   128  func testExecute(
   129  	t *testing.T,
   130  	clientConn *client.CqlClientConnection,
   131  	query string,
   132  	columns *message.RowsMetadata,
   133  	pk *primitive.Value,
   134  	row message.Row,
   135  ) {
   136  	execute := frame.NewFrame(
   137  		primitive.ProtocolVersion4,
   138  		client.ManagedStreamId,
   139  		&message.Execute{
   140  			QueryId: []byte(query),
   141  			Options: &message.QueryOptions{PositionalValues: []*primitive.Value{pk}},
   142  		},
   143  	)
   144  	response, err := clientConn.SendAndReceive(execute)
   145  	require.NotNil(t, response)
   146  	require.NoError(t, err)
   147  	require.Equal(t, primitive.OpCodeResult, response.Header.OpCode)
   148  	require.IsType(t, &message.RowsResult{}, response.Body.Message)
   149  	result := response.Body.Message.(*message.RowsResult)
   150  	require.Equal(t, message.RowSet{row}, result.Data)
   151  	require.Equal(t, columns, result.Metadata)
   152  }