github.com/okex/exchain@v1.8.0/libs/tendermint/rpc/jsonrpc/server/ws_handler.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"reflect"
     9  	"runtime/debug"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	"github.com/pkg/errors"
    14  
    15  	amino "github.com/tendermint/go-amino"
    16  
    17  	"github.com/okex/exchain/libs/tendermint/libs/log"
    18  	"github.com/okex/exchain/libs/tendermint/libs/service"
    19  	types "github.com/okex/exchain/libs/tendermint/rpc/jsonrpc/types"
    20  )
    21  
    22  ///////////////////////////////////////////////////////////////////////////////
    23  // WebSocket handler
    24  ///////////////////////////////////////////////////////////////////////////////
    25  
    26  const (
    27  	defaultWSWriteChanCapacity = 1000
    28  	defaultWSWriteWait         = 10 * time.Second
    29  	defaultWSReadWait          = 30 * time.Second
    30  	defaultWSPingPeriod        = (defaultWSReadWait * 9) / 10
    31  )
    32  
// WebsocketManager provides a WS handler for incoming connections and passes a
// map of functions along with any additional params to new connections.
// NOTE: The websocket path is defined externally, e.g. in node/node.go
type WebsocketManager struct {
	// Upgrader performs the HTTP-to-websocket upgrade in WebsocketHandler.
	websocket.Upgrader

	// funcMap holds the RPC methods, keyed by method name, that every accepted
	// connection dispatches requests to.
	funcMap map[string]*RPCFunc
	// cdc is handed to every connection to decode params and encode results.
	cdc *amino.Codec
	// logger is a no-op logger until SetLogger is called.
	logger log.Logger
	// wsConnOptions are applied to every new wsConnection.
	wsConnOptions []func(*wsConnection)
}
    44  
    45  // NewWebsocketManager returns a new WebsocketManager that passes a map of
    46  // functions, connection options and logger to new WS connections.
    47  func NewWebsocketManager(
    48  	funcMap map[string]*RPCFunc,
    49  	cdc *amino.Codec,
    50  	wsConnOptions ...func(*wsConnection),
    51  ) *WebsocketManager {
    52  	return &WebsocketManager{
    53  		funcMap: funcMap,
    54  		cdc:     cdc,
    55  		Upgrader: websocket.Upgrader{
    56  			CheckOrigin: func(r *http.Request) bool {
    57  				// TODO ???
    58  				//
    59  				// The default behaviour would be relevant to browser-based clients,
    60  				// afaik. I suppose having a pass-through is a workaround for allowing
    61  				// for more complex security schemes, shifting the burden of
    62  				// AuthN/AuthZ outside the Tendermint RPC.
    63  				// I can't think of other uses right now that would warrant a TODO
    64  				// though. The real backstory of this TODO shall remain shrouded in
    65  				// mystery
    66  				return true
    67  			},
    68  		},
    69  		logger:        log.NewNopLogger(),
    70  		wsConnOptions: wsConnOptions,
    71  	}
    72  }
    73  
// SetLogger replaces the manager's logger, which NewWebsocketManager sets to a
// no-op logger. Connections accepted afterwards log through l, tagged with
// their remote address.
func (wm *WebsocketManager) SetLogger(l log.Logger) {
	wm.logger = l
}
    78  
    79  // WebsocketHandler upgrades the request/response (via http.Hijack) and starts
    80  // the wsConnection.
    81  func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Request) {
    82  	wsConn, err := wm.Upgrade(w, r, nil)
    83  	if err != nil {
    84  		// TODO - return http error
    85  		wm.logger.Error("Failed to upgrade connection", "err", err)
    86  		return
    87  	}
    88  	defer func() {
    89  		if err := wsConn.Close(); err != nil {
    90  			wm.logger.Error("Failed to close connection", "err", err)
    91  		}
    92  	}()
    93  
    94  	// register connection
    95  	con := newWSConnection(wsConn, wm.funcMap, wm.cdc, wm.wsConnOptions...)
    96  	con.SetLogger(wm.logger.With("remote", wsConn.RemoteAddr()))
    97  	wm.logger.Info("New websocket connection", "remote", con.remoteAddr)
    98  	err = con.Start() // BLOCKING
    99  	if err != nil {
   100  		wm.logger.Error("Failed to start connection", "err", err)
   101  		return
   102  	}
   103  	con.Stop()
   104  }
   105  
   106  ///////////////////////////////////////////////////////////////////////////////
   107  // WebSocket connection
   108  ///////////////////////////////////////////////////////////////////////////////
   109  
// wsConnection serves JSON-RPC requests over a single websocket connection.
// It holds the underlying gorilla connection, the queue of outgoing
// responses, and the settings used by its read and write goroutines (started
// by OnStart).
//
// In case of an error, the connection is stopped.
type wsConnection struct {
	service.BaseService

	// remoteAddr caches baseConn.RemoteAddr().String() for logs and callbacks.
	remoteAddr string
	baseConn   *websocket.Conn
	// writeChan is allocated in OnStart and is never closed. Because of that, a
	// late WriteRPCResponse/TryWriteRPCResponse (which also select on Quit())
	// can never panic by sending on a closed channel.
	writeChan chan types.RPCResponse

	// chan, which is closed when/if readRoutine errors
	// used to abort writeRoutine
	readRoutineQuit chan struct{}

	// RPC methods keyed by name, and the codec used for their params/results.
	funcMap map[string]*RPCFunc
	cdc     *amino.Codec

	// write channel capacity
	writeChanCapacity int

	// each write times out after this.
	writeWait time.Duration

	// Connection times out if we haven't received *anything* in this long, not even pings.
	readWait time.Duration

	// Send pings to the client with this period. Must be greater than zero
	// (time.NewTicker panics otherwise) and should be less than readWait.
	pingPeriod time.Duration

	// Maximum message size, passed to baseConn.SetReadLimit by newWSConnection.
	readLimit int64

	// callback which is called upon disconnect (from OnStop); nil means none
	onDisconnect func(remoteAddr string)

	// Connection-scoped context returned by Context(); OnStop calls cancel
	// when ctx is non-nil.
	ctx    context.Context
	cancel context.CancelFunc
}
   150  
   151  // NewWSConnection wraps websocket.Conn.
   152  //
   153  // See the commentary on the func(*wsConnection) functions for a detailed
   154  // description of how to configure ping period and pong wait time. NOTE: if the
   155  // write buffer is full, pongs may be dropped, which may cause clients to
   156  // disconnect. see https://github.com/gorilla/websocket/issues/97
   157  func newWSConnection(
   158  	baseConn *websocket.Conn,
   159  	funcMap map[string]*RPCFunc,
   160  	cdc *amino.Codec,
   161  	options ...func(*wsConnection),
   162  ) *wsConnection {
   163  	wsc := &wsConnection{
   164  		remoteAddr:        baseConn.RemoteAddr().String(),
   165  		baseConn:          baseConn,
   166  		funcMap:           funcMap,
   167  		cdc:               cdc,
   168  		writeWait:         defaultWSWriteWait,
   169  		writeChanCapacity: defaultWSWriteChanCapacity,
   170  		readWait:          defaultWSReadWait,
   171  		pingPeriod:        defaultWSPingPeriod,
   172  		readRoutineQuit:   make(chan struct{}),
   173  	}
   174  	for _, option := range options {
   175  		option(wsc)
   176  	}
   177  	wsc.baseConn.SetReadLimit(wsc.readLimit)
   178  	wsc.BaseService = *service.NewBaseService(nil, "wsConnection", wsc)
   179  	return wsc
   180  }
   181  
   182  // OnDisconnect sets a callback which is used upon disconnect - not
   183  // Goroutine-safe. Nop by default.
   184  func OnDisconnect(onDisconnect func(remoteAddr string)) func(*wsConnection) {
   185  	return func(wsc *wsConnection) {
   186  		wsc.onDisconnect = onDisconnect
   187  	}
   188  }
   189  
   190  // WriteWait sets the amount of time to wait before a websocket write times out.
   191  // It should only be used in the constructor - not Goroutine-safe.
   192  func WriteWait(writeWait time.Duration) func(*wsConnection) {
   193  	return func(wsc *wsConnection) {
   194  		wsc.writeWait = writeWait
   195  	}
   196  }
   197  
   198  // WriteChanCapacity sets the capacity of the websocket write channel.
   199  // It should only be used in the constructor - not Goroutine-safe.
   200  func WriteChanCapacity(cap int) func(*wsConnection) {
   201  	return func(wsc *wsConnection) {
   202  		wsc.writeChanCapacity = cap
   203  	}
   204  }
   205  
   206  // ReadWait sets the amount of time to wait before a websocket read times out.
   207  // It should only be used in the constructor - not Goroutine-safe.
   208  func ReadWait(readWait time.Duration) func(*wsConnection) {
   209  	return func(wsc *wsConnection) {
   210  		wsc.readWait = readWait
   211  	}
   212  }
   213  
   214  // PingPeriod sets the duration for sending websocket pings.
   215  // It should only be used in the constructor - not Goroutine-safe.
   216  func PingPeriod(pingPeriod time.Duration) func(*wsConnection) {
   217  	return func(wsc *wsConnection) {
   218  		wsc.pingPeriod = pingPeriod
   219  	}
   220  }
   221  
   222  // ReadLimit sets the maximum size for reading message.
   223  // It should only be used in the constructor - not Goroutine-safe.
   224  func ReadLimit(readLimit int64) func(*wsConnection) {
   225  	return func(wsc *wsConnection) {
   226  		wsc.readLimit = readLimit
   227  	}
   228  }
   229  
   230  // OnStart implements service.Service by starting the read and write routines. It
   231  // blocks until there's some error.
   232  func (wsc *wsConnection) OnStart() error {
   233  	wsc.writeChan = make(chan types.RPCResponse, wsc.writeChanCapacity)
   234  
   235  	// Read subscriptions/unsubscriptions to events
   236  	go wsc.readRoutine()
   237  	// Write responses, BLOCKING.
   238  	wsc.writeRoutine()
   239  
   240  	return nil
   241  }
   242  
   243  // OnStop implements service.Service by unsubscribing remoteAddr from all
   244  // subscriptions.
   245  func (wsc *wsConnection) OnStop() {
   246  	if wsc.onDisconnect != nil {
   247  		wsc.onDisconnect(wsc.remoteAddr)
   248  	}
   249  
   250  	if wsc.ctx != nil {
   251  		wsc.cancel()
   252  	}
   253  }
   254  
// GetRemoteAddr returns the remote address of the underlying connection, as
// captured (stringified) when the wsConnection was constructed.
// It implements WSRPCConnection
func (wsc *wsConnection) GetRemoteAddr() string {
	return wsc.remoteAddr
}
   260  
   261  // WriteRPCResponse pushes a response to the writeChan, and blocks until it is accepted.
   262  // It implements WSRPCConnection. It is Goroutine-safe.
   263  func (wsc *wsConnection) WriteRPCResponse(resp types.RPCResponse) {
   264  	select {
   265  	case <-wsc.Quit():
   266  		return
   267  	case wsc.writeChan <- resp:
   268  	}
   269  }
   270  
   271  // TryWriteRPCResponse attempts to push a response to the writeChan, but does not block.
   272  // It implements WSRPCConnection. It is Goroutine-safe
   273  func (wsc *wsConnection) TryWriteRPCResponse(resp types.RPCResponse) bool {
   274  	select {
   275  	case <-wsc.Quit():
   276  		return false
   277  	case wsc.writeChan <- resp:
   278  		return true
   279  	default:
   280  		return false
   281  	}
   282  }
   283  
// Codec returns an amino codec used to decode parameters and encode results.
// It is the codec passed to newWSConnection.
// It implements WSRPCConnection.
func (wsc *wsConnection) Codec() *amino.Codec {
	return wsc.cdc
}
   289  
   290  // Context returns the connection's context.
   291  // The context is canceled when the client's connection closes.
   292  func (wsc *wsConnection) Context() context.Context {
   293  	if wsc.ctx != nil {
   294  		return wsc.ctx
   295  	}
   296  	wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
   297  	return wsc.ctx
   298  }
   299  
// Read from the socket and subscribe to or unsubscribe from events.
//
// readRoutine is the connection's request loop. It reads one message at a
// time, decodes it as a JSON-RPC request, and invokes the matching RPCFunc
// from funcMap. The reply is queued via WriteRPCResponse:
//   - undecodable JSON gets an RPCParseError;
//   - requests without an ID are notifications and get no reply;
//   - unknown methods get RPCMethodNotFoundError;
//   - param conversion failures and errors returned by the RPCFunc get
//     RPCInternalError;
//   - everything else gets a success response.
//
// On a read error it stops the service, closes readRoutineQuit (which makes
// writeRoutine return), and exits. If anything in the loop panics, the
// deferred recover logs the panic and replies with an internal error under
// ID -1, because the request ID is not in scope there. It then starts a fresh
// readRoutine goroutine in place of this one.
func (wsc *wsConnection) readRoutine() {
	defer func() {
		if r := recover(); r != nil {
			err, ok := r.(error)
			if !ok {
				err = fmt.Errorf("WSJSONRPC: %v", r)
			}
			wsc.Logger.Error("Panic in WSJSONRPC handler", "err", err, "stack", string(debug.Stack()))
			wsc.WriteRPCResponse(types.RPCInternalError(types.JSONRPCIntID(-1), err))
			// Keep serving the connection: replace this goroutine with a new loop.
			go wsc.readRoutine()
		}
	}()

	// Every pong from the client extends the read deadline by readWait.
	wsc.baseConn.SetPongHandler(func(m string) error {
		return wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait))
	})

	for {
		select {
		case <-wsc.Quit():
			return
		default:
			// reset deadline for every type of message (control or data)
			if err := wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.readWait)); err != nil {
				wsc.Logger.Error("failed to set read deadline", "err", err)
			}
			// NOTE(review): this declaration is redundant; the := below reuses
			// `in` and only introduces err.
			var in []byte
			_, in, err := wsc.baseConn.ReadMessage()
			if err != nil {
				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
					wsc.Logger.Info("Client closed the connection")
				} else {
					wsc.Logger.Error("Failed to read request", "err", err)
				}
				// Stop's error is ignored: the service may already be stopped.
				// Closing readRoutineQuit tells writeRoutine to return.
				wsc.Stop()
				close(wsc.readRoutineQuit)
				return
			}

			var request types.RPCRequest
			err = json.Unmarshal(in, &request)
			if err != nil {
				wsc.WriteRPCResponse(types.RPCParseError(errors.Wrap(err, "error unmarshaling request")))
				continue
			}

			// A Notification is a Request object without an "id" member.
			// The Server MUST NOT reply to a Notification, including those that are within a batch request.
			if request.ID == nil {
				wsc.Logger.Debug(
					"WSJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)",
					"req", request,
				)
				continue
			}

			// Now, fetch the RPCFunc and execute it.
			rpcFunc := wsc.funcMap[request.Method]
			if rpcFunc == nil {
				wsc.WriteRPCResponse(types.RPCMethodNotFoundError(request.ID))
				continue
			}

			// The RPCFunc receives a *types.Context (request + this connection)
			// as its first argument, followed by the decoded params, if any.
			ctx := &types.Context{JSONReq: &request, WSConn: wsc}
			args := []reflect.Value{reflect.ValueOf(ctx)}
			if len(request.Params) > 0 {
				fnArgs, err := jsonParamsToArgs(rpcFunc, wsc.cdc, request.Params)
				if err != nil {
					wsc.WriteRPCResponse(
						types.RPCInternalError(request.ID, errors.Wrap(err, "error converting json params to arguments")),
					)
					continue
				}
				args = append(args, fnArgs...)
			}

			// A panic inside the handler is caught by the deferred recover above.
			returns := rpcFunc.f.Call(args)

			// TODO: Need to encode args/returns to string if we want to log them
			wsc.Logger.Info("WSJSONRPC", "method", request.Method)

			result, err := unreflectResult(returns)
			if err != nil {
				wsc.WriteRPCResponse(types.RPCInternalError(request.ID, err))
				continue
			}

			wsc.WriteRPCResponse(types.NewRPCSuccessResponse(wsc.cdc, request.ID, result))
		}
	}
}
   392  
   393  // receives on a write channel and writes out on the socket
   394  func (wsc *wsConnection) writeRoutine() {
   395  	pingTicker := time.NewTicker(wsc.pingPeriod)
   396  	defer func() {
   397  		pingTicker.Stop()
   398  	}()
   399  
   400  	// https://github.com/gorilla/websocket/issues/97
   401  	pongs := make(chan string, 1)
   402  	wsc.baseConn.SetPingHandler(func(m string) error {
   403  		select {
   404  		case pongs <- m:
   405  		default:
   406  		}
   407  		return nil
   408  	})
   409  
   410  	for {
   411  		select {
   412  		case <-wsc.Quit():
   413  			return
   414  		case <-wsc.readRoutineQuit: // error in readRoutine
   415  			return
   416  		case m := <-pongs:
   417  			err := wsc.writeMessageWithDeadline(websocket.PongMessage, []byte(m))
   418  			if err != nil {
   419  				wsc.Logger.Info("Failed to write pong (client may disconnect)", "err", err)
   420  			}
   421  		case <-pingTicker.C:
   422  			err := wsc.writeMessageWithDeadline(websocket.PingMessage, []byte{})
   423  			if err != nil {
   424  				wsc.Logger.Error("Failed to write ping", "err", err)
   425  				return
   426  			}
   427  		case msg := <-wsc.writeChan:
   428  			jsonBytes, err := json.MarshalIndent(msg, "", "  ")
   429  			if err != nil {
   430  				wsc.Logger.Error("Failed to marshal RPCResponse to JSON", "err", err)
   431  			} else if err = wsc.writeMessageWithDeadline(websocket.TextMessage, jsonBytes); err != nil {
   432  				wsc.Logger.Error("Failed to write response", "msg", msg, "err", err)
   433  				return
   434  			}
   435  		}
   436  	}
   437  }
   438  
   439  // All writes to the websocket must (re)set the write deadline.
   440  // If some writes don't set it while others do, they may timeout incorrectly
   441  // (https://github.com/tendermint/tendermint/issues/553)
   442  func (wsc *wsConnection) writeMessageWithDeadline(msgType int, msg []byte) error {
   443  	if err := wsc.baseConn.SetWriteDeadline(time.Now().Add(wsc.writeWait)); err != nil {
   444  		return err
   445  	}
   446  	return wsc.baseConn.WriteMessage(msgType, msg)
   447  }