github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/graphql/handler/transport/websocket_test.go (about) 1 package transport_test 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http/httptest" 8 "strings" 9 "testing" 10 "time" 11 12 "github.com/animeshon/gqlgen/client" 13 "github.com/animeshon/gqlgen/graphql" 14 "github.com/animeshon/gqlgen/graphql/handler" 15 "github.com/animeshon/gqlgen/graphql/handler/testserver" 16 "github.com/animeshon/gqlgen/graphql/handler/transport" 17 "github.com/gorilla/websocket" 18 "github.com/stretchr/testify/assert" 19 "github.com/stretchr/testify/require" 20 "github.com/vektah/gqlparser/v2" 21 "github.com/vektah/gqlparser/v2/ast" 22 ) 23 24 type ckey string 25 26 func TestWebsocket(t *testing.T) { 27 handler := testserver.New() 28 handler.AddTransport(transport.Websocket{}) 29 30 srv := httptest.NewServer(handler) 31 defer srv.Close() 32 33 t.Run("client must send valid json", func(t *testing.T) { 34 c := wsConnect(srv.URL) 35 defer c.Close() 36 37 writeRaw(c, "hello") 38 39 msg := readOp(c) 40 assert.Equal(t, "connection_error", msg.Type) 41 assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload)) 42 }) 43 44 t.Run("client can terminate before init", func(t *testing.T) { 45 c := wsConnect(srv.URL) 46 defer c.Close() 47 48 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 49 50 _, _, err := c.ReadMessage() 51 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 52 }) 53 54 t.Run("client must send init first", func(t *testing.T) { 55 c := wsConnect(srv.URL) 56 defer c.Close() 57 58 require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg})) 59 60 msg := readOp(c) 61 assert.Equal(t, connectionErrorMsg, msg.Type) 62 assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload)) 63 }) 64 65 t.Run("server acks init", func(t *testing.T) { 66 c := wsConnect(srv.URL) 67 defer c.Close() 68 69 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 70 71 assert.Equal(t, connectionAckMsg, readOp(c).Type) 72 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 73 }) 74 75 t.Run("client can terminate before run", func(t *testing.T) { 76 c := wsConnect(srv.URL) 77 defer c.Close() 78 79 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 80 assert.Equal(t, connectionAckMsg, readOp(c).Type) 81 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 82 83 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 84 85 _, _, err := c.ReadMessage() 86 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 87 }) 88 89 t.Run("client gets parse errors", func(t *testing.T) { 90 c := wsConnect(srv.URL) 91 defer c.Close() 92 93 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 94 assert.Equal(t, connectionAckMsg, readOp(c).Type) 95 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 96 97 require.NoError(t, c.WriteJSON(&operationMessage{ 98 Type: startMsg, 99 ID: "test_1", 100 Payload: json.RawMessage(`{"query": "!"}`), 101 })) 102 103 msg := readOp(c) 104 assert.Equal(t, errorMsg, msg.Type) 105 assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload)) 106 }) 107 108 t.Run("client can receive data", func(t *testing.T) { 109 c := wsConnect(srv.URL) 110 defer c.Close() 111 112 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 113 assert.Equal(t, connectionAckMsg, readOp(c).Type) 114 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 115 116 require.NoError(t, c.WriteJSON(&operationMessage{ 117 Type: startMsg, 118 ID: "test_1", 119 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 120 })) 121 122 handler.SendNextSubscriptionMessage() 123 msg := readOp(c) 124 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 125 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 126 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 127 128 handler.SendNextSubscriptionMessage() 129 msg = readOp(c) 130 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 131 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 132 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 133 134 require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"})) 135 136 msg = readOp(c) 137 require.Equal(t, completeMsg, msg.Type) 138 require.Equal(t, "test_1", msg.ID) 139 }) 140 } 141 142 func TestWebsocketWithKeepAlive(t *testing.T) { 143 144 h := testserver.New() 145 h.AddTransport(transport.Websocket{ 146 KeepAlivePingInterval: 100 * time.Millisecond, 147 }) 148 149 srv := httptest.NewServer(h) 150 defer srv.Close() 151 152 c := wsConnect(srv.URL) 153 defer c.Close() 154 155 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 156 assert.Equal(t, connectionAckMsg, readOp(c).Type) 157 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 158 159 require.NoError(t, c.WriteJSON(&operationMessage{ 160 Type: startMsg, 161 ID: "test_1", 162 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 163 })) 164 165 // keepalive 166 msg := readOp(c) 167 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 168 169 // server message 170 h.SendNextSubscriptionMessage() 171 msg = readOp(c) 172 assert.Equal(t, dataMsg, msg.Type) 173 174 // keepalive 175 msg = readOp(c) 176 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 177 } 178 179 func TestWebsocketInitFunc(t *testing.T) { 180 t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { 181 h := testserver.New() 182 h.AddTransport(transport.Websocket{}) 183 srv := httptest.NewServer(h) 184 defer srv.Close() 185 186 c := wsConnect(srv.URL) 187 defer c.Close() 188 189 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 190 191 assert.Equal(t, connectionAckMsg, readOp(c).Type) 192 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 193 }) 194 195 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 196 h := testserver.New() 197 h.AddTransport(transport.Websocket{ 198 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 199 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil 200 }, 201 }) 202 srv := httptest.NewServer(h) 203 defer srv.Close() 204 205 c := wsConnect(srv.URL) 206 defer c.Close() 207 208 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 209 210 assert.Equal(t, connectionAckMsg, readOp(c).Type) 211 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 212 }) 213 214 t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 215 h := testserver.New() 216 h.AddTransport(transport.Websocket{ 217 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 218 return ctx, errors.New("invalid init payload") 219 }, 220 }) 221 srv := httptest.NewServer(h) 222 defer srv.Close() 223 224 c := wsConnect(srv.URL) 225 defer c.Close() 226 227 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 228 229 msg := readOp(c) 230 assert.Equal(t, connectionErrorMsg, msg.Type) 231 assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload)) 232 }) 233 234 t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) { 235 es := &graphql.ExecutableSchemaMock{ 236 ExecFunc: func(ctx context.Context) graphql.ResponseHandler { 237 assert.Equal(t, "newvalue", ctx.Value(ckey("newkey"))) 238 return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)}) 239 }, 240 SchemaFunc: func() *ast.Schema { 241 return gqlparser.MustLoadSchema(&ast.Source{Input: ` 242 schema { query: Query } 243 type Query { 244 empty: String 245 } 246 `}) 247 }, 248 } 249 h := handler.New(es) 250 251 h.AddTransport(transport.Websocket{ 252 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 253 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil 254 }, 255 }) 256 257 c := client.New(h) 258 259 socket := c.Websocket("{ empty } ") 260 defer socket.Close() 261 var resp struct { 262 Empty string 263 } 264 err := socket.Next(&resp) 265 require.NoError(t, err) 266 assert.Equal(t, "ok", resp.Empty) 267 }) 268 } 269 270 func wsConnect(url string) *websocket.Conn { 271 c, resp, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil) 272 if err != nil { 273 panic(err) 274 } 275 _ = resp.Body.Close() 276 277 return c 278 } 279 280 func writeRaw(conn *websocket.Conn, msg string) { 281 if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { 282 panic(err) 283 } 284 } 285 286 func readOp(conn *websocket.Conn) operationMessage { 287 var msg operationMessage 288 if err := conn.ReadJSON(&msg); err != nil { 289 panic(err) 290 } 291 return msg 292 } 293 294 // copied out from weboscket.go to keep these private 295 296 const ( 297 connectionInitMsg = "connection_init" // Client -> Server 298 connectionTerminateMsg = "connection_terminate" // Client -> Server 299 startMsg = "start" // Client -> Server 300 stopMsg = "stop" // Client -> Server 301 connectionAckMsg = "connection_ack" // Server -> Client 302 connectionErrorMsg = "connection_error" // Server -> Client 303 dataMsg = "data" // Server -> Client 304 errorMsg = "error" // Server -> Client 305 completeMsg = "complete" // Server -> Client 306 connectionKeepAliveMsg = "ka" // Server -> Client 307 ) 308 309 type operationMessage struct { 310 Payload json.RawMessage `json:"payload,omitempty"` 311 ID string `json:"id,omitempty"` 312 Type string `json:"type"` 313 }