github.com/fortexxx/gqlgen@v0.10.3-0.20191216030626-ca5ea8b21ead/graphql/handler/transport/websocket.go (about) 1 package transport 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "sync" 11 "time" 12 13 "github.com/99designs/gqlgen/graphql" 14 "github.com/gorilla/websocket" 15 "github.com/vektah/gqlparser/gqlerror" 16 ) 17 18 const ( 19 connectionInitMsg = "connection_init" // Client -> Server 20 connectionTerminateMsg = "connection_terminate" // Client -> Server 21 startMsg = "start" // Client -> Server 22 stopMsg = "stop" // Client -> Server 23 connectionAckMsg = "connection_ack" // Server -> Client 24 connectionErrorMsg = "connection_error" // Server -> Client 25 dataMsg = "data" // Server -> Client 26 errorMsg = "error" // Server -> Client 27 completeMsg = "complete" // Server -> Client 28 connectionKeepAliveMsg = "ka" // Server -> Client 29 ) 30 31 type ( 32 Websocket struct { 33 Upgrader websocket.Upgrader 34 InitFunc WebsocketInitFunc 35 KeepAlivePingInterval time.Duration 36 } 37 wsConnection struct { 38 Websocket 39 ctx context.Context 40 conn *websocket.Conn 41 active map[string]context.CancelFunc 42 mu sync.Mutex 43 keepAliveTicker *time.Ticker 44 exec graphql.GraphExecutor 45 46 initPayload InitPayload 47 } 48 operationMessage struct { 49 Payload json.RawMessage `json:"payload,omitempty"` 50 ID string `json:"id,omitempty"` 51 Type string `json:"type"` 52 } 53 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) 54 ) 55 56 var _ graphql.Transport = Websocket{} 57 58 func (t Websocket) Supports(r *http.Request) bool { 59 return r.Header.Get("Upgrade") != "" 60 } 61 62 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { 63 ws, err := t.Upgrader.Upgrade(w, r, http.Header{ 64 "Sec-Websocket-Protocol": []string{"graphql-ws"}, 65 }) 66 if err != nil { 67 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) 68 SendErrorf(w, http.StatusBadRequest, "unable to upgrade") 69 return 70 } 71 72 conn := wsConnection{ 73 active: map[string]context.CancelFunc{}, 74 conn: ws, 75 ctx: r.Context(), 76 exec: exec, 77 Websocket: t, 78 } 79 80 if !conn.init() { 81 return 82 } 83 84 conn.run() 85 } 86 87 func (c *wsConnection) init() bool { 88 message := c.readOp() 89 if message == nil { 90 c.close(websocket.CloseProtocolError, "decoding error") 91 return false 92 } 93 94 switch message.Type { 95 case connectionInitMsg: 96 if len(message.Payload) > 0 { 97 c.initPayload = make(InitPayload) 98 err := json.Unmarshal(message.Payload, &c.initPayload) 99 if err != nil { 100 return false 101 } 102 } 103 104 if c.InitFunc != nil { 105 ctx, err := c.InitFunc(c.ctx, c.initPayload) 106 if err != nil { 107 c.sendConnectionError(err.Error()) 108 c.close(websocket.CloseNormalClosure, "terminated") 109 return false 110 } 111 c.ctx = ctx 112 } 113 114 c.write(&operationMessage{Type: connectionAckMsg}) 115 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 116 case connectionTerminateMsg: 117 c.close(websocket.CloseNormalClosure, "terminated") 118 return false 119 default: 120 c.sendConnectionError("unexpected message %s", message.Type) 121 c.close(websocket.CloseProtocolError, "unexpected message") 122 return false 123 } 124 125 return true 126 } 127 128 func (c *wsConnection) write(msg *operationMessage) { 129 c.mu.Lock() 130 c.conn.WriteJSON(msg) 131 c.mu.Unlock() 132 } 133 134 func (c *wsConnection) run() { 135 // We create a cancellation that will shutdown the keep-alive when we leave 136 // this function. 137 ctx, cancel := context.WithCancel(c.ctx) 138 defer cancel() 139 140 // Create a timer that will fire every interval to keep the connection alive. 141 if c.KeepAlivePingInterval != 0 { 142 c.mu.Lock() 143 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval) 144 c.mu.Unlock() 145 146 go c.keepAlive(ctx) 147 } 148 149 for { 150 message := c.readOp() 151 if message == nil { 152 return 153 } 154 155 switch message.Type { 156 case startMsg: 157 if !c.subscribe(message) { 158 return 159 } 160 case stopMsg: 161 c.mu.Lock() 162 closer := c.active[message.ID] 163 c.mu.Unlock() 164 if closer == nil { 165 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID)) 166 continue 167 } 168 169 closer() 170 case connectionTerminateMsg: 171 c.close(websocket.CloseNormalClosure, "terminated") 172 return 173 default: 174 c.sendConnectionError("unexpected message %s", message.Type) 175 c.close(websocket.CloseProtocolError, "unexpected message") 176 return 177 } 178 } 179 } 180 181 func (c *wsConnection) keepAlive(ctx context.Context) { 182 for { 183 select { 184 case <-ctx.Done(): 185 c.keepAliveTicker.Stop() 186 return 187 case <-c.keepAliveTicker.C: 188 c.write(&operationMessage{Type: connectionKeepAliveMsg}) 189 } 190 } 191 } 192 193 func (c *wsConnection) subscribe(message *operationMessage) bool { 194 ctx := graphql.StartOperationTrace(c.ctx) 195 var params *graphql.RawParams 196 if err := jsonDecode(bytes.NewReader(message.Payload), ¶ms); err != nil { 197 c.sendConnectionError("invalid json") 198 return false 199 } 200 201 rc, err := c.exec.CreateOperationContext(ctx, params) 202 if err != nil { 203 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) 204 c.sendError(message.ID, resp.Errors...) 205 return false 206 } 207 208 ctx = graphql.WithOperationContext(ctx, rc) 209 210 if c.initPayload != nil { 211 ctx = withInitPayload(ctx, c.initPayload) 212 } 213 214 ctx, cancel := context.WithCancel(ctx) 215 c.mu.Lock() 216 c.active[message.ID] = cancel 217 c.mu.Unlock() 218 219 go func() { 220 defer func() { 221 if r := recover(); r != nil { 222 userErr := rc.Recover(ctx, r) 223 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()}) 224 } 225 }() 226 responses, ctx := c.exec.DispatchOperation(ctx, rc) 227 for { 228 response := responses(ctx) 229 if response == nil { 230 break 231 } 232 233 b, err := json.Marshal(response) 234 if err != nil { 235 panic(err) 236 } 237 c.write(&operationMessage{ 238 Payload: b, 239 ID: message.ID, 240 Type: dataMsg, 241 }) 242 } 243 c.write(&operationMessage{ID: message.ID, Type: completeMsg}) 244 245 c.mu.Lock() 246 delete(c.active, message.ID) 247 c.mu.Unlock() 248 cancel() 249 }() 250 251 return true 252 } 253 254 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { 255 errs := make([]error, len(errors)) 256 for i, err := range errors { 257 errs[i] = err 258 } 259 b, err := json.Marshal(errs) 260 if err != nil { 261 panic(err) 262 } 263 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b}) 264 } 265 266 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { 267 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)}) 268 if err != nil { 269 panic(err) 270 } 271 272 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b}) 273 } 274 275 func (c *wsConnection) readOp() *operationMessage { 276 _, r, err := c.conn.NextReader() 277 if err != nil { 278 c.sendConnectionError("invalid json") 279 return nil 280 } 281 message := operationMessage{} 282 if err := jsonDecode(r, &message); err != nil { 283 c.sendConnectionError("invalid json") 284 return nil 285 } 286 287 return &message 288 } 289 290 func (c *wsConnection) close(closeCode int, message string) { 291 c.mu.Lock() 292 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) 293 c.mu.Unlock() 294 _ = c.conn.Close() 295 }