github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/graphql/handler/transport/websocket.go (about) 1 package transport 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "log" 10 "net" 11 "net/http" 12 "sync" 13 "time" 14 15 "github.com/gorilla/websocket" 16 "github.com/mstephano/gqlgen-schemagen/graphql" 17 "github.com/mstephano/gqlgen-schemagen/graphql/errcode" 18 "github.com/vektah/gqlparser/v2/gqlerror" 19 ) 20 21 type ( 22 Websocket struct { 23 Upgrader websocket.Upgrader 24 InitFunc WebsocketInitFunc 25 InitTimeout time.Duration 26 ErrorFunc WebsocketErrorFunc 27 KeepAlivePingInterval time.Duration 28 PingPongInterval time.Duration 29 30 didInjectSubprotocols bool 31 } 32 wsConnection struct { 33 Websocket 34 ctx context.Context 35 conn *websocket.Conn 36 me messageExchanger 37 active map[string]context.CancelFunc 38 mu sync.Mutex 39 keepAliveTicker *time.Ticker 40 pingPongTicker *time.Ticker 41 exec graphql.GraphExecutor 42 43 initPayload InitPayload 44 } 45 46 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) 47 WebsocketErrorFunc func(ctx context.Context, err error) 48 ) 49 50 var errReadTimeout = errors.New("read timeout") 51 52 type WebsocketError struct { 53 Err error 54 55 // IsReadError flags whether the error occurred on read or write to the websocket 56 IsReadError bool 57 } 58 59 func (e WebsocketError) Error() string { 60 if e.IsReadError { 61 return fmt.Sprintf("websocket read: %v", e.Err) 62 } 63 return fmt.Sprintf("websocket write: %v", e.Err) 64 } 65 66 var ( 67 _ graphql.Transport = Websocket{} 68 _ error = WebsocketError{} 69 ) 70 71 func (t Websocket) Supports(r *http.Request) bool { 72 return r.Header.Get("Upgrade") != "" 73 } 74 75 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 76 t.injectGraphQLWSSubprotocols() 77 ws, err := t.Upgrader.Upgrade(w, r, http.Header{}) 78 if err != nil { 79 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 80 SendErrorf(w, http.StatusBadRequest, "unable to upgrade") 81 return 82 } 83 84 var me messageExchanger 85 switch ws.Subprotocol() { 86 default: 87 msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol())) 88 ws.WriteMessage(websocket.CloseMessage, msg) 89 return 90 case graphqlwsSubprotocol, "": 91 // clients are required to send a subprotocol, to be backward compatible with the previous implementation we select 92 // "graphql-ws" by default 93 me = graphqlwsMessageExchanger{c: ws} 94 case graphqltransportwsSubprotocol: 95 me = graphqltransportwsMessageExchanger{c: ws} 96 } 97 98 conn := wsConnection{ 99 active: map[string]context.CancelFunc{}, 100 conn: ws, 101 ctx: r.Context(), 102 exec: exec, 103 me: me, 104 Websocket: t, 105 } 106 107 if !conn.init() { 108 return 109 } 110 111 conn.run() 112 } 113 114 func (c *wsConnection) handlePossibleError(err error, isReadError bool) { 115 if c.ErrorFunc != nil && err != nil { 116 c.ErrorFunc(c.ctx, WebsocketError{ 117 Err: err, 118 IsReadError: isReadError, 119 }) 120 } 121 } 122 123 func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) { 124 messages, errs := make(chan message, 1), make(chan error, 1) 125 126 go func() { 127 if m, err := c.me.NextMessage(); err != nil { 128 errs <- err 129 } else { 130 messages <- m 131 } 132 }() 133 134 select { 135 case m := <-messages: 136 return m, nil 137 case err := <-errs: 138 return message{}, err 139 case <-time.After(timeout): 140 return message{}, errReadTimeout 141 } 142 } 143 144 func (c *wsConnection) init() bool { 145 var m message 146 var err error 147 148 if c.InitTimeout != 0 { 149 m, err = c.nextMessageWithTimeout(c.InitTimeout) 150 } else { 151 m, err = c.me.NextMessage() 152 } 153 154 if err != nil { 155 if err == errReadTimeout { 156 c.close(websocket.CloseProtocolError, "connection initialisation timeout") 157 return false 158 } 159 160 if err == errInvalidMsg { 161 c.sendConnectionError("invalid json") 162 } 163 164 c.close(websocket.CloseProtocolError, "decoding error") 165 return false 166 } 167 168 switch m.t { 169 case initMessageType: 170 if len(m.payload) > 0 { 171 c.initPayload = make(InitPayload) 172 err := json.Unmarshal(m.payload, &c.initPayload) 173 if err != nil { 174 return false 175 } 176 } 177 178 if c.InitFunc != nil { 179 ctx, err := c.InitFunc(c.ctx, c.initPayload) 180 if err != nil { 181 c.sendConnectionError(err.Error()) 182 c.close(websocket.CloseNormalClosure, "terminated") 183 return false 184 } 185 c.ctx = ctx 186 } 187 188 c.write(&message{t: connectionAckMessageType}) 189 c.write(&message{t: keepAliveMessageType}) 190 case connectionCloseMessageType: 191 c.close(websocket.CloseNormalClosure, "terminated") 192 return false 193 default: 194 c.sendConnectionError("unexpected message %s", m.t) 195 c.close(websocket.CloseProtocolError, "unexpected message") 196 return false 197 } 198 199 return true 200 } 201 202 func (c *wsConnection) write(msg *message) { 203 c.mu.Lock() 204 c.handlePossibleError(c.me.Send(msg), false) 205 c.mu.Unlock() 206 } 207 208 func (c *wsConnection) run() { 209 // We create a cancellation that will shutdown the keep-alive when we leave 210 // this function. 211 ctx, cancel := context.WithCancel(c.ctx) 212 defer func() { 213 cancel() 214 c.close(websocket.CloseAbnormalClosure, "unexpected closure") 215 }() 216 217 // If we're running in graphql-ws mode, create a timer that will trigger a 218 // keep alive message every interval 219 if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 { 220 c.mu.Lock() 221 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) 222 c.mu.Unlock() 223 224 go c.keepAlive(ctx) 225 } 226 227 // If we're running in graphql-transport-ws mode, create a timer that will 228 // trigger a ping message every interval 229 if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 { 230 c.mu.Lock() 231 c.pingPongTicker = time.NewTicker(c.PingPongInterval) 232 c.mu.Unlock() 233 234 // Note: when the connection is closed by this deadline, the client 235 // will receive an "invalid close code" 236 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 237 go c.ping(ctx) 238 } 239 240 // Close the connection when the context is cancelled. 241 // Will optionally send a "close reason" that is retrieved from the context. 242 go c.closeOnCancel(ctx) 243 244 for { 245 start := graphql.Now() 246 m, err := c.me.NextMessage() 247 if err != nil { 248 // If the connection got closed by us, don't report the error 249 if !errors.Is(err, net.ErrClosed) { 250 c.handlePossibleError(err, true) 251 } 252 return 253 } 254 255 switch m.t { 256 case startMessageType: 257 c.subscribe(start, &m) 258 case stopMessageType: 259 c.mu.Lock() 260 closer := c.active[m.id] 261 c.mu.Unlock() 262 if closer != nil { 263 closer() 264 } 265 case connectionCloseMessageType: 266 c.close(websocket.CloseNormalClosure, "terminated") 267 return 268 case pingMessageType: 269 c.write(&message{t: pongMessageType, payload: m.payload}) 270 case pongMessageType: 271 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 272 default: 273 c.sendConnectionError("unexpected message %s", m.t) 274 c.close(websocket.CloseProtocolError, "unexpected message") 275 return 276 } 277 } 278 } 279 280 func (c *wsConnection) keepAlive(ctx context.Context) { 281 for { 282 select { 283 case <-ctx.Done(): 284 c.keepAliveTicker.Stop() 285 return 286 case <-c.keepAliveTicker.C: 287 c.write(&message{t: keepAliveMessageType}) 288 } 289 } 290 } 291 292 func (c *wsConnection) ping(ctx context.Context) { 293 for { 294 select { 295 case <-ctx.Done(): 296 c.pingPongTicker.Stop() 297 return 298 case <-c.pingPongTicker.C: 299 c.write(&message{t: pingMessageType, payload: json.RawMessage{}}) 300 } 301 } 302 } 303 304 func (c *wsConnection) closeOnCancel(ctx context.Context) { 305 <-ctx.Done() 306 307 if r := closeReasonForContext(ctx); r != "" { 308 c.sendConnectionError(r) 309 } 310 c.close(websocket.CloseNormalClosure, "terminated") 311 } 312 313 func (c *wsConnection) subscribe(start time.Time, msg *message) { 314 ctx := graphql.StartOperationTrace(c.ctx) 315 var params *graphql.RawParams 316 if err := jsonDecode(bytes.NewReader(msg.payload), ¶ms); err != nil { 317 c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"}) 318 c.complete(msg.id) 319 return 320 } 321 322 params.ReadTime = graphql.TraceTiming{ 323 Start: start, 324 End: graphql.Now(), 325 } 326 327 rc, err := c.exec.CreateOperationContext(ctx, params) 328 if err != nil { 329 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) 330 switch errcode.GetErrorKind(err) { 331 case errcode.KindProtocol: 332 c.sendError(msg.id, resp.Errors...) 333 default: 334 c.sendResponse(msg.id, &graphql.Response{Errors: err}) 335 } 336 337 c.complete(msg.id) 338 return 339 } 340 341 ctx = graphql.WithOperationContext(ctx, rc) 342 343 if c.initPayload != nil { 344 ctx = withInitPayload(ctx, c.initPayload) 345 } 346 347 ctx, cancel := context.WithCancel(ctx) 348 c.mu.Lock() 349 c.active[msg.id] = cancel 350 c.mu.Unlock() 351 352 go func() { 353 defer func() { 354 if r := recover(); r != nil { 355 err := rc.Recover(ctx, r) 356 var gqlerr *gqlerror.Error 357 if !errors.As(err, &gqlerr) { 358 gqlerr = &gqlerror.Error{} 359 if err != nil { 360 gqlerr.Message = err.Error() 361 } 362 } 363 c.sendError(msg.id, gqlerr) 364 } 365 c.complete(msg.id) 366 c.mu.Lock() 367 delete(c.active, msg.id) 368 c.mu.Unlock() 369 cancel() 370 }() 371 372 responses, ctx := c.exec.DispatchOperation(ctx, rc) 373 for { 374 response := responses(ctx) 375 if response == nil { 376 break 377 } 378 379 c.sendResponse(msg.id, response) 380 } 381 382 // complete and context cancel comes from the defer 383 }() 384 } 385 386 func (c *wsConnection) sendResponse(id string, response *graphql.Response) { 387 b, err := json.Marshal(response) 388 if err != nil { 389 panic(err) 390 } 391 c.write(&message{ 392 payload: b, 393 id: id, 394 t: dataMessageType, 395 }) 396 } 397 398 func (c *wsConnection) complete(id string) { 399 c.write(&message{id: id, t: completeMessageType}) 400 } 401 402 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 403 errs := make([]error, len(errors)) 404 for i, err := range errors { 405 errs[i] = err 406 } 407 b, err := json.Marshal(errs) 408 if err != nil { 409 panic(err) 410 } 411 c.write(&message{t: errorMessageType, id: id, payload: b}) 412 } 413 414 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 415 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 416 if err != nil { 417 panic(err) 418 } 419 420 c.write(&message{t: connectionErrorMessageType, payload: b}) 421 } 422 423 func (c *wsConnection) close(closeCode int, message string) { 424 c.mu.Lock() 425 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 426 for _, closer := range c.active { 427 closer() 428 } 429 c.mu.Unlock() 430 _ = c.conn.Close() 431 }