// Source: github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/rest/routes/websocket_handler.go

package routes

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"time"

	"github.com/gorilla/websocket"
	"github.com/rs/zerolog"
	"go.uber.org/atomic"

	"github.com/onflow/flow-go/engine/access/rest/models"
	"github.com/onflow/flow-go/engine/access/rest/request"
	"github.com/onflow/flow-go/engine/access/state_stream"
	"github.com/onflow/flow-go/engine/access/state_stream/backend"
	"github.com/onflow/flow-go/engine/access/subscription"
	"github.com/onflow/flow-go/engine/common/rpc/convert"
	"github.com/onflow/flow-go/model/flow"
)

const (
	// pongWait is the time allowed to read the next pong message from the peer.
	// The read deadline is pushed forward by this amount every time a pong arrives.
	pongWait = 10 * time.Second

	// pingPeriod is how often pings are sent to the peer. It must be less than
	// pongWait: the read deadline is only extended when a pong arrives, so pings
	// have to go out more often than the deadline window.
	pingPeriod = (pongWait * 9) / 10

	// writeWait is the time allowed to write a single message to the peer.
	writeWait = 10 * time.Second
)

// WebsocketController holds the components and parameters needed to serve a single
// WebSocket subscription: the connection to the client, the stream limits shared with
// the handler, and the channel used to learn that the client side has finished.
type WebsocketController struct {
	logger            zerolog.Logger
	conn              *websocket.Conn                // the WebSocket connection used to talk to the client
	api               state_stream.API               // the state_stream.API instance used to create subscriptions
	eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events
	maxStreams        int32                          // the maximum number of concurrent streams allowed
	activeStreamCount *atomic.Int32                  // the current number of active streams, shared across connections
	readChannel       chan error                     // closed by read() when the client side ends; may first carry the error that ended it
	heartbeatInterval uint64                         // forward one empty response as a heartbeat every this many empty responses [IN BLOCKS]
}

// SetWebsocketConf sets the initial write deadline (writeWait) and read deadline
// (pongWait) on the connection. It also installs a pong handler that pushes the read
// deadline forward by pongWait each time a pong arrives. If the client stops answering
// pings for longer than pongWait, reads on the connection start failing with a timeout.
// Failures to set a deadline are returned as rest errors with status 500.
50 func (wsController *WebsocketController) SetWebsocketConf() error { 51 err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message 52 if err != nil { 53 return models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err) 54 } 55 err = wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message 56 if err != nil { 57 return models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err) 58 } 59 // Establish a Pong handler 60 wsController.conn.SetPongHandler(func(string) error { 61 err := wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) 62 if err != nil { 63 return err 64 } 65 return nil 66 }) 67 return nil 68 } 69 70 // wsErrorHandler handles WebSocket errors by sending an appropriate close message 71 // to the client WebSocket connection. 72 // 73 // If the error is an instance of models.StatusError, the function extracts the 74 // relevant information like status code and user message to construct the WebSocket 75 // close code and message. If the error is not a models.StatusError, a default 76 // internal server error close code and the error's message are used. 77 // The connection is then closed using WriteControl to send a CloseMessage with the 78 // constructed close code and message. Any errors that occur during the closing 79 // process are logged using the provided logger. 
80 func (wsController *WebsocketController) wsErrorHandler(err error) { 81 // rest status type error should be returned with status and user message provided 82 var statusErr models.StatusError 83 var wsCode int 84 var wsMsg string 85 86 if errors.As(err, &statusErr) { 87 if statusErr.Status() == http.StatusBadRequest { 88 wsCode = websocket.CloseUnsupportedData 89 } 90 if statusErr.Status() == http.StatusServiceUnavailable { 91 wsCode = websocket.CloseTryAgainLater 92 } 93 if statusErr.Status() == http.StatusRequestTimeout { 94 wsCode = websocket.CloseGoingAway 95 } 96 wsMsg = statusErr.UserMessage() 97 98 } else { 99 wsCode = websocket.CloseInternalServerErr 100 wsMsg = err.Error() 101 } 102 103 // Close the connection with the CloseError message 104 err = wsController.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(wsCode, wsMsg), time.Now().Add(time.Second)) 105 if err != nil { 106 wsController.logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err)) 107 } 108 } 109 110 // writeEvents is used for writing events and pings to the WebSocket connection for a given subscription. 111 // It listens to the subscription's channel for events and writes them to the WebSocket connection. 112 // If an error occurs or the subscription channel is closed, it handles the error or termination accordingly. 113 // The function uses a ticker to periodically send ping messages to the client to maintain the connection. 114 func (wsController *WebsocketController) writeEvents(sub subscription.Subscription) { 115 ticker := time.NewTicker(pingPeriod) 116 defer ticker.Stop() 117 118 blocksSinceLastMessage := uint64(0) 119 for { 120 select { 121 case err := <-wsController.readChannel: 122 // we use `readChannel` 123 // 1) as indicator of client's status, when `readChannel` closes it means that client 124 // connection has been terminated and we need to stop this goroutine to avoid memory leak. 
125 // 2) as error receiver for any errors that occur during the reading process 126 if err != nil { 127 wsController.wsErrorHandler(err) 128 } 129 return 130 case event, ok := <-sub.Channel(): 131 if !ok { 132 if sub.Err() != nil { 133 err := fmt.Errorf("stream encountered an error: %v", sub.Err()) 134 wsController.wsErrorHandler(err) 135 return 136 } 137 err := fmt.Errorf("subscription channel closed, no error occurred") 138 wsController.wsErrorHandler(models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) 139 return 140 } 141 err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) 142 if err != nil { 143 wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) 144 return 145 } 146 147 resp, ok := event.(*backend.EventsResponse) 148 if !ok { 149 err = fmt.Errorf("unexpected response type: %s", event) 150 wsController.wsErrorHandler(err) 151 return 152 } 153 // responses with empty events increase heartbeat interval counter, when threshold is met a heartbeat 154 // message will be emitted. 155 if len(resp.Events) == 0 { 156 blocksSinceLastMessage++ 157 if blocksSinceLastMessage < wsController.heartbeatInterval { 158 continue 159 } 160 blocksSinceLastMessage = 0 161 } 162 163 // EventsResponse contains CCF encoded events, and this API returns JSON-CDC events. 164 // convert event payload formats. 
165 for i, e := range resp.Events { 166 payload, err := convert.CcfPayloadToJsonPayload(e.Payload) 167 if err != nil { 168 err = fmt.Errorf("could not convert event payload from CCF to Json: %w", err) 169 wsController.wsErrorHandler(err) 170 return 171 } 172 resp.Events[i].Payload = payload 173 } 174 175 // Write the response to the WebSocket connection 176 err = wsController.conn.WriteJSON(event) 177 if err != nil { 178 wsController.wsErrorHandler(err) 179 return 180 } 181 case <-ticker.C: 182 err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) 183 if err != nil { 184 wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) 185 return 186 } 187 if err := wsController.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 188 wsController.wsErrorHandler(err) 189 return 190 } 191 } 192 } 193 } 194 195 // read function handles WebSocket messages from the client. 196 // It continuously reads messages from the WebSocket connection and closes 197 // the associated read channel when the connection is closed by client or when an 198 // any additional message is received from the client. 199 // 200 // This method should be called after establishing the WebSocket connection 201 // to handle incoming messages asynchronously. 
202 func (wsController *WebsocketController) read() { 203 // Start a goroutine to handle the WebSocket connection 204 defer close(wsController.readChannel) // notify websocket about closed connection 205 206 for { 207 // reads messages from the WebSocket connection when 208 // 1) the connection is closed by client 209 // 2) a message is received from the client 210 _, msg, err := wsController.conn.ReadMessage() 211 if err != nil { 212 if _, ok := err.(*websocket.CloseError); !ok { 213 wsController.readChannel <- err 214 } 215 return 216 } 217 218 // Check the message from the client, if is any just close the connection 219 if len(msg) > 0 { 220 err := fmt.Errorf("the client sent an unexpected message, connection closed") 221 wsController.logger.Debug().Msg(err.Error()) 222 wsController.readChannel <- err 223 return 224 } 225 } 226 } 227 228 // SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, fetches necessary resources 229 type SubscribeHandlerFunc func( 230 ctx context.Context, 231 request *request.Request, 232 wsController *WebsocketController, 233 ) (subscription.Subscription, error) 234 235 // WSHandler is websocket handler implementing custom websocket handler function and allows easier handling of errors and 236 // responses as it wraps functionality for handling error and responses outside of endpoint handling. 
237 type WSHandler struct { 238 *HttpHandler 239 subscribeFunc SubscribeHandlerFunc 240 241 api state_stream.API 242 eventFilterConfig state_stream.EventFilterConfig 243 maxStreams int32 244 defaultHeartbeatInterval uint64 245 activeStreamCount *atomic.Int32 246 } 247 248 var _ http.Handler = (*WSHandler)(nil) 249 250 func NewWSHandler( 251 logger zerolog.Logger, 252 api state_stream.API, 253 subscribeFunc SubscribeHandlerFunc, 254 chain flow.Chain, 255 stateStreamConfig backend.Config, 256 ) *WSHandler { 257 handler := &WSHandler{ 258 subscribeFunc: subscribeFunc, 259 api: api, 260 eventFilterConfig: stateStreamConfig.EventFilterConfig, 261 maxStreams: int32(stateStreamConfig.MaxGlobalStreams), 262 defaultHeartbeatInterval: stateStreamConfig.HeartbeatInterval, 263 activeStreamCount: atomic.NewInt32(0), 264 HttpHandler: NewHttpHandler(logger, chain), 265 } 266 267 return handler 268 } 269 270 // ServeHTTP function acts as a wrapper to each request providing common handling functionality 271 // such as logging, error handling, request decorators 272 func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 273 // create a logger 274 logger := h.Logger.With().Str("subscribe_url", r.URL.String()).Logger() 275 276 err := h.VerifyRequest(w, r) 277 if err != nil { 278 // VerifyRequest sets the response error before returning 279 return 280 } 281 282 // Upgrade the HTTP connection to a WebSocket connection 283 upgrader := websocket.Upgrader{ 284 // allow all origins by default, operators can override using a proxy 285 CheckOrigin: func(r *http.Request) bool { 286 return true 287 }, 288 } 289 conn, err := upgrader.Upgrade(w, r, nil) 290 if err != nil { 291 h.errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) 292 return 293 } 294 defer conn.Close() 295 296 wsController := &WebsocketController{ 297 logger: logger, 298 conn: conn, 299 api: h.api, 300 eventFilterConfig: h.eventFilterConfig, 301 
maxStreams: h.maxStreams, 302 activeStreamCount: h.activeStreamCount, 303 readChannel: make(chan error), 304 heartbeatInterval: h.defaultHeartbeatInterval, // set default heartbeat interval from state stream config 305 } 306 307 err = wsController.SetWebsocketConf() 308 if err != nil { 309 wsController.wsErrorHandler(err) 310 return 311 } 312 313 if wsController.activeStreamCount.Load() >= wsController.maxStreams { 314 err := fmt.Errorf("maximum number of streams reached") 315 wsController.wsErrorHandler(models.NewRestError(http.StatusServiceUnavailable, err.Error(), err)) 316 return 317 } 318 wsController.activeStreamCount.Add(1) 319 defer wsController.activeStreamCount.Add(-1) 320 321 // cancelling the context passed into the `subscribeFunc` to ensure that when the client disconnects, 322 // gorountines setup by the backend are cleaned up. 323 ctx, cancel := context.WithCancel(context.Background()) 324 defer cancel() 325 326 sub, err := h.subscribeFunc(ctx, request.Decorate(r, h.HttpHandler.Chain), wsController) 327 if err != nil { 328 wsController.wsErrorHandler(err) 329 return 330 } 331 332 go wsController.read() 333 wsController.writeEvents(sub) 334 }