github.com/apipluspower/gqlgen@v0.15.2/graphql/handler/transport/websocket.go (about) 1 package transport 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "log" 10 "net/http" 11 "sync" 12 "time" 13 14 "github.com/apipluspower/gqlgen/graphql" 15 "github.com/apipluspower/gqlgen/graphql/errcode" 16 "github.com/gorilla/websocket" 17 "github.com/vektah/gqlparser/v2/gqlerror" 18 ) 19 20 type ( 21 Websocket struct { 22 Upgrader websocket.Upgrader 23 InitFunc WebsocketInitFunc 24 KeepAlivePingInterval time.Duration 25 PingPongInterval time.Duration 26 27 didInjectSubprotocols bool 28 } 29 wsConnection struct { 30 Websocket 31 ctx context.Context 32 conn *websocket.Conn 33 me messageExchanger 34 active map[string]context.CancelFunc 35 mu sync.Mutex 36 keepAliveTicker *time.Ticker 37 pingPongTicker *time.Ticker 38 exec graphql.GraphExecutor 39 40 initPayload InitPayload 41 } 42 43 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) 44 ) 45 46 var _ graphql.Transport = Websocket{} 47 48 func (t Websocket) Supports(r *http.Request) bool { 49 return r.Header.Get("Upgrade") != "" 50 } 51 52 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 53 t.injectGraphQLWSSubprotocols() 54 ws, err := t.Upgrader.Upgrade(w, r, http.Header{}) 55 if err != nil { 56 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 57 SendErrorf(w, http.StatusBadRequest, "unable to upgrade") 58 return 59 } 60 61 var me messageExchanger 62 switch ws.Subprotocol() { 63 default: 64 msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol())) 65 ws.WriteMessage(websocket.CloseMessage, msg) 66 return 67 case graphqlwsSubprotocol, "": 68 // clients are required to send a subprotocol, to be backward compatible with the previous implementation we select 69 // "graphql-ws" by default 70 me = graphqlwsMessageExchanger{c: ws} 71 case graphqltransportwsSubprotocol: 72 me = graphqltransportwsMessageExchanger{c: ws} 73 } 74 75 conn := wsConnection{ 76 active: map[string]context.CancelFunc{}, 77 conn: ws, 78 ctx: r.Context(), 79 exec: exec, 80 me: me, 81 Websocket: t, 82 } 83 84 if !conn.init() { 85 return 86 } 87 88 conn.run() 89 } 90 91 func (c *wsConnection) init() bool { 92 m, err := c.me.NextMessage() 93 if err != nil { 94 if err == errInvalidMsg { 95 c.sendConnectionError("invalid json") 96 } 97 98 c.close(websocket.CloseProtocolError, "decoding error") 99 return false 100 } 101 102 switch m.t { 103 case initMessageType: 104 if len(m.payload) > 0 { 105 c.initPayload = make(InitPayload) 106 err := json.Unmarshal(m.payload, &c.initPayload) 107 if err != nil { 108 return false 109 } 110 } 111 112 if c.InitFunc != nil { 113 ctx, err := c.InitFunc(c.ctx, c.initPayload) 114 if err != nil { 115 c.sendConnectionError(err.Error()) 116 c.close(websocket.CloseNormalClosure, "terminated") 117 return false 118 } 119 c.ctx = ctx 120 } 121 122 c.write(&message{t: connectionAckMessageType}) 123 c.write(&message{t: keepAliveMessageType}) 124 case connectionCloseMessageType: 125 c.close(websocket.CloseNormalClosure, "terminated") 126 return false 127 default: 128 c.sendConnectionError("unexpected message %s", m.t) 129 c.close(websocket.CloseProtocolError, "unexpected message") 130 return false 131 } 132 133 return true 134 } 135 136 func (c *wsConnection) write(msg *message) { 137 c.mu.Lock() 138 // TODO: missing error handling here, err from previous implementation 139 // was ignored 140 _ = c.me.Send(msg) 141 c.mu.Unlock() 142 } 143 144 func (c *wsConnection) run() { 145 // We create a cancellation that will shutdown the keep-alive when we leave 146 // this function. 147 ctx, cancel := context.WithCancel(c.ctx) 148 defer func() { 149 cancel() 150 c.close(websocket.CloseAbnormalClosure, "unexpected closure") 151 }() 152 153 // Create a timer that will fire every interval to keep the connection alive. 154 if c.KeepAlivePingInterval != 0 { 155 c.mu.Lock() 156 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) 157 c.mu.Unlock() 158 159 go c.keepAlive(ctx) 160 } 161 162 // Create a timer that will fire every interval a ping message that should 163 // receive a pong (SetPongHandler in init() function) 164 if c.PingPongInterval != 0 { 165 c.mu.Lock() 166 c.pingPongTicker = time.NewTicker(c.PingPongInterval) 167 c.mu.Unlock() 168 169 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 170 go c.ping(ctx) 171 } 172 173 // Close the connection when the context is cancelled. 174 // Will optionally send a "close reason" that is retrieved from the context. 175 go c.closeOnCancel(ctx) 176 177 for { 178 start := graphql.Now() 179 m, err := c.me.NextMessage() 180 if err != nil { 181 // TODO: better error handling here 182 return 183 } 184 185 switch m.t { 186 case startMessageType: 187 c.subscribe(start, &m) 188 case stopMessageType: 189 c.mu.Lock() 190 closer := c.active[m.id] 191 c.mu.Unlock() 192 if closer != nil { 193 closer() 194 } 195 case connectionCloseMessageType: 196 c.close(websocket.CloseNormalClosure, "terminated") 197 return 198 case pingMesageType: 199 c.write(&message{t: pongMessageType, payload: m.payload}) 200 case pongMessageType: 201 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) 202 default: 203 c.sendConnectionError("unexpected message %s", m.t) 204 c.close(websocket.CloseProtocolError, "unexpected message") 205 return 206 } 207 } 208 } 209 210 func (c *wsConnection) keepAlive(ctx context.Context) { 211 for { 212 select { 213 case <-ctx.Done(): 214 c.keepAliveTicker.Stop() 215 return 216 case <-c.keepAliveTicker.C: 217 c.write(&message{t: keepAliveMessageType}) 218 } 219 } 220 } 221 222 func (c *wsConnection) ping(ctx context.Context) { 223 for { 224 select { 225 case <-ctx.Done(): 226 c.pingPongTicker.Stop() 227 return 228 case <-c.pingPongTicker.C: 229 c.write(&message{t: pingMesageType, payload: json.RawMessage{}}) 230 } 231 } 232 } 233 234 func (c *wsConnection) closeOnCancel(ctx context.Context) { 235 <-ctx.Done() 236 237 if r := closeReasonForContext(ctx); r != "" { 238 c.sendConnectionError(r) 239 } 240 c.close(websocket.CloseNormalClosure, "terminated") 241 } 242 243 func (c *wsConnection) subscribe(start time.Time, msg *message) { 244 ctx := graphql.StartOperationTrace(c.ctx) 245 var params *graphql.RawParams 246 if err := jsonDecode(bytes.NewReader(msg.payload), ¶ms); err != nil { 247 c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"}) 248 c.complete(msg.id) 249 return 250 } 251 252 params.ReadTime = graphql.TraceTiming{ 253 Start: start, 254 End: graphql.Now(), 255 } 256 257 rc, err := c.exec.CreateOperationContext(ctx, params) 258 if err != nil { 259 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) 260 switch errcode.GetErrorKind(err) { 261 case errcode.KindProtocol: 262 c.sendError(msg.id, resp.Errors...) 263 default: 264 c.sendResponse(msg.id, &graphql.Response{Errors: err}) 265 } 266 267 c.complete(msg.id) 268 return 269 } 270 271 ctx = graphql.WithOperationContext(ctx, rc) 272 273 if c.initPayload != nil { 274 ctx = withInitPayload(ctx, c.initPayload) 275 } 276 277 ctx, cancel := context.WithCancel(ctx) 278 c.mu.Lock() 279 c.active[msg.id] = cancel 280 c.mu.Unlock() 281 282 go func() { 283 defer func() { 284 if r := recover(); r != nil { 285 err := rc.Recover(ctx, r) 286 var gqlerr *gqlerror.Error 287 if !errors.As(err, &gqlerr) { 288 gqlerr = &gqlerror.Error{} 289 if err != nil { 290 gqlerr.Message = err.Error() 291 } 292 } 293 c.sendError(msg.id, gqlerr) 294 } 295 c.complete(msg.id) 296 c.mu.Lock() 297 delete(c.active, msg.id) 298 c.mu.Unlock() 299 cancel() 300 }() 301 302 responses, ctx := c.exec.DispatchOperation(ctx, rc) 303 for { 304 response := responses(ctx) 305 if response == nil { 306 break 307 } 308 309 c.sendResponse(msg.id, response) 310 } 311 c.complete(msg.id) 312 313 c.mu.Lock() 314 delete(c.active, msg.id) 315 c.mu.Unlock() 316 cancel() 317 }() 318 } 319 320 func (c *wsConnection) sendResponse(id string, response *graphql.Response) { 321 b, err := json.Marshal(response) 322 if err != nil { 323 panic(err) 324 } 325 c.write(&message{ 326 payload: b, 327 id: id, 328 t: dataMessageType, 329 }) 330 } 331 332 func (c *wsConnection) complete(id string) { 333 c.write(&message{id: id, t: completeMessageType}) 334 } 335 336 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 337 errs := make([]error, len(errors)) 338 for i, err := range errors { 339 errs[i] = err 340 } 341 b, err := json.Marshal(errs) 342 if err != nil { 343 panic(err) 344 } 345 c.write(&message{t: errorMessageType, id: id, payload: b}) 346 } 347 348 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 349 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 350 if err != nil { 351 panic(err) 352 } 353 354 c.write(&message{t: connectionErrorMessageType, payload: b}) 355 } 356 357 func (c *wsConnection) close(closeCode int, message string) { 358 c.mu.Lock() 359 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 360 for _, closer := range c.active { 361 closer() 362 } 363 c.mu.Unlock() 364 _ = c.conn.Close() 365 }