github.com/99designs/gqlgen@v0.17.45/graphql/handler/transport/websocket.go (about) 1 package transport 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "log" 10 "net" 11 "net/http" 12 "sync" 13 "time" 14 15 "github.com/gorilla/websocket" 16 "github.com/vektah/gqlparser/v2/gqlerror" 17 18 "github.com/99designs/gqlgen/graphql" 19 "github.com/99designs/gqlgen/graphql/errcode" 20 ) 21 22 type ( 23 Websocket struct { 24 Upgrader websocket.Upgrader 25 InitFunc WebsocketInitFunc 26 InitTimeout time.Duration 27 ErrorFunc WebsocketErrorFunc 28 CloseFunc WebsocketCloseFunc 29 KeepAlivePingInterval time.Duration 30 PongOnlyInterval time.Duration 31 PingPongInterval time.Duration 32 /* If PingPongInterval has a non-0 duration, then when the server sends a ping 33 * it sets a ReadDeadline of PingPongInterval*2 and if the client doesn't respond 34 * with pong before that deadline is reached then the connection will die with a 35 * 1006 error code. 36 * 37 * MissingPongOk if true, tells the server to not use a ReadDeadline such that a 38 * missing/slow pong response from the client doesn't kill the connection. 39 */ 40 MissingPongOk bool 41 42 didInjectSubprotocols bool 43 } 44 wsConnection struct { 45 Websocket 46 ctx context.Context 47 conn *websocket.Conn 48 me messageExchanger 49 active map[string]context.CancelFunc 50 mu sync.Mutex 51 keepAliveTicker *time.Ticker 52 pongOnlyTicker *time.Ticker 53 pingPongTicker *time.Ticker 54 receivedPong bool 55 exec graphql.GraphExecutor 56 closed bool 57 58 initPayload InitPayload 59 } 60 61 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, *InitPayload, error) 62 WebsocketErrorFunc func(ctx context.Context, err error) 63 64 // Callback called when websocket is closed. 65 WebsocketCloseFunc func(ctx context.Context, closeCode int) 66 ) 67 68 var errReadTimeout = errors.New("read timeout") 69 70 type WebsocketError struct { 71 Err error 72 73 // IsReadError flags whether the error occurred on read or write to the websocket 74 IsReadError bool 75 } 76 77 func (e WebsocketError) Error() string { 78 if e.IsReadError { 79 return fmt.Sprintf("websocket read: %v", e.Err) 80 } 81 return fmt.Sprintf("websocket write: %v", e.Err) 82 } 83 84 var ( 85 _ graphql.Transport = Websocket{} 86 _ error = WebsocketError{} 87 ) 88 89 func (t Websocket) Supports(r *http.Request) bool { 90 return r.Header.Get("Upgrade") != "" 91 } 92 93 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 94 t.injectGraphQLWSSubprotocols() 95 ws, err := t.Upgrader.Upgrade(w, r, http.Header{}) 96 if err != nil { 97 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 98 SendErrorf(w, http.StatusBadRequest, "unable to upgrade") 99 return 100 } 101 102 var me messageExchanger 103 switch ws.Subprotocol() { 104 default: 105 msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol())) 106 ws.WriteMessage(websocket.CloseMessage, msg) 107 return 108 case graphqlwsSubprotocol, "": 109 // clients are required to send a subprotocol, to be backward compatible with the previous implementation we select 110 // "graphql-ws" by default 111 me = graphqlwsMessageExchanger{c: ws} 112 case graphqltransportwsSubprotocol: 113 me = graphqltransportwsMessageExchanger{c: ws} 114 } 115 116 conn := wsConnection{ 117 active: map[string]context.CancelFunc{}, 118 conn: ws, 119 ctx: r.Context(), 120 exec: exec, 121 me: me, 122 Websocket: t, 123 } 124 125 if !conn.init() { 126 return 127 } 128 129 conn.run() 130 } 131 132 func (c *wsConnection) handlePossibleError(err error, isReadError bool) { 133 if c.ErrorFunc != nil && err != nil { 134 c.ErrorFunc(c.ctx, WebsocketError{ 135 Err: err, 136 IsReadError: isReadError, 137 }) 138 } 139 } 140 141 func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) { 142 messages, errs := make(chan message, 1), make(chan error, 1) 143 144 go func() { 145 if m, err := c.me.NextMessage(); err != nil { 146 errs <- err 147 } else { 148 messages <- m 149 } 150 }() 151 152 select { 153 case m := <-messages: 154 return m, nil 155 case err := <-errs: 156 return message{}, err 157 case <-time.After(timeout): 158 return message{}, errReadTimeout 159 } 160 } 161 162 func (c *wsConnection) init() bool { 163 var m message 164 var err error 165 166 if c.InitTimeout != 0 { 167 m, err = c.nextMessageWithTimeout(c.InitTimeout) 168 } else { 169 m, err = c.me.NextMessage() 170 } 171 172 if err != nil { 173 if err == errReadTimeout { 174 c.close(websocket.CloseProtocolError, "connection initialisation timeout") 175 return false 176 } 177 178 if err == errInvalidMsg { 179 c.sendConnectionError("invalid json") 180 } 181 182 c.close(websocket.CloseProtocolError, "decoding error") 183 return false 184 } 185 186 switch m.t { 187 case initMessageType: 188 if len(m.payload) > 0 { 189 c.initPayload = make(InitPayload) 190 err := json.Unmarshal(m.payload, &c.initPayload) 191 if err != nil { 192 return false 193 } 194 } 195 196 var initAckPayload *InitPayload = nil 197 if c.InitFunc != nil { 198 var ctx context.Context 199 ctx, initAckPayload, err = c.InitFunc(c.ctx, c.initPayload) 200 if err != nil { 201 c.sendConnectionError(err.Error()) 202 c.close(websocket.CloseNormalClosure, "terminated") 203 return false 204 } 205 c.ctx = ctx 206 } 207 208 if initAckPayload != nil { 209 initJsonAckPayload, err := json.Marshal(*initAckPayload) 210 if err != nil { 211 panic(err) 212 } 213 c.write(&message{t: connectionAckMessageType, payload: initJsonAckPayload}) 214 } else { 215 c.write(&message{t: connectionAckMessageType}) 216 } 217 c.write(&message{t: keepAliveMessageType}) 218 case connectionCloseMessageType: 219 c.close(websocket.CloseNormalClosure, "terminated") 220 return false 221 default: 222 c.sendConnectionError("unexpected message %s", m.t) 223 c.close(websocket.CloseProtocolError, "unexpected message") 224 return false 225 } 226 227 return true 228 } 229 230 func (c *wsConnection) write(msg *message) { 231 c.mu.Lock() 232 c.handlePossibleError(c.me.Send(msg), false) 233 c.mu.Unlock() 234 } 235 236 func (c *wsConnection) run() { 237 // We create a cancellation that will shutdown the keep-alive when we leave 238 // this function. 239 ctx, cancel := context.WithCancel(c.ctx) 240 defer func() { 241 cancel() 242 c.close(websocket.CloseAbnormalClosure, "unexpected closure") 243 }() 244 245 // If we're running in graphql-ws mode, create a timer that will trigger a 246 // keep alive message every interval 247 if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 { 248 c.mu.Lock() 249 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) 250 c.mu.Unlock() 251 252 go c.keepAlive(ctx) 253 } 254 255 // If we're running in graphql-transport-ws mode, create a timer that will trigger a 256 // just a pong message every interval 257 if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PongOnlyInterval != 0 { 258 c.mu.Lock() 259 c.pongOnlyTicker = time.NewTicker(c.PongOnlyInterval) 260 c.mu.Unlock() 261 262 go c.keepAlivePongOnly(ctx) 263 } 264 265 // If we're running in graphql-transport-ws mode, create a timer that will 266 // trigger a ping message every interval and expect a pong! 267 if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 { 268 c.mu.Lock() 269 c.pingPongTicker = time.NewTicker(c.PingPongInterval) 270 c.mu.Unlock() 271 272 if !c.MissingPongOk { 273 // Note: when the connection is closed by this deadline, the client 274 // will receive an "invalid close code" 275 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 276 } 277 go c.ping(ctx) 278 } 279 280 // Close the connection when the context is cancelled. 281 // Will optionally send a "close reason" that is retrieved from the context. 282 go c.closeOnCancel(ctx) 283 284 for { 285 start := graphql.Now() 286 m, err := c.me.NextMessage() 287 if err != nil { 288 // If the connection got closed by us, don't report the error 289 if !errors.Is(err, net.ErrClosed) { 290 c.handlePossibleError(err, true) 291 } 292 return 293 } 294 295 switch m.t { 296 case startMessageType: 297 c.subscribe(start, &m) 298 case stopMessageType: 299 c.mu.Lock() 300 closer := c.active[m.id] 301 c.mu.Unlock() 302 if closer != nil { 303 closer() 304 } 305 case connectionCloseMessageType: 306 c.close(websocket.CloseNormalClosure, "terminated") 307 return 308 case pingMessageType: 309 c.write(&message{t: pongMessageType, payload: m.payload}) 310 case pongMessageType: 311 c.mu.Lock() 312 c.receivedPong = true 313 c.mu.Unlock() 314 // Clear ReadTimeout -- 0 time val clears. 315 c.conn.SetReadDeadline(time.Time{}) 316 default: 317 c.sendConnectionError("unexpected message %s", m.t) 318 c.close(websocket.CloseProtocolError, "unexpected message") 319 return 320 } 321 } 322 } 323 324 func (c *wsConnection) keepAlivePongOnly(ctx context.Context) { 325 for { 326 select { 327 case <-ctx.Done(): 328 c.pongOnlyTicker.Stop() 329 return 330 case <-c.pongOnlyTicker.C: 331 c.write(&message{t: pongMessageType, payload: json.RawMessage{}}) 332 } 333 } 334 } 335 336 func (c *wsConnection) keepAlive(ctx context.Context) { 337 for { 338 select { 339 case <-ctx.Done(): 340 c.keepAliveTicker.Stop() 341 return 342 case <-c.keepAliveTicker.C: 343 c.write(&message{t: keepAliveMessageType}) 344 } 345 } 346 } 347 348 func (c *wsConnection) ping(ctx context.Context) { 349 for { 350 select { 351 case <-ctx.Done(): 352 c.pingPongTicker.Stop() 353 return 354 case <-c.pingPongTicker.C: 355 c.write(&message{t: pingMessageType, payload: json.RawMessage{}}) 356 // The initial deadline for this method is set in run() 357 // if we have not yet received a pong, don't reset the deadline. 358 c.mu.Lock() 359 if !c.MissingPongOk && c.receivedPong { 360 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 361 } 362 c.receivedPong = false 363 c.mu.Unlock() 364 } 365 } 366 } 367 368 func (c *wsConnection) closeOnCancel(ctx context.Context) { 369 <-ctx.Done() 370 371 if r := closeReasonForContext(ctx); r != "" { 372 c.sendConnectionError(r) 373 } 374 c.close(websocket.CloseNormalClosure, "terminated") 375 } 376 377 func (c *wsConnection) subscribe(start time.Time, msg *message) { 378 ctx := graphql.StartOperationTrace(c.ctx) 379 var params *graphql.RawParams 380 if err := jsonDecode(bytes.NewReader(msg.payload), ¶ms); err != nil { 381 c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"}) 382 c.complete(msg.id) 383 return 384 } 385 386 params.ReadTime = graphql.TraceTiming{ 387 Start: start, 388 End: graphql.Now(), 389 } 390 391 rc, err := c.exec.CreateOperationContext(ctx, params) 392 if err != nil { 393 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) 394 switch errcode.GetErrorKind(err) { 395 case errcode.KindProtocol: 396 c.sendError(msg.id, resp.Errors...) 397 default: 398 c.sendResponse(msg.id, &graphql.Response{Errors: err}) 399 } 400 401 c.complete(msg.id) 402 return 403 } 404 405 ctx = graphql.WithOperationContext(ctx, rc) 406 407 if c.initPayload != nil { 408 ctx = withInitPayload(ctx, c.initPayload) 409 } 410 411 ctx, cancel := context.WithCancel(ctx) 412 c.mu.Lock() 413 c.active[msg.id] = cancel 414 c.mu.Unlock() 415 416 go func() { 417 ctx = withSubscriptionErrorContext(ctx) 418 defer func() { 419 if r := recover(); r != nil { 420 err := rc.Recover(ctx, r) 421 var gqlerr *gqlerror.Error 422 if !errors.As(err, &gqlerr) { 423 gqlerr = &gqlerror.Error{} 424 if err != nil { 425 gqlerr.Message = err.Error() 426 } 427 } 428 c.sendError(msg.id, gqlerr) 429 } 430 if errs := getSubscriptionError(ctx); len(errs) != 0 { 431 c.sendError(msg.id, errs...) 432 } else { 433 c.complete(msg.id) 434 } 435 c.mu.Lock() 436 delete(c.active, msg.id) 437 c.mu.Unlock() 438 cancel() 439 }() 440 441 responses, ctx := c.exec.DispatchOperation(ctx, rc) 442 for { 443 response := responses(ctx) 444 if response == nil { 445 break 446 } 447 448 c.sendResponse(msg.id, response) 449 } 450 451 // complete and context cancel comes from the defer 452 }() 453 } 454 455 func (c *wsConnection) sendResponse(id string, response *graphql.Response) { 456 b, err := json.Marshal(response) 457 if err != nil { 458 panic(err) 459 } 460 c.write(&message{ 461 payload: b, 462 id: id, 463 t: dataMessageType, 464 }) 465 } 466 467 func (c *wsConnection) complete(id string) { 468 c.write(&message{id: id, t: completeMessageType}) 469 } 470 471 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 472 errs := make([]error, len(errors)) 473 for i, err := range errors { 474 errs[i] = err 475 } 476 b, err := json.Marshal(errs) 477 if err != nil { 478 panic(err) 479 } 480 c.write(&message{t: errorMessageType, id: id, payload: b}) 481 } 482 483 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 484 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 485 if err != nil { 486 panic(err) 487 } 488 489 c.write(&message{t: connectionErrorMessageType, payload: b}) 490 } 491 492 func (c *wsConnection) close(closeCode int, message string) { 493 c.mu.Lock() 494 if c.closed { 495 c.mu.Unlock() 496 return 497 } 498 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 499 for _, closer := range c.active { 500 closer() 501 } 502 c.closed = true 503 c.mu.Unlock() 504 _ = c.conn.Close() 505 506 if c.CloseFunc != nil { 507 c.CloseFunc(c.ctx, closeCode) 508 } 509 }