github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/handler/websocket.go (about) 1 package handler 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "sync" 11 "time" 12 13 "github.com/HaswinVidanage/gqlgen/graphql" 14 "github.com/gorilla/websocket" 15 "github.com/hashicorp/golang-lru" 16 "github.com/vektah/gqlparser" 17 "github.com/vektah/gqlparser/ast" 18 "github.com/vektah/gqlparser/gqlerror" 19 "github.com/vektah/gqlparser/validator" 20 ) 21 22 const ( 23 connectionInitMsg = "connection_init" // Client -> Server 24 connectionTerminateMsg = "connection_terminate" // Client -> Server 25 startMsg = "start" // Client -> Server 26 stopMsg = "stop" // Client -> Server 27 connectionAckMsg = "connection_ack" // Server -> Client 28 connectionErrorMsg = "connection_error" // Server -> Client 29 dataMsg = "data" // Server -> Client 30 errorMsg = "error" // Server -> Client 31 completeMsg = "complete" // Server -> Client 32 connectionKeepAliveMsg = "ka" // Server -> Client 33 ) 34 35 type operationMessage struct { 36 Payload json.RawMessage `json:"payload,omitempty"` 37 ID string `json:"id,omitempty"` 38 Type string `json:"type"` 39 } 40 41 type wsConnection struct { 42 ctx context.Context 43 conn *websocket.Conn 44 exec graphql.ExecutableSchema 45 active map[string]context.CancelFunc 46 mu sync.Mutex 47 cfg *Config 48 cache *lru.Cache 49 keepAliveTicker *time.Ticker 50 51 initPayload InitPayload 52 } 53 54 func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) { 55 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{ 56 "Sec-Websocket-Protocol": []string{"graphql-ws"}, 57 }) 58 if err != nil { 59 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 60 sendErrorf(w, http.StatusBadRequest, "unable to upgrade") 61 return 62 } 63 64 conn := wsConnection{ 65 active: map[string]context.CancelFunc{}, 66 exec: exec, 67 conn: ws, 68 ctx: r.Context(), 69 cfg: cfg, 70 cache: cache, 71 } 72 73 if !conn.init() { 74 return 75 } 76 77 conn.run() 78 } 79 80 func (c *wsConnection) init() bool { 81 message := c.readOp() 82 if message == nil { 83 c.close(websocket.CloseProtocolError, "decoding error") 84 return false 85 } 86 87 switch message.Type { 88 case connectionInitMsg: 89 if len(message.Payload) > 0 { 90 c.initPayload = make(InitPayload) 91 err := json.Unmarshal(message.Payload, &c.initPayload) 92 if err != nil { 93 return false 94 } 95 } 96 97 c.write(&operationMessage{Type: connectionAckMsg}) 98 case connectionTerminateMsg: 99 c.close(websocket.CloseNormalClosure, "terminated") 100 return false 101 default: 102 c.sendConnectionError("unexpected message %s", message.Type) 103 c.close(websocket.CloseProtocolError, "unexpected message") 104 return false 105 } 106 107 return true 108 } 109 110 func (c *wsConnection) write(msg *operationMessage) { 111 c.mu.Lock() 112 c.conn.WriteJSON(msg) 113 c.mu.Unlock() 114 } 115 116 func (c *wsConnection) run() { 117 // We create a cancellation that will shutdown the keep-alive when we leave 118 // this function. 119 ctx, cancel := context.WithCancel(c.ctx) 120 defer cancel() 121 122 // Create a timer that will fire every interval to keep the connection alive. 123 if c.cfg.connectionKeepAlivePingInterval != 0 { 124 c.mu.Lock() 125 c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval) 126 c.mu.Unlock() 127 128 go c.keepAlive(ctx) 129 } 130 131 for { 132 message := c.readOp() 133 if message == nil { 134 return 135 } 136 137 switch message.Type { 138 case startMsg: 139 if !c.subscribe(message) { 140 return 141 } 142 case stopMsg: 143 c.mu.Lock() 144 closer := c.active[message.ID] 145 c.mu.Unlock() 146 if closer == nil { 147 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID)) 148 continue 149 } 150 151 closer() 152 case connectionTerminateMsg: 153 c.close(websocket.CloseNormalClosure, "terminated") 154 return 155 default: 156 c.sendConnectionError("unexpected message %s", message.Type) 157 c.close(websocket.CloseProtocolError, "unexpected message") 158 return 159 } 160 } 161 } 162 163 func (c *wsConnection) keepAlive(ctx context.Context) { 164 for { 165 select { 166 case <-ctx.Done(): 167 c.keepAliveTicker.Stop() 168 return 169 case <-c.keepAliveTicker.C: 170 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 171 } 172 } 173 } 174 175 func (c *wsConnection) subscribe(message *operationMessage) bool { 176 var reqParams params 177 if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil { 178 c.sendConnectionError("invalid json") 179 return false 180 } 181 182 var ( 183 doc *ast.QueryDocument 184 cacheHit bool 185 ) 186 if c.cache != nil { 187 val, ok := c.cache.Get(reqParams.Query) 188 if ok { 189 doc = val.(*ast.QueryDocument) 190 cacheHit = true 191 } 192 } 193 if !cacheHit { 194 var qErr gqlerror.List 195 doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query) 196 if qErr != nil { 197 c.sendError(message.ID, qErr...) 198 return true 199 } 200 if c.cache != nil { 201 c.cache.Add(reqParams.Query, doc) 202 } 203 } 204 205 op := doc.Operations.ForName(reqParams.OperationName) 206 if op == nil { 207 c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName)) 208 return true 209 } 210 211 vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables) 212 if err != nil { 213 c.sendError(message.ID, err) 214 return true 215 } 216 reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars) 217 ctx := graphql.WithRequestContext(c.ctx, reqCtx) 218 219 if c.initPayload != nil { 220 ctx = withInitPayload(ctx, c.initPayload) 221 } 222 223 if op.Operation != ast.Subscription { 224 var result *graphql.Response 225 if op.Operation == ast.Query { 226 result = c.exec.Query(ctx, op) 227 } else { 228 result = c.exec.Mutation(ctx, op) 229 } 230 231 c.sendData(message.ID, result) 232 c.write(&operationMessage{ID: message.ID, Type: completeMsg}) 233 return true 234 } 235 236 ctx, cancel := context.WithCancel(ctx) 237 c.mu.Lock() 238 c.active[message.ID] = cancel 239 c.mu.Unlock() 240 go func() { 241 defer func() { 242 if r := recover(); r != nil { 243 userErr := reqCtx.Recover(ctx, r) 244 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()}) 245 } 246 }() 247 next := c.exec.Subscription(ctx, op) 248 for result := next(); result != nil; result = next() { 249 c.sendData(message.ID, result) 250 } 251 252 c.write(&operationMessage{ID: message.ID, Type: completeMsg}) 253 254 c.mu.Lock() 255 delete(c.active, message.ID) 256 c.mu.Unlock() 257 cancel() 258 }() 259 260 return true 261 } 262 263 func (c *wsConnection) sendData(id string, response *graphql.Response) { 264 b, err := json.Marshal(response) 265 if err != nil { 266 c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error())) 267 return 268 } 269 270 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b}) 271 } 272 273 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 274 var errs []error 275 for _, err := range errors { 276 errs = append(errs, err) 277 } 278 b, err := json.Marshal(errs) 279 if err != nil { 280 panic(err) 281 } 282 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b}) 283 } 284 285 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 286 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 287 if err != nil { 288 panic(err) 289 } 290 291 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b}) 292 } 293 294 func (c *wsConnection) readOp() *operationMessage { 295 _, r, err := c.conn.NextReader() 296 if err != nil { 297 c.sendConnectionError("invalid json") 298 return nil 299 } 300 message := operationMessage{} 301 if err := jsonDecode(r, &message); err != nil { 302 c.sendConnectionError("invalid json") 303 return nil 304 } 305 306 return &message 307 } 308 309 func (c *wsConnection) close(closeCode int, message string) { 310 c.mu.Lock() 311 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 312 c.mu.Unlock() 313 _ = c.conn.Close() 314 }