git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/graphql/handler/transport/websocket_test.go (about) 1 package transport_test 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http/httptest" 8 "strings" 9 "testing" 10 "time" 11 12 "git.sr.ht/~sircmpwn/gqlgen/client" 13 "git.sr.ht/~sircmpwn/gqlgen/graphql" 14 "git.sr.ht/~sircmpwn/gqlgen/graphql/handler" 15 "git.sr.ht/~sircmpwn/gqlgen/graphql/handler/testserver" 16 "git.sr.ht/~sircmpwn/gqlgen/graphql/handler/transport" 17 "github.com/gorilla/websocket" 18 "github.com/stretchr/testify/assert" 19 "github.com/stretchr/testify/require" 20 "github.com/vektah/gqlparser/v2" 21 "github.com/vektah/gqlparser/v2/ast" 22 ) 23 24 func TestWebsocket(t *testing.T) { 25 handler := testserver.New() 26 handler.AddTransport(transport.Websocket{}) 27 28 srv := httptest.NewServer(handler) 29 defer srv.Close() 30 31 t.Run("client must send valid json", func(t *testing.T) { 32 c := wsConnect(srv.URL) 33 defer c.Close() 34 35 writeRaw(c, "hello") 36 37 msg := readOp(c) 38 assert.Equal(t, "connection_error", msg.Type) 39 assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload)) 40 }) 41 42 t.Run("client can terminate before init", func(t *testing.T) { 43 c := wsConnect(srv.URL) 44 defer c.Close() 45 46 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 47 48 _, _, err := c.ReadMessage() 49 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 50 }) 51 52 t.Run("client must send init first", func(t *testing.T) { 53 c := wsConnect(srv.URL) 54 defer c.Close() 55 56 require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg})) 57 58 msg := readOp(c) 59 assert.Equal(t, connectionErrorMsg, msg.Type) 60 assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload)) 61 }) 62 63 t.Run("server acks init", func(t *testing.T) { 64 c := wsConnect(srv.URL) 65 defer c.Close() 66 67 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 68 69 assert.Equal(t, connectionAckMsg, readOp(c).Type) 70 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 71 }) 72 73 t.Run("client can terminate before run", func(t *testing.T) { 74 c := wsConnect(srv.URL) 75 defer c.Close() 76 77 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 78 assert.Equal(t, connectionAckMsg, readOp(c).Type) 79 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 80 81 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) 82 83 _, _, err := c.ReadMessage() 84 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code) 85 }) 86 87 t.Run("client gets parse errors", func(t *testing.T) { 88 c := wsConnect(srv.URL) 89 defer c.Close() 90 91 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 92 assert.Equal(t, connectionAckMsg, readOp(c).Type) 93 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 94 95 require.NoError(t, c.WriteJSON(&operationMessage{ 96 Type: startMsg, 97 ID: "test_1", 98 Payload: json.RawMessage(`{"query": "!"}`), 99 })) 100 101 msg := readOp(c) 102 assert.Equal(t, errorMsg, msg.Type) 103 assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload)) 104 }) 105 106 t.Run("client can receive data", func(t *testing.T) { 107 c := wsConnect(srv.URL) 108 defer c.Close() 109 110 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 111 assert.Equal(t, connectionAckMsg, readOp(c).Type) 112 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 113 114 require.NoError(t, c.WriteJSON(&operationMessage{ 115 Type: startMsg, 116 ID: "test_1", 117 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 118 })) 119 120 handler.SendNextSubscriptionMessage() 121 msg := readOp(c) 122 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 123 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 124 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 125 126 handler.SendNextSubscriptionMessage() 127 msg = readOp(c) 128 require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) 129 require.Equal(t, "test_1", msg.ID, string(msg.Payload)) 130 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) 131 132 require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"})) 133 134 msg = readOp(c) 135 require.Equal(t, completeMsg, msg.Type) 136 require.Equal(t, "test_1", msg.ID) 137 }) 138 } 139 140 func TestWebsocketWithKeepAlive(t *testing.T) { 141 142 h := testserver.New() 143 h.AddTransport(transport.Websocket{ 144 KeepAlivePingInterval: 100 * time.Millisecond, 145 }) 146 147 srv := httptest.NewServer(h) 148 defer srv.Close() 149 150 c := wsConnect(srv.URL) 151 defer c.Close() 152 153 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 154 assert.Equal(t, connectionAckMsg, readOp(c).Type) 155 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 156 157 require.NoError(t, c.WriteJSON(&operationMessage{ 158 Type: startMsg, 159 ID: "test_1", 160 Payload: json.RawMessage(`{"query": "subscription { name }"}`), 161 })) 162 163 // keepalive 164 msg := readOp(c) 165 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 166 167 // server message 168 h.SendNextSubscriptionMessage() 169 msg = readOp(c) 170 assert.Equal(t, dataMsg, msg.Type) 171 172 // keepalive 173 msg = readOp(c) 174 assert.Equal(t, connectionKeepAliveMsg, msg.Type) 175 } 176 177 func TestWebsocketInitFunc(t *testing.T) { 178 t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { 179 h := testserver.New() 180 h.AddTransport(transport.Websocket{}) 181 srv := httptest.NewServer(h) 182 defer srv.Close() 183 184 c := wsConnect(srv.URL) 185 defer c.Close() 186 187 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 188 189 assert.Equal(t, connectionAckMsg, readOp(c).Type) 190 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 191 }) 192 193 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 194 h := testserver.New() 195 h.AddTransport(transport.Websocket{ 196 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 197 return context.WithValue(ctx, "newkey", "newvalue"), nil 198 }, 199 }) 200 srv := httptest.NewServer(h) 201 defer srv.Close() 202 203 c := wsConnect(srv.URL) 204 defer c.Close() 205 206 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 207 208 assert.Equal(t, connectionAckMsg, readOp(c).Type) 209 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) 210 }) 211 212 t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { 213 h := testserver.New() 214 h.AddTransport(transport.Websocket{ 215 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 216 return ctx, errors.New("invalid init payload") 217 }, 218 }) 219 srv := httptest.NewServer(h) 220 defer srv.Close() 221 222 c := wsConnect(srv.URL) 223 defer c.Close() 224 225 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) 226 227 msg := readOp(c) 228 assert.Equal(t, connectionErrorMsg, msg.Type) 229 assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload)) 230 }) 231 232 t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) { 233 es := &graphql.ExecutableSchemaMock{ 234 ExecFunc: func(ctx context.Context) graphql.ResponseHandler { 235 assert.Equal(t, "newvalue", ctx.Value("newkey")) 236 return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)}) 237 }, 238 SchemaFunc: func() *ast.Schema { 239 return gqlparser.MustLoadSchema(&ast.Source{Input: ` 240 schema { query: Query } 241 type Query { 242 empty: String 243 } 244 `}) 245 }, 246 } 247 h := handler.New(es) 248 249 h.AddTransport(transport.Websocket{ 250 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { 251 return context.WithValue(ctx, "newkey", "newvalue"), nil 252 }, 253 }) 254 255 c := client.New(h) 256 257 socket := c.Websocket("{ empty } ") 258 defer socket.Close() 259 var resp struct { 260 Empty string 261 } 262 err := socket.Next(&resp) 263 require.NoError(t, err) 264 assert.Equal(t, "ok", resp.Empty) 265 }) 266 } 267 268 func wsConnect(url string) *websocket.Conn { 269 c, resp, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil) 270 if err != nil { 271 panic(err) 272 } 273 _ = resp.Body.Close() 274 275 return c 276 } 277 278 func writeRaw(conn *websocket.Conn, msg string) { 279 if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { 280 panic(err) 281 } 282 } 283 284 func readOp(conn *websocket.Conn) operationMessage { 285 var msg operationMessage 286 if err := conn.ReadJSON(&msg); err != nil { 287 panic(err) 288 } 289 return msg 290 } 291 292 // copied out from weboscket.go to keep these private 293 294 const ( 295 connectionInitMsg = "connection_init" // Client -> Server 296 connectionTerminateMsg = "connection_terminate" // Client -> Server 297 startMsg = "start" // Client -> Server 298 stopMsg = "stop" // Client -> Server 299 connectionAckMsg = "connection_ack" // Server -> Client 300 connectionErrorMsg = "connection_error" // Server -> Client 301 dataMsg = "data" // Server -> Client 302 errorMsg = "error" // Server -> Client 303 completeMsg = "complete" // Server -> Client 304 connectionKeepAliveMsg = "ka" // Server -> Client 305 ) 306 307 type operationMessage struct { 308 Payload json.RawMessage `json:"payload,omitempty"` 309 ID string `json:"id,omitempty"` 310 Type string `json:"type"` 311 }