github.com/fibonacci-chain/fbc@v0.0.0-20231124064014-c7636198c1e9/app/rpc/websockets/server.go (about)

     1  package websockets
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"math/big"
     9  	"net/http"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/ethereum/go-ethereum/rpc"
    14  	"github.com/fibonacci-chain/fbc/libs/cosmos-sdk/client/context"
    15  	"github.com/fibonacci-chain/fbc/libs/cosmos-sdk/server"
    16  	"github.com/fibonacci-chain/fbc/libs/tendermint/libs/log"
    17  	"github.com/fibonacci-chain/fbc/x/common/monitor"
    18  	"github.com/go-kit/kit/metrics"
    19  	"github.com/go-kit/kit/metrics/prometheus"
    20  	"github.com/gorilla/mux"
    21  	"github.com/gorilla/websocket"
    22  	stdprometheus "github.com/prometheus/client_golang/prometheus"
    23  	"github.com/spf13/viper"
    24  )
    25  
    26  const FlagSubscribeLimit = "ws.max-subscriptions"
    27  
    28  // Server defines a server that handles Ethereum websockets.
    29  type Server struct {
    30  	rpcAddr string // listen address of rest-server
    31  	wsAddr  string // listen address of ws server
    32  	api     *PubSubAPI
    33  	logger  log.Logger
    34  
    35  	connPool       chan struct{}
    36  	connPoolLock   *sync.Mutex
    37  	currentConnNum metrics.Gauge
    38  	maxConnNum     metrics.Gauge
    39  	maxSubLimit    int
    40  }
    41  
    42  // NewServer creates a new websocket server instance.
    43  func NewServer(clientCtx context.CLIContext, log log.Logger, wsAddr string) *Server {
    44  	restServerAddr := viper.GetString(server.FlagListenAddr)
    45  	parts := strings.SplitN(restServerAddr, "://", 2)
    46  	if len(parts) != 2 {
    47  		panic(fmt.Errorf("invalid listening address %s (use fully formed addresses, including the tcp:// or unix:// prefix)", restServerAddr))
    48  	}
    49  	url := parts[1]
    50  	urlParts := strings.SplitN(url, ":", 2)
    51  	if len(urlParts) != 2 {
    52  		panic(fmt.Errorf("invalid listening address %s (use ip:port as an url)", url))
    53  	}
    54  	port := urlParts[1]
    55  
    56  	return &Server{
    57  		rpcAddr:      "http://localhost:" + port,
    58  		wsAddr:       wsAddr,
    59  		api:          NewAPI(clientCtx, log),
    60  		logger:       log.With("module", "websocket-server"),
    61  		connPool:     make(chan struct{}, viper.GetInt(server.FlagWsMaxConnections)),
    62  		connPoolLock: new(sync.Mutex),
    63  		currentConnNum: prometheus.NewGaugeFrom(stdprometheus.GaugeOpts{
    64  			Namespace: monitor.XNameSpace,
    65  			Subsystem: "websocket",
    66  			Name:      "connection_number",
    67  			Help:      "the number of current websocket client connections",
    68  		}, nil),
    69  		maxConnNum: prometheus.NewGaugeFrom(stdprometheus.GaugeOpts{
    70  			Namespace: monitor.XNameSpace,
    71  			Subsystem: "websocket",
    72  			Name:      "connection_capacity",
    73  			Help:      "the capacity number of websocket client connections",
    74  		}, nil),
    75  		maxSubLimit: viper.GetInt(FlagSubscribeLimit),
    76  	}
    77  }
    78  
    79  // Start runs the websocket server
    80  func (s *Server) Start() {
    81  	ws := mux.NewRouter()
    82  	ws.Handle("/", s)
    83  	s.maxConnNum.Set(float64(viper.GetInt(server.FlagWsMaxConnections)))
    84  	s.currentConnNum.Set(0)
    85  
    86  	go func() {
    87  		err := http.ListenAndServe(fmt.Sprintf(":%s", s.wsAddr), ws)
    88  		if err != nil {
    89  			s.logger.Error("http error:", err)
    90  		}
    91  	}()
    92  }
    93  
    94  func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    95  	s.connPoolLock.Lock()
    96  	defer s.connPoolLock.Unlock()
    97  	if len(s.connPool) >= cap(s.connPool) {
    98  		w.WriteHeader(http.StatusServiceUnavailable)
    99  		return
   100  	}
   101  
   102  	var upgrader = websocket.Upgrader{
   103  		CheckOrigin: func(r *http.Request) bool {
   104  			return true
   105  		},
   106  	}
   107  
   108  	conn, err := upgrader.Upgrade(w, r, nil)
   109  	if err != nil {
   110  		s.logger.Error("websocket upgrade failed", " error", err)
   111  		return
   112  	}
   113  
   114  	s.connPool <- struct{}{}
   115  	s.currentConnNum.Set(float64(len(s.connPool)))
   116  	go s.readLoop(&wsConn{
   117  		mux:  new(sync.Mutex),
   118  		conn: conn,
   119  	})
   120  }
   121  
   122  func (s *Server) sendErrResponse(conn *wsConn, msg string) {
   123  	res := makeErrResponse(msg)
   124  	err := conn.WriteJSON(res)
   125  	if err != nil {
   126  		s.logger.Error("websocket failed write message", "error", err)
   127  	}
   128  }
   129  
   130  func makeErrResponse(errMsg string) *ErrorResponseJSON {
   131  	return &ErrorResponseJSON{
   132  		Jsonrpc: "2.0",
   133  		Error: &ErrorMessageJSON{
   134  			Code:    big.NewInt(-32600),
   135  			Message: errMsg,
   136  		},
   137  		ID: big.NewInt(1),
   138  	}
   139  }
   140  
   141  type wsConn struct {
   142  	conn     *websocket.Conn
   143  	mux      *sync.Mutex
   144  	subCount int
   145  }
   146  
   147  func (w *wsConn) GetSubCount() int {
   148  	return w.subCount
   149  }
   150  
   151  func (w *wsConn) AddSubCount(delta int) {
   152  	w.subCount += delta
   153  }
   154  
   155  func (w *wsConn) WriteJSON(v interface{}) error {
   156  	w.mux.Lock()
   157  	defer w.mux.Unlock()
   158  
   159  	return w.conn.WriteJSON(v)
   160  }
   161  
   162  func (w *wsConn) Close() error {
   163  	w.mux.Lock()
   164  	defer w.mux.Unlock()
   165  
   166  	return w.conn.Close()
   167  }
   168  
   169  func (w *wsConn) ReadMessage() (messageType int, p []byte, err error) {
   170  	// not protected by write mutex
   171  
   172  	return w.conn.ReadMessage()
   173  }
   174  
   175  func (s *Server) readLoop(wsConn *wsConn) {
   176  	subIds := make(map[rpc.ID]struct{})
   177  	for {
   178  		_, mb, err := wsConn.ReadMessage()
   179  		if err != nil {
   180  			_ = wsConn.Close()
   181  			s.logger.Error("failed to read message, close the websocket connection.", "error", err)
   182  			s.closeWsConnection(subIds)
   183  			return
   184  		}
   185  
   186  		var msg map[string]interface{}
   187  		if err = json.Unmarshal(mb, &msg); err != nil {
   188  			if err = s.batchCall(mb, wsConn); err != nil {
   189  				s.sendErrResponse(wsConn, "invalid request")
   190  			}
   191  			continue
   192  		}
   193  
   194  		// check if method == eth_subscribe or eth_unsubscribe
   195  		method := msg["method"]
   196  		methodStr, ok := method.(string)
   197  		if !ok {
   198  			s.sendErrResponse(wsConn, "invalid request")
   199  		}
   200  		if methodStr == "eth_subscribe" {
   201  			if wsConn.GetSubCount() >= s.maxSubLimit {
   202  				s.sendErrResponse(wsConn,
   203  					fmt.Sprintf("subscription has reached the upper limit(%d)", s.maxSubLimit))
   204  				continue
   205  			}
   206  			params, ok := msg["params"].([]interface{})
   207  			if !ok || len(params) == 0 {
   208  				s.sendErrResponse(wsConn, "invalid parameters")
   209  				continue
   210  			}
   211  
   212  			reqId, ok := msg["id"].(float64)
   213  			if !ok {
   214  				s.sendErrResponse(wsConn, "invaild id in request message")
   215  				continue
   216  			}
   217  
   218  			id, err := s.api.subscribe(wsConn, params)
   219  			if err != nil {
   220  				s.sendErrResponse(wsConn, err.Error())
   221  				continue
   222  			}
   223  
   224  			res := &SubscriptionResponseJSON{
   225  				Jsonrpc: "2.0",
   226  				ID:      reqId,
   227  				Result:  id,
   228  			}
   229  
   230  			err = wsConn.WriteJSON(res)
   231  			if err != nil {
   232  				s.logger.Error("failed to write json response", "ID", id, "error", err)
   233  				continue
   234  			}
   235  			s.logger.Debug("successfully subscribe", "ID", id)
   236  			subIds[id] = struct{}{}
   237  			wsConn.AddSubCount(1)
   238  			continue
   239  		} else if methodStr == "eth_unsubscribe" {
   240  			ids, ok := msg["params"].([]interface{})
   241  			if len(ids) == 0 {
   242  				s.sendErrResponse(wsConn, "invalid parameters")
   243  				continue
   244  			}
   245  			id, idok := ids[0].(string)
   246  			if !ok || !idok {
   247  				s.sendErrResponse(wsConn, "invalid parameters")
   248  				continue
   249  			}
   250  
   251  			reqId, ok := msg["id"].(float64)
   252  			if !ok {
   253  				s.sendErrResponse(wsConn, "invaild id in request message")
   254  				continue
   255  			}
   256  
   257  			ok = s.api.unsubscribe(rpc.ID(id))
   258  			res := &SubscriptionResponseJSON{
   259  				Jsonrpc: "2.0",
   260  				ID:      reqId,
   261  				Result:  ok,
   262  			}
   263  
   264  			err = wsConn.WriteJSON(res)
   265  			if err != nil {
   266  				s.logger.Error("failed to write json response", "ID", id, "error", err)
   267  				continue
   268  			}
   269  			s.logger.Debug("successfully unsubscribe", "ID", id)
   270  			delete(subIds, rpc.ID(id))
   271  			wsConn.AddSubCount(-1)
   272  			continue
   273  		}
   274  
   275  		// otherwise, call the usual rpc server to respond
   276  		data, err := s.getRpcResponse(mb)
   277  		if err != nil {
   278  			s.sendErrResponse(wsConn, err.Error())
   279  		} else {
   280  			wsConn.WriteJSON(data)
   281  		}
   282  	}
   283  }
   284  
   285  // getRpcResponse connects to the rest-server over tcp, posts a JSON-RPC request, and return response
   286  func (s *Server) getRpcResponse(mb []byte) (interface{}, error) {
   287  	req, err := http.NewRequest(http.MethodPost, s.rpcAddr, bytes.NewReader(mb))
   288  	if err != nil {
   289  		return nil, fmt.Errorf("failed to request; %s", err)
   290  	}
   291  	req.Header.Set("Content-Type", "application/json")
   292  	resp, err := http.DefaultClient.Do(req)
   293  	if err != nil {
   294  		return nil, fmt.Errorf("failed to write to rest-server; %s", err)
   295  	}
   296  
   297  	defer resp.Body.Close()
   298  	body, err := ioutil.ReadAll(resp.Body)
   299  	if err != nil {
   300  		return nil, fmt.Errorf("could not read body from response; %s", err)
   301  	}
   302  
   303  	var wsSend interface{}
   304  	err = json.Unmarshal(body, &wsSend)
   305  	if err != nil {
   306  		return nil, fmt.Errorf("failed to unmarshal rest-server response; %s", err)
   307  	}
   308  	return wsSend, nil
   309  }
   310  
   311  func (s *Server) closeWsConnection(subIds map[rpc.ID]struct{}) {
   312  	for id := range subIds {
   313  		s.api.unsubscribe(id)
   314  		delete(subIds, id)
   315  	}
   316  	s.connPoolLock.Lock()
   317  	defer s.connPoolLock.Unlock()
   318  	<-s.connPool
   319  	s.currentConnNum.Set(float64(len(s.connPool)))
   320  }
   321  
   322  func (s *Server) batchCall(mb []byte, wsConn *wsConn) error {
   323  	var msgs []interface{}
   324  	if err := json.Unmarshal(mb, &msgs); err != nil {
   325  		return err
   326  	}
   327  
   328  	for i := 0; i < len(msgs); i++ {
   329  		b, err := json.Marshal(msgs[i])
   330  		if err != nil {
   331  			s.sendErrResponse(wsConn, "invalid request")
   332  			s.logger.Error("web socket batchCall  failed", "error", err)
   333  			break
   334  		}
   335  
   336  		data, err := s.getRpcResponse(b)
   337  		if err != nil {
   338  			data = makeErrResponse(err.Error())
   339  		}
   340  		if err := wsConn.WriteJSON(data); err != nil {
   341  			break // connection broken
   342  		}
   343  	}
   344  	return nil
   345  }