github.com/99designs/gqlgen@v0.17.45/graphql/handler/transport/websocket_test.go (about) 1 package transport_test 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/gorilla/websocket" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 "github.com/vektah/gqlparser/v2" 17 "github.com/vektah/gqlparser/v2/ast" 18 19 "github.com/99designs/gqlgen/client" 20 "github.com/99designs/gqlgen/graphql" 21 "github.com/99designs/gqlgen/graphql/handler" 22 "github.com/99designs/gqlgen/graphql/handler/testserver" 23 "github.com/99designs/gqlgen/graphql/handler/transport" 24 ) 25 26 type ckey string 27 28 func TestWebsocket(t *testing.T) { 29 handler := testserver.New() 30 handler.AddTransport(transport.Websocket{}) 31 32 srv := httptest.NewServer(handler) 33 defer srv.Close() 34 35 t.Run("client must send valid json", func(t *testing.T) { 36 c := wsConnect(srv.URL) 37 defer c.Close() 38 39 writeRaw(c, "hello") 40 41 msg := readOp(c) 42 assert.Equal(t, "connection_error", msg.Type) 43 assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload)) 44 }) 45 46 t.Run("client can terminate before init", func(t *testing.T) { 47 c := wsConnect(srv.URL) 48 defer c.Close() 49 50 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 51 52 _, _, err := c.ReadMessage() 53 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 54 }) 55 56 t.Run("client must send init first", func(t *testing.T) { 57 c := wsConnect(srv.URL) 58 defer c.Close() 59 60 require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg})) 61 62 msg := readOp(c) 63 assert.Equal(t, connectionErrorMsg, msg.Type) 64 assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload)) 65 }) 66 67 t.Run("server acks init", func(t *testing.T) { 68 c := wsConnect(srv.URL) 69 defer c.Close() 70 71 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 72 73 assert.Equal(t, connectionAckMsg, readOp(c).Type) 74 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 75 }) 76 77 t.Run("client can terminate before run", func(t *testing.T) { 78 c := wsConnect(srv.URL) 79 defer c.Close() 80 81 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 82 assert.Equal(t, connectionAckMsg, readOp(c).Type) 83 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 84 85 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 86 87 _, _, err := c.ReadMessage() 88 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 89 }) 90 91 t.Run("client gets parse errors", func(t *testing.T) { 92 c := wsConnect(srv.URL) 93 defer c.Close() 94 95 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 96 assert.Equal(t, connectionAckMsg, readOp(c).Type) 97 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 98 99 require.NoError(t, c.WriteJSON(&operationMessage{ 100 Type: startMsg, 101 ID: "test_1", 102 Payload: json.RawMessage(`{"query": "!"}`), 103 })) 104 105 msg := readOp(c) 106 assert.Equal(t, errorMsg, msg.Type) 107 assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload)) 108 }) 109 110 t.Run("client can receive data", func(t *testing.T) { 111 c := wsConnect(srv.URL) 112 defer c.Close() 113 114 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 115 assert.Equal(t, connectionAckMsg, readOp(c).Type) 116 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 117 118 require.NoError(t, c.WriteJSON(&operationMessage{ 119 Type: startMsg, 120 ID: "test_1", 121 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 122 })) 123 124 handler.SendNextSubscriptionMessage() 125 msg := readOp(c) 126 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 127 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 128 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 129 130 handler.SendNextSubscriptionMessage() 131 msg = readOp(c) 132 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 133 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 134 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 135 136 require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"})) 137 138 msg = readOp(c) 139 require.Equal(t, completeMsg, msg.Type) 140 require.Equal(t, "test_1", msg.ID) 141 142 // At this point we should be done and should not receive another message. 143 c.SetReadDeadline(time.Now().UTC().Add(1 * time.Millisecond)) 144 145 err := c.ReadJSON(&msg) 146 if err == nil { 147 // This should not send a second close message for the same id. 148 require.NotEqual(t, completeMsg, msg.Type) 149 require.NotEqual(t, "test_1", msg.ID) 150 } else { 151 assert.Contains(t, err.Error(), "timeout") 152 } 153 }) 154 } 155 156 func TestWebsocketWithKeepAlive(t *testing.T) { 157 h := testserver.New() 158 h.AddTransport(transport.Websocket{ 159 KeepAlivePingInterval: 100 * time.Millisecond, 160 }) 161 162 srv := httptest.NewServer(h) 163 defer srv.Close() 164 165 c := wsConnect(srv.URL) 166 defer c.Close() 167 168 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 169 assert.Equal(t, connectionAckMsg, readOp(c).Type) 170 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 171 172 require.NoError(t, c.WriteJSON(&operationMessage{ 173 Type: startMsg, 174 ID: "test_1", 175 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 176 })) 177 178 // keepalive 179 msg := readOp(c) 180 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 181 182 // server message 183 h.SendNextSubscriptionMessage() 184 msg = readOp(c) 185 assert.Equal(t, dataMsg, msg.Type) 186 187 // keepalive 188 msg = readOp(c) 189 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 190 } 191 192 func TestWebsocketInitFunc(t *testing.T) { 193 t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { 194 h := testserver.New() 195 h.AddTransport(transport.Websocket{}) 196 srv := httptest.NewServer(h) 197 defer srv.Close() 198 199 c := wsConnect(srv.URL) 200 defer c.Close() 201 202 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 203 204 assert.Equal(t, connectionAckMsg, readOp(c).Type) 205 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 206 }) 207 208 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 209 h := testserver.New() 210 h.AddTransport(transport.Websocket{ 211 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { 212 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil 213 }, 214 }) 215 srv := httptest.NewServer(h) 216 defer srv.Close() 217 218 c := wsConnect(srv.URL) 219 defer c.Close() 220 221 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 222 223 assert.Equal(t, connectionAckMsg, readOp(c).Type) 224 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 225 }) 226 227 t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 228 h := testserver.New() 229 h.AddTransport(transport.Websocket{ 230 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { 231 return ctx, nil, errors.New("invalid init payload") 232 }, 233 }) 234 srv := httptest.NewServer(h) 235 defer srv.Close() 236 237 c := wsConnect(srv.URL) 238 defer c.Close() 239 240 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 241 242 msg := readOp(c) 243 assert.Equal(t, connectionErrorMsg, msg.Type) 244 assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload)) 245 }) 246 247 t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) { 248 es := &graphql.ExecutableSchemaMock{ 249 ExecFunc: func(ctx context.Context) graphql.ResponseHandler { 250 assert.Equal(t, "newvalue", ctx.Value(ckey("newkey"))) 251 return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)}) 252 }, 253 SchemaFunc: func() *ast.Schema { 254 return gqlparser.MustLoadSchema(&ast.Source{Input: ` 255 schema { query: Query } 256 type Query { 257 empty: String 258 } 259 `}) 260 }, 261 } 262 h := handler.New(es) 263 264 h.AddTransport(transport.Websocket{ 265 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { 266 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil 267 }, 268 }) 269 270 c := client.New(h) 271 272 socket := c.Websocket("{ empty } ") 273 defer socket.Close() 274 var resp struct { 275 Empty string 276 } 277 err := socket.Next(&resp) 278 require.NoError(t, err) 279 assert.Equal(t, "ok", resp.Empty) 280 }) 281 282 t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) { 283 h := testserver.New() 284 var cancel func() 285 h.AddTransport(transport.Websocket{ 286 InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) { 287 newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5) 288 return 289 }, 290 }) 291 srv := httptest.NewServer(h) 292 defer srv.Close() 293 294 c := wsConnect(srv.URL) 295 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 296 assert.Equal(t, connectionAckMsg, readOp(c).Type) 297 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 298 299 // Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy) 300 defer cancel() 301 302 time.Sleep(time.Millisecond * 10) 303 m := readOp(c) 304 assert.Equal(t, m.Type, connectionErrorMsg) 305 assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`) 306 }) 307 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 308 h := testserver.New() 309 h.AddTransport(transport.Websocket{ 310 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { 311 initResponsePayload := transport.InitPayload{"trackingId": "123-456"} 312 return context.WithValue(ctx, ckey("newkey"), "newvalue"), &initResponsePayload, nil 313 }, 314 }) 315 srv := httptest.NewServer(h) 316 defer srv.Close() 317 318 c := wsConnect(srv.URL) 319 defer c.Close() 320 321 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 322 323 connAck := readOp(c) 324 assert.Equal(t, connectionAckMsg, connAck.Type) 325 326 var payload map[string]interface{} 327 err := json.Unmarshal(connAck.Payload, &payload) 328 if err != nil { 329 t.Fatal("Unexpected Error", err) 330 } 331 assert.EqualValues(t, "123-456", payload["trackingId"]) 332 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 333 }) 334 } 335 336 func TestWebSocketInitTimeout(t *testing.T) { 337 t.Run("times out if no init message is received within the configured duration", func(t *testing.T) { 338 h := testserver.New() 339 h.AddTransport(transport.Websocket{ 340 InitTimeout: 5 * time.Millisecond, 341 }) 342 srv := httptest.NewServer(h) 343 defer srv.Close() 344 345 c := wsConnect(srv.URL) 346 defer c.Close() 347 348 var msg operationMessage 349 err := c.ReadJSON(&msg) 350 assert.Error(t, err) 351 assert.Contains(t, err.Error(), "timeout") 352 }) 353 354 t.Run("keeps waiting for an init message if no time out is configured", func(t *testing.T) { 355 h := testserver.New() 356 h.AddTransport(transport.Websocket{}) 357 srv := httptest.NewServer(h) 358 defer srv.Close() 359 360 c := wsConnect(srv.URL) 361 defer c.Close() 362 363 done := make(chan interface{}, 1) 364 go func() { 365 var msg operationMessage 366 _ = c.ReadJSON(&msg) 367 done <- 1 368 }() 369 370 select { 371 case <-done: 372 assert.Fail(t, "web socket read operation finished while it shouldn't have") 373 case <-time.After(100 * time.Millisecond): 374 // Success! I guess? Can't really wait forever to see if the read waits forever... 375 } 376 }) 377 } 378 379 func TestWebSocketErrorFunc(t *testing.T) { 380 t.Run("the error handler gets called when an error occurs", func(t *testing.T) { 381 errFuncCalled := make(chan bool, 1) 382 h := testserver.New() 383 h.AddTransport(transport.Websocket{ 384 ErrorFunc: func(_ context.Context, err error) { 385 require.Error(t, err) 386 assert.Equal(t, err.Error(), "websocket read: invalid message received") 387 assert.IsType(t, transport.WebsocketError{}, err) 388 assert.True(t, err.(transport.WebsocketError).IsReadError) 389 errFuncCalled <- true 390 }, 391 }) 392 393 srv := httptest.NewServer(h) 394 defer srv.Close() 395 396 c := wsConnect(srv.URL) 397 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 398 assert.Equal(t, connectionAckMsg, readOp(c).Type) 399 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 400 require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte("mark my words, you will regret this"))) 401 402 select { 403 case res := <-errFuncCalled: 404 assert.True(t, res) 405 case <-time.NewTimer(time.Millisecond * 20).C: 406 assert.Fail(t, "The fail handler was not called in time") 407 } 408 }) 409 410 t.Run("init func errors do not call the error handler", func(t *testing.T) { 411 h := testserver.New() 412 h.AddTransport(transport.Websocket{ 413 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) { 414 return ctx, nil, errors.New("this is not what we agreed upon") 415 }, 416 ErrorFunc: func(_ context.Context, err error) { 417 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) 418 }, 419 }) 420 srv := httptest.NewServer(h) 421 defer srv.Close() 422 423 c := wsConnect(srv.URL) 424 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 425 time.Sleep(time.Millisecond * 20) 426 }) 427 428 t.Run("init func context closes do not call the error handler", func(t *testing.T) { 429 h := testserver.New() 430 h.AddTransport(transport.Websocket{ 431 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) { 432 newCtx, cancel := context.WithCancel(ctx) 433 time.AfterFunc(time.Millisecond*5, cancel) 434 return newCtx, nil, nil 435 }, 436 ErrorFunc: func(_ context.Context, err error) { 437 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) 438 }, 439 }) 440 srv := httptest.NewServer(h) 441 defer srv.Close() 442 443 c := wsConnect(srv.URL) 444 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 445 assert.Equal(t, connectionAckMsg, readOp(c).Type) 446 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 447 time.Sleep(time.Millisecond * 20) 448 }) 449 450 t.Run("init func context deadlines do not call the error handler", func(t *testing.T) { 451 h := testserver.New() 452 var cancel func() 453 h.AddTransport(transport.Websocket{ 454 InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) { 455 newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5)) 456 return newCtx, nil, nil 457 }, 458 ErrorFunc: func(_ context.Context, err error) { 459 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) 460 }, 461 }) 462 srv := httptest.NewServer(h) 463 defer srv.Close() 464 465 c := wsConnect(srv.URL) 466 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 467 assert.Equal(t, connectionAckMsg, readOp(c).Type) 468 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 469 470 // Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy) 471 defer cancel() 472 473 time.Sleep(time.Millisecond * 20) 474 }) 475 } 476 477 func TestWebSocketCloseFunc(t *testing.T) { 478 t.Run("the on close handler gets called when the websocket is closed", func(t *testing.T) { 479 closeFuncCalled := make(chan bool, 1) 480 h := testserver.New() 481 h.AddTransport(transport.Websocket{ 482 CloseFunc: func(_ context.Context, _closeCode int) { 483 closeFuncCalled <- true 484 }, 485 }) 486 487 srv := httptest.NewServer(h) 488 defer srv.Close() 489 490 c := wsConnect(srv.URL) 491 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 492 assert.Equal(t, connectionAckMsg, readOp(c).Type) 493 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 494 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 495 496 select { 497 case res := <-closeFuncCalled: 498 assert.True(t, res) 499 case <-time.NewTimer(time.Millisecond * 20).C: 500 assert.Fail(t, "The close handler was not called in time") 501 } 502 }) 503 504 t.Run("the on close handler gets called only once when the websocket is closed", func(t *testing.T) { 505 closeFuncCalled := make(chan bool, 1) 506 h := testserver.New() 507 h.AddTransport(transport.Websocket{ 508 CloseFunc: func(_ context.Context, _closeCode int) { 509 closeFuncCalled <- true 510 }, 511 }) 512 513 srv := httptest.NewServer(h) 514 defer srv.Close() 515 516 c := wsConnect(srv.URL) 517 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 518 assert.Equal(t, connectionAckMsg, readOp(c).Type) 519 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 520 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 521 522 select { 523 case res := <-closeFuncCalled: 524 assert.True(t, res) 525 case <-time.NewTimer(time.Millisecond * 20).C: 526 assert.Fail(t, "The close handler was not called in time") 527 } 528 529 select { 530 case <-closeFuncCalled: 531 assert.Fail(t, "The close handler was called more than once") 532 case <-time.NewTimer(time.Millisecond * 20).C: 533 // ok 534 } 535 }) 536 537 t.Run("init func errors call the close handler", func(t *testing.T) { 538 h := testserver.New() 539 closeFuncCalled := make(chan bool, 1) 540 h.AddTransport(transport.Websocket{ 541 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) { 542 return ctx, nil, errors.New("error during init") 543 }, 544 CloseFunc: func(_ context.Context, _closeCode int) { 545 closeFuncCalled <- true 546 }, 547 }) 548 srv := httptest.NewServer(h) 549 defer srv.Close() 550 551 c := wsConnect(srv.URL) 552 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 553 select { 554 case res := <-closeFuncCalled: 555 assert.True(t, res) 556 case <-time.NewTimer(time.Millisecond * 20).C: 557 assert.Fail(t, "The close handler was not called in time") 558 } 559 }) 560 } 561 562 func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) { 563 initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) { 564 h := testserver.New() 565 h.AddTransport(ws) 566 return h, httptest.NewServer(h) 567 } 568 569 t.Run("server acks init", func(t *testing.T) { 570 _, srv := initialize(transport.Websocket{}) 571 defer srv.Close() 572 573 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 574 defer c.Close() 575 576 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 577 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 578 }) 579 580 t.Run("client can receive data", func(t *testing.T) { 581 handler, srv := initialize(transport.Websocket{}) 582 defer srv.Close() 583 584 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 585 defer c.Close() 586 587 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 588 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 589 590 require.NoError(t, c.WriteJSON(&operationMessage{ 591 Type: graphqltransportwsSubscribeMsg, 592 ID: "test_1", 593 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 594 })) 595 596 handler.SendNextSubscriptionMessage() 597 msg := readOp(c) 598 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) 599 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 600 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 601 602 handler.SendNextSubscriptionMessage() 603 msg = readOp(c) 604 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) 605 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 606 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 607 608 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"})) 609 610 msg = readOp(c) 611 require.Equal(t, graphqltransportwsCompleteMsg, msg.Type) 612 require.Equal(t, "test_1", msg.ID) 613 }) 614 615 t.Run("receives no graphql-ws keep alive messages", func(t *testing.T) { 616 _, srv := initialize(transport.Websocket{KeepAlivePingInterval: 5 * time.Millisecond}) 617 defer srv.Close() 618 619 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 620 defer c.Close() 621 622 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 623 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 624 625 // If the keep-alives are sent, this deadline will not be used, and no timeout error will be found 626 c.SetReadDeadline(time.Now().UTC().Add(50 * time.Millisecond)) 627 var msg operationMessage 628 err := c.ReadJSON(&msg) 629 require.Error(t, err) 630 assert.Contains(t, err.Error(), "timeout") 631 }) 632 } 633 634 func TestWebsocketWithPingPongInterval(t *testing.T) { 635 initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) { 636 h := testserver.New() 637 h.AddTransport(ws) 638 return h, httptest.NewServer(h) 639 } 640 641 t.Run("client receives ping and responds with pong", func(t *testing.T) { 642 _, srv := initialize(transport.Websocket{PingPongInterval: 20 * time.Millisecond}) 643 defer srv.Close() 644 645 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 646 defer c.Close() 647 648 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 649 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 650 651 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 652 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg})) 653 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 654 }) 655 656 t.Run("client sends ping and expects pong", func(t *testing.T) { 657 _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond}) 658 defer srv.Close() 659 }) 660 661 t.Run("client sends ping and expects pong", func(t *testing.T) { 662 _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond}) 663 defer srv.Close() 664 665 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 666 defer c.Close() 667 668 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 669 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 670 671 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg})) 672 assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type) 673 }) 674 675 t.Run("server closes with error if client does not pong and !MissingPongOk", func(t *testing.T) { 676 h := testserver.New() 677 closeFuncCalled := make(chan bool, 1) 678 h.AddTransport(transport.Websocket{ 679 MissingPongOk: false, // default value but beign explicit for test clarity. 680 PingPongInterval: 5 * time.Millisecond, 681 CloseFunc: func(_ context.Context, _closeCode int) { 682 closeFuncCalled <- true 683 }, 684 }) 685 686 srv := httptest.NewServer(h) 687 defer srv.Close() 688 689 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 690 defer c.Close() 691 692 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 693 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 694 695 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 696 697 select { 698 case res := <-closeFuncCalled: 699 assert.True(t, res) 700 case <-time.NewTimer(time.Millisecond * 20).C: 701 // with a 5ms interval 10ms should be the timeout, double that to make the test less likely to flake under load 702 assert.Fail(t, "The close handler was not called in time") 703 } 704 }) 705 706 t.Run("server does not close with error if client does not pong and MissingPongOk", func(t *testing.T) { 707 h := testserver.New() 708 closeFuncCalled := make(chan bool, 1) 709 h.AddTransport(transport.Websocket{ 710 MissingPongOk: true, 711 PingPongInterval: 10 * time.Millisecond, 712 CloseFunc: func(_ context.Context, _closeCode int) { 713 closeFuncCalled <- true 714 }, 715 }) 716 717 srv := httptest.NewServer(h) 718 defer srv.Close() 719 720 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 721 defer c.Close() 722 723 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 724 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 725 726 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 727 728 select { 729 case <-closeFuncCalled: 730 assert.Fail(t, "The close handler was called even with MissingPongOk = true") 731 case _, ok := <-time.NewTimer(time.Millisecond * 20).C: 732 assert.True(t, ok) 733 } 734 }) 735 736 t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) { 737 // Regression test 738 // --- 739 // Before the refactor, the code would try to convert a ping message to a graphql-ws message type 740 // But since this message type does not exist in the graphql-ws sub protocol, it would fail 741 742 _, srv := initialize(transport.Websocket{ 743 PingPongInterval: 5 * time.Millisecond, 744 KeepAlivePingInterval: 10 * time.Millisecond, 745 }) 746 defer srv.Close() 747 748 // Create connection 749 c := wsConnect(srv.URL) 750 defer c.Close() 751 752 // Initialize connection 753 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 754 assert.Equal(t, connectionAckMsg, readOp(c).Type) 755 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 756 757 // Wait for a few more keep alives to be sure nothing goes wrong 758 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 759 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 760 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 761 }) 762 t.Run("pong only messages are sent when configured with graphql-transport-ws", func(t *testing.T) { 763 764 h, srv := initialize(transport.Websocket{PongOnlyInterval: 10 * time.Millisecond}) 765 defer srv.Close() 766 767 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 768 defer c.Close() 769 770 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 771 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 772 773 assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type) 774 775 require.NoError(t, c.WriteJSON(&operationMessage{ 776 Type: graphqltransportwsSubscribeMsg, 777 ID: "test_1", 778 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 779 })) 780 781 // pong 782 msg := readOp(c) 783 assert.Equal(t, graphqltransportwsPongMsg, msg.Type) 784 785 // server message 786 h.SendNextSubscriptionMessage() 787 msg = readOp(c) 788 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) 789 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 790 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 791 792 // keepalive 793 msg = readOp(c) 794 assert.Equal(t, graphqltransportwsPongMsg, msg.Type) 795 }) 796 797 } 798 799 func wsConnect(url string) *websocket.Conn { 800 return wsConnectWithSubprocotol(url, "") 801 } 802 803 func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn { 804 h := make(http.Header) 805 if subprocotol != "" { 806 h.Add("Sec-WebSocket-Protocol", subprocotol) 807 } 808 809 c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h) 810 if err != nil { 811 panic(err) 812 } 813 _ = resp.Body.Close() 814 815 return c 816 } 817 818 func writeRaw(conn *websocket.Conn, msg string) { 819 if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { 820 panic(err) 821 } 822 } 823 824 func readOp(conn *websocket.Conn) operationMessage { 825 var msg operationMessage 826 if err := conn.ReadJSON(&msg); err != nil { 827 panic(err) 828 } 829 return msg 830 } 831 832 // copied out from websocket_graphqlws.go to keep these private 833 834 const ( 835 connectionInitMsg = "connection_init" // Client -> Server 836 connectionTerminateMsg = "connection_terminate" // Client -> Server 837 startMsg = "start" // Client -> Server 838 stopMsg = "stop" // Client -> Server 839 connectionAckMsg = "connection_ack" // Server -> Client 840 connectionErrorMsg = "connection_error" // Server -> Client 841 dataMsg = "data" // Server -> Client 842 errorMsg = "error" // Server -> Client 843 completeMsg = "complete" // Server -> Client 844 connectionKeepAliveMsg = "ka" // Server -> Client 845 ) 846 847 // copied out from websocket_graphql_transport_ws.go to keep these private 848 849 const ( 850 graphqltransportwsSubprotocol = "graphql-transport-ws" 851 852 graphqltransportwsConnectionInitMsg = "connection_init" 853 graphqltransportwsConnectionAckMsg = "connection_ack" 854 graphqltransportwsSubscribeMsg = "subscribe" 855 graphqltransportwsNextMsg = "next" 856 graphqltransportwsCompleteMsg = "complete" 857 graphqltransportwsPingMsg = "ping" 858 graphqltransportwsPongMsg = "pong" 859 ) 860 861 type operationMessage struct { 862 Payload json.RawMessage `json:"payload,omitempty"` 863 ID string `json:"id,omitempty"` 864 Type string `json:"type"` 865 }