github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/handler/transport/websocket.go (about) 1 package transport 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "sync" 11 "time" 12 13 "github.com/99designs/gqlgen/graphql" 14 "github.com/99designs/gqlgen/graphql/errcode" 15 "github.com/gorilla/websocket" 16 "github.com/vektah/gqlparser/v2/gqlerror" 17 ) 18 19 const ( 20 connectionInitMsg = "connection_init" // Client -> Server 21 connectionTerminateMsg = "connection_terminate" // Client -> Server 22 startMsg = "start" // Client -> Server 23 stopMsg = "stop" // Client -> Server 24 connectionAckMsg = "connection_ack" // Server -> Client 25 connectionErrorMsg = "connection_error" // Server -> Client 26 dataMsg = "data" // Server -> Client 27 errorMsg = "error" // Server -> Client 28 completeMsg = "complete" // Server -> Client 29 connectionKeepAliveMsg = "ka" // Server -> Client 30 ) 31 32 type ( 33 Websocket struct { 34 Upgrader websocket.Upgrader 35 InitFunc WebsocketInitFunc 36 KeepAlivePingInterval time.Duration 37 PingPongInterval time.Duration 38 } 39 wsConnection struct { 40 Websocket 41 ctx context.Context 42 conn *websocket.Conn 43 active map[string]context.CancelFunc 44 mu sync.Mutex 45 keepAliveTicker *time.Ticker 46 pingPongTicker *time.Ticker 47 exec graphql.GraphExecutor 48 49 initPayload InitPayload 50 } 51 operationMessage struct { 52 Payload json.RawMessage `json:"payload,omitempty"` 53 ID string `json:"id,omitempty"` 54 Type string `json:"type"` 55 } 56 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) 57 ) 58 59 var _ graphql.Transport = Websocket{} 60 61 func (t Websocket) Supports(r *http.Request) bool { 62 return r.Header.Get("Upgrade") != "" 63 } 64 65 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 66 ws, err := t.Upgrader.Upgrade(w, r, http.Header{ 67 "Sec-Websocket-Protocol": []string{"graphql-ws"}, 68 }) 69 if err != nil { 70 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 71 SendErrorf(w, http.StatusBadRequest, "unable to upgrade") 72 return 73 } 74 75 conn := wsConnection{ 76 active: map[string]context.CancelFunc{}, 77 conn: ws, 78 ctx: r.Context(), 79 exec: exec, 80 Websocket: t, 81 } 82 83 if !conn.init() { 84 return 85 } 86 87 conn.run() 88 } 89 90 func (c *wsConnection) init() bool { 91 message := c.readOp() 92 if message == nil { 93 c.close(websocket.CloseProtocolError, "decoding error") 94 return false 95 } 96 97 switch message.Type { 98 case connectionInitMsg: 99 if len(message.Payload) > 0 { 100 c.initPayload = make(InitPayload) 101 err := json.Unmarshal(message.Payload, &c.initPayload) 102 if err != nil { 103 return false 104 } 105 } 106 107 if c.InitFunc != nil { 108 ctx, err := c.InitFunc(c.ctx, c.initPayload) 109 if err != nil { 110 c.sendConnectionError(err.Error()) 111 c.close(websocket.CloseNormalClosure, "terminated") 112 return false 113 } 114 c.ctx = ctx 115 } 116 117 c.write(&operationMessage{Type: connectionAckMsg}) 118 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 119 case connectionTerminateMsg: 120 c.close(websocket.CloseNormalClosure, "terminated") 121 return false 122 default: 123 c.sendConnectionError("unexpected message %s", message.Type) 124 c.close(websocket.CloseProtocolError, "unexpected message") 125 return false 126 } 127 128 return true 129 } 130 131 func (c *wsConnection) write(msg *operationMessage) { 132 c.mu.Lock() 133 c.conn.WriteJSON(msg) 134 c.mu.Unlock() 135 } 136 137 func (c *wsConnection) run() { 138 // We create a cancellation that will shutdown the keep-alive when we leave 139 // this function. 140 ctx, cancel := context.WithCancel(c.ctx) 141 defer func() { 142 cancel() 143 }() 144 145 // Create a timer that will fire every interval to keep the connection alive. 146 if c.KeepAlivePingInterval != 0 { 147 c.mu.Lock() 148 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) 149 c.mu.Unlock() 150 go c.keepAlive(ctx) 151 } 152 153 // Create a timer that will fire every interval a ping message that should 154 // receive a pong (SetPongHandler in init() function) 155 if c.PingPongInterval != 0 { 156 157 pongWait := 2 * c.PingPongInterval 158 c.conn.SetReadDeadline(time.Now().Add(pongWait)) 159 c.conn.SetPongHandler(func(string) error { 160 return c.conn.SetReadDeadline(time.Now().UTC().Add(pongWait)) 161 }) 162 163 c.mu.Lock() 164 c.pingPongTicker = time.NewTicker(c.PingPongInterval) 165 c.mu.Unlock() 166 167 go c.ping(ctx) 168 } 169 170 for { 171 start := graphql.Now() 172 message := c.readOp() 173 if message == nil { 174 c.close(websocket.CloseAbnormalClosure, "unexpected closure") 175 return 176 } 177 178 switch message.Type { 179 case startMsg: 180 c.subscribe(start, message) 181 case stopMsg: 182 c.mu.Lock() 183 closer := c.active[message.ID] 184 c.mu.Unlock() 185 if closer != nil { 186 closer() 187 } 188 case connectionTerminateMsg: 189 c.close(websocket.CloseNormalClosure, "terminated") 190 return 191 default: 192 c.sendConnectionError("unexpected message %s", message.Type) 193 c.close(websocket.CloseProtocolError, "unexpected message") 194 return 195 } 196 } 197 } 198 199 func (c *wsConnection) keepAlive(ctx context.Context) { 200 for { 201 select { 202 case <-ctx.Done(): 203 c.keepAliveTicker.Stop() 204 return 205 case <-c.keepAliveTicker.C: 206 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 207 } 208 } 209 } 210 211 func (c *wsConnection) ping(ctx context.Context) { 212 for { 213 select { 214 case <-ctx.Done(): 215 c.pingPongTicker.Stop() 216 return 217 case <-c.pingPongTicker.C: 218 c.mu.Lock() 219 c.conn.WriteMessage(websocket.PingMessage, nil) 220 c.mu.Unlock() 221 } 222 } 223 } 224 225 func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { 226 ctx := graphql.StartOperationTrace(c.ctx) 227 var params *graphql.RawParams 228 if err := jsonDecode(bytes.NewReader(message.Payload), ¶ms); err != nil { 229 c.sendError(message.ID, &gqlerror.Error{Message: "invalid json"}) 230 c.complete(message.ID) 231 return 232 } 233 234 params.ReadTime = graphql.TraceTiming{ 235 Start: start, 236 End: graphql.Now(), 237 } 238 239 rc, err := c.exec.CreateOperationContext(ctx, params) 240 if err != nil { 241 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) 242 switch errcode.GetErrorKind(err) { 243 case errcode.KindProtocol: 244 c.sendError(message.ID, resp.Errors...) 245 default: 246 c.sendResponse(message.ID, &graphql.Response{Errors: err}) 247 } 248 249 c.complete(message.ID) 250 return 251 } 252 253 ctx = graphql.WithOperationContext(ctx, rc) 254 255 if c.initPayload != nil { 256 ctx = withInitPayload(ctx, c.initPayload) 257 } 258 259 ctx, cancel := context.WithCancel(ctx) 260 c.mu.Lock() 261 c.active[message.ID] = cancel 262 c.mu.Unlock() 263 264 go func() { 265 defer func() { 266 if r := recover(); r != nil { 267 userErr := rc.Recover(ctx, r) 268 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()}) 269 } 270 }() 271 responses, ctx := c.exec.DispatchOperation(ctx, rc) 272 for { 273 response := responses(ctx) 274 if response == nil { 275 break 276 } 277 278 c.sendResponse(message.ID, response) 279 } 280 c.complete(message.ID) 281 282 c.mu.Lock() 283 delete(c.active, message.ID) 284 c.mu.Unlock() 285 cancel() 286 }() 287 } 288 289 func (c *wsConnection) sendResponse(id string, response *graphql.Response) { 290 b, err := json.Marshal(response) 291 if err != nil { 292 panic(err) 293 } 294 c.write(&operationMessage{ 295 Payload: b, 296 ID: id, 297 Type: dataMsg, 298 }) 299 } 300 301 func (c *wsConnection) complete(id string) { 302 c.write(&operationMessage{ID: id, Type: completeMsg}) 303 } 304 305 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 306 errs := make([]error, len(errors)) 307 for i, err := range errors { 308 errs[i] = err 309 } 310 b, err := json.Marshal(errs) 311 if err != nil { 312 panic(err) 313 } 314 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b}) 315 } 316 317 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 318 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 319 if err != nil { 320 panic(err) 321 } 322 323 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b}) 324 } 325 326 func (c *wsConnection) readOp() *operationMessage { 327 _, r, err := c.conn.NextReader() 328 if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { 329 return nil 330 } else if err != nil { 331 c.sendConnectionError("invalid json: %T %s", err, err.Error()) 332 return nil 333 } 334 message := operationMessage{} 335 if err := jsonDecode(r, &message); err != nil { 336 c.sendConnectionError("invalid json") 337 return nil 338 } 339 340 return &message 341 } 342 343 func (c *wsConnection) close(closeCode int, message string) { 344 c.mu.Lock() 345 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 346 for _, closer := range c.active { 347 closer() 348 } 349 c.mu.Unlock() 350 _ = c.conn.Close() 351 }