github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/graphql/handler/transport/websocket_test.go (about) 1 package transport_test 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/geneva/gqlgen/client" 14 "github.com/geneva/gqlgen/graphql" 15 "github.com/geneva/gqlgen/graphql/handler" 16 "github.com/geneva/gqlgen/graphql/handler/testserver" 17 "github.com/geneva/gqlgen/graphql/handler/transport" 18 "github.com/gorilla/websocket" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 "github.com/vektah/gqlparser/v2" 22 "github.com/vektah/gqlparser/v2/ast" 23 ) 24 25 type ckey string 26 27 func TestWebsocket(t *testing.T) { 28 handler := testserver.New() 29 handler.AddTransport(transport.Websocket{}) 30 31 srv := httptest.NewServer(handler) 32 defer srv.Close() 33 34 t.Run("client must send valid json", func(t *testing.T) { 35 c := wsConnect(srv.URL) 36 defer c.Close() 37 38 writeRaw(c, "hello") 39 40 msg := readOp(c) 41 assert.Equal(t, "connection_error", msg.Type) 42 assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload)) 43 }) 44 45 t.Run("client can terminate before init", func(t *testing.T) { 46 c := wsConnect(srv.URL) 47 defer c.Close() 48 49 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 50 51 _, _, err := c.ReadMessage() 52 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 53 }) 54 55 t.Run("client must send init first", func(t *testing.T) { 56 c := wsConnect(srv.URL) 57 defer c.Close() 58 59 require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg})) 60 61 msg := readOp(c) 62 assert.Equal(t, connectionErrorMsg, msg.Type) 63 assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload)) 64 }) 65 66 t.Run("server acks init", func(t *testing.T) { 67 c := wsConnect(srv.URL) 68 defer c.Close() 69 70 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 71 72 assert.Equal(t, connectionAckMsg, readOp(c).Type) 73 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 74 }) 75 76 t.Run("client can terminate before run", func(t *testing.T) { 77 c := wsConnect(srv.URL) 78 defer c.Close() 79 80 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 81 assert.Equal(t, connectionAckMsg, readOp(c).Type) 82 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 83 84 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 85 86 _, _, err := c.ReadMessage() 87 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 88 }) 89 90 t.Run("client gets parse errors", func(t *testing.T) { 91 c := wsConnect(srv.URL) 92 defer c.Close() 93 94 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 95 assert.Equal(t, connectionAckMsg, readOp(c).Type) 96 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 97 98 require.NoError(t, c.WriteJSON(&operationMessage{ 99 Type: startMsg, 100 ID: "test_1", 101 Payload: json.RawMessage(`{"query": "!"}`), 102 })) 103 104 msg := readOp(c) 105 assert.Equal(t, errorMsg, msg.Type) 106 assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload)) 107 }) 108 109 t.Run("client can receive data", func(t *testing.T) { 110 c := wsConnect(srv.URL) 111 defer c.Close() 112 113 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 114 assert.Equal(t, connectionAckMsg, readOp(c).Type) 115 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 116 117 require.NoError(t, c.WriteJSON(&operationMessage{ 118 Type: startMsg, 119 ID: "test_1", 120 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 121 })) 122 123 handler.SendNextSubscriptionMessage() 124 msg := readOp(c) 125 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 126 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 127 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 128 129 handler.SendNextSubscriptionMessage() 130 msg = readOp(c) 131 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 132 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 133 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 134 135 require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"})) 136 137 msg = readOp(c) 138 require.Equal(t, completeMsg, msg.Type) 139 require.Equal(t, "test_1", msg.ID) 140 141 // At this point we should be done and should not receive another message. 142 c.SetReadDeadline(time.Now().UTC().Add(1 * time.Millisecond)) 143 144 err := c.ReadJSON(&msg) 145 if err == nil { 146 // This should not send a second close message for the same id. 147 require.NotEqual(t, completeMsg, msg.Type) 148 require.NotEqual(t, "test_1", msg.ID) 149 } else { 150 assert.Contains(t, err.Error(), "timeout") 151 } 152 }) 153 } 154 155 func TestWebsocketWithKeepAlive(t *testing.T) { 156 h := testserver.New() 157 h.AddTransport(transport.Websocket{ 158 KeepAlivePingInterval: 100 * time.Millisecond, 159 }) 160 161 srv := httptest.NewServer(h) 162 defer srv.Close() 163 164 c := wsConnect(srv.URL) 165 defer c.Close() 166 167 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 168 assert.Equal(t, connectionAckMsg, readOp(c).Type) 169 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 170 171 require.NoError(t, c.WriteJSON(&operationMessage{ 172 Type: startMsg, 173 ID: "test_1", 174 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 175 })) 176 177 // keepalive 178 msg := readOp(c) 179 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 180 181 // server message 182 h.SendNextSubscriptionMessage() 183 msg = readOp(c) 184 assert.Equal(t, dataMsg, msg.Type) 185 186 // keepalive 187 msg = readOp(c) 188 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 189 } 190 191 func TestWebsocketInitFunc(t *testing.T) { 192 t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { 193 h := testserver.New() 194 h.AddTransport(transport.Websocket{}) 195 srv := httptest.NewServer(h) 196 defer srv.Close() 197 198 c := wsConnect(srv.URL) 199 defer c.Close() 200 201 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 202 203 assert.Equal(t, connectionAckMsg, readOp(c).Type) 204 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 205 }) 206 207 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 208 h := testserver.New() 209 h.AddTransport(transport.Websocket{ 210 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 211 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil 212 }, 213 }) 214 srv := httptest.NewServer(h) 215 defer srv.Close() 216 217 c := wsConnect(srv.URL) 218 defer c.Close() 219 220 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 221 222 assert.Equal(t, connectionAckMsg, readOp(c).Type) 223 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 224 }) 225 226 t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 227 h := testserver.New() 228 h.AddTransport(transport.Websocket{ 229 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 230 return ctx, errors.New("invalid init payload") 231 }, 232 }) 233 srv := httptest.NewServer(h) 234 defer srv.Close() 235 236 c := wsConnect(srv.URL) 237 defer c.Close() 238 239 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 240 241 msg := readOp(c) 242 assert.Equal(t, connectionErrorMsg, msg.Type) 243 assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload)) 244 }) 245 246 t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) { 247 es := &graphql.ExecutableSchemaMock{ 248 ExecFunc: func(ctx context.Context) graphql.ResponseHandler { 249 assert.Equal(t, "newvalue", ctx.Value(ckey("newkey"))) 250 return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)}) 251 }, 252 SchemaFunc: func() *ast.Schema { 253 return gqlparser.MustLoadSchema(&ast.Source{Input: ` 254 schema { query: Query } 255 type Query { 256 empty: String 257 } 258 `}) 259 }, 260 } 261 h := handler.New(es) 262 263 h.AddTransport(transport.Websocket{ 264 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 265 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil 266 }, 267 }) 268 269 c := client.New(h) 270 271 socket := c.Websocket("{ empty } ") 272 defer socket.Close() 273 var resp struct { 274 Empty string 275 } 276 err := socket.Next(&resp) 277 require.NoError(t, err) 278 assert.Equal(t, "ok", resp.Empty) 279 }) 280 281 t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) { 282 h := testserver.New() 283 var cancel func() 284 h.AddTransport(transport.Websocket{ 285 InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) { 286 newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5) 287 return 288 }, 289 }) 290 srv := httptest.NewServer(h) 291 defer srv.Close() 292 293 c := wsConnect(srv.URL) 294 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 295 assert.Equal(t, connectionAckMsg, readOp(c).Type) 296 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 297 298 // Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy) 299 defer cancel() 300 301 time.Sleep(time.Millisecond * 10) 302 m := readOp(c) 303 assert.Equal(t, m.Type, connectionErrorMsg) 304 assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`) 305 }) 306 } 307 308 func TestWebSocketInitTimeout(t *testing.T) { 309 t.Run("times out if no init message is received within the configured duration", func(t *testing.T) { 310 h := testserver.New() 311 h.AddTransport(transport.Websocket{ 312 InitTimeout: 5 * time.Millisecond, 313 }) 314 srv := httptest.NewServer(h) 315 defer srv.Close() 316 317 c := wsConnect(srv.URL) 318 defer c.Close() 319 320 var msg operationMessage 321 err := c.ReadJSON(&msg) 322 assert.Error(t, err) 323 assert.Contains(t, err.Error(), "timeout") 324 }) 325 326 t.Run("keeps waiting for an init message if no time out is configured", func(t *testing.T) { 327 h := testserver.New() 328 h.AddTransport(transport.Websocket{}) 329 srv := httptest.NewServer(h) 330 defer srv.Close() 331 332 c := wsConnect(srv.URL) 333 defer c.Close() 334 335 done := make(chan interface{}, 1) 336 go func() { 337 var msg operationMessage 338 _ = c.ReadJSON(&msg) 339 done <- 1 340 }() 341 342 select { 343 case <-done: 344 assert.Fail(t, "web socket read operation finished while it shouldn't have") 345 case <-time.After(100 * time.Millisecond): 346 // Success! I guess? Can't really wait forever to see if the read waits forever... 347 } 348 }) 349 } 350 351 func TestWebSocketErrorFunc(t *testing.T) { 352 t.Run("the error handler gets called when an error occurs", func(t *testing.T) { 353 errFuncCalled := make(chan bool, 1) 354 h := testserver.New() 355 h.AddTransport(transport.Websocket{ 356 ErrorFunc: func(_ context.Context, err error) { 357 require.Error(t, err) 358 assert.Equal(t, err.Error(), "websocket read: invalid message received") 359 assert.IsType(t, transport.WebsocketError{}, err) 360 assert.True(t, err.(transport.WebsocketError).IsReadError) 361 errFuncCalled <- true 362 }, 363 }) 364 365 srv := httptest.NewServer(h) 366 defer srv.Close() 367 368 c := wsConnect(srv.URL) 369 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 370 assert.Equal(t, connectionAckMsg, readOp(c).Type) 371 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 372 require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte("mark my words, you will regret this"))) 373 374 select { 375 case res := <-errFuncCalled: 376 assert.True(t, res) 377 case <-time.NewTimer(time.Millisecond * 20).C: 378 assert.Fail(t, "The fail handler was not called in time") 379 } 380 }) 381 382 t.Run("init func errors do not call the error handler", func(t *testing.T) { 383 h := testserver.New() 384 h.AddTransport(transport.Websocket{ 385 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { 386 return ctx, errors.New("this is not what we agreed upon") 387 }, 388 ErrorFunc: func(_ context.Context, err error) { 389 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) 390 }, 391 }) 392 srv := httptest.NewServer(h) 393 defer srv.Close() 394 395 c := wsConnect(srv.URL) 396 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 397 time.Sleep(time.Millisecond * 20) 398 }) 399 400 t.Run("init func context closes do not call the error handler", func(t *testing.T) { 401 h := testserver.New() 402 h.AddTransport(transport.Websocket{ 403 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { 404 newCtx, cancel := context.WithCancel(ctx) 405 time.AfterFunc(time.Millisecond*5, cancel) 406 return newCtx, nil 407 }, 408 ErrorFunc: func(_ context.Context, err error) { 409 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) 410 }, 411 }) 412 srv := httptest.NewServer(h) 413 defer srv.Close() 414 415 c := wsConnect(srv.URL) 416 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 417 assert.Equal(t, connectionAckMsg, readOp(c).Type) 418 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 419 time.Sleep(time.Millisecond * 20) 420 }) 421 422 t.Run("init func context deadlines do not call the error handler", func(t *testing.T) { 423 h := testserver.New() 424 var cancel func() 425 h.AddTransport(transport.Websocket{ 426 InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) { 427 newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5)) 428 return newCtx, nil 429 }, 430 ErrorFunc: func(_ context.Context, err error) { 431 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) 432 }, 433 }) 434 srv := httptest.NewServer(h) 435 defer srv.Close() 436 437 c := wsConnect(srv.URL) 438 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 439 assert.Equal(t, connectionAckMsg, readOp(c).Type) 440 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 441 442 // Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy) 443 defer cancel() 444 445 time.Sleep(time.Millisecond * 20) 446 }) 447 } 448 449 func TestWebSocketCloseFunc(t *testing.T) { 450 t.Run("the on close handler gets called when the websocket is closed", func(t *testing.T) { 451 closeFuncCalled := make(chan bool, 1) 452 h := testserver.New() 453 h.AddTransport(transport.Websocket{ 454 CloseFunc: func(_ context.Context, _closeCode int) { 455 closeFuncCalled <- true 456 }, 457 }) 458 459 srv := httptest.NewServer(h) 460 defer srv.Close() 461 462 c := wsConnect(srv.URL) 463 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 464 assert.Equal(t, connectionAckMsg, readOp(c).Type) 465 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 466 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 467 468 select { 469 case res := <-closeFuncCalled: 470 assert.True(t, res) 471 case <-time.NewTimer(time.Millisecond * 20).C: 472 assert.Fail(t, "The close handler was not called in time") 473 } 474 }) 475 476 t.Run("init func errors call the close handler", func(t *testing.T) { 477 h := testserver.New() 478 closeFuncCalled := make(chan bool, 1) 479 h.AddTransport(transport.Websocket{ 480 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { 481 return ctx, errors.New("error during init") 482 }, 483 CloseFunc: func(_ context.Context, _closeCode int) { 484 closeFuncCalled <- true 485 }, 486 }) 487 srv := httptest.NewServer(h) 488 defer srv.Close() 489 490 c := wsConnect(srv.URL) 491 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 492 select { 493 case res := <-closeFuncCalled: 494 assert.True(t, res) 495 case <-time.NewTimer(time.Millisecond * 20).C: 496 assert.Fail(t, "The close handler was not called in time") 497 } 498 }) 499 } 500 501 func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) { 502 initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) { 503 h := testserver.New() 504 h.AddTransport(ws) 505 return h, httptest.NewServer(h) 506 } 507 508 t.Run("server acks init", func(t *testing.T) { 509 _, srv := initialize(transport.Websocket{}) 510 defer srv.Close() 511 512 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 513 defer c.Close() 514 515 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 516 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 517 }) 518 519 t.Run("client can receive data", func(t *testing.T) { 520 handler, srv := initialize(transport.Websocket{}) 521 defer srv.Close() 522 523 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 524 defer c.Close() 525 526 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 527 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 528 529 require.NoError(t, c.WriteJSON(&operationMessage{ 530 Type: graphqltransportwsSubscribeMsg, 531 ID: "test_1", 532 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 533 })) 534 535 handler.SendNextSubscriptionMessage() 536 msg := readOp(c) 537 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) 538 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 539 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 540 541 handler.SendNextSubscriptionMessage() 542 msg = readOp(c) 543 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) 544 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 545 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 546 547 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"})) 548 549 msg = readOp(c) 550 require.Equal(t, graphqltransportwsCompleteMsg, msg.Type) 551 require.Equal(t, "test_1", msg.ID) 552 }) 553 554 t.Run("receives no graphql-ws keep alive messages", func(t *testing.T) { 555 _, srv := initialize(transport.Websocket{KeepAlivePingInterval: 5 * time.Millisecond}) 556 defer srv.Close() 557 558 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 559 defer c.Close() 560 561 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 562 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 563 564 // If the keep-alives are sent, this deadline will not be used, and no timeout error will be found 565 c.SetReadDeadline(time.Now().UTC().Add(50 * time.Millisecond)) 566 var msg operationMessage 567 err := c.ReadJSON(&msg) 568 require.Error(t, err) 569 assert.Contains(t, err.Error(), "timeout") 570 }) 571 } 572 573 func TestWebsocketWithPingPongInterval(t *testing.T) { 574 initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) { 575 h := testserver.New() 576 h.AddTransport(ws) 577 return h, httptest.NewServer(h) 578 } 579 580 t.Run("client receives ping and responds with pong", func(t *testing.T) { 581 _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond}) 582 defer srv.Close() 583 584 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 585 defer c.Close() 586 587 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 588 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 589 590 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 591 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg})) 592 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 593 }) 594 595 t.Run("client sends ping and expects pong", func(t *testing.T) { 596 _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond}) 597 defer srv.Close() 598 599 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 600 defer c.Close() 601 602 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 603 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 604 605 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg})) 606 assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type) 607 }) 608 609 t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) { 610 // Regression test 611 // --- 612 // Before the refactor, the code would try to convert a ping message to a graphql-ws message type 613 // But since this message type does not exist in the graphql-ws sub protocol, it would fail 614 615 _, srv := initialize(transport.Websocket{ 616 PingPongInterval: 5 * time.Millisecond, 617 KeepAlivePingInterval: 10 * time.Millisecond, 618 }) 619 defer srv.Close() 620 621 // Create connection 622 c := wsConnect(srv.URL) 623 defer c.Close() 624 625 // Initialize connection 626 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 627 assert.Equal(t, connectionAckMsg, readOp(c).Type) 628 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 629 630 // Wait for a few more keep alives to be sure nothing goes wrong 631 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 632 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 633 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 634 }) 635 } 636 637 func wsConnect(url string) *websocket.Conn { 638 return wsConnectWithSubprocotol(url, "") 639 } 640 641 func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn { 642 h := make(http.Header) 643 if subprocotol != "" { 644 h.Add("Sec-WebSocket-Protocol", subprocotol) 645 } 646 647 c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h) 648 if err != nil { 649 panic(err) 650 } 651 _ = resp.Body.Close() 652 653 return c 654 } 655 656 func writeRaw(conn *websocket.Conn, msg string) { 657 if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { 658 panic(err) 659 } 660 } 661 662 func readOp(conn *websocket.Conn) operationMessage { 663 var msg operationMessage 664 if err := conn.ReadJSON(&msg); err != nil { 665 panic(err) 666 } 667 return msg 668 } 669 670 // copied out from websocket_graphqlws.go to keep these private 671 672 const ( 673 connectionInitMsg = "connection_init" // Client -> Server 674 connectionTerminateMsg = "connection_terminate" // Client -> Server 675 startMsg = "start" // Client -> Server 676 stopMsg = "stop" // Client -> Server 677 connectionAckMsg = "connection_ack" // Server -> Client 678 connectionErrorMsg = "connection_error" // Server -> Client 679 dataMsg = "data" // Server -> Client 680 errorMsg = "error" // Server -> Client 681 completeMsg = "complete" // Server -> Client 682 connectionKeepAliveMsg = "ka" // Server -> Client 683 ) 684 685 // copied out from websocket_graphql_transport_ws.go to keep these private 686 687 const ( 688 graphqltransportwsSubprotocol = "graphql-transport-ws" 689 690 graphqltransportwsConnectionInitMsg = "connection_init" 691 graphqltransportwsConnectionAckMsg = "connection_ack" 692 graphqltransportwsSubscribeMsg = "subscribe" 693 graphqltransportwsNextMsg = "next" 694 graphqltransportwsCompleteMsg = "complete" 695 graphqltransportwsPingMsg = "ping" 696 graphqltransportwsPongMsg = "pong" 697 ) 698 699 type operationMessage struct { 700 Payload json.RawMessage `json:"payload,omitempty"` 701 ID string `json:"id,omitempty"` 702 Type string `json:"type"` 703 }