github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/graphql/handler/transport/websocket.go (about) 1 package transport 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "sync" 11 "time" 12 13 "github.com/animeshon/gqlgen/graphql" 14 "github.com/animeshon/gqlgen/graphql/errcode" 15 "github.com/gorilla/websocket" 16 "github.com/vektah/gqlparser/v2/gqlerror" 17 ) 18 19 const ( 20 connectionInitMsg = "connection_init" // Client -> Server 21 connectionTerminateMsg = "connection_terminate" // Client -> Server 22 startMsg = "start" // Client -> Server 23 stopMsg = "stop" // Client -> Server 24 connectionAckMsg = "connection_ack" // Server -> Client 25 connectionErrorMsg = "connection_error" // Server -> Client 26 dataMsg = "data" // Server -> Client 27 errorMsg = "error" // Server -> Client 28 completeMsg = "complete" // Server -> Client 29 connectionKeepAliveMsg = "ka" // Server -> Client 30 ) 31 32 type ( 33 Websocket struct { 34 Upgrader websocket.Upgrader 35 InitFunc WebsocketInitFunc 36 KeepAlivePingInterval time.Duration 37 } 38 wsConnection struct { 39 Websocket 40 ctx context.Context 41 conn *websocket.Conn 42 active map[string]context.CancelFunc 43 mu sync.Mutex 44 keepAliveTicker *time.Ticker 45 exec graphql.GraphExecutor 46 47 initPayload InitPayload 48 } 49 operationMessage struct { 50 Payload json.RawMessage `json:"payload,omitempty"` 51 ID string `json:"id,omitempty"` 52 Type string `json:"type"` 53 } 54 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) 55 ) 56 57 var _ graphql.Transport = Websocket{} 58 59 func (t Websocket) Supports(r *http.Request) bool { 60 return r.Header.Get("Upgrade") != "" 61 } 62 63 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 64 ws, err := t.Upgrader.Upgrade(w, r, http.Header{ 65 "Sec-Websocket-Protocol": []string{"graphql-ws"}, 66 }) 67 if err != nil { 68 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 69 SendErrorf(w, http.StatusBadRequest, "unable to upgrade") 70 return 71 } 72 73 conn := wsConnection{ 74 active: map[string]context.CancelFunc{}, 75 conn: ws, 76 ctx: r.Context(), 77 exec: exec, 78 Websocket: t, 79 } 80 81 if !conn.init() { 82 return 83 } 84 85 conn.run() 86 } 87 88 func (c *wsConnection) init() bool { 89 message := c.readOp() 90 if message == nil { 91 c.close(websocket.CloseProtocolError, "decoding error") 92 return false 93 } 94 95 switch message.Type { 96 case connectionInitMsg: 97 if len(message.Payload) > 0 { 98 c.initPayload = make(InitPayload) 99 err := json.Unmarshal(message.Payload, &c.initPayload) 100 if err != nil { 101 return false 102 } 103 } 104 105 if c.InitFunc != nil { 106 ctx, err := c.InitFunc(c.ctx, c.initPayload) 107 if err != nil { 108 c.sendConnectionError(err.Error()) 109 c.close(websocket.CloseNormalClosure, "terminated") 110 return false 111 } 112 c.ctx = ctx 113 } 114 115 c.write(&operationMessage{Type: connectionAckMsg}) 116 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 117 case connectionTerminateMsg: 118 c.close(websocket.CloseNormalClosure, "terminated") 119 return false 120 default: 121 c.sendConnectionError("unexpected message %s", message.Type) 122 c.close(websocket.CloseProtocolError, "unexpected message") 123 return false 124 } 125 126 return true 127 } 128 129 func (c *wsConnection) write(msg *operationMessage) { 130 c.mu.Lock() 131 c.conn.WriteJSON(msg) 132 c.mu.Unlock() 133 } 134 135 func (c *wsConnection) run() { 136 // We create a cancellation that will shutdown the keep-alive when we leave 137 // this function. 138 ctx, cancel := context.WithCancel(c.ctx) 139 defer func() { 140 cancel() 141 c.close(websocket.CloseAbnormalClosure, "unexpected closure") 142 }() 143 144 // Create a timer that will fire every interval to keep the connection alive. 145 if c.KeepAlivePingInterval != 0 { 146 c.mu.Lock() 147 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) 148 c.mu.Unlock() 149 150 go c.keepAlive(ctx) 151 } 152 153 for { 154 start := graphql.Now() 155 message := c.readOp() 156 if message == nil { 157 return 158 } 159 160 switch message.Type { 161 case startMsg: 162 c.subscribe(start, message) 163 case stopMsg: 164 c.mu.Lock() 165 closer := c.active[message.ID] 166 c.mu.Unlock() 167 if closer != nil { 168 closer() 169 } 170 case connectionTerminateMsg: 171 c.close(websocket.CloseNormalClosure, "terminated") 172 return 173 default: 174 c.sendConnectionError("unexpected message %s", message.Type) 175 c.close(websocket.CloseProtocolError, "unexpected message") 176 return 177 } 178 } 179 } 180 181 func (c *wsConnection) keepAlive(ctx context.Context) { 182 for { 183 select { 184 case <-ctx.Done(): 185 c.keepAliveTicker.Stop() 186 return 187 case <-c.keepAliveTicker.C: 188 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 189 } 190 } 191 } 192 193 func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { 194 ctx := graphql.StartOperationTrace(c.ctx) 195 var params *graphql.RawParams 196 if err := jsonDecode(bytes.NewReader(message.Payload), ¶ms); err != nil { 197 c.sendError(message.ID, &gqlerror.Error{Message: "invalid json"}) 198 c.complete(message.ID) 199 return 200 } 201 202 params.ReadTime = graphql.TraceTiming{ 203 Start: start, 204 End: graphql.Now(), 205 } 206 207 rc, err := c.exec.CreateOperationContext(ctx, params) 208 if err != nil { 209 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) 210 switch errcode.GetErrorKind(err) { 211 case errcode.KindProtocol: 212 c.sendError(message.ID, resp.Errors...) 213 default: 214 c.sendResponse(message.ID, &graphql.Response{Errors: err}) 215 } 216 217 c.complete(message.ID) 218 return 219 } 220 221 ctx = graphql.WithOperationContext(ctx, rc) 222 223 if c.initPayload != nil { 224 ctx = withInitPayload(ctx, c.initPayload) 225 } 226 227 ctx, cancel := context.WithCancel(ctx) 228 c.mu.Lock() 229 c.active[message.ID] = cancel 230 c.mu.Unlock() 231 232 go func() { 233 defer func() { 234 if r := recover(); r != nil { 235 userErr := rc.Recover(ctx, r) 236 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()}) 237 } 238 }() 239 responses, ctx := c.exec.DispatchOperation(ctx, rc) 240 for { 241 response := responses(ctx) 242 if response == nil { 243 break 244 } 245 246 c.sendResponse(message.ID, response) 247 } 248 c.complete(message.ID) 249 250 c.mu.Lock() 251 delete(c.active, message.ID) 252 c.mu.Unlock() 253 cancel() 254 }() 255 } 256 257 func (c *wsConnection) sendResponse(id string, response *graphql.Response) { 258 b, err := json.Marshal(response) 259 if err != nil { 260 panic(err) 261 } 262 c.write(&operationMessage{ 263 Payload: b, 264 ID: id, 265 Type: dataMsg, 266 }) 267 } 268 269 func (c *wsConnection) complete(id string) { 270 c.write(&operationMessage{ID: id, Type: completeMsg}) 271 } 272 273 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 274 errs := make([]error, len(errors)) 275 for i, err := range errors { 276 errs[i] = err 277 } 278 b, err := json.Marshal(errs) 279 if err != nil { 280 panic(err) 281 } 282 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b}) 283 } 284 285 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 286 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 287 if err != nil { 288 panic(err) 289 } 290 291 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b}) 292 } 293 294 func (c *wsConnection) readOp() *operationMessage { 295 _, r, err := c.conn.NextReader() 296 if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { 297 return nil 298 } else if err != nil { 299 c.sendConnectionError("invalid json: %T %s", err, err.Error()) 300 return nil 301 } 302 message := operationMessage{} 303 if err := jsonDecode(r, &message); err != nil { 304 c.sendConnectionError("invalid json") 305 return nil 306 } 307 308 return &message 309 } 310 311 func (c *wsConnection) close(closeCode int, message string) { 312 c.mu.Lock() 313 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 314 c.mu.Unlock() 315 _ = c.conn.Close() 316 }