github.com/apipluspower/gqlgen@v0.15.2/graphql/handler/transport/websocket_test.go (about) 1 package transport_test 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/apipluspower/gqlgen/client" 14 "github.com/apipluspower/gqlgen/graphql" 15 "github.com/apipluspower/gqlgen/graphql/handler" 16 "github.com/apipluspower/gqlgen/graphql/handler/testserver" 17 "github.com/apipluspower/gqlgen/graphql/handler/transport" 18 "github.com/gorilla/websocket" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 "github.com/vektah/gqlparser/v2" 22 "github.com/vektah/gqlparser/v2/ast" 23 ) 24 25 type ckey string 26 27 func TestWebsocket(t *testing.T) { 28 handler := testserver.New() 29 handler.AddTransport(transport.Websocket{}) 30 31 srv := httptest.NewServer(handler) 32 defer srv.Close() 33 34 t.Run("client must send valid json", func(t *testing.T) { 35 c := wsConnect(srv.URL) 36 defer c.Close() 37 38 writeRaw(c, "hello") 39 40 msg := readOp(c) 41 assert.Equal(t, "connection_error", msg.Type) 42 assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload)) 43 }) 44 45 t.Run("client can terminate before init", func(t *testing.T) { 46 c := wsConnect(srv.URL) 47 defer c.Close() 48 49 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 50 51 _, _, err := c.ReadMessage() 52 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 53 }) 54 55 t.Run("client must send init first", func(t *testing.T) { 56 c := wsConnect(srv.URL) 57 defer c.Close() 58 59 require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg})) 60 61 msg := readOp(c) 62 assert.Equal(t, connectionErrorMsg, msg.Type) 63 assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload)) 64 }) 65 66 t.Run("server acks init", func(t *testing.T) { 67 c := wsConnect(srv.URL) 68 defer c.Close() 69 70 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 71 72 assert.Equal(t, connectionAckMsg, readOp(c).Type) 73 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 74 }) 75 76 t.Run("client can terminate before run", func(t *testing.T) { 77 c := wsConnect(srv.URL) 78 defer c.Close() 79 80 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 81 assert.Equal(t, connectionAckMsg, readOp(c).Type) 82 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 83 84 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 85 86 _, _, err := c.ReadMessage() 87 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 88 }) 89 90 t.Run("client gets parse errors", func(t *testing.T) { 91 c := wsConnect(srv.URL) 92 defer c.Close() 93 94 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 95 assert.Equal(t, connectionAckMsg, readOp(c).Type) 96 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 97 98 require.NoError(t, c.WriteJSON(&operationMessage{ 99 Type: startMsg, 100 ID: "test_1", 101 Payload: json.RawMessage(`{"query": "!"}`), 102 })) 103 104 msg := readOp(c) 105 assert.Equal(t, errorMsg, msg.Type) 106 assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload)) 107 }) 108 109 t.Run("client can receive data", func(t *testing.T) { 110 c := wsConnect(srv.URL) 111 defer c.Close() 112 113 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 114 assert.Equal(t, connectionAckMsg, readOp(c).Type) 115 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 116 117 require.NoError(t, c.WriteJSON(&operationMessage{ 118 Type: startMsg, 119 ID: "test_1", 120 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 121 })) 122 123 handler.SendNextSubscriptionMessage() 124 msg := readOp(c) 125 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 126 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 127 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 128 129 handler.SendNextSubscriptionMessage() 130 msg = readOp(c) 131 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 132 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 133 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 134 135 require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"})) 136 137 msg = readOp(c) 138 require.Equal(t, completeMsg, msg.Type) 139 require.Equal(t, "test_1", msg.ID) 140 }) 141 } 142 143 func TestWebsocketWithKeepAlive(t *testing.T) { 144 h := testserver.New() 145 h.AddTransport(transport.Websocket{ 146 KeepAlivePingInterval: 100 * time.Millisecond, 147 }) 148 149 srv := httptest.NewServer(h) 150 defer srv.Close() 151 152 c := wsConnect(srv.URL) 153 defer c.Close() 154 155 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 156 assert.Equal(t, connectionAckMsg, readOp(c).Type) 157 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 158 159 require.NoError(t, c.WriteJSON(&operationMessage{ 160 Type: startMsg, 161 ID: "test_1", 162 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 163 })) 164 165 // keepalive 166 msg := readOp(c) 167 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 168 169 // server message 170 h.SendNextSubscriptionMessage() 171 msg = readOp(c) 172 assert.Equal(t, dataMsg, msg.Type) 173 174 // keepalive 175 msg = readOp(c) 176 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 177 } 178 179 func TestWebsocketInitFunc(t *testing.T) { 180 t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { 181 h := testserver.New() 182 h.AddTransport(transport.Websocket{}) 183 srv := httptest.NewServer(h) 184 defer srv.Close() 185 186 c := wsConnect(srv.URL) 187 defer c.Close() 188 189 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 190 191 assert.Equal(t, connectionAckMsg, readOp(c).Type) 192 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 193 }) 194 195 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 196 h := testserver.New() 197 h.AddTransport(transport.Websocket{ 198 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 199 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil 200 }, 201 }) 202 srv := httptest.NewServer(h) 203 defer srv.Close() 204 205 c := wsConnect(srv.URL) 206 defer c.Close() 207 208 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 209 210 assert.Equal(t, connectionAckMsg, readOp(c).Type) 211 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 212 }) 213 214 t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 215 h := testserver.New() 216 h.AddTransport(transport.Websocket{ 217 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 218 return ctx, errors.New("invalid init payload") 219 }, 220 }) 221 srv := httptest.NewServer(h) 222 defer srv.Close() 223 224 c := wsConnect(srv.URL) 225 defer c.Close() 226 227 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 228 229 msg := readOp(c) 230 assert.Equal(t, connectionErrorMsg, msg.Type) 231 assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload)) 232 }) 233 234 t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) { 235 es := &graphql.ExecutableSchemaMock{ 236 ExecFunc: func(ctx context.Context) graphql.ResponseHandler { 237 assert.Equal(t, "newvalue", ctx.Value(ckey("newkey"))) 238 return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)}) 239 }, 240 SchemaFunc: func() *ast.Schema { 241 return gqlparser.MustLoadSchema(&ast.Source{Input: ` 242 schema { query: Query } 243 type Query { 244 empty: String 245 } 246 `}) 247 }, 248 } 249 h := handler.New(es) 250 251 h.AddTransport(transport.Websocket{ 252 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 253 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil 254 }, 255 }) 256 257 c := client.New(h) 258 259 socket := c.Websocket("{ empty } ") 260 defer socket.Close() 261 var resp struct { 262 Empty string 263 } 264 err := socket.Next(&resp) 265 require.NoError(t, err) 266 assert.Equal(t, "ok", resp.Empty) 267 }) 268 269 t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) { 270 h := testserver.New() 271 var cancel func() 272 h.AddTransport(transport.Websocket{ 273 InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) { 274 newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5) 275 return 276 }, 277 }) 278 srv := httptest.NewServer(h) 279 defer srv.Close() 280 281 c := wsConnect(srv.URL) 282 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 283 assert.Equal(t, connectionAckMsg, readOp(c).Type) 284 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 285 286 // Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy) 287 defer cancel() 288 289 time.Sleep(time.Millisecond * 10) 290 m := readOp(c) 291 assert.Equal(t, m.Type, connectionErrorMsg) 292 assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`) 293 }) 294 } 295 296 func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) { 297 handler := testserver.New() 298 handler.AddTransport(transport.Websocket{}) 299 300 srv := httptest.NewServer(handler) 301 defer srv.Close() 302 303 t.Run("server acks init", func(t *testing.T) { 304 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 305 defer c.Close() 306 307 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 308 309 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 310 }) 311 312 t.Run("client can receive data", func(t *testing.T) { 313 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 314 defer c.Close() 315 316 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 317 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 318 319 require.NoError(t, c.WriteJSON(&operationMessage{ 320 Type: graphqltransportwsSubscribeMsg, 321 ID: "test_1", 322 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 323 })) 324 325 handler.SendNextSubscriptionMessage() 326 msg := readOp(c) 327 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) 328 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 329 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 330 331 handler.SendNextSubscriptionMessage() 332 msg = readOp(c) 333 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) 334 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 335 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 336 337 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"})) 338 339 msg = readOp(c) 340 require.Equal(t, graphqltransportwsCompleteMsg, msg.Type) 341 require.Equal(t, "test_1", msg.ID) 342 }) 343 } 344 345 func TestWebsocketWithPingPongInterval(t *testing.T) { 346 handler := testserver.New() 347 handler.AddTransport(transport.Websocket{ 348 PingPongInterval: time.Second * 1, 349 }) 350 351 srv := httptest.NewServer(handler) 352 defer srv.Close() 353 354 t.Run("client receives ping and responds with pong", func(t *testing.T) { 355 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 356 defer c.Close() 357 358 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 359 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 360 361 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 362 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg})) 363 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) 364 }) 365 366 t.Run("client sends ping and expects pong", func(t *testing.T) { 367 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) 368 defer c.Close() 369 370 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) 371 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) 372 373 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg})) 374 assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type) 375 }) 376 } 377 378 func wsConnect(url string) *websocket.Conn { 379 return wsConnectWithSubprocotol(url, "") 380 } 381 382 func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn { 383 h := make(http.Header) 384 if subprocotol != "" { 385 h.Add("Sec-WebSocket-Protocol", subprocotol) 386 } 387 388 c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h) 389 if err != nil { 390 panic(err) 391 } 392 _ = resp.Body.Close() 393 394 return c 395 } 396 397 func writeRaw(conn *websocket.Conn, msg string) { 398 if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { 399 panic(err) 400 } 401 } 402 403 func readOp(conn *websocket.Conn) operationMessage { 404 var msg operationMessage 405 if err := conn.ReadJSON(&msg); err != nil { 406 panic(err) 407 } 408 return msg 409 } 410 411 // copied out from websocket_graphqlws.go to keep these private 412 413 const ( 414 connectionInitMsg = "connection_init" // Client -> Server 415 connectionTerminateMsg = "connection_terminate" // Client -> Server 416 startMsg = "start" // Client -> Server 417 stopMsg = "stop" // Client -> Server 418 connectionAckMsg = "connection_ack" // Server -> Client 419 connectionErrorMsg = "connection_error" // Server -> Client 420 dataMsg = "data" // Server -> Client 421 errorMsg = "error" // Server -> Client 422 completeMsg = "complete" // Server -> Client 423 connectionKeepAliveMsg = "ka" // Server -> Client 424 ) 425 426 // copied out from websocket_graphql_transport_ws.go to keep these private 427 428 const ( 429 graphqltransportwsSubprotocol = "graphql-transport-ws" 430 431 graphqltransportwsConnectionInitMsg = "connection_init" 432 graphqltransportwsConnectionAckMsg = "connection_ack" 433 graphqltransportwsSubscribeMsg = "subscribe" 434 graphqltransportwsNextMsg = "next" 435 graphqltransportwsCompleteMsg = "complete" 436 graphqltransportwsPingMsg = "ping" 437 graphqltransportwsPongMsg = "pong" 438 ) 439 440 type operationMessage struct { 441 Payload json.RawMessage `json:"payload,omitempty"` 442 ID string `json:"id,omitempty"` 443 Type string `json:"type"` 444 }