github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/handlers_test.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package client_test 16 17 import ( 18 "bytes" 19 "context" 20 "github.com/datastax/go-cassandra-native-protocol/client" 21 "github.com/datastax/go-cassandra-native-protocol/frame" 22 "github.com/datastax/go-cassandra-native-protocol/message" 23 "github.com/datastax/go-cassandra-native-protocol/primitive" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 "testing" 27 "time" 28 ) 29 30 func TestHeartbeatHandler(t *testing.T) { 31 32 server, clientConn, cancelFn := createServerAndClient(t, []client.RequestHandler{client.HeartbeatHandler}, nil) 33 defer cancelFn() 34 35 testHeartbeat(t, clientConn) 36 37 cancelFn() 38 checkClosed(t, clientConn, server) 39 40 } 41 42 func TestSetKeyspaceHandler(t *testing.T) { 43 44 onKeyspaceSetCalled := false 45 var onKeyspaceSet = func(keyspace string) { 46 require.Equal(t, "ks1", keyspace) 47 onKeyspaceSetCalled = true 48 } 49 server, clientConn, cancelFn := createServerAndClient(t, []client.RequestHandler{client.NewSetKeyspaceHandler(onKeyspaceSet)}, nil) 50 defer cancelFn() 51 52 testUseQuery(t, clientConn) 53 require.True(t, onKeyspaceSetCalled) 54 55 cancelFn() 56 checkClosed(t, clientConn, server) 57 58 } 59 60 func TestRegisterHandler(t *testing.T) { 61 62 server, clientConn, cancelFn := createServerAndClient(t, []client.RequestHandler{client.RegisterHandler}, nil) 63 defer cancelFn() 64 65 testRegister(t, clientConn) 66 67 cancelFn() 68 checkClosed(t, clientConn, server) 69 70 } 71 72 func TestNewCompositeRequestHandler(t *testing.T) { 73 74 handler := client.NewCompositeRequestHandler(client.HeartbeatHandler, client.RegisterHandler) 75 server, clientConn, cancelFn := createServerAndClient(t, []client.RequestHandler{handler}, nil) 76 defer cancelFn() 77 78 testHeartbeat(t, clientConn) 79 testRegister(t, clientConn) 80 81 cancelFn() 82 checkClosed(t, clientConn, server) 83 84 } 85 86 func TestNewDriverConnectionInitializationHandler(t *testing.T) { 87 88 var onKeyspaceSet = func(keyspace string) { 89 require.Equal(t, "ks1", keyspace) 90 } 91 handler := client.NewDriverConnectionInitializationHandler("cluster_test", "datacenter_test", onKeyspaceSet) 92 server, clientConn, cancelFn := createServerAndClient(t, []client.RequestHandler{handler}, nil) 93 defer cancelFn() 94 95 err := clientConn.InitiateHandshake(primitive.ProtocolVersion4, client.ManagedStreamId) 96 require.NoError(t, err) 97 98 testHeartbeat(t, clientConn) 99 testRegister(t, clientConn) 100 testUseQuery(t, clientConn) 101 testClusterName(t, clientConn) 102 testSchemaVersion(t, clientConn) 103 testFullSystemLocal(t, clientConn) 104 testSystemPeers(t, clientConn) 105 106 cancelFn() 107 checkClosed(t, clientConn, server) 108 109 } 110 111 func TestRawHandler(t *testing.T) { 112 var rawHandler client.RawRequestHandler 113 rawHandler = func(request *frame.Frame, conn *client.CqlServerConnection, ctx client.RequestHandlerContext) (rawResponse []byte) { 114 bytesBuf := bytes.Buffer{} 115 err := frame.NewCodec().EncodeFrame(frame.NewFrame(primitive.ProtocolVersion4, 1, &message.Ready{}), &bytesBuf) 116 if err == nil { 117 return bytesBuf.Bytes() 118 } else { 119 return nil 120 } 121 } 122 server, clientConn, cancelFn := createServerAndClient( 123 t, []client.RequestHandler{client.HeartbeatHandler}, []client.RawRequestHandler{rawHandler}) 124 defer cancelFn() 125 126 testRawRequestHandler(t, clientConn) 127 128 cancelFn() 129 checkClosed(t, clientConn, server) 130 131 } 132 133 func createServerAndClient(t *testing.T, handlers []client.RequestHandler, rawHandlers []client.RawRequestHandler) (*client.CqlServer, *client.CqlClientConnection, context.CancelFunc) { 134 server := client.NewCqlServer("127.0.0.1:9043", nil) 135 server.RequestHandlers = handlers 136 server.RequestRawHandlers = rawHandlers 137 clt := client.NewCqlClient("127.0.0.1:9043", nil) 138 ctx, cancelFn := context.WithCancel(context.Background()) 139 err := server.Start(ctx) 140 require.NoError(t, err) 141 clientConn, err := clt.Connect(ctx) 142 require.NoError(t, err) 143 require.NotNil(t, clientConn) 144 return server, clientConn, cancelFn 145 } 146 147 func checkClosed(t *testing.T, clientConn *client.CqlClientConnection, server *client.CqlServer) { 148 assert.Eventually(t, clientConn.IsClosed, time.Second*10, time.Millisecond*10) 149 assert.Eventually(t, server.IsClosed, time.Second*10, time.Millisecond*10) 150 } 151 152 func testRawRequestHandler(t *testing.T, clientConn *client.CqlClientConnection) { 153 heartbeat := frame.NewFrame( 154 primitive.ProtocolVersion4, 155 client.ManagedStreamId, 156 &message.Options{}, 157 ) 158 for i := 0; i < 100; i++ { 159 response, err := clientConn.SendAndReceive(heartbeat) 160 require.NoError(t, err) 161 require.NotNil(t, response) 162 require.Equal(t, primitive.OpCodeReady, response.Header.OpCode) 163 require.IsType(t, &message.Ready{}, response.Body.Message) 164 } 165 } 166 167 func testHeartbeat(t *testing.T, clientConn *client.CqlClientConnection) { 168 heartbeat := frame.NewFrame( 169 primitive.ProtocolVersion4, 170 client.ManagedStreamId, 171 &message.Options{}, 172 ) 173 for i := 0; i < 100; i++ { 174 response, err := clientConn.SendAndReceive(heartbeat) 175 require.NoError(t, err) 176 require.NotNil(t, response) 177 require.Equal(t, primitive.OpCodeSupported, response.Header.OpCode) 178 require.IsType(t, &message.Supported{}, response.Body.Message) 179 } 180 } 181 182 func testUseQuery(t *testing.T, clientConn *client.CqlClientConnection) { 183 useQuery := frame.NewFrame( 184 primitive.ProtocolVersion4, 185 client.ManagedStreamId, 186 &message.Query{Query: " USE \n ks1 "}, 187 ) 188 response, err := clientConn.SendAndReceive(useQuery) 189 require.NoError(t, err) 190 require.NotNil(t, response) 191 require.Equal(t, primitive.OpCodeResult, response.Header.OpCode) 192 require.IsType(t, &message.SetKeyspaceResult{}, response.Body.Message) 193 result := response.Body.Message.(*message.SetKeyspaceResult) 194 require.Equal(t, "ks1", result.Keyspace) 195 } 196 197 func testRegister(t *testing.T, clientConn *client.CqlClientConnection) { 198 register := frame.NewFrame( 199 primitive.ProtocolVersion4, 200 client.ManagedStreamId, 201 &message.Register{EventTypes: []primitive.EventType{primitive.EventTypeSchemaChange, primitive.EventTypeTopologyChange}}, 202 ) 203 response, err := clientConn.SendAndReceive(register) 204 require.NoError(t, err) 205 require.NotNil(t, response) 206 require.Equal(t, primitive.OpCodeReady, response.Header.OpCode) 207 require.IsType(t, &message.Ready{}, response.Body.Message) 208 }