github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/graphql/handler/transport/websocket.go (about) 1 package transport 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "log" 10 "net" 11 "net/http" 12 "sync" 13 "time" 14 15 "github.com/geneva/gqlgen/graphql" 16 "github.com/geneva/gqlgen/graphql/errcode" 17 "github.com/gorilla/websocket" 18 "github.com/vektah/gqlparser/v2/gqlerror" 19 ) 20 21 type ( 22 Websocket struct { 23 Upgrader websocket.Upgrader 24 InitFunc WebsocketInitFunc 25 InitTimeout time.Duration 26 ErrorFunc WebsocketErrorFunc 27 CloseFunc WebsocketCloseFunc 28 KeepAlivePingInterval time.Duration 29 PingPongInterval time.Duration 30 31 didInjectSubprotocols bool 32 } 33 wsConnection struct { 34 Websocket 35 ctx context.Context 36 conn *websocket.Conn 37 me messageExchanger 38 active map[string]context.CancelFunc 39 mu sync.Mutex 40 keepAliveTicker *time.Ticker 41 pingPongTicker *time.Ticker 42 exec graphql.GraphExecutor 43 44 initPayload InitPayload 45 } 46 47 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) 48 WebsocketErrorFunc func(ctx context.Context, err error) 49 50 // Callback called when websocket is closed. 51 WebsocketCloseFunc func(ctx context.Context, closeCode int) 52 ) 53 54 var errReadTimeout = errors.New("read timeout") 55 56 type WebsocketError struct { 57 Err error 58 59 // IsReadError flags whether the error occurred on read or write to the websocket 60 IsReadError bool 61 } 62 63 func (e WebsocketError) Error() string { 64 if e.IsReadError { 65 return fmt.Sprintf("websocket read: %v", e.Err) 66 } 67 return fmt.Sprintf("websocket write: %v", e.Err) 68 } 69 70 var ( 71 _ graphql.Transport = Websocket{} 72 _ error = WebsocketError{} 73 ) 74 75 func (t Websocket) Supports(r *http.Request) bool { 76 return r.Header.Get("Upgrade") != "" 77 } 78 79 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 80 t.injectGraphQLWSSubprotocols() 81 ws, err := t.Upgrader.Upgrade(w, r, http.Header{}) 82 if err != nil { 83 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 84 SendErrorf(w, http.StatusBadRequest, "unable to upgrade") 85 return 86 } 87 88 var me messageExchanger 89 switch ws.Subprotocol() { 90 default: 91 msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol())) 92 ws.WriteMessage(websocket.CloseMessage, msg) 93 return 94 case graphqlwsSubprotocol, "": 95 // clients are required to send a subprotocol, to be backward compatible with the previous implementation we select 96 // "graphql-ws" by default 97 me = graphqlwsMessageExchanger{c: ws} 98 case graphqltransportwsSubprotocol: 99 me = graphqltransportwsMessageExchanger{c: ws} 100 } 101 102 conn := wsConnection{ 103 active: map[string]context.CancelFunc{}, 104 conn: ws, 105 ctx: r.Context(), 106 exec: exec, 107 me: me, 108 Websocket: t, 109 } 110 111 if !conn.init() { 112 return 113 } 114 115 conn.run() 116 } 117 118 func (c *wsConnection) handlePossibleError(err error, isReadError bool) { 119 if c.ErrorFunc != nil && err != nil { 120 c.ErrorFunc(c.ctx, WebsocketError{ 121 Err: err, 122 IsReadError: isReadError, 123 }) 124 } 125 } 126 127 func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) { 128 messages, errs := make(chan message, 1), make(chan error, 1) 129 130 go func() { 131 if m, err := c.me.NextMessage(); err != nil { 132 errs <- err 133 } else { 134 messages <- m 135 } 136 }() 137 138 select { 139 case m := <-messages: 140 return m, nil 141 case err := <-errs: 142 return message{}, err 143 case <-time.After(timeout): 144 return message{}, errReadTimeout 145 } 146 } 147 148 func (c *wsConnection) init() bool { 149 var m message 150 var err error 151 152 if c.InitTimeout != 0 { 153 m, err = c.nextMessageWithTimeout(c.InitTimeout) 154 } else { 155 m, err = c.me.NextMessage() 156 } 157 158 if err != nil { 159 if err == errReadTimeout { 160 c.close(websocket.CloseProtocolError, "connection initialisation timeout") 161 return false 162 } 163 164 if err == errInvalidMsg { 165 c.sendConnectionError("invalid json") 166 } 167 168 c.close(websocket.CloseProtocolError, "decoding error") 169 return false 170 } 171 172 switch m.t { 173 case initMessageType: 174 if len(m.payload) > 0 { 175 c.initPayload = make(InitPayload) 176 err := json.Unmarshal(m.payload, &c.initPayload) 177 if err != nil { 178 return false 179 } 180 } 181 182 if c.InitFunc != nil { 183 ctx, err := c.InitFunc(c.ctx, c.initPayload) 184 if err != nil { 185 c.sendConnectionError(err.Error()) 186 c.close(websocket.CloseNormalClosure, "terminated") 187 return false 188 } 189 c.ctx = ctx 190 } 191 192 c.write(&message{t: connectionAckMessageType}) 193 c.write(&message{t: keepAliveMessageType}) 194 case connectionCloseMessageType: 195 c.close(websocket.CloseNormalClosure, "terminated") 196 return false 197 default: 198 c.sendConnectionError("unexpected message %s", m.t) 199 c.close(websocket.CloseProtocolError, "unexpected message") 200 return false 201 } 202 203 return true 204 } 205 206 func (c *wsConnection) write(msg *message) { 207 c.mu.Lock() 208 c.handlePossibleError(c.me.Send(msg), false) 209 c.mu.Unlock() 210 } 211 212 func (c *wsConnection) run() { 213 // We create a cancellation that will shutdown the keep-alive when we leave 214 // this function. 215 ctx, cancel := context.WithCancel(c.ctx) 216 defer func() { 217 cancel() 218 c.close(websocket.CloseAbnormalClosure, "unexpected closure") 219 }() 220 221 // If we're running in graphql-ws mode, create a timer that will trigger a 222 // keep alive message every interval 223 if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 { 224 c.mu.Lock() 225 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) 226 c.mu.Unlock() 227 228 go c.keepAlive(ctx) 229 } 230 231 // If we're running in graphql-transport-ws mode, create a timer that will 232 // trigger a ping message every interval 233 if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 { 234 c.mu.Lock() 235 c.pingPongTicker = time.NewTicker(c.PingPongInterval) 236 c.mu.Unlock() 237 238 // Note: when the connection is closed by this deadline, the client 239 // will receive an "invalid close code" 240 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 241 go c.ping(ctx) 242 } 243 244 // Close the connection when the context is cancelled. 245 // Will optionally send a "close reason" that is retrieved from the context. 246 go c.closeOnCancel(ctx) 247 248 for { 249 start := graphql.Now() 250 m, err := c.me.NextMessage() 251 if err != nil { 252 // If the connection got closed by us, don't report the error 253 if !errors.Is(err, net.ErrClosed) { 254 c.handlePossibleError(err, true) 255 } 256 return 257 } 258 259 switch m.t { 260 case startMessageType: 261 c.subscribe(start, &m) 262 case stopMessageType: 263 c.mu.Lock() 264 closer := c.active[m.id] 265 c.mu.Unlock() 266 if closer != nil { 267 closer() 268 } 269 case connectionCloseMessageType: 270 c.close(websocket.CloseNormalClosure, "terminated") 271 return 272 case pingMessageType: 273 c.write(&message{t: pongMessageType, payload: m.payload}) 274 case pongMessageType: 275 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 276 default: 277 c.sendConnectionError("unexpected message %s", m.t) 278 c.close(websocket.CloseProtocolError, "unexpected message") 279 return 280 } 281 } 282 } 283 284 func (c *wsConnection) keepAlive(ctx context.Context) { 285 for { 286 select { 287 case <-ctx.Done(): 288 c.keepAliveTicker.Stop() 289 return 290 case <-c.keepAliveTicker.C: 291 c.write(&message{t: keepAliveMessageType}) 292 } 293 } 294 } 295 296 func (c *wsConnection) ping(ctx context.Context) { 297 for { 298 select { 299 case <-ctx.Done(): 300 c.pingPongTicker.Stop() 301 return 302 case <-c.pingPongTicker.C: 303 c.write(&message{t: pingMessageType, payload: json.RawMessage{}}) 304 } 305 } 306 } 307 308 func (c *wsConnection) closeOnCancel(ctx context.Context) { 309 <-ctx.Done() 310 311 if r := closeReasonForContext(ctx); r != "" { 312 c.sendConnectionError(r) 313 } 314 c.close(websocket.CloseNormalClosure, "terminated") 315 } 316 317 func (c *wsConnection) subscribe(start time.Time, msg *message) { 318 ctx := graphql.StartOperationTrace(c.ctx) 319 var params *graphql.RawParams 320 if err := jsonDecode(bytes.NewReader(msg.payload), ¶ms); err != nil { 321 c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"}) 322 c.complete(msg.id) 323 return 324 } 325 326 params.ReadTime = graphql.TraceTiming{ 327 Start: start, 328 End: graphql.Now(), 329 } 330 331 rc, err := c.exec.CreateOperationContext(ctx, params) 332 if err != nil { 333 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) 334 switch errcode.GetErrorKind(err) { 335 case errcode.KindProtocol: 336 c.sendError(msg.id, resp.Errors...) 337 default: 338 c.sendResponse(msg.id, &graphql.Response{Errors: err}) 339 } 340 341 c.complete(msg.id) 342 return 343 } 344 345 ctx = graphql.WithOperationContext(ctx, rc) 346 347 if c.initPayload != nil { 348 ctx = withInitPayload(ctx, c.initPayload) 349 } 350 351 ctx, cancel := context.WithCancel(ctx) 352 c.mu.Lock() 353 c.active[msg.id] = cancel 354 c.mu.Unlock() 355 356 go func() { 357 ctx = withSubscriptionErrorContext(ctx) 358 defer func() { 359 if r := recover(); r != nil { 360 err := rc.Recover(ctx, r) 361 var gqlerr *gqlerror.Error 362 if !errors.As(err, &gqlerr) { 363 gqlerr = &gqlerror.Error{} 364 if err != nil { 365 gqlerr.Message = err.Error() 366 } 367 } 368 c.sendError(msg.id, gqlerr) 369 } 370 if errs := getSubscriptionError(ctx); len(errs) != 0 { 371 c.sendError(msg.id, errs...) 372 } else { 373 c.complete(msg.id) 374 } 375 c.mu.Lock() 376 delete(c.active, msg.id) 377 c.mu.Unlock() 378 cancel() 379 }() 380 381 responses, ctx := c.exec.DispatchOperation(ctx, rc) 382 for { 383 response := responses(ctx) 384 if response == nil { 385 break 386 } 387 388 c.sendResponse(msg.id, response) 389 } 390 391 // complete and context cancel comes from the defer 392 }() 393 } 394 395 func (c *wsConnection) sendResponse(id string, response *graphql.Response) { 396 b, err := json.Marshal(response) 397 if err != nil { 398 panic(err) 399 } 400 c.write(&message{ 401 payload: b, 402 id: id, 403 t: dataMessageType, 404 }) 405 } 406 407 func (c *wsConnection) complete(id string) { 408 c.write(&message{id: id, t: completeMessageType}) 409 } 410 411 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 412 errs := make([]error, len(errors)) 413 for i, err := range errors { 414 errs[i] = err 415 } 416 b, err := json.Marshal(errs) 417 if err != nil { 418 panic(err) 419 } 420 c.write(&message{t: errorMessageType, id: id, payload: b}) 421 } 422 423 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 424 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 425 if err != nil { 426 panic(err) 427 } 428 429 c.write(&message{t: connectionErrorMessageType, payload: b}) 430 } 431 432 func (c *wsConnection) close(closeCode int, message string) { 433 c.mu.Lock() 434 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 435 for _, closer := range c.active { 436 closer() 437 } 438 c.mu.Unlock() 439 _ = c.conn.Close() 440 441 if c.CloseFunc != nil { 442 c.CloseFunc(c.ctx, closeCode) 443 } 444 }