github.com/vipernet-xyz/tm@v0.34.24/rpc/jsonrpc/server/ws_handler.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"reflect"
    10  	"runtime/debug"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  
    15  	"github.com/vipernet-xyz/tm/libs/log"
    16  	"github.com/vipernet-xyz/tm/libs/service"
    17  	types "github.com/vipernet-xyz/tm/rpc/jsonrpc/types"
    18  )
    19  
    20  // WebSocket handler
    21  
    22  const (
    23  	defaultWSWriteChanCapacity = 100
    24  	defaultWSWriteWait         = 10 * time.Second
    25  	defaultWSReadWait          = 30 * time.Second
    26  	defaultWSPingPeriod        = (defaultWSReadWait * 9) / 10
    27  )
    28  
// WebsocketManager provides a WS handler for incoming connections and passes a
// map of functions along with any additional params to new connections.
// NOTE: The websocket path is defined externally, e.g. in node/node.go
type WebsocketManager struct {
	// Upgrader performs the HTTP-to-websocket upgrade in WebsocketHandler.
	// NewWebsocketManager configures its CheckOrigin to accept every origin.
	websocket.Upgrader

	// funcMap: RPC method name -> handler, handed to every new connection.
	// logger: a no-op logger unless SetLogger is called; per-connection
	// loggers are derived from it.
	// wsConnOptions: options applied to each new wsConnection.
	funcMap       map[string]*RPCFunc
	logger        log.Logger
	wsConnOptions []func(*wsConnection)
}
    39  
    40  // NewWebsocketManager returns a new WebsocketManager that passes a map of
    41  // functions, connection options and logger to new WS connections.
    42  func NewWebsocketManager(
    43  	funcMap map[string]*RPCFunc,
    44  	wsConnOptions ...func(*wsConnection),
    45  ) *WebsocketManager {
    46  	return &WebsocketManager{
    47  		funcMap: funcMap,
    48  		Upgrader: websocket.Upgrader{
    49  			CheckOrigin: func(r *http.Request) bool {
    50  				// TODO ???
    51  				//
    52  				// The default behaviour would be relevant to browser-based clients,
    53  				// afaik. I suppose having a pass-through is a workaround for allowing
    54  				// for more complex security schemes, shifting the burden of
    55  				// AuthN/AuthZ outside the Tendermint RPC.
    56  				// I can't think of other uses right now that would warrant a TODO
    57  				// though. The real backstory of this TODO shall remain shrouded in
    58  				// mystery
    59  				return true
    60  			},
    61  		},
    62  		logger:        log.NewNopLogger(),
    63  		wsConnOptions: wsConnOptions,
    64  	}
    65  }
    66  
    67  // SetLogger sets the logger.
    68  func (wm *WebsocketManager) SetLogger(l log.Logger) {
    69  	wm.logger = l
    70  }
    71  
    72  // WebsocketHandler upgrades the request/response (via http.Hijack) and starts
    73  // the wsConnection.
    74  func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Request) {
    75  	wsConn, err := wm.Upgrade(w, r, nil)
    76  	if err != nil {
    77  		// TODO - return http error
    78  		wm.logger.Error("Failed to upgrade connection", "err", err)
    79  		return
    80  	}
    81  	defer func() {
    82  		if err := wsConn.Close(); err != nil {
    83  			wm.logger.Error("Failed to close connection", "err", err)
    84  		}
    85  	}()
    86  
    87  	// register connection
    88  	con := newWSConnection(wsConn, wm.funcMap, wm.wsConnOptions...)
    89  	con.SetLogger(wm.logger.With("remote", wsConn.RemoteAddr()))
    90  	wm.logger.Info("New websocket connection", "remote", con.remoteAddr)
    91  	err = con.Start() // BLOCKING
    92  	if err != nil {
    93  		wm.logger.Error("Failed to start connection", "err", err)
    94  		return
    95  	}
    96  	if err := con.Stop(); err != nil {
    97  		wm.logger.Error("error while stopping connection", "error", err)
    98  	}
    99  }
   100  
   101  // WebSocket connection
   102  
// wsConnection is a single server-side websocket connection. It wraps the
// underlying gorilla connection, dispatches incoming JSON-RPC requests to
// funcMap (readRoutine), and writes queued responses back (writeRoutine).
//
// In case of an error, the connection is stopped.
type wsConnection struct {
	service.BaseService

	// remoteAddr is baseConn's remote address, captured at construction time.
	remoteAddr string
	baseConn   *websocket.Conn
	// writeChan is never closed, to allow WriteRPCResponse() to fail.
	// It is allocated in OnStart with capacity writeChanCapacity.
	writeChan chan types.RPCResponse

	// chan, which is closed when/if readRoutine errors
	// used to abort writeRoutine
	readRoutineQuit chan struct{}

	// RPC method name -> handler.
	funcMap map[string]*RPCFunc

	// write channel capacity
	writeChanCapacity int

	// each write times out after this.
	writeWait time.Duration

	// Connection times out if we haven't received *anything* in this long, not even pings.
	readWait time.Duration

	// The server sends pings to the client with this period. Must be less than
	// readWait, but greater than zero.
	pingPeriod time.Duration

	// Maximum message size, passed to baseConn.SetReadLimit; zero unless the
	// ReadLimit option is used.
	readLimit int64

	// callback which is called upon disconnect
	onDisconnect func(remoteAddr string)

	// ctx/cancel are created lazily by Context() and canceled in OnStop.
	// NOTE(review): they are read and written without synchronization, which
	// is only safe if Context() and OnStop never run concurrently — confirm.
	ctx    context.Context
	cancel context.CancelFunc
}
   142  
   143  // NewWSConnection wraps websocket.Conn.
   144  //
   145  // See the commentary on the func(*wsConnection) functions for a detailed
   146  // description of how to configure ping period and pong wait time. NOTE: if the
   147  // write buffer is full, pongs may be dropped, which may cause clients to
   148  // disconnect. see https://github.com/gorilla/websocket/issues/97
   149  func newWSConnection(
   150  	baseConn *websocket.Conn,
   151  	funcMap map[string]*RPCFunc,
   152  	options ...func(*wsConnection),
   153  ) *wsConnection {
   154  	wsc := &wsConnection{
   155  		remoteAddr:        baseConn.RemoteAddr().String(),
   156  		baseConn:          baseConn,
   157  		funcMap:           funcMap,
   158  		writeWait:         defaultWSWriteWait,
   159  		writeChanCapacity: defaultWSWriteChanCapacity,
   160  		readWait:          defaultWSReadWait,
   161  		pingPeriod:        defaultWSPingPeriod,
   162  		readRoutineQuit:   make(chan struct{}),
   163  	}
   164  	for _, option := range options {
   165  		option(wsc)
   166  	}
   167  	wsc.baseConn.SetReadLimit(wsc.readLimit)
   168  	wsc.BaseService = *service.NewBaseService(nil, "wsConnection", wsc)
   169  	return wsc
   170  }
   171  
   172  // OnDisconnect sets a callback which is used upon disconnect - not
   173  // Goroutine-safe. Nop by default.
   174  func OnDisconnect(onDisconnect func(remoteAddr string)) func(*wsConnection) {
   175  	return func(wsc *wsConnection) {
   176  		wsc.onDisconnect = onDisconnect
   177  	}
   178  }
   179  
   180  // WriteWait sets the amount of time to wait before a websocket write times out.
   181  // It should only be used in the constructor - not Goroutine-safe.
   182  func WriteWait(writeWait time.Duration) func(*wsConnection) {
   183  	return func(wsc *wsConnection) {
   184  		wsc.writeWait = writeWait
   185  	}
   186  }
   187  
   188  // WriteChanCapacity sets the capacity of the websocket write channel.
   189  // It should only be used in the constructor - not Goroutine-safe.
   190  func WriteChanCapacity(cap int) func(*wsConnection) {
   191  	return func(wsc *wsConnection) {
   192  		wsc.writeChanCapacity = cap
   193  	}
   194  }
   195  
   196  // ReadWait sets the amount of time to wait before a websocket read times out.
   197  // It should only be used in the constructor - not Goroutine-safe.
   198  func ReadWait(readWait time.Duration) func(*wsConnection) {
   199  	return func(wsc *wsConnection) {
   200  		wsc.readWait = readWait
   201  	}
   202  }
   203  
   204  // PingPeriod sets the duration for sending websocket pings.
   205  // It should only be used in the constructor - not Goroutine-safe.
   206  func PingPeriod(pingPeriod time.Duration) func(*wsConnection) {
   207  	return func(wsc *wsConnection) {
   208  		wsc.pingPeriod = pingPeriod
   209  	}
   210  }
   211  
   212  // ReadLimit sets the maximum size for reading message.
   213  // It should only be used in the constructor - not Goroutine-safe.
   214  func ReadLimit(readLimit int64) func(*wsConnection) {
   215  	return func(wsc *wsConnection) {
   216  		wsc.readLimit = readLimit
   217  	}
   218  }
   219  
   220  // OnStart implements service.Service by starting the read and write routines. It
   221  // blocks until there's some error.
   222  func (wsc *wsConnection) OnStart() error {
   223  	wsc.writeChan = make(chan types.RPCResponse, wsc.writeChanCapacity)
   224  
   225  	// Read subscriptions/unsubscriptions to events
   226  	go wsc.readRoutine()
   227  	// Write responses, BLOCKING.
   228  	wsc.writeRoutine()
   229  
   230  	return nil
   231  }
   232  
   233  // OnStop implements service.Service by unsubscribing remoteAddr from all
   234  // subscriptions.
   235  func (wsc *wsConnection) OnStop() {
   236  	if wsc.onDisconnect != nil {
   237  		wsc.onDisconnect(wsc.remoteAddr)
   238  	}
   239  
   240  	if wsc.ctx != nil {
   241  		wsc.cancel()
   242  	}
   243  }
   244  
   245  // GetRemoteAddr returns the remote address of the underlying connection.
   246  // It implements WSRPCConnection
   247  func (wsc *wsConnection) GetRemoteAddr() string {
   248  	return wsc.remoteAddr
   249  }
   250  
   251  // WriteRPCResponse pushes a response to the writeChan, and blocks until it is
   252  // accepted.
   253  // It implements WSRPCConnection. It is Goroutine-safe.
   254  func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp types.RPCResponse) error {
   255  	select {
   256  	case <-wsc.Quit():
   257  		return errors.New("connection was stopped")
   258  	case <-ctx.Done():
   259  		return ctx.Err()
   260  	case wsc.writeChan <- resp:
   261  		return nil
   262  	}
   263  }
   264  
   265  // TryWriteRPCResponse attempts to push a response to the writeChan, but does
   266  // not block.
   267  // It implements WSRPCConnection. It is Goroutine-safe
   268  func (wsc *wsConnection) TryWriteRPCResponse(resp types.RPCResponse) bool {
   269  	select {
   270  	case <-wsc.Quit():
   271  		return false
   272  	case wsc.writeChan <- resp:
   273  		return true
   274  	default:
   275  		return false
   276  	}
   277  }
   278  
   279  // Context returns the connection's context.
   280  // The context is canceled when the client's connection closes.
   281  func (wsc *wsConnection) Context() context.Context {
   282  	if wsc.ctx != nil {
   283  		return wsc.ctx
   284  	}
   285  	wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
   286  	return wsc.ctx
   287  }
   288  
   289  // Read from the socket and subscribe to or unsubscribe from events
   290  func (wsc *wsConnection) readRoutine() {
   291  	// readRoutine will block until response is written or WS connection is closed
   292  	writeCtx := context.Background()
   293  
   294  	defer func() {
   295  		if r := recover(); r != nil {
   296  			err, ok := r.(error)
   297  			if !ok {
   298  				err = fmt.Errorf("WSJSONRPC: %v", r)
   299  			}
   300  			wsc.Logger.Error("Panic in WSJSONRPC handler", "err", err, "stack", string(debug.Stack()))
   301  			if err := wsc.WriteRPCResponse(writeCtx, types.RPCInternalError(types.JSONRPCIntID(-1), err)); err != nil {
   302  				wsc.Logger.Error("Error writing RPC response", "err", err)
   303  			}
   304  			go wsc.readRoutine()
   305  		}
   306  	}()
   307  
   308  	wsc.baseConn.SetPongHandler(func(m string) error {
   309  		return wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait))
   310  	})
   311  
   312  	for {
   313  		select {
   314  		case <-wsc.Quit():
   315  			return
   316  		default:
   317  			// reset deadline for every type of message (control or data)
   318  			if err := wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait)); err != nil {
   319  				wsc.Logger.Error("failed to set read deadline", "err", err)
   320  			}
   321  
   322  			_, r, err := wsc.baseConn.NextReader()
   323  			if err != nil {
   324  				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   325  					wsc.Logger.Info("Client closed the connection")
   326  				} else {
   327  					wsc.Logger.Error("Failed to read request", "err", err)
   328  				}
   329  				if err := wsc.Stop(); err != nil {
   330  					wsc.Logger.Error("Error closing websocket connection", "err", err)
   331  				}
   332  				close(wsc.readRoutineQuit)
   333  				return
   334  			}
   335  
   336  			dec := json.NewDecoder(r)
   337  			var request types.RPCRequest
   338  			err = dec.Decode(&request)
   339  			if err != nil {
   340  				if err := wsc.WriteRPCResponse(writeCtx,
   341  					types.RPCParseError(fmt.Errorf("error unmarshaling request: %w", err))); err != nil {
   342  					wsc.Logger.Error("Error writing RPC response", "err", err)
   343  				}
   344  				continue
   345  			}
   346  
   347  			// A Notification is a Request object without an "id" member.
   348  			// The Server MUST NOT reply to a Notification, including those that are within a batch request.
   349  			if request.ID == nil {
   350  				wsc.Logger.Debug(
   351  					"WSJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)",
   352  					"req", request,
   353  				)
   354  				continue
   355  			}
   356  
   357  			// Now, fetch the RPCFunc and execute it.
   358  			rpcFunc := wsc.funcMap[request.Method]
   359  			if rpcFunc == nil {
   360  				if err := wsc.WriteRPCResponse(writeCtx, types.RPCMethodNotFoundError(request.ID)); err != nil {
   361  					wsc.Logger.Error("Error writing RPC response", "err", err)
   362  				}
   363  				continue
   364  			}
   365  
   366  			ctx := &types.Context{JSONReq: &request, WSConn: wsc}
   367  			args := []reflect.Value{reflect.ValueOf(ctx)}
   368  			if len(request.Params) > 0 {
   369  				fnArgs, err := jsonParamsToArgs(rpcFunc, request.Params)
   370  				if err != nil {
   371  					if err := wsc.WriteRPCResponse(writeCtx,
   372  						types.RPCInternalError(request.ID, fmt.Errorf("error converting json params to arguments: %w", err)),
   373  					); err != nil {
   374  						wsc.Logger.Error("Error writing RPC response", "err", err)
   375  					}
   376  					continue
   377  				}
   378  				args = append(args, fnArgs...)
   379  			}
   380  
   381  			returns := rpcFunc.f.Call(args)
   382  
   383  			// TODO: Need to encode args/returns to string if we want to log them
   384  			wsc.Logger.Info("WSJSONRPC", "method", request.Method)
   385  
   386  			result, err := unreflectResult(returns)
   387  			if err != nil {
   388  				if err := wsc.WriteRPCResponse(writeCtx, types.RPCInternalError(request.ID, err)); err != nil {
   389  					wsc.Logger.Error("Error writing RPC response", "err", err)
   390  				}
   391  				continue
   392  			}
   393  
   394  			if err := wsc.WriteRPCResponse(writeCtx, types.NewRPCSuccessResponse(request.ID, result)); err != nil {
   395  				wsc.Logger.Error("Error writing RPC response", "err", err)
   396  			}
   397  		}
   398  	}
   399  }
   400  
   401  // receives on a write channel and writes out on the socket
   402  func (wsc *wsConnection) writeRoutine() {
   403  	pingTicker := time.NewTicker(wsc.pingPeriod)
   404  	defer pingTicker.Stop()
   405  
   406  	// https://github.com/gorilla/websocket/issues/97
   407  	pongs := make(chan string, 1)
   408  	wsc.baseConn.SetPingHandler(func(m string) error {
   409  		select {
   410  		case pongs <- m:
   411  		default:
   412  		}
   413  		return nil
   414  	})
   415  
   416  	for {
   417  		select {
   418  		case <-wsc.Quit():
   419  			return
   420  		case <-wsc.readRoutineQuit: // error in readRoutine
   421  			return
   422  		case m := <-pongs:
   423  			err := wsc.writeMessageWithDeadline(websocket.PongMessage, []byte(m))
   424  			if err != nil {
   425  				wsc.Logger.Info("Failed to write pong (client may disconnect)", "err", err)
   426  			}
   427  		case <-pingTicker.C:
   428  			err := wsc.writeMessageWithDeadline(websocket.PingMessage, []byte{})
   429  			if err != nil {
   430  				wsc.Logger.Error("Failed to write ping", "err", err)
   431  				return
   432  			}
   433  		case msg := <-wsc.writeChan:
   434  			// Use json.MarshalIndent instead of Marshal for pretty output.
   435  			// Pretty output not necessary, since most consumers of WS events are
   436  			// automated processes, not humans.
   437  			jsonBytes, err := json.Marshal(msg)
   438  			if err != nil {
   439  				wsc.Logger.Error("Failed to marshal RPCResponse to JSON", "err", err)
   440  				continue
   441  			}
   442  			if err = wsc.writeMessageWithDeadline(websocket.TextMessage, jsonBytes); err != nil {
   443  				wsc.Logger.Error("Failed to write response", "err", err, "msg", msg)
   444  				return
   445  			}
   446  		}
   447  	}
   448  }
   449  
   450  // All writes to the websocket must (re)set the write deadline.
   451  // If some writes don't set it while others do, they may timeout incorrectly
   452  // (https://github.com/vipernet-xyz/tm/issues/553)
   453  func (wsc *wsConnection) writeMessageWithDeadline(msgType int, msg []byte) error {
   454  	if err := wsc.baseConn.SetWriteDeadline(time.Now().Add(wsc.writeWait)); err != nil {
   455  		return err
   456  	}
   457  	return wsc.baseConn.WriteMessage(msgType, msg)
   458  }