github.com/humans-group/gqlgen@v0.7.2/handler/websocket.go (about) 1 package handler 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "sync" 11 "time" 12 13 "github.com/99designs/gqlgen/graphql" 14 "github.com/gorilla/websocket" 15 "github.com/vektah/gqlparser" 16 "github.com/vektah/gqlparser/ast" 17 "github.com/vektah/gqlparser/gqlerror" 18 "github.com/vektah/gqlparser/validator" 19 ) 20 21 const ( 22 connectionInitMsg = "connection_init" // Client -> Server 23 connectionTerminateMsg = "connection_terminate" // Client -> Server 24 startMsg = "start" // Client -> Server 25 stopMsg = "stop" // Client -> Server 26 connectionAckMsg = "connection_ack" // Server -> Client 27 connectionErrorMsg = "connection_error" // Server -> Client 28 dataMsg = "data" // Server -> Client 29 errorMsg = "error" // Server -> Client 30 completeMsg = "complete" // Server -> Client 31 connectionKeepAliveMsg = "ka" // Server -> Client 32 ) 33 34 type operationMessage struct { 35 Payload json.RawMessage `json:"payload,omitempty"` 36 ID string `json:"id,omitempty"` 37 Type string `json:"type"` 38 } 39 40 type wsConnection struct { 41 ctx context.Context 42 conn *websocket.Conn 43 exec graphql.ExecutableSchema 44 active map[string]context.CancelFunc 45 mu sync.Mutex 46 cfg *Config 47 keepAliveTicker *time.Ticker 48 49 initPayload InitPayload 50 } 51 52 func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) { 53 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{ 54 "Sec-Websocket-Protocol": []string{"graphql-ws"}, 55 }) 56 if err != nil { 57 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 58 sendErrorf(w, http.StatusBadRequest, "unable to upgrade") 59 return 60 } 61 62 conn := wsConnection{ 63 active: map[string]context.CancelFunc{}, 64 exec: exec, 65 conn: ws, 66 ctx: r.Context(), 67 cfg: cfg, 68 } 69 70 if !conn.init() { 71 return 72 } 73 74 conn.run() 75 } 76 77 func (c *wsConnection) init() bool { 78 message := c.readOp() 79 if message == nil { 80 c.close(websocket.CloseProtocolError, "decoding error") 81 return false 82 } 83 84 switch message.Type { 85 case connectionInitMsg: 86 if len(message.Payload) > 0 { 87 c.initPayload = make(InitPayload) 88 err := json.Unmarshal(message.Payload, &c.initPayload) 89 if err != nil { 90 return false 91 } 92 } 93 94 c.write(&operationMessage{Type: connectionAckMsg}) 95 case connectionTerminateMsg: 96 c.close(websocket.CloseNormalClosure, "terminated") 97 return false 98 default: 99 c.sendConnectionError("unexpected message %s", message.Type) 100 c.close(websocket.CloseProtocolError, "unexpected message") 101 return false 102 } 103 104 return true 105 } 106 107 func (c *wsConnection) write(msg *operationMessage) { 108 c.mu.Lock() 109 c.conn.WriteJSON(msg) 110 c.mu.Unlock() 111 } 112 113 func (c *wsConnection) run() { 114 // We create a cancellation that will shutdown the keep-alive when we leave 115 // this function. 116 ctx, cancel := context.WithCancel(c.ctx) 117 defer cancel() 118 119 // Create a timer that will fire every interval to keep the connection alive. 120 if c.cfg.connectionKeepAlivePingInterval != 0 { 121 c.mu.Lock() 122 c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval) 123 c.mu.Unlock() 124 125 go c.keepAlive(ctx) 126 } 127 128 for { 129 message := c.readOp() 130 if message == nil { 131 return 132 } 133 134 switch message.Type { 135 case startMsg: 136 if !c.subscribe(message) { 137 return 138 } 139 case stopMsg: 140 c.mu.Lock() 141 closer := c.active[message.ID] 142 c.mu.Unlock() 143 if closer == nil { 144 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID)) 145 continue 146 } 147 148 closer() 149 case connectionTerminateMsg: 150 c.close(websocket.CloseNormalClosure, "terminated") 151 return 152 default: 153 c.sendConnectionError("unexpected message %s", message.Type) 154 c.close(websocket.CloseProtocolError, "unexpected message") 155 return 156 } 157 } 158 } 159 160 func (c *wsConnection) keepAlive(ctx context.Context) { 161 for { 162 select { 163 case <-ctx.Done(): 164 c.keepAliveTicker.Stop() 165 return 166 case <-c.keepAliveTicker.C: 167 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 168 } 169 } 170 } 171 172 func (c *wsConnection) subscribe(message *operationMessage) bool { 173 var reqParams params 174 if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil { 175 c.sendConnectionError("invalid json") 176 return false 177 } 178 179 doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query) 180 if qErr != nil { 181 c.sendError(message.ID, qErr...) 182 return true 183 } 184 185 op := doc.Operations.ForName(reqParams.OperationName) 186 if op == nil { 187 c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName)) 188 return true 189 } 190 191 vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables) 192 if err != nil { 193 c.sendError(message.ID, err) 194 return true 195 } 196 reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars) 197 ctx := graphql.WithRequestContext(c.ctx, reqCtx) 198 199 if c.initPayload != nil { 200 ctx = withInitPayload(ctx, c.initPayload) 201 } 202 203 if op.Operation != ast.Subscription { 204 var result *graphql.Response 205 if op.Operation == ast.Query { 206 result = c.exec.Query(ctx, op) 207 } else { 208 result = c.exec.Mutation(ctx, op) 209 } 210 211 c.sendData(message.ID, result) 212 c.write(&operationMessage{ID: message.ID, Type: completeMsg}) 213 return true 214 } 215 216 ctx, cancel := context.WithCancel(ctx) 217 c.mu.Lock() 218 c.active[message.ID] = cancel 219 c.mu.Unlock() 220 go func() { 221 defer func() { 222 if r := recover(); r != nil { 223 userErr := reqCtx.Recover(ctx, r) 224 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()}) 225 } 226 }() 227 next := c.exec.Subscription(ctx, op) 228 for result := next(); result != nil; result = next() { 229 c.sendData(message.ID, result) 230 } 231 232 c.write(&operationMessage{ID: message.ID, Type: completeMsg}) 233 234 c.mu.Lock() 235 delete(c.active, message.ID) 236 c.mu.Unlock() 237 cancel() 238 }() 239 240 return true 241 } 242 243 func (c *wsConnection) sendData(id string, response *graphql.Response) { 244 b, err := json.Marshal(response) 245 if err != nil { 246 c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error())) 247 return 248 } 249 250 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b}) 251 } 252 253 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 254 var errs []error 255 for _, err := range errors { 256 errs = append(errs, err) 257 } 258 b, err := json.Marshal(errs) 259 if err != nil { 260 panic(err) 261 } 262 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b}) 263 } 264 265 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 266 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 267 if err != nil { 268 panic(err) 269 } 270 271 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b}) 272 } 273 274 func (c *wsConnection) readOp() *operationMessage { 275 _, r, err := c.conn.NextReader() 276 if err != nil { 277 c.sendConnectionError("invalid json") 278 return nil 279 } 280 message := operationMessage{} 281 if err := jsonDecode(r, &message); err != nil { 282 c.sendConnectionError("invalid json") 283 return nil 284 } 285 286 return &message 287 } 288 289 func (c *wsConnection) close(closeCode int, message string) { 290 c.mu.Lock() 291 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 292 c.mu.Unlock() 293 _ = c.conn.Close() 294 }