github.com/0xPolygon/supernets2-node@v0.0.0-20230711153321-2fe574524eaa/jsonrpc/server.go (about)

     1  package jsonrpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"time"
    13  
    14  	"github.com/0xPolygon/supernets2-node/jsonrpc/metrics"
    15  	"github.com/0xPolygon/supernets2-node/jsonrpc/types"
    16  	"github.com/0xPolygon/supernets2-node/log"
    17  	"github.com/didip/tollbooth/v6"
    18  	"github.com/gorilla/websocket"
    19  )
    20  
    21  const (
    22  	// APIEth represents the eth API prefix.
    23  	APIEth = "eth"
    24  	// APINet represents the net API prefix.
    25  	APINet = "net"
    26  	// APIDebug represents the debug API prefix.
    27  	APIDebug = "debug"
    28  	// APIZKEVM represents the zkevm API prefix.
    29  	APIZKEVM = "zkevm"
    30  	// APITxPool represents the txpool API prefix.
    31  	APITxPool = "txpool"
    32  	// APIWeb3 represents the web3 API prefix.
    33  	APIWeb3 = "web3"
    34  
    35  	wsBufferSizeLimitInBytes = 1024
    36  )
    37  
    38  // Server is an API backend to handle RPC requests
    39  type Server struct {
    40  	config     Config
    41  	chainID    uint64
    42  	handler    *Handler
    43  	srv        *http.Server
    44  	wsSrv      *http.Server
    45  	wsUpgrader websocket.Upgrader
    46  }
    47  
// Service pairs a service implementation with the name it is registered under.
// NewServer passes each Service to the handler's registerService.
type Service struct {
	// Name is the namespace the service is registered under; presumably one
	// of the API* prefix constants — TODO confirm against registerService.
	Name    string
	// Service is the implementation object handed to the handler.
	Service interface{}
}
    53  
    54  // NewServer returns the JsonRPC server
    55  func NewServer(
    56  	cfg Config,
    57  	chainID uint64,
    58  	p types.PoolInterface,
    59  	s types.StateInterface,
    60  	storage storageInterface,
    61  	services []Service,
    62  ) *Server {
    63  	s.PrepareWebSocket()
    64  	handler := newJSONRpcHandler()
    65  
    66  	for _, service := range services {
    67  		handler.registerService(service)
    68  	}
    69  
    70  	srv := &Server{
    71  		config:  cfg,
    72  		handler: handler,
    73  		chainID: chainID,
    74  	}
    75  	return srv
    76  }
    77  
    78  // Start initializes the JSON RPC server to listen for request
    79  func (s *Server) Start() error {
    80  	metrics.Register()
    81  
    82  	if s.config.WebSockets.Enabled {
    83  		go s.startWS()
    84  	}
    85  
    86  	return s.startHTTP()
    87  }
    88  
    89  // startHTTP starts a server to respond http requests
    90  func (s *Server) startHTTP() error {
    91  	if s.srv != nil {
    92  		return fmt.Errorf("server already started")
    93  	}
    94  
    95  	address := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
    96  
    97  	lis, err := net.Listen("tcp", address)
    98  	if err != nil {
    99  		log.Errorf("failed to create tcp listener: %v", err)
   100  		return err
   101  	}
   102  
   103  	mux := http.NewServeMux()
   104  
   105  	lmt := tollbooth.NewLimiter(s.config.MaxRequestsPerIPAndSecond, nil)
   106  	mux.Handle("/", tollbooth.LimitFuncHandler(lmt, s.handle))
   107  
   108  	s.srv = &http.Server{
   109  		Handler:           mux,
   110  		ReadHeaderTimeout: s.config.ReadTimeout.Duration,
   111  		ReadTimeout:       s.config.ReadTimeout.Duration,
   112  		WriteTimeout:      s.config.WriteTimeout.Duration,
   113  	}
   114  	log.Infof("http server started: %s", address)
   115  	if err := s.srv.Serve(lis); err != nil {
   116  		if err == http.ErrServerClosed {
   117  			log.Infof("http server stopped")
   118  			return nil
   119  		}
   120  		log.Errorf("closed http connection: %v", err)
   121  		return err
   122  	}
   123  	return nil
   124  }
   125  
   126  // startWS starts a server to respond WebSockets connections
   127  func (s *Server) startWS() {
   128  	log.Infof("starting websocket server")
   129  
   130  	if s.wsSrv != nil {
   131  		log.Errorf("websocket server already started")
   132  		return
   133  	}
   134  
   135  	address := fmt.Sprintf("%s:%d", s.config.WebSockets.Host, s.config.WebSockets.Port)
   136  
   137  	lis, err := net.Listen("tcp", address)
   138  	if err != nil {
   139  		log.Errorf("failed to create tcp listener: %v", err)
   140  		return
   141  	}
   142  
   143  	mux := http.NewServeMux()
   144  	mux.HandleFunc("/", s.handleWs)
   145  
   146  	s.wsSrv = &http.Server{
   147  		Handler:           mux,
   148  		ReadHeaderTimeout: s.config.ReadTimeout.Duration,
   149  		ReadTimeout:       s.config.ReadTimeout.Duration,
   150  		WriteTimeout:      s.config.WriteTimeout.Duration,
   151  	}
   152  	s.wsUpgrader = websocket.Upgrader{
   153  		ReadBufferSize:  wsBufferSizeLimitInBytes,
   154  		WriteBufferSize: wsBufferSizeLimitInBytes,
   155  	}
   156  	log.Infof("websocket server started: %s", address)
   157  	if err := s.wsSrv.Serve(lis); err != nil {
   158  		if err == http.ErrServerClosed {
   159  			log.Infof("websocket server stopped")
   160  			return
   161  		}
   162  		log.Errorf("closed websocket connection: %v", err)
   163  		return
   164  	}
   165  }
   166  
   167  // Stop shutdown the rpc server
   168  func (s *Server) Stop() error {
   169  	if s.srv != nil {
   170  		if err := s.srv.Shutdown(context.Background()); err != nil {
   171  			return err
   172  		}
   173  
   174  		if err := s.srv.Close(); err != nil {
   175  			return err
   176  		}
   177  		s.srv = nil
   178  	}
   179  
   180  	if s.wsSrv != nil {
   181  		if err := s.wsSrv.Shutdown(context.Background()); err != nil {
   182  			return err
   183  		}
   184  
   185  		if err := s.wsSrv.Close(); err != nil {
   186  			return err
   187  		}
   188  		s.wsSrv = nil
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  func (s *Server) handle(w http.ResponseWriter, req *http.Request) {
   195  	w.Header().Set("Content-Type", "application/json")
   196  	w.Header().Set("Access-Control-Allow-Origin", "*")
   197  	w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
   198  	w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
   199  
   200  	if (*req).Method == "OPTIONS" {
   201  		// TODO(pg): need to count it in the metrics?
   202  		return
   203  	}
   204  
   205  	if req.Method == "GET" {
   206  		// TODO(pg): need to count it in the metrics?
   207  		_, err := w.Write([]byte("zkEVM JSON RPC Server"))
   208  		if err != nil {
   209  			log.Error(err)
   210  		}
   211  		return
   212  	}
   213  
   214  	if req.Method != "POST" {
   215  		err := errors.New("method " + req.Method + " not allowed")
   216  		s.handleInvalidRequest(w, err)
   217  		return
   218  	}
   219  
   220  	data, err := io.ReadAll(req.Body)
   221  	if err != nil {
   222  		s.handleInvalidRequest(w, err)
   223  		return
   224  	}
   225  
   226  	single, err := s.isSingleRequest(data)
   227  	if err != nil {
   228  		s.handleInvalidRequest(w, err)
   229  		return
   230  	}
   231  
   232  	start := time.Now()
   233  	var respLen int
   234  	if single {
   235  		respLen = s.handleSingleRequest(req, w, data)
   236  	} else {
   237  		respLen = s.handleBatchRequest(req, w, data)
   238  	}
   239  	metrics.RequestDuration(start)
   240  	combinedLog(req, start, http.StatusOK, respLen)
   241  }
   242  
   243  func (s *Server) isSingleRequest(data []byte) (bool, types.Error) {
   244  	x := bytes.TrimLeft(data, " \t\r\n")
   245  
   246  	if len(x) == 0 {
   247  		return false, types.NewRPCError(types.InvalidRequestErrorCode, "Invalid json request")
   248  	}
   249  
   250  	return x[0] == '{', nil
   251  }
   252  
   253  func (s *Server) handleSingleRequest(httpRequest *http.Request, w http.ResponseWriter, data []byte) int {
   254  	defer metrics.RequestHandled(metrics.RequestHandledLabelSingle)
   255  	request, err := s.parseRequest(data)
   256  	if err != nil {
   257  		handleError(w, err)
   258  		return 0
   259  	}
   260  	req := handleRequest{Request: request, HttpRequest: httpRequest}
   261  	response := s.handler.Handle(req)
   262  
   263  	respBytes, err := json.Marshal(response)
   264  	if err != nil {
   265  		handleError(w, err)
   266  		return 0
   267  	}
   268  
   269  	_, err = w.Write(respBytes)
   270  	if err != nil {
   271  		handleError(w, err)
   272  		return 0
   273  	}
   274  	return len(respBytes)
   275  }
   276  
   277  func (s *Server) handleBatchRequest(httpRequest *http.Request, w http.ResponseWriter, data []byte) int {
   278  	defer metrics.RequestHandled(metrics.RequestHandledLabelBatch)
   279  	requests, err := s.parseRequests(data)
   280  	if err != nil {
   281  		handleError(w, err)
   282  		return 0
   283  	}
   284  
   285  	responses := make([]types.Response, 0, len(requests))
   286  
   287  	for _, request := range requests {
   288  		req := handleRequest{Request: request, HttpRequest: httpRequest}
   289  		response := s.handler.Handle(req)
   290  		responses = append(responses, response)
   291  	}
   292  
   293  	respBytes, _ := json.Marshal(responses)
   294  	_, err = w.Write(respBytes)
   295  	if err != nil {
   296  		log.Error(err)
   297  		return 0
   298  	}
   299  	return len(respBytes)
   300  }
   301  
   302  func (s *Server) parseRequest(data []byte) (types.Request, error) {
   303  	var req types.Request
   304  
   305  	if err := json.Unmarshal(data, &req); err != nil {
   306  		return types.Request{}, types.NewRPCError(types.InvalidRequestErrorCode, "Invalid json request")
   307  	}
   308  
   309  	return req, nil
   310  }
   311  
   312  func (s *Server) parseRequests(data []byte) ([]types.Request, error) {
   313  	var requests []types.Request
   314  
   315  	if err := json.Unmarshal(data, &requests); err != nil {
   316  		return nil, types.NewRPCError(types.InvalidRequestErrorCode, "Invalid json request")
   317  	}
   318  
   319  	return requests, nil
   320  }
   321  
   322  func (s *Server) handleInvalidRequest(w http.ResponseWriter, err error) {
   323  	defer metrics.RequestHandled(metrics.RequestHandledLabelInvalid)
   324  	handleError(w, err)
   325  }
   326  
   327  func (s *Server) handleWs(w http.ResponseWriter, req *http.Request) {
   328  	// CORS rule - Allow requests from anywhere
   329  	s.wsUpgrader.CheckOrigin = func(r *http.Request) bool { return true }
   330  
   331  	// Upgrade the connection to a WS one
   332  	wsConn, err := s.wsUpgrader.Upgrade(w, req, nil)
   333  	if err != nil {
   334  		log.Error(fmt.Sprintf("Unable to upgrade to a WS connection, %s", err.Error()))
   335  
   336  		return
   337  	}
   338  
   339  	// Defer WS closure
   340  	defer func(ws *websocket.Conn) {
   341  		err = ws.Close()
   342  		if err != nil {
   343  			log.Error(fmt.Sprintf("Unable to gracefully close WS connection, %s", err.Error()))
   344  		}
   345  	}(wsConn)
   346  
   347  	log.Info("Websocket connection established")
   348  	for {
   349  		msgType, message, err := wsConn.ReadMessage()
   350  		if err != nil {
   351  			if websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) {
   352  				log.Info("Closing WS connection gracefully")
   353  			} else {
   354  				log.Error(fmt.Sprintf("Unable to read WS message, %s", err.Error()))
   355  				log.Info("Closing WS connection with error")
   356  			}
   357  
   358  			s.handler.RemoveFilterByWsConn(wsConn)
   359  
   360  			break
   361  		}
   362  
   363  		if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage {
   364  			go func() {
   365  				resp, err := s.handler.HandleWs(message, wsConn)
   366  				if err != nil {
   367  					log.Error(fmt.Sprintf("Unable to handle WS request, %s", err.Error()))
   368  					_ = wsConn.WriteMessage(msgType, []byte(fmt.Sprintf("WS Handle error: %s", err.Error())))
   369  				} else {
   370  					_ = wsConn.WriteMessage(msgType, resp)
   371  				}
   372  			}()
   373  		}
   374  	}
   375  }
   376  
   377  func handleError(w http.ResponseWriter, err error) {
   378  	log.Error(err)
   379  	_, err = w.Write([]byte(err.Error()))
   380  	if err != nil {
   381  		log.Error(err)
   382  	}
   383  }
   384  
   385  // RPCErrorResponse formats error to be returned through RPC
   386  func RPCErrorResponse(code int, message string, err error) (interface{}, types.Error) {
   387  	return RPCErrorResponseWithData(code, message, nil, err)
   388  }
   389  
   390  // RPCErrorResponseWithData formats error to be returned through RPC
   391  func RPCErrorResponseWithData(code int, message string, data *[]byte, err error) (interface{}, types.Error) {
   392  	if err != nil {
   393  		log.Errorf("%v:%v", message, err.Error())
   394  	} else {
   395  		log.Error(message)
   396  	}
   397  	return nil, types.NewRPCErrorWithData(code, message, data)
   398  }
   399  
   400  func combinedLog(r *http.Request, start time.Time, httpStatus, dataLen int) {
   401  	log.Infof("%s - - %s \"%s %s %s\" %d %d \"%s\" \"%s\"",
   402  		r.RemoteAddr,
   403  		start.Format("[02/Jan/2006:15:04:05 -0700]"),
   404  		r.Method,
   405  		r.URL.Path,
   406  		r.Proto,
   407  		httpStatus,
   408  		dataLen,
   409  		r.Host,
   410  		r.UserAgent(),
   411  	)
   412  }