github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/rest/routes/websocket_handler.go (about)

     1  package routes
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"time"
     9  
    10  	"github.com/gorilla/websocket"
    11  	"github.com/rs/zerolog"
    12  	"go.uber.org/atomic"
    13  
    14  	"github.com/onflow/flow-go/engine/access/rest/models"
    15  	"github.com/onflow/flow-go/engine/access/rest/request"
    16  	"github.com/onflow/flow-go/engine/access/state_stream"
    17  	"github.com/onflow/flow-go/engine/access/state_stream/backend"
    18  	"github.com/onflow/flow-go/engine/access/subscription"
    19  	"github.com/onflow/flow-go/engine/common/rpc/convert"
    20  	"github.com/onflow/flow-go/model/flow"
    21  )
    22  
    23  const (
    24  	// Time allowed to read the next pong message from the peer.
    25  	pongWait = 10 * time.Second
    26  
    27  	// Send pings to peer with this period. Must be less than pongWait.
    28  	pingPeriod = (pongWait * 9) / 10
    29  
    30  	// Time allowed to write a message to the peer.
    31  	writeWait = 10 * time.Second
    32  )
    33  
    34  // WebsocketController holds the necessary components and parameters for handling a WebSocket subscription.
    35  // It manages the communication between the server and the WebSocket client for subscribing.
    36  type WebsocketController struct {
    37  	logger            zerolog.Logger
    38  	conn              *websocket.Conn                // the WebSocket connection for communication with the client
    39  	api               state_stream.API               // the state_stream.API instance for managing event subscriptions
    40  	eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events
    41  	maxStreams        int32                          // the maximum number of streams allowed
    42  	activeStreamCount *atomic.Int32                  // the current number of active streams
    43  	readChannel       chan error                     // channel which notify closing connection by the client and provide errors to the client
    44  	heartbeatInterval uint64                         // the interval to deliver heartbeat messages to client[IN BLOCKS]
    45  }
    46  
    47  // SetWebsocketConf used to set read and write deadlines for WebSocket connections and establishes a Pong handler to
    48  // manage incoming Pong messages. These methods allow to specify a time limit for reading from or writing to a WebSocket
    49  // connection. If the operation (reading or writing) takes longer than the specified deadline, the connection will be closed.
    50  func (wsController *WebsocketController) SetWebsocketConf() error {
    51  	err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message
    52  	if err != nil {
    53  		return models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err)
    54  	}
    55  	err = wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message
    56  	if err != nil {
    57  		return models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err)
    58  	}
    59  	// Establish a Pong handler
    60  	wsController.conn.SetPongHandler(func(string) error {
    61  		err := wsController.conn.SetReadDeadline(time.Now().Add(pongWait))
    62  		if err != nil {
    63  			return err
    64  		}
    65  		return nil
    66  	})
    67  	return nil
    68  }
    69  
    70  // wsErrorHandler handles WebSocket errors by sending an appropriate close message
    71  // to the client WebSocket connection.
    72  //
    73  // If the error is an instance of models.StatusError, the function extracts the
    74  // relevant information like status code and user message to construct the WebSocket
    75  // close code and message. If the error is not a models.StatusError, a default
    76  // internal server error close code and the error's message are used.
    77  // The connection is then closed using WriteControl to send a CloseMessage with the
    78  // constructed close code and message. Any errors that occur during the closing
    79  // process are logged using the provided logger.
    80  func (wsController *WebsocketController) wsErrorHandler(err error) {
    81  	// rest status type error should be returned with status and user message provided
    82  	var statusErr models.StatusError
    83  	var wsCode int
    84  	var wsMsg string
    85  
    86  	if errors.As(err, &statusErr) {
    87  		if statusErr.Status() == http.StatusBadRequest {
    88  			wsCode = websocket.CloseUnsupportedData
    89  		}
    90  		if statusErr.Status() == http.StatusServiceUnavailable {
    91  			wsCode = websocket.CloseTryAgainLater
    92  		}
    93  		if statusErr.Status() == http.StatusRequestTimeout {
    94  			wsCode = websocket.CloseGoingAway
    95  		}
    96  		wsMsg = statusErr.UserMessage()
    97  
    98  	} else {
    99  		wsCode = websocket.CloseInternalServerErr
   100  		wsMsg = err.Error()
   101  	}
   102  
   103  	// Close the connection with the CloseError message
   104  	err = wsController.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(wsCode, wsMsg), time.Now().Add(time.Second))
   105  	if err != nil {
   106  		wsController.logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err))
   107  	}
   108  }
   109  
   110  // writeEvents is used for writing events and pings to the WebSocket connection for a given subscription.
   111  // It listens to the subscription's channel for events and writes them to the WebSocket connection.
   112  // If an error occurs or the subscription channel is closed, it handles the error or termination accordingly.
   113  // The function uses a ticker to periodically send ping messages to the client to maintain the connection.
   114  func (wsController *WebsocketController) writeEvents(sub subscription.Subscription) {
   115  	ticker := time.NewTicker(pingPeriod)
   116  	defer ticker.Stop()
   117  
   118  	blocksSinceLastMessage := uint64(0)
   119  	for {
   120  		select {
   121  		case err := <-wsController.readChannel:
   122  			// we use `readChannel`
   123  			// 1) as indicator of client's status, when `readChannel` closes it means that client
   124  			// connection has been terminated and we need to stop this goroutine to avoid memory leak.
   125  			// 2) as error receiver for any errors that occur during the reading process
   126  			if err != nil {
   127  				wsController.wsErrorHandler(err)
   128  			}
   129  			return
   130  		case event, ok := <-sub.Channel():
   131  			if !ok {
   132  				if sub.Err() != nil {
   133  					err := fmt.Errorf("stream encountered an error: %v", sub.Err())
   134  					wsController.wsErrorHandler(err)
   135  					return
   136  				}
   137  				err := fmt.Errorf("subscription channel closed, no error occurred")
   138  				wsController.wsErrorHandler(models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err))
   139  				return
   140  			}
   141  			err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait))
   142  			if err != nil {
   143  				wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err))
   144  				return
   145  			}
   146  
   147  			resp, ok := event.(*backend.EventsResponse)
   148  			if !ok {
   149  				err = fmt.Errorf("unexpected response type: %s", event)
   150  				wsController.wsErrorHandler(err)
   151  				return
   152  			}
   153  			// responses with empty events increase heartbeat interval counter, when threshold is met a heartbeat
   154  			// message will be emitted.
   155  			if len(resp.Events) == 0 {
   156  				blocksSinceLastMessage++
   157  				if blocksSinceLastMessage < wsController.heartbeatInterval {
   158  					continue
   159  				}
   160  				blocksSinceLastMessage = 0
   161  			}
   162  
   163  			// EventsResponse contains CCF encoded events, and this API returns JSON-CDC events.
   164  			// convert event payload formats.
   165  			for i, e := range resp.Events {
   166  				payload, err := convert.CcfPayloadToJsonPayload(e.Payload)
   167  				if err != nil {
   168  					err = fmt.Errorf("could not convert event payload from CCF to Json: %w", err)
   169  					wsController.wsErrorHandler(err)
   170  					return
   171  				}
   172  				resp.Events[i].Payload = payload
   173  			}
   174  
   175  			// Write the response to the WebSocket connection
   176  			err = wsController.conn.WriteJSON(event)
   177  			if err != nil {
   178  				wsController.wsErrorHandler(err)
   179  				return
   180  			}
   181  		case <-ticker.C:
   182  			err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait))
   183  			if err != nil {
   184  				wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err))
   185  				return
   186  			}
   187  			if err := wsController.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
   188  				wsController.wsErrorHandler(err)
   189  				return
   190  			}
   191  		}
   192  	}
   193  }
   194  
   195  // read function handles WebSocket messages from the client.
   196  // It continuously reads messages from the WebSocket connection and closes
   197  // the associated read channel when the connection is closed by client or when an
   198  // any additional message is received from the client.
   199  //
   200  // This method should be called after establishing the WebSocket connection
   201  // to handle incoming messages asynchronously.
   202  func (wsController *WebsocketController) read() {
   203  	// Start a goroutine to handle the WebSocket connection
   204  	defer close(wsController.readChannel) // notify websocket about closed connection
   205  
   206  	for {
   207  		// reads messages from the WebSocket connection when
   208  		// 1) the connection is closed by client
   209  		// 2) a message is received from the client
   210  		_, msg, err := wsController.conn.ReadMessage()
   211  		if err != nil {
   212  			if _, ok := err.(*websocket.CloseError); !ok {
   213  				wsController.readChannel <- err
   214  			}
   215  			return
   216  		}
   217  
   218  		// Check the message from the client, if is any just close the connection
   219  		if len(msg) > 0 {
   220  			err := fmt.Errorf("the client sent an unexpected message, connection closed")
   221  			wsController.logger.Debug().Msg(err.Error())
   222  			wsController.readChannel <- err
   223  			return
   224  		}
   225  	}
   226  }
   227  
   228  // SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, fetches necessary resources
   229  type SubscribeHandlerFunc func(
   230  	ctx context.Context,
   231  	request *request.Request,
   232  	wsController *WebsocketController,
   233  ) (subscription.Subscription, error)
   234  
   235  // WSHandler is websocket handler implementing custom websocket handler function and allows easier handling of errors and
   236  // responses as it wraps functionality for handling error and responses outside of endpoint handling.
   237  type WSHandler struct {
   238  	*HttpHandler
   239  	subscribeFunc SubscribeHandlerFunc
   240  
   241  	api                      state_stream.API
   242  	eventFilterConfig        state_stream.EventFilterConfig
   243  	maxStreams               int32
   244  	defaultHeartbeatInterval uint64
   245  	activeStreamCount        *atomic.Int32
   246  }
   247  
   248  var _ http.Handler = (*WSHandler)(nil)
   249  
   250  func NewWSHandler(
   251  	logger zerolog.Logger,
   252  	api state_stream.API,
   253  	subscribeFunc SubscribeHandlerFunc,
   254  	chain flow.Chain,
   255  	stateStreamConfig backend.Config,
   256  ) *WSHandler {
   257  	handler := &WSHandler{
   258  		subscribeFunc:            subscribeFunc,
   259  		api:                      api,
   260  		eventFilterConfig:        stateStreamConfig.EventFilterConfig,
   261  		maxStreams:               int32(stateStreamConfig.MaxGlobalStreams),
   262  		defaultHeartbeatInterval: stateStreamConfig.HeartbeatInterval,
   263  		activeStreamCount:        atomic.NewInt32(0),
   264  		HttpHandler:              NewHttpHandler(logger, chain),
   265  	}
   266  
   267  	return handler
   268  }
   269  
   270  // ServeHTTP function acts as a wrapper to each request providing common handling functionality
   271  // such as logging, error handling, request decorators
   272  func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   273  	// create a logger
   274  	logger := h.Logger.With().Str("subscribe_url", r.URL.String()).Logger()
   275  
   276  	err := h.VerifyRequest(w, r)
   277  	if err != nil {
   278  		// VerifyRequest sets the response error before returning
   279  		return
   280  	}
   281  
   282  	// Upgrade the HTTP connection to a WebSocket connection
   283  	upgrader := websocket.Upgrader{
   284  		// allow all origins by default, operators can override using a proxy
   285  		CheckOrigin: func(r *http.Request) bool {
   286  			return true
   287  		},
   288  	}
   289  	conn, err := upgrader.Upgrade(w, r, nil)
   290  	if err != nil {
   291  		h.errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger)
   292  		return
   293  	}
   294  	defer conn.Close()
   295  
   296  	wsController := &WebsocketController{
   297  		logger:            logger,
   298  		conn:              conn,
   299  		api:               h.api,
   300  		eventFilterConfig: h.eventFilterConfig,
   301  		maxStreams:        h.maxStreams,
   302  		activeStreamCount: h.activeStreamCount,
   303  		readChannel:       make(chan error),
   304  		heartbeatInterval: h.defaultHeartbeatInterval, // set default heartbeat interval from state stream config
   305  	}
   306  
   307  	err = wsController.SetWebsocketConf()
   308  	if err != nil {
   309  		wsController.wsErrorHandler(err)
   310  		return
   311  	}
   312  
   313  	if wsController.activeStreamCount.Load() >= wsController.maxStreams {
   314  		err := fmt.Errorf("maximum number of streams reached")
   315  		wsController.wsErrorHandler(models.NewRestError(http.StatusServiceUnavailable, err.Error(), err))
   316  		return
   317  	}
   318  	wsController.activeStreamCount.Add(1)
   319  	defer wsController.activeStreamCount.Add(-1)
   320  
   321  	// cancelling the context passed into the `subscribeFunc` to ensure that when the client disconnects,
   322  	// gorountines setup by the backend are cleaned up.
   323  	ctx, cancel := context.WithCancel(context.Background())
   324  	defer cancel()
   325  
   326  	sub, err := h.subscribeFunc(ctx, request.Decorate(r, h.HttpHandler.Chain), wsController)
   327  	if err != nil {
   328  		wsController.wsErrorHandler(err)
   329  		return
   330  	}
   331  
   332  	go wsController.read()
   333  	wsController.writeEvents(sub)
   334  }