github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/rpc/jsonrpc/server/ws_handler.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"runtime/debug"
     9  	"time"
    10  
    11  	"github.com/gorilla/websocket"
    12  
    13  	"github.com/ari-anchor/sei-tendermint/libs/log"
    14  	rpctypes "github.com/ari-anchor/sei-tendermint/rpc/jsonrpc/types"
    15  )
    16  
    17  // WebSocket handler
    18  
    19  const (
    20  	defaultWSWriteChanCapacity = 100
    21  	defaultWSWriteWait         = 10 * time.Second
    22  	defaultWSReadWait          = 30 * time.Second
    23  	defaultWSPingPeriod        = (defaultWSReadWait * 9) / 10
    24  )
    25  
// WebsocketManager provides a WS handler for incoming connections and passes a
// map of functions along with any additional params to new connections.
// NOTE: The websocket path is defined externally, e.g. in node/node.go
type WebsocketManager struct {
	// Embedded upgrader; WebsocketHandler calls its Upgrade method to turn an
	// HTTP request into a websocket connection.
	websocket.Upgrader

	// funcMap is handed to every new wsConnection, which looks up each
	// request's method name in it.
	funcMap       map[string]*RPCFunc
	// logger is used for manager-level messages and is the parent of each
	// connection's logger.
	logger        log.Logger
	// wsConnOptions are applied to every wsConnection this manager creates.
	wsConnOptions []func(*wsConnection)
}
    36  
    37  // NewWebsocketManager returns a new WebsocketManager that passes a map of
    38  // functions, connection options and logger to new WS connections.
    39  func NewWebsocketManager(logger log.Logger, funcMap map[string]*RPCFunc, wsConnOptions ...func(*wsConnection)) *WebsocketManager {
    40  	return &WebsocketManager{
    41  		funcMap: funcMap,
    42  		Upgrader: websocket.Upgrader{
    43  			CheckOrigin: func(r *http.Request) bool {
    44  				// TODO ???
    45  				//
    46  				// The default behavior would be relevant to browser-based clients,
    47  				// afaik. I suppose having a pass-through is a workaround for allowing
    48  				// for more complex security schemes, shifting the burden of
    49  				// AuthN/AuthZ outside the Tendermint RPC.
    50  				// I can't think of other uses right now that would warrant a TODO
    51  				// though. The real backstory of this TODO shall remain shrouded in
    52  				// mystery
    53  				return true
    54  			},
    55  		},
    56  		logger:        logger,
    57  		wsConnOptions: wsConnOptions,
    58  	}
    59  }
    60  
    61  // WebsocketHandler upgrades the request/response (via http.Hijack) and starts
    62  // the wsConnection.
    63  func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Request) {
    64  	wsConn, err := wm.Upgrade(w, r, nil)
    65  	if err != nil {
    66  		// The upgrader has already reported an HTTP error to the client, so we
    67  		// need only log it.
    68  		wm.logger.Error("Failed to upgrade connection", "err", err)
    69  		return
    70  	}
    71  	defer func() {
    72  		if err := wsConn.Close(); err != nil {
    73  			wm.logger.Error("Failed to close connection", "err", err)
    74  		}
    75  	}()
    76  
    77  	// register connection
    78  	logger := wm.logger.With("remote", wsConn.RemoteAddr())
    79  	conn := newWSConnection(wsConn, wm.funcMap, logger, wm.wsConnOptions...)
    80  	wm.logger.Info("New websocket connection", "remote", conn.remoteAddr)
    81  
    82  	// starting the conn is blocking
    83  	if err = conn.Start(r.Context()); err != nil {
    84  		wm.logger.Error("Failed to start connection", "err", err)
    85  		writeInternalError(w, err)
    86  		return
    87  	}
    88  
    89  	if err := conn.Stop(); err != nil {
    90  		wm.logger.Error("error while stopping connection", "error", err)
    91  	}
    92  }
    93  
    94  // WebSocket connection
    95  
// wsConnection serves JSON-RPC over a single websocket connection.
// readRoutine decodes incoming requests and dispatches them through funcMap.
// writeRoutine writes queued responses, pongs and periodic pings to the
// socket.
//
// In case of an error, the connection is stopped.
type wsConnection struct {
	// Logger receives this connection's log output.
	Logger log.Logger

	// remoteAddr is baseConn.RemoteAddr() captured at construction. It is
	// passed to onDisconnect and returned by GetRemoteAddr.
	remoteAddr string
	// baseConn is the underlying gorilla websocket connection.
	baseConn   *websocket.Conn
	// writeChan queues responses for writeRoutine. It is created by Start and
	// is nil before then.
	// writeChan is never closed, to allow WriteRPCResponse() to fail.
	writeChan chan rpctypes.RPCResponse

	// chan, which is closed when/if readRoutine errors
	// used to abort writeRoutine
	readRoutineQuit chan struct{}

	// funcMap maps request method names to their handlers.
	funcMap map[string]*RPCFunc

	// Connection times out if we haven't received *anything* in this long, not even pings.
	readWait time.Duration

	// Period at which writeRoutine sends pings to the client. Must be less
	// than readWait, but greater than zero.
	pingPeriod time.Duration

	// Maximum message size, installed with baseConn.SetReadLimit in
	// newWSConnection. NOTE(review): stays 0 unless the ReadLimit option is
	// used; gorilla/websocket appears to treat 0 as "no limit" (confirm).
	readLimit int64

	// callback which is called upon disconnect (invoked by Stop)
	onDisconnect func(remoteAddr string)

	// Connection-scoped context returned by Context(), and the CancelFunc
	// that Stop invokes. Context() creates them lazily if they are unset.
	ctx    context.Context
	cancel context.CancelFunc
}
   129  
   130  // NewWSConnection wraps websocket.Conn.
   131  //
   132  // See the commentary on the func(*wsConnection) functions for a detailed
   133  // description of how to configure ping period and pong wait time. NOTE: if the
   134  // write buffer is full, pongs may be dropped, which may cause clients to
   135  // disconnect. see https://github.com/gorilla/websocket/issues/97
   136  func newWSConnection(baseConn *websocket.Conn, funcMap map[string]*RPCFunc, logger log.Logger, options ...func(*wsConnection)) *wsConnection {
   137  	wsc := &wsConnection{
   138  		Logger:          logger,
   139  		remoteAddr:      baseConn.RemoteAddr().String(),
   140  		baseConn:        baseConn,
   141  		funcMap:         funcMap,
   142  		readWait:        defaultWSReadWait,
   143  		pingPeriod:      defaultWSPingPeriod,
   144  		readRoutineQuit: make(chan struct{}),
   145  	}
   146  	for _, option := range options {
   147  		option(wsc)
   148  	}
   149  	wsc.baseConn.SetReadLimit(wsc.readLimit)
   150  	return wsc
   151  }
   152  
   153  // OnDisconnect sets a callback which is used upon disconnect - not
   154  // Goroutine-safe. Nop by default.
   155  func OnDisconnect(onDisconnect func(remoteAddr string)) func(*wsConnection) {
   156  	return func(wsc *wsConnection) {
   157  		wsc.onDisconnect = onDisconnect
   158  	}
   159  }
   160  
   161  // ReadWait sets the amount of time to wait before a websocket read times out.
   162  // It should only be used in the constructor - not Goroutine-safe.
   163  func ReadWait(readWait time.Duration) func(*wsConnection) {
   164  	return func(wsc *wsConnection) {
   165  		wsc.readWait = readWait
   166  	}
   167  }
   168  
   169  // PingPeriod sets the duration for sending websocket pings.
   170  // It should only be used in the constructor - not Goroutine-safe.
   171  func PingPeriod(pingPeriod time.Duration) func(*wsConnection) {
   172  	return func(wsc *wsConnection) {
   173  		wsc.pingPeriod = pingPeriod
   174  	}
   175  }
   176  
   177  // ReadLimit sets the maximum size for reading message.
   178  // It should only be used in the constructor - not Goroutine-safe.
   179  func ReadLimit(readLimit int64) func(*wsConnection) {
   180  	return func(wsc *wsConnection) {
   181  		wsc.readLimit = readLimit
   182  	}
   183  }
   184  
   185  // Start starts the client service routines and blocks until there is an error.
   186  func (wsc *wsConnection) Start(ctx context.Context) error {
   187  	wsc.writeChan = make(chan rpctypes.RPCResponse, defaultWSWriteChanCapacity)
   188  
   189  	// Read subscriptions/unsubscriptions to events
   190  	go wsc.readRoutine(ctx)
   191  	// Write responses, BLOCKING.
   192  	wsc.writeRoutine(ctx)
   193  
   194  	return nil
   195  }
   196  
   197  // Stop unsubscribes the remote from all subscriptions.
   198  func (wsc *wsConnection) Stop() error {
   199  	if wsc.onDisconnect != nil {
   200  		wsc.onDisconnect(wsc.remoteAddr)
   201  	}
   202  	if wsc.ctx != nil {
   203  		wsc.cancel()
   204  	}
   205  	return nil
   206  }
   207  
   208  // GetRemoteAddr returns the remote address of the underlying connection.
   209  // It implements WSRPCConnection
   210  func (wsc *wsConnection) GetRemoteAddr() string {
   211  	return wsc.remoteAddr
   212  }
   213  
   214  // WriteRPCResponse pushes a response to the writeChan, and blocks until it is
   215  // accepted.
   216  // It implements WSRPCConnection. It is Goroutine-safe.
   217  func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) error {
   218  	select {
   219  	case <-ctx.Done():
   220  		return ctx.Err()
   221  	case wsc.writeChan <- resp:
   222  		return nil
   223  	}
   224  }
   225  
   226  // TryWriteRPCResponse attempts to push a response to the writeChan, but does
   227  // not block.
   228  // It implements WSRPCConnection. It is Goroutine-safe
   229  func (wsc *wsConnection) TryWriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) bool {
   230  	select {
   231  	case <-ctx.Done():
   232  		return false
   233  	case wsc.writeChan <- resp:
   234  		return true
   235  	default:
   236  		return false
   237  	}
   238  }
   239  
   240  // Context returns the connection's context.
   241  // The context is canceled when the client's connection closes.
   242  func (wsc *wsConnection) Context() context.Context {
   243  	if wsc.ctx != nil {
   244  		return wsc.ctx
   245  	}
   246  	wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
   247  	return wsc.ctx
   248  }
   249  
   250  // Read from the socket and subscribe to or unsubscribe from events
   251  func (wsc *wsConnection) readRoutine(ctx context.Context) {
   252  	// readRoutine will block until response is written or WS connection is closed
   253  	writeCtx := context.Background()
   254  
   255  	defer func() {
   256  		if r := recover(); r != nil {
   257  			err, ok := r.(error)
   258  			if !ok {
   259  				err = fmt.Errorf("WSJSONRPC: %v", r)
   260  			}
   261  			req := rpctypes.NewRequest(uriReqID)
   262  			wsc.Logger.Error("Panic in WSJSONRPC handler", "err", err, "stack", string(debug.Stack()))
   263  			if err := wsc.WriteRPCResponse(writeCtx,
   264  				req.MakeErrorf(rpctypes.CodeInternalError, "Panic in handler: %v", err)); err != nil {
   265  				wsc.Logger.Error("error writing RPC response", "err", err)
   266  			}
   267  			go wsc.readRoutine(ctx)
   268  		}
   269  	}()
   270  
   271  	wsc.baseConn.SetPongHandler(func(m string) error {
   272  		return wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait))
   273  	})
   274  
   275  	for {
   276  		select {
   277  		case <-ctx.Done():
   278  			return
   279  		default:
   280  			// reset deadline for every type of message (control or data)
   281  			if err := wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait)); err != nil {
   282  				wsc.Logger.Error("failed to set read deadline", "err", err)
   283  			}
   284  
   285  			_, r, err := wsc.baseConn.NextReader()
   286  			if err != nil {
   287  				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   288  					wsc.Logger.Info("Client closed the connection")
   289  				} else {
   290  					wsc.Logger.Error("Failed to read request", "err", err)
   291  				}
   292  				if err := wsc.Stop(); err != nil {
   293  					wsc.Logger.Error("error closing websocket connection", "err", err)
   294  				}
   295  				close(wsc.readRoutineQuit)
   296  				return
   297  			}
   298  
   299  			dec := json.NewDecoder(r)
   300  			var request rpctypes.RPCRequest
   301  			err = dec.Decode(&request)
   302  			if err != nil {
   303  				if err := wsc.WriteRPCResponse(writeCtx,
   304  					request.MakeErrorf(rpctypes.CodeParseError, "unmarshaling request: %v", err)); err != nil {
   305  					wsc.Logger.Error("error writing RPC response", "err", err)
   306  				}
   307  				continue
   308  			}
   309  
   310  			// A Notification is a Request object without an "id" member.
   311  			// The Server MUST NOT reply to a Notification, including those that are within a batch request.
   312  			if request.IsNotification() {
   313  				wsc.Logger.Debug(
   314  					"WSJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)",
   315  					"req", request,
   316  				)
   317  				continue
   318  			}
   319  
   320  			// Now, fetch the RPCFunc and execute it.
   321  			rpcFunc := wsc.funcMap[request.Method]
   322  			if rpcFunc == nil {
   323  				if err := wsc.WriteRPCResponse(writeCtx,
   324  					request.MakeErrorf(rpctypes.CodeMethodNotFound, request.Method)); err != nil {
   325  					wsc.Logger.Error("error writing RPC response", "err", err)
   326  				}
   327  				continue
   328  			}
   329  
   330  			fctx := rpctypes.WithCallInfo(wsc.Context(), &rpctypes.CallInfo{
   331  				RPCRequest: &request,
   332  				WSConn:     wsc,
   333  			})
   334  			var resp rpctypes.RPCResponse
   335  			result, err := rpcFunc.Call(fctx, request.Params)
   336  			if err == nil {
   337  				resp = request.MakeResponse(result)
   338  			} else {
   339  				resp = request.MakeError(err)
   340  			}
   341  			if err := wsc.WriteRPCResponse(writeCtx, resp); err != nil {
   342  				wsc.Logger.Error("error writing RPC response", "err", err)
   343  			}
   344  		}
   345  	}
   346  }
   347  
   348  // receives on a write channel and writes out on the socket
   349  func (wsc *wsConnection) writeRoutine(ctx context.Context) {
   350  	pingTicker := time.NewTicker(wsc.pingPeriod)
   351  	defer pingTicker.Stop()
   352  
   353  	// https://github.com/gorilla/websocket/issues/97
   354  	pongs := make(chan string, 1)
   355  	wsc.baseConn.SetPingHandler(func(m string) error {
   356  		select {
   357  		case pongs <- m:
   358  		default:
   359  		}
   360  		return nil
   361  	})
   362  
   363  	for {
   364  		select {
   365  		case <-ctx.Done():
   366  			return
   367  		case <-wsc.readRoutineQuit: // error in readRoutine
   368  			return
   369  		case m := <-pongs:
   370  			err := wsc.writeMessageWithDeadline(websocket.PongMessage, []byte(m))
   371  			if err != nil {
   372  				wsc.Logger.Info("Failed to write pong (client may disconnect)", "err", err)
   373  			}
   374  		case <-pingTicker.C:
   375  			err := wsc.writeMessageWithDeadline(websocket.PingMessage, []byte{})
   376  			if err != nil {
   377  				wsc.Logger.Error("Failed to write ping", "err", err)
   378  				return
   379  			}
   380  		case msg := <-wsc.writeChan:
   381  			data, err := json.Marshal(msg)
   382  			if err != nil {
   383  				wsc.Logger.Error("Failed to marshal RPCResponse to JSON", "msg", msg, "err", err)
   384  				continue
   385  			}
   386  			if err = wsc.writeMessageWithDeadline(websocket.TextMessage, data); err != nil {
   387  				wsc.Logger.Error("Failed to write response", "msg", msg, "err", err)
   388  				return
   389  			}
   390  		}
   391  	}
   392  }
   393  
   394  // All writes to the websocket must (re)set the write deadline.
   395  // If some writes don't set it while others do, they may timeout incorrectly
   396  // (https://github.com/ari-anchor/sei-tendermint/issues/553)
   397  func (wsc *wsConnection) writeMessageWithDeadline(msgType int, msg []byte) error {
   398  	if err := wsc.baseConn.SetWriteDeadline(time.Now().Add(defaultWSWriteWait)); err != nil {
   399  		return err
   400  	}
   401  	return wsc.baseConn.WriteMessage(msgType, msg)
   402  }