github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/events/websockets/websockets_test.go (about) 1 // Copyright © 2021 Kaleido, Inc. 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 package websockets 18 19 import ( 20 "context" 21 "encoding/json" 22 "fmt" 23 "net/http" 24 "net/http/httptest" 25 "net/url" 26 "strings" 27 "testing" 28 29 "github.com/kaleido-io/firefly/internal/config" 30 "github.com/kaleido-io/firefly/internal/log" 31 "github.com/kaleido-io/firefly/internal/restclient" 32 "github.com/kaleido-io/firefly/internal/wsclient" 33 "github.com/kaleido-io/firefly/mocks/eventsmocks" 34 "github.com/kaleido-io/firefly/pkg/events" 35 "github.com/kaleido-io/firefly/pkg/fftypes" 36 "github.com/stretchr/testify/assert" 37 "github.com/stretchr/testify/mock" 38 ) 39 40 func newTestWebsockets(t *testing.T, cbs *eventsmocks.Callbacks, queryParams ...string) (ws *WebSockets, wsc wsclient.WSClient, cancel func()) { 41 config.Reset() 42 43 ws = &WebSockets{} 44 ctx, cancelCtx := context.WithCancel(context.Background()) 45 svrPrefix := config.NewPluginConfig("ut.websockets") 46 ws.InitPrefix(svrPrefix) 47 ws.Init(ctx, svrPrefix, cbs) 48 assert.Equal(t, "websockets", ws.Name()) 49 assert.NotNil(t, ws.Capabilities()) 50 cbs.On("ConnnectionClosed", mock.Anything).Return(nil).Maybe() 51 52 svr := httptest.NewServer(ws) 53 54 clientPrefix := config.NewPluginConfig("ut.wsclient") 55 wsclient.InitPrefix(clientPrefix) 56 qs := "" 57 if len(queryParams) > 0 { 58 qs = fmt.Sprintf("?%s", strings.Join(queryParams, "&")) 59 } 60 clientPrefix.Set(restclient.HTTPConfigURL, fmt.Sprintf("http://%s%s", svr.Listener.Addr(), qs)) 61 wsc, err := wsclient.New(ctx, clientPrefix, nil) 62 assert.NoError(t, err) 63 err = wsc.Connect() 64 assert.NoError(t, err) 65 66 return ws, wsc, func() { 67 cancelCtx() 68 wsc.Close() 69 ws.WaitClosed() 70 svr.Close() 71 } 72 } 73 74 func TestSendBadData(t *testing.T) { 75 cbs := &eventsmocks.Callbacks{} 76 _, wsc, cancel := newTestWebsockets(t, cbs) 77 defer cancel() 78 79 cbs.On("ConnnectionClosed", mock.Anything).Return(nil) 80 81 err := wsc.Send(context.Background(), []byte(`!json`)) 82 assert.NoError(t, err) 83 b := <-wsc.Receive() 84 var res fftypes.WSProtocolErrorPayload 85 err = json.Unmarshal(b, &res) 86 assert.NoError(t, err) 87 assert.Equal(t, fftypes.WSProtocolErrorEventType, res.Type) 88 assert.Regexp(t, "FF10176", res.Error) 89 } 90 91 func TestSendBadAction(t *testing.T) { 92 cbs := &eventsmocks.Callbacks{} 93 _, wsc, cancel := newTestWebsockets(t, cbs) 94 defer cancel() 95 cbs.On("ConnnectionClosed", mock.Anything).Return(nil) 96 97 err := wsc.Send(context.Background(), []byte(`{"type":"lobster"}`)) 98 assert.NoError(t, err) 99 b := <-wsc.Receive() 100 var res fftypes.WSProtocolErrorPayload 101 err = json.Unmarshal(b, &res) 102 assert.NoError(t, err) 103 assert.Equal(t, fftypes.WSProtocolErrorEventType, res.Type) 104 assert.Regexp(t, "FF10176", res.Error) 105 } 106 107 func TestSendEmptyStartAction(t *testing.T) { 108 cbs := &eventsmocks.Callbacks{} 109 _, wsc, cancel := newTestWebsockets(t, cbs) 110 defer cancel() 111 cbs.On("ConnnectionClosed", mock.Anything).Return(nil) 112 113 err := wsc.Send(context.Background(), []byte(`{"type":"start"}`)) 114 assert.NoError(t, err) 115 b := <-wsc.Receive() 116 var res fftypes.WSProtocolErrorPayload 117 err = json.Unmarshal(b, &res) 118 assert.NoError(t, err) 119 assert.Equal(t, fftypes.WSProtocolErrorEventType, res.Type) 120 assert.Regexp(t, "FF10176", res.Error) 121 } 122 123 func TestStartReceiveAckEphemeral(t *testing.T) { 124 log.SetLevel("trace") 125 126 cbs := &eventsmocks.Callbacks{} 127 ws, wsc, cancel := newTestWebsockets(t, cbs) 128 defer cancel() 129 var connID string 130 sub := cbs.On("EphemeralSubscription", 131 mock.MatchedBy(func(s string) bool { connID = s; return true }), 132 "ns1", mock.Anything, mock.Anything).Return(nil) 133 ack := cbs.On("DeliveryResponse", 134 mock.MatchedBy(func(s string) bool { return s == connID }), 135 mock.Anything).Return(nil) 136 137 waitSubscribed := make(chan struct{}) 138 sub.RunFn = func(a mock.Arguments) { 139 close(waitSubscribed) 140 } 141 142 waitAcked := make(chan struct{}) 143 ack.RunFn = func(a mock.Arguments) { 144 close(waitAcked) 145 } 146 147 err := wsc.Send(context.Background(), []byte(`{"type":"start","namespace":"ns1","ephemeral":true}`)) 148 assert.NoError(t, err) 149 150 <-waitSubscribed 151 ws.DeliveryRequest(connID, &fftypes.EventDelivery{ 152 Event: fftypes.Event{ID: fftypes.NewUUID()}, 153 Subscription: fftypes.SubscriptionRef{ID: fftypes.NewUUID()}, 154 }) 155 156 b := <-wsc.Receive() 157 var res fftypes.EventDelivery 158 err = json.Unmarshal(b, &res) 159 assert.NoError(t, err) 160 161 err = wsc.Send(context.Background(), []byte(`{"type":"ack"}`)) 162 assert.NoError(t, err) 163 164 <-waitAcked 165 cbs.AssertExpectations(t) 166 } 167 168 func TestStartReceiveDurable(t *testing.T) { 169 cbs := &eventsmocks.Callbacks{} 170 ws, wsc, cancel := newTestWebsockets(t, cbs) 171 defer cancel() 172 var connID string 173 sub := cbs.On("RegisterConnection", 174 mock.MatchedBy(func(s string) bool { connID = s; return true }), 175 mock.MatchedBy(func(subMatch events.SubscriptionMatcher) bool { 176 return subMatch(fftypes.SubscriptionRef{Namespace: "ns1", Name: "sub1"}) && 177 !subMatch(fftypes.SubscriptionRef{Namespace: "ns2", Name: "sub1"}) && 178 !subMatch(fftypes.SubscriptionRef{Namespace: "ns1", Name: "sub2"}) 179 }), 180 ).Return(nil) 181 ack := cbs.On("DeliveryResponse", 182 mock.MatchedBy(func(s string) bool { return s == connID }), 183 mock.Anything).Return(nil) 184 185 waitSubscribed := make(chan struct{}) 186 sub.RunFn = func(a mock.Arguments) { 187 close(waitSubscribed) 188 } 189 190 waitAcked := make(chan struct{}) 191 ack.RunFn = func(a mock.Arguments) { 192 close(waitAcked) 193 } 194 195 err := wsc.Send(context.Background(), []byte(`{"type":"start","namespace":"ns1","name":"sub1"}`)) 196 assert.NoError(t, err) 197 198 <-waitSubscribed 199 ws.DeliveryRequest(connID, &fftypes.EventDelivery{ 200 Event: fftypes.Event{ID: fftypes.NewUUID()}, 201 Subscription: fftypes.SubscriptionRef{ 202 ID: fftypes.NewUUID(), 203 Namespace: "ns1", 204 Name: "sub1", 205 }, 206 }) 207 // Put a second in flight 208 ws.DeliveryRequest(connID, &fftypes.EventDelivery{ 209 Event: fftypes.Event{ID: fftypes.NewUUID()}, 210 Subscription: fftypes.SubscriptionRef{ 211 ID: fftypes.NewUUID(), 212 Namespace: "ns1", 213 Name: "sub2", 214 }, 215 }) 216 217 b := <-wsc.Receive() 218 var res fftypes.EventDelivery 219 err = json.Unmarshal(b, &res) 220 assert.NoError(t, err) 221 222 assert.Equal(t, "ns1", res.Subscription.Namespace) 223 assert.Equal(t, "sub1", res.Subscription.Name) 224 err = wsc.Send(context.Background(), []byte(fmt.Sprintf(`{ 225 "type":"ack", 226 "id": "%s", 227 "subscription": { 228 "namespace": "ns1", 229 "name": "sub1" 230 } 231 }`, res.ID))) 232 assert.NoError(t, err) 233 234 <-waitAcked 235 236 // Check we left the right one behind 237 conn := ws.connections[connID] 238 assert.Equal(t, 1, len(conn.inflight)) 239 assert.Equal(t, "sub2", conn.inflight[0].Subscription.Name) 240 241 cbs.AssertExpectations(t) 242 } 243 244 func TestAutoStartReceiveAckEphemeral(t *testing.T) { 245 var connID string 246 cbs := &eventsmocks.Callbacks{} 247 sub := cbs.On("EphemeralSubscription", 248 mock.MatchedBy(func(s string) bool { connID = s; return true }), 249 "ns1", mock.Anything, mock.Anything).Return(nil) 250 ack := cbs.On("DeliveryResponse", 251 mock.MatchedBy(func(s string) bool { return s == connID }), 252 mock.Anything).Return(nil) 253 254 waitSubscribed := make(chan struct{}) 255 sub.RunFn = func(a mock.Arguments) { 256 close(waitSubscribed) 257 } 258 259 waitAcked := make(chan struct{}) 260 ack.RunFn = func(a mock.Arguments) { 261 close(waitAcked) 262 } 263 264 ws, wsc, cancel := newTestWebsockets(t, cbs, "ephemeral", "namespace=ns1") 265 defer cancel() 266 267 <-waitSubscribed 268 ws.DeliveryRequest(connID, &fftypes.EventDelivery{ 269 Event: fftypes.Event{ID: fftypes.NewUUID()}, 270 Subscription: fftypes.SubscriptionRef{ID: fftypes.NewUUID()}, 271 }) 272 273 b := <-wsc.Receive() 274 var res fftypes.EventDelivery 275 err := json.Unmarshal(b, &res) 276 assert.NoError(t, err) 277 278 err = wsc.Send(context.Background(), []byte(`{"type":"ack"}`)) 279 assert.NoError(t, err) 280 281 <-waitAcked 282 cbs.AssertExpectations(t) 283 } 284 285 func TestAutoStartBadOptions(t *testing.T) { 286 cbs := &eventsmocks.Callbacks{} 287 _, wsc, cancel := newTestWebsockets(t, cbs, "name=missingnamespace") 288 defer cancel() 289 290 b := <-wsc.Receive() 291 var res fftypes.WSProtocolErrorPayload 292 err := json.Unmarshal(b, &res) 293 assert.NoError(t, err) 294 assert.Regexp(t, "FF10178", res.Error) 295 cbs.AssertExpectations(t) 296 } 297 298 func TestHandleAckWithAutoAck(t *testing.T) { 299 eventUUID := fftypes.NewUUID() 300 wsc := &websocketConnection{ 301 ctx: context.Background(), 302 startedCount: 1, 303 sendMessages: make(chan interface{}, 1), 304 inflight: []*fftypes.EventDeliveryResponse{ 305 {ID: eventUUID}, 306 }, 307 autoAck: true, 308 } 309 err := wsc.handleAck(&fftypes.WSClientActionAckPayload{ 310 ID: eventUUID, 311 }) 312 assert.Regexp(t, "FF10180", err) 313 } 314 315 func TestHandleStartFlippingAutoAck(t *testing.T) { 316 eventUUID := fftypes.NewUUID() 317 wsc := &websocketConnection{ 318 ctx: context.Background(), 319 startedCount: 1, 320 sendMessages: make(chan interface{}, 1), 321 inflight: []*fftypes.EventDeliveryResponse{ 322 {ID: eventUUID}, 323 }, 324 autoAck: true, 325 } 326 no := false 327 err := wsc.handleStart(&fftypes.WSClientActionStartPayload{ 328 AutoAck: &no, 329 }) 330 assert.Regexp(t, "FF10179", err) 331 } 332 333 func TestHandleAckMultipleStartedMissingSub(t *testing.T) { 334 eventUUID := fftypes.NewUUID() 335 wsc := &websocketConnection{ 336 ctx: context.Background(), 337 startedCount: 3, 338 sendMessages: make(chan interface{}, 1), 339 inflight: []*fftypes.EventDeliveryResponse{ 340 {ID: eventUUID}, 341 }, 342 } 343 err := wsc.handleAck(&fftypes.WSClientActionAckPayload{ 344 ID: eventUUID, 345 }) 346 assert.Regexp(t, "FF10175", err) 347 348 } 349 350 func TestHandleAckMultipleStartedNoSubSingleMatch(t *testing.T) { 351 cbs := &eventsmocks.Callbacks{} 352 cbs.On("DeliveryResponse", mock.Anything, mock.Anything).Return(nil) 353 eventUUID := fftypes.NewUUID() 354 wsc := &websocketConnection{ 355 ctx: context.Background(), 356 ws: &WebSockets{ 357 ctx: context.Background(), 358 callbacks: cbs, 359 }, 360 startedCount: 1, 361 sendMessages: make(chan interface{}, 1), 362 inflight: []*fftypes.EventDeliveryResponse{ 363 {ID: eventUUID}, 364 }, 365 } 366 err := wsc.handleAck(&fftypes.WSClientActionAckPayload{ 367 ID: eventUUID, 368 }) 369 assert.NoError(t, err) 370 371 } 372 373 func TestHandleAckNoneInflight(t *testing.T) { 374 wsc := &websocketConnection{ 375 ctx: context.Background(), 376 sendMessages: make(chan interface{}, 1), 377 inflight: []*fftypes.EventDeliveryResponse{}, 378 } 379 err := wsc.handleAck(&fftypes.WSClientActionAckPayload{}) 380 assert.Regexp(t, "FF10175", err) 381 } 382 383 func TestProtocolErrorSwallowsSendError(t *testing.T) { 384 ctx, cancel := context.WithCancel(context.Background()) 385 cancel() 386 wsc := &websocketConnection{ 387 ctx: ctx, 388 sendMessages: make(chan interface{}), 389 } 390 wsc.protocolError(fmt.Errorf("pop")) 391 392 } 393 394 func TestSendLoopBadData(t *testing.T) { 395 cbs := &eventsmocks.Callbacks{} 396 ws, wsc, cancel := newTestWebsockets(t, cbs) 397 defer cancel() 398 399 subscribedConn := make(chan string, 1) 400 cbs.On("EphemeralSubscription", 401 mock.MatchedBy(func(s string) bool { 402 subscribedConn <- s 403 return true 404 }), 405 "ns1", mock.Anything, mock.Anything).Return(nil) 406 407 err := wsc.Send(context.Background(), []byte(`{"type":"start","namespace":"ns1","ephemeral":true}`)) 408 assert.NoError(t, err) 409 410 connID := <-subscribedConn 411 connection := ws.connections[connID] 412 connection.sendMessages <- map[bool]bool{false: true} // no JSON representation 413 414 // Connection should close on its own with that bad data 415 <-connection.senderDone 416 417 } 418 419 func TestUpgradeFail(t *testing.T) { 420 cbs := &eventsmocks.Callbacks{} 421 _, wsc, cancel := newTestWebsockets(t, cbs) 422 defer cancel() 423 424 u, _ := url.Parse(wsc.URL()) 425 u.Scheme = "http" 426 res, err := http.Get(u.String()) 427 assert.NoError(t, err) 428 assert.Equal(t, 400, res.StatusCode) 429 430 } 431 func TestConnectionDispatchAfterClose(t *testing.T) { 432 ctx, cancel := context.WithCancel(context.Background()) 433 cancel() 434 wsc := &websocketConnection{ 435 ctx: ctx, 436 } 437 err := wsc.dispatch(&fftypes.EventDelivery{}) 438 assert.Regexp(t, "FF10160", err) 439 } 440 441 func TestWebsocketDispatchAfterClose(t *testing.T) { 442 ws := &WebSockets{ 443 ctx: context.Background(), 444 connections: make(map[string]*websocketConnection), 445 } 446 err := ws.DeliveryRequest("gone", &fftypes.EventDelivery{}) 447 assert.Regexp(t, "FF10173", err) 448 } 449 450 func TestDispatchAutoAck(t *testing.T) { 451 cbs := &eventsmocks.Callbacks{} 452 cbs.On("DeliveryResponse", mock.Anything, mock.Anything).Return(nil) 453 wsc := &websocketConnection{ 454 ctx: context.Background(), 455 connID: fftypes.NewUUID().String(), 456 ws: &WebSockets{ 457 ctx: context.Background(), 458 callbacks: cbs, 459 connections: make(map[string]*websocketConnection), 460 }, 461 startedCount: 1, 462 sendMessages: make(chan interface{}, 1), 463 autoAck: true, 464 } 465 wsc.ws.connections[wsc.connID] = wsc 466 err := wsc.ws.DeliveryRequest(wsc.connID, &fftypes.EventDelivery{ 467 Event: fftypes.Event{ID: fftypes.NewUUID()}, 468 Subscription: fftypes.SubscriptionRef{ID: fftypes.NewUUID(), Namespace: "ns1", Name: "sub1"}, 469 }) 470 assert.NoError(t, err) 471 cbs.AssertExpectations(t) 472 }