github.com/number571/tendermint@v0.34.11-gost/rpc/jsonrpc/server/ws_handler.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"reflect"
    10  	"runtime/debug"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  
    15  	"github.com/number571/tendermint/libs/log"
    16  	"github.com/number571/tendermint/libs/service"
    17  	ctypes "github.com/number571/tendermint/rpc/core/types"
    18  	types "github.com/number571/tendermint/rpc/jsonrpc/types"
    19  )
    20  
    21  // WebSocket handler
    22  
    23  const (
    24  	defaultWSWriteChanCapacity = 100
    25  	defaultWSWriteWait         = 10 * time.Second
    26  	defaultWSReadWait          = 30 * time.Second
    27  	defaultWSPingPeriod        = (defaultWSReadWait * 9) / 10
    28  )
    29  
// WebsocketManager provides a WS handler for incoming connections and passes a
// map of functions along with any additional connection options to new
// connections.
// NOTE: The websocket path is defined externally, e.g. in node/node.go
type WebsocketManager struct {
	// Upgrader is embedded so the manager can call Upgrade directly when
	// switching an HTTP request over to the websocket protocol.
	websocket.Upgrader

	// funcMap maps JSON-RPC method names to handlers and is shared with every
	// connection; logger starts as a no-op logger (see NewWebsocketManager)
	// and can be replaced with SetLogger; wsConnOptions are applied to each
	// new wsConnection.
	funcMap       map[string]*RPCFunc
	logger        log.Logger
	wsConnOptions []func(*wsConnection)
}
    40  
    41  // NewWebsocketManager returns a new WebsocketManager that passes a map of
    42  // functions, connection options and logger to new WS connections.
    43  func NewWebsocketManager(
    44  	funcMap map[string]*RPCFunc,
    45  	wsConnOptions ...func(*wsConnection),
    46  ) *WebsocketManager {
    47  	return &WebsocketManager{
    48  		funcMap: funcMap,
    49  		Upgrader: websocket.Upgrader{
    50  			CheckOrigin: func(r *http.Request) bool {
    51  				// TODO ???
    52  				//
    53  				// The default behavior would be relevant to browser-based clients,
    54  				// afaik. I suppose having a pass-through is a workaround for allowing
    55  				// for more complex security schemes, shifting the burden of
    56  				// AuthN/AuthZ outside the Tendermint RPC.
    57  				// I can't think of other uses right now that would warrant a TODO
    58  				// though. The real backstory of this TODO shall remain shrouded in
    59  				// mystery
    60  				return true
    61  			},
    62  		},
    63  		logger:        log.NewNopLogger(),
    64  		wsConnOptions: wsConnOptions,
    65  	}
    66  }
    67  
    68  // SetLogger sets the logger.
    69  func (wm *WebsocketManager) SetLogger(l log.Logger) {
    70  	wm.logger = l
    71  }
    72  
    73  // WebsocketHandler upgrades the request/response (via http.Hijack) and starts
    74  // the wsConnection.
    75  func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Request) {
    76  	wsConn, err := wm.Upgrade(w, r, nil)
    77  	if err != nil {
    78  		// TODO - return http error
    79  		wm.logger.Error("Failed to upgrade connection", "err", err)
    80  		return
    81  	}
    82  	defer func() {
    83  		if err := wsConn.Close(); err != nil {
    84  			wm.logger.Error("Failed to close connection", "err", err)
    85  		}
    86  	}()
    87  
    88  	// register connection
    89  	con := newWSConnection(wsConn, wm.funcMap, wm.wsConnOptions...)
    90  	con.SetLogger(wm.logger.With("remote", wsConn.RemoteAddr()))
    91  	wm.logger.Info("New websocket connection", "remote", con.remoteAddr)
    92  	err = con.Start() // BLOCKING
    93  	if err != nil {
    94  		wm.logger.Error("Failed to start connection", "err", err)
    95  		return
    96  	}
    97  	if err := con.Stop(); err != nil {
    98  		wm.logger.Error("error while stopping connection", "error", err)
    99  	}
   100  }
   101  
   102  // WebSocket connection
   103  
// wsConnection is a single websocket connection: it wraps the underlying
// gorilla websocket together with the table of callable RPC functions and the
// channels and timeouts used by its read and write routines.
//
// In case of an error, the connection is stopped.
type wsConnection struct {
	service.BaseService

	// remoteAddr is captured once from baseConn when the connection is created.
	remoteAddr string
	baseConn   *websocket.Conn
	// writeChan is created in OnStart and never closed, so concurrent
	// WriteRPCResponse callers fail via Quit()/their ctx instead of panicking
	// on a send to a closed channel.
	writeChan chan types.RPCResponse

	// chan, which is closed when/if readRoutine fails to read from the socket;
	// used to abort writeRoutine
	readRoutineQuit chan struct{}

	// funcMap maps JSON-RPC method names to the functions they invoke.
	funcMap map[string]*RPCFunc

	// write channel capacity (buffer size of writeChan)
	writeChanCapacity int

	// each write times out after this.
	writeWait time.Duration

	// Connection times out if we haven't received *anything* in this long,
	// not even pings or pongs (control frames reset the deadline too).
	readWait time.Duration

	// Send pings to the client with this period. Must be less than readWait,
	// but greater than zero.
	pingPeriod time.Duration

	// Maximum message size, passed to baseConn.SetReadLimit.
	// NOTE(review): zero (the default) presumably means "no limit" in
	// gorilla/websocket — confirm.
	readLimit int64

	// callback which is called upon disconnect (from OnStop) with remoteAddr
	onDisconnect func(remoteAddr string)

	// ctx is handed out by Context() and canceled through cancel in OnStop.
	ctx    context.Context
	cancel context.CancelFunc
}
   143  
   144  // NewWSConnection wraps websocket.Conn.
   145  //
   146  // See the commentary on the func(*wsConnection) functions for a detailed
   147  // description of how to configure ping period and pong wait time. NOTE: if the
   148  // write buffer is full, pongs may be dropped, which may cause clients to
   149  // disconnect. see https://github.com/gorilla/websocket/issues/97
   150  func newWSConnection(
   151  	baseConn *websocket.Conn,
   152  	funcMap map[string]*RPCFunc,
   153  	options ...func(*wsConnection),
   154  ) *wsConnection {
   155  	wsc := &wsConnection{
   156  		remoteAddr:        baseConn.RemoteAddr().String(),
   157  		baseConn:          baseConn,
   158  		funcMap:           funcMap,
   159  		writeWait:         defaultWSWriteWait,
   160  		writeChanCapacity: defaultWSWriteChanCapacity,
   161  		readWait:          defaultWSReadWait,
   162  		pingPeriod:        defaultWSPingPeriod,
   163  		readRoutineQuit:   make(chan struct{}),
   164  	}
   165  	for _, option := range options {
   166  		option(wsc)
   167  	}
   168  	wsc.baseConn.SetReadLimit(wsc.readLimit)
   169  	wsc.BaseService = *service.NewBaseService(nil, "wsConnection", wsc)
   170  	return wsc
   171  }
   172  
   173  // OnDisconnect sets a callback which is used upon disconnect - not
   174  // Goroutine-safe. Nop by default.
   175  func OnDisconnect(onDisconnect func(remoteAddr string)) func(*wsConnection) {
   176  	return func(wsc *wsConnection) {
   177  		wsc.onDisconnect = onDisconnect
   178  	}
   179  }
   180  
   181  // WriteWait sets the amount of time to wait before a websocket write times out.
   182  // It should only be used in the constructor - not Goroutine-safe.
   183  func WriteWait(writeWait time.Duration) func(*wsConnection) {
   184  	return func(wsc *wsConnection) {
   185  		wsc.writeWait = writeWait
   186  	}
   187  }
   188  
   189  // WriteChanCapacity sets the capacity of the websocket write channel.
   190  // It should only be used in the constructor - not Goroutine-safe.
   191  func WriteChanCapacity(cap int) func(*wsConnection) {
   192  	return func(wsc *wsConnection) {
   193  		wsc.writeChanCapacity = cap
   194  	}
   195  }
   196  
   197  // ReadWait sets the amount of time to wait before a websocket read times out.
   198  // It should only be used in the constructor - not Goroutine-safe.
   199  func ReadWait(readWait time.Duration) func(*wsConnection) {
   200  	return func(wsc *wsConnection) {
   201  		wsc.readWait = readWait
   202  	}
   203  }
   204  
   205  // PingPeriod sets the duration for sending websocket pings.
   206  // It should only be used in the constructor - not Goroutine-safe.
   207  func PingPeriod(pingPeriod time.Duration) func(*wsConnection) {
   208  	return func(wsc *wsConnection) {
   209  		wsc.pingPeriod = pingPeriod
   210  	}
   211  }
   212  
   213  // ReadLimit sets the maximum size for reading message.
   214  // It should only be used in the constructor - not Goroutine-safe.
   215  func ReadLimit(readLimit int64) func(*wsConnection) {
   216  	return func(wsc *wsConnection) {
   217  		wsc.readLimit = readLimit
   218  	}
   219  }
   220  
   221  // OnStart implements service.Service by starting the read and write routines. It
   222  // blocks until there's some error.
   223  func (wsc *wsConnection) OnStart() error {
   224  	wsc.writeChan = make(chan types.RPCResponse, wsc.writeChanCapacity)
   225  
   226  	// Read subscriptions/unsubscriptions to events
   227  	go wsc.readRoutine()
   228  	// Write responses, BLOCKING.
   229  	wsc.writeRoutine()
   230  
   231  	return nil
   232  }
   233  
   234  // OnStop implements service.Service by unsubscribing remoteAddr from all
   235  // subscriptions.
   236  func (wsc *wsConnection) OnStop() {
   237  	if wsc.onDisconnect != nil {
   238  		wsc.onDisconnect(wsc.remoteAddr)
   239  	}
   240  
   241  	if wsc.ctx != nil {
   242  		wsc.cancel()
   243  	}
   244  }
   245  
   246  // GetRemoteAddr returns the remote address of the underlying connection.
   247  // It implements WSRPCConnection
   248  func (wsc *wsConnection) GetRemoteAddr() string {
   249  	return wsc.remoteAddr
   250  }
   251  
   252  // WriteRPCResponse pushes a response to the writeChan, and blocks until it is
   253  // accepted.
   254  // It implements WSRPCConnection. It is Goroutine-safe.
   255  func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp types.RPCResponse) error {
   256  	select {
   257  	case <-wsc.Quit():
   258  		return errors.New("connection was stopped")
   259  	case <-ctx.Done():
   260  		return ctx.Err()
   261  	case wsc.writeChan <- resp:
   262  		return nil
   263  	}
   264  }
   265  
   266  // TryWriteRPCResponse attempts to push a response to the writeChan, but does
   267  // not block.
   268  // It implements WSRPCConnection. It is Goroutine-safe
   269  func (wsc *wsConnection) TryWriteRPCResponse(resp types.RPCResponse) bool {
   270  	select {
   271  	case <-wsc.Quit():
   272  		return false
   273  	case wsc.writeChan <- resp:
   274  		return true
   275  	default:
   276  		return false
   277  	}
   278  }
   279  
   280  // Context returns the connection's context.
   281  // The context is canceled when the client's connection closes.
   282  func (wsc *wsConnection) Context() context.Context {
   283  	if wsc.ctx != nil {
   284  		return wsc.ctx
   285  	}
   286  	wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
   287  	return wsc.ctx
   288  }
   289  
   290  // Read from the socket and subscribe to or unsubscribe from events
   291  func (wsc *wsConnection) readRoutine() {
   292  	// readRoutine will block until response is written or WS connection is closed
   293  	writeCtx := context.Background()
   294  
   295  	defer func() {
   296  		if r := recover(); r != nil {
   297  			err, ok := r.(error)
   298  			if !ok {
   299  				err = fmt.Errorf("WSJSONRPC: %v", r)
   300  			}
   301  			wsc.Logger.Error("Panic in WSJSONRPC handler", "err", err, "stack", string(debug.Stack()))
   302  			if err := wsc.WriteRPCResponse(writeCtx, types.RPCInternalError(types.JSONRPCIntID(-1), err)); err != nil {
   303  				wsc.Logger.Error("Error writing RPC response", "err", err)
   304  			}
   305  			go wsc.readRoutine()
   306  		}
   307  	}()
   308  
   309  	wsc.baseConn.SetPongHandler(func(m string) error {
   310  		return wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait))
   311  	})
   312  
   313  	for {
   314  		select {
   315  		case <-wsc.Quit():
   316  			return
   317  		default:
   318  			// reset deadline for every type of message (control or data)
   319  			if err := wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait)); err != nil {
   320  				wsc.Logger.Error("failed to set read deadline", "err", err)
   321  			}
   322  
   323  			_, r, err := wsc.baseConn.NextReader()
   324  			if err != nil {
   325  				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   326  					wsc.Logger.Info("Client closed the connection")
   327  				} else {
   328  					wsc.Logger.Error("Failed to read request", "err", err)
   329  				}
   330  				if err := wsc.Stop(); err != nil {
   331  					wsc.Logger.Error("Error closing websocket connection", "err", err)
   332  				}
   333  				close(wsc.readRoutineQuit)
   334  				return
   335  			}
   336  
   337  			dec := json.NewDecoder(r)
   338  			var request types.RPCRequest
   339  			err = dec.Decode(&request)
   340  			if err != nil {
   341  				if err := wsc.WriteRPCResponse(writeCtx,
   342  					types.RPCParseError(fmt.Errorf("error unmarshaling request: %w", err))); err != nil {
   343  					wsc.Logger.Error("Error writing RPC response", "err", err)
   344  				}
   345  				continue
   346  			}
   347  
   348  			// A Notification is a Request object without an "id" member.
   349  			// The Server MUST NOT reply to a Notification, including those that are within a batch request.
   350  			if request.ID == nil {
   351  				wsc.Logger.Debug(
   352  					"WSJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)",
   353  					"req", request,
   354  				)
   355  				continue
   356  			}
   357  
   358  			// Now, fetch the RPCFunc and execute it.
   359  			rpcFunc := wsc.funcMap[request.Method]
   360  			if rpcFunc == nil {
   361  				if err := wsc.WriteRPCResponse(writeCtx, types.RPCMethodNotFoundError(request.ID)); err != nil {
   362  					wsc.Logger.Error("Error writing RPC response", "err", err)
   363  				}
   364  				continue
   365  			}
   366  
   367  			ctx := &types.Context{JSONReq: &request, WSConn: wsc}
   368  			args := []reflect.Value{reflect.ValueOf(ctx)}
   369  			if len(request.Params) > 0 {
   370  				fnArgs, err := jsonParamsToArgs(rpcFunc, request.Params)
   371  				if err != nil {
   372  					if err := wsc.WriteRPCResponse(writeCtx,
   373  						types.RPCInvalidParamsError(request.ID, fmt.Errorf("error converting json params to arguments: %w", err)),
   374  					); err != nil {
   375  						wsc.Logger.Error("Error writing RPC response", "err", err)
   376  					}
   377  					continue
   378  				}
   379  				args = append(args, fnArgs...)
   380  			}
   381  
   382  			returns := rpcFunc.f.Call(args)
   383  
   384  			// TODO: Need to encode args/returns to string if we want to log them
   385  			wsc.Logger.Info("WSJSONRPC", "method", request.Method)
   386  
   387  			var resp types.RPCResponse
   388  			result, err := unreflectResult(returns)
   389  			switch e := err.(type) {
   390  			// if no error then return a success response
   391  			case nil:
   392  				resp = types.NewRPCSuccessResponse(request.ID, result)
   393  
   394  			// if this already of type RPC error then forward that error
   395  			case *types.RPCError:
   396  				resp = types.NewRPCErrorResponse(request.ID, e.Code, e.Message, e.Data)
   397  
   398  			default: // we need to unwrap the error and parse it accordingly
   399  				switch errors.Unwrap(err) {
   400  				// check if the error was due to an invald request
   401  				case ctypes.ErrZeroOrNegativeHeight, ctypes.ErrZeroOrNegativePerPage,
   402  					ctypes.ErrPageOutOfRange, ctypes.ErrInvalidRequest:
   403  					resp = types.RPCInvalidRequestError(request.ID, err)
   404  
   405  				// lastly default all remaining errors as internal errors
   406  				default: // includes ctypes.ErrHeightNotAvailable and ctypes.ErrHeightExceedsChainHead
   407  					resp = types.RPCInternalError(request.ID, err)
   408  				}
   409  			}
   410  
   411  			if err := wsc.WriteRPCResponse(writeCtx, resp); err != nil {
   412  				wsc.Logger.Error("Error writing RPC response", "err", err)
   413  			}
   414  
   415  		}
   416  	}
   417  }
   418  
   419  // receives on a write channel and writes out on the socket
   420  func (wsc *wsConnection) writeRoutine() {
   421  	pingTicker := time.NewTicker(wsc.pingPeriod)
   422  	defer pingTicker.Stop()
   423  
   424  	// https://github.com/gorilla/websocket/issues/97
   425  	pongs := make(chan string, 1)
   426  	wsc.baseConn.SetPingHandler(func(m string) error {
   427  		select {
   428  		case pongs <- m:
   429  		default:
   430  		}
   431  		return nil
   432  	})
   433  
   434  	for {
   435  		select {
   436  		case <-wsc.Quit():
   437  			return
   438  		case <-wsc.readRoutineQuit: // error in readRoutine
   439  			return
   440  		case m := <-pongs:
   441  			err := wsc.writeMessageWithDeadline(websocket.PongMessage, []byte(m))
   442  			if err != nil {
   443  				wsc.Logger.Info("Failed to write pong (client may disconnect)", "err", err)
   444  			}
   445  		case <-pingTicker.C:
   446  			err := wsc.writeMessageWithDeadline(websocket.PingMessage, []byte{})
   447  			if err != nil {
   448  				wsc.Logger.Error("Failed to write ping", "err", err)
   449  				return
   450  			}
   451  		case msg := <-wsc.writeChan:
   452  			jsonBytes, err := json.MarshalIndent(msg, "", "  ")
   453  			if err != nil {
   454  				wsc.Logger.Error("Failed to marshal RPCResponse to JSON", "err", err)
   455  				continue
   456  			}
   457  			if err = wsc.writeMessageWithDeadline(websocket.TextMessage, jsonBytes); err != nil {
   458  				wsc.Logger.Error("Failed to write response", "err", err, "msg", msg)
   459  				return
   460  			}
   461  		}
   462  	}
   463  }
   464  
   465  // All writes to the websocket must (re)set the write deadline.
   466  // If some writes don't set it while others do, they may timeout incorrectly
   467  // (https://github.com/number571/tendermint/issues/553)
   468  func (wsc *wsConnection) writeMessageWithDeadline(msgType int, msg []byte) error {
   469  	if err := wsc.baseConn.SetWriteDeadline(time.Now().Add(wsc.writeWait)); err != nil {
   470  		return err
   471  	}
   472  	return wsc.baseConn.WriteMessage(msgType, msg)
   473  }