github.com/amazechain/amc@v0.1.3/internal/node/rpcstack.go (about)

     1  // Copyright 2022 The AmazeChain Authors
     2  // This file is part of the AmazeChain library.
     3  //
     4  // The AmazeChain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The AmazeChain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the AmazeChain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package node
    18  
    19  import (
    20  	"compress/gzip"
    21  	"context"
    22  	"fmt"
    23  	"github.com/amazechain/amc/log"
    24  	"github.com/rs/cors"
    25  	"io"
    26  	"io/ioutil"
    27  	"net"
    28  	"net/http"
    29  	"sort"
    30  	"strings"
    31  	"sync"
    32  	"sync/atomic"
    33  	"time"
    34  
    35  	"github.com/amazechain/amc/conf"
    36  	"github.com/amazechain/amc/modules/rpc/jsonrpc"
    37  )
    38  
// httpConfig is the JSON-RPC/HTTP configuration.
type httpConfig struct {
	// Modules lists the API namespaces handed to RegisterApisFromWhitelist
	// when the HTTP handler is enabled.
	Modules            []string
	// CorsAllowedOrigins is passed to the CORS handler; when it is empty no
	// CORS handler is installed.
	CorsAllowedOrigins []string
	// Vhosts lists the host names accepted by the virtual-host check.
	Vhosts             []string
	// prefix is the path prefix checked before a request is dispatched to
	// the JSON-RPC handler.
	prefix             string
	jwtSecret          []byte // optional JWT secret
}
    46  
// wsConfig is the JSON-RPC/Websocket configuration.
type wsConfig struct {
	// Origins is forwarded to the JSON-RPC server's WebsocketHandler.
	Origins   []string
	// Modules lists the API namespaces handed to RegisterApisFromWhitelist
	// when the WebSocket handler is enabled.
	Modules   []string
	prefix    string // path prefix on which to mount ws handler
	jwtSecret []byte // optional JWT secret
}
    54  
// rpcHandler pairs the HTTP handler stack that serves requests with the
// JSON-RPC server behind it, so the server can be stopped when the handler
// is disabled.
type rpcHandler struct {
	http.Handler
	server *jsonrpc.Server
}
    59  
// httpServer serves JSON-RPC over HTTP and over WebSocket from one TCP listener.
type httpServer struct {
	// mux is consulted for plain HTTP requests while the HTTP handler is
	// enabled; a request matching one of its patterns is served from here.
	mux      http.ServeMux
	// mu is held by the methods that configure, start and stop the server.
	mu       sync.Mutex
	// server and listener are set by start and reset to nil by doStop.
	server   *http.Server
	listener net.Listener

	// httpHandler holds a *rpcHandler; a nil pointer means JSON-RPC over HTTP
	// is disabled.
	httpConfig  httpConfig
	httpHandler atomic.Value

	// WebSocket handler things. A nil *rpcHandler means JSON-RPC over
	// WebSocket is disabled.
	wsConfig  wsConfig
	wsHandler atomic.Value // *rpcHandler

	// endpoint is the "host:port" string built by setListenAddr from host and port.
	endpoint string
	host     string
	port     int

	// handlerNames maps a path to the name that start logs for it.
	handlerNames map[string]string
}
    79  
    80  func newHTTPServer() *httpServer {
    81  	h := &httpServer{handlerNames: make(map[string]string)}
    82  
    83  	h.httpHandler.Store((*rpcHandler)(nil))
    84  	h.wsHandler.Store((*rpcHandler)(nil))
    85  	return h
    86  }
    87  
    88  func (h *httpServer) setListenAddr(host string, port int) error {
    89  	h.mu.Lock()
    90  	defer h.mu.Unlock()
    91  
    92  	if h.listener != nil && (host != h.host || port != h.port) {
    93  		return fmt.Errorf("HTTP server already running on %s", h.endpoint)
    94  	}
    95  
    96  	h.host, h.port = host, port
    97  	h.endpoint = fmt.Sprintf("%s:%d", host, port)
    98  	return nil
    99  }
   100  
   101  func (h *httpServer) listenAddr() string {
   102  	h.mu.Lock()
   103  	defer h.mu.Unlock()
   104  
   105  	if h.listener != nil {
   106  		return h.listener.Addr().String()
   107  	}
   108  	return h.endpoint
   109  }
   110  
   111  func (h *httpServer) start() error {
   112  	h.mu.Lock()
   113  	defer h.mu.Unlock()
   114  
   115  	if h.endpoint == "" || h.listener != nil {
   116  		return nil // already running or not configured
   117  	}
   118  
   119  	h.server = &http.Server{Handler: h}
   120  
   121  	//todo
   122  	h.server.ReadTimeout = time.Duration(60 * time.Second)
   123  	h.server.WriteTimeout = time.Duration(60 * time.Second)
   124  	h.server.IdleTimeout = time.Duration(60 * time.Second)
   125  
   126  	listener, err := net.Listen("tcp", h.endpoint)
   127  	if err != nil {
   128  		h.disableRPC()
   129  		h.disableWS()
   130  		return err
   131  	}
   132  	h.listener = listener
   133  	go h.server.Serve(listener)
   134  
   135  	if h.wsAllowed() {
   136  		url := fmt.Sprintf("ws://%v", listener.Addr())
   137  		if h.wsConfig.prefix != "" {
   138  			url += h.wsConfig.prefix
   139  		}
   140  		log.Info("WebSocket enabled", "url", url)
   141  	}
   142  
   143  	if !h.rpcAllowed() {
   144  		return nil
   145  	}
   146  	log.Info("HTTP server started",
   147  		"endpoint", listener.Addr(),
   148  		"prefix", h.httpConfig.prefix,
   149  		"cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","),
   150  		"vhosts", strings.Join(h.httpConfig.Vhosts, ","),
   151  	)
   152  
   153  	var paths []string
   154  	for path := range h.handlerNames {
   155  		paths = append(paths, path)
   156  	}
   157  	sort.Strings(paths)
   158  	logged := make(map[string]bool, len(paths))
   159  	for _, path := range paths {
   160  		name := h.handlerNames[path]
   161  		if !logged[name] {
   162  			log.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path)
   163  			logged[name] = true
   164  		}
   165  	}
   166  	return nil
   167  }
   168  
   169  func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   170  	// check if ws request and serve if ws enabled
   171  	ws := h.wsHandler.Load().(*rpcHandler)
   172  	if ws != nil && isWebsocket(r) {
   173  		if checkPath(r, h.wsConfig.prefix) {
   174  			ws.ServeHTTP(w, r)
   175  		}
   176  		return
   177  	}
   178  	// if http-rpc is enabled, try to serve request
   179  
   180  	rpc := h.httpHandler.Load().(*rpcHandler)
   181  	if rpc != nil {
   182  		muxHandler, pattern := h.mux.Handler(r)
   183  		if pattern != "" {
   184  			muxHandler.ServeHTTP(w, r)
   185  			return
   186  		}
   187  
   188  		if checkPath(r, h.httpConfig.prefix) {
   189  			rpc.ServeHTTP(w, r)
   190  			return
   191  		}
   192  	}
   193  	w.WriteHeader(http.StatusNotFound)
   194  }
   195  
   196  // enableWS turns on JSON-RPC over WebSocket on the server.
   197  func (h *httpServer) enableWS(apis []jsonrpc.API, config wsConfig) error {
   198  	h.mu.Lock()
   199  	defer h.mu.Unlock()
   200  
   201  	if h.wsAllowed() {
   202  		return fmt.Errorf("JSON-RPC over WebSocket is already enabled")
   203  	}
   204  	// Create RPC server and handler.
   205  	srv := jsonrpc.NewServer()
   206  	if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
   207  		return err
   208  	}
   209  	h.wsConfig = config
   210  	h.wsHandler.Store(&rpcHandler{
   211  		Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret),
   212  		server:  srv,
   213  	})
   214  	return nil
   215  }
   216  
   217  // stopWS disables JSON-RPC over WebSocket and also stops the server if it only serves WebSocket.
   218  func (h *httpServer) stopWS() {
   219  	h.mu.Lock()
   220  	defer h.mu.Unlock()
   221  
   222  	if h.disableWS() {
   223  		if !h.rpcAllowed() {
   224  			h.doStop()
   225  		}
   226  	}
   227  }
   228  
   229  // disableWS disables the WebSocket handler. This is internal, the caller must hold h.mu.
   230  func (h *httpServer) disableWS() bool {
   231  	ws := h.wsHandler.Load().(*rpcHandler)
   232  	if ws != nil {
   233  		h.wsHandler.Store((*rpcHandler)(nil))
   234  		ws.server.Stop()
   235  	}
   236  	return ws != nil
   237  }
   238  
   239  // isWebsocket checks the header of an http request for a websocket upgrade request.
   240  func isWebsocket(r *http.Request) bool {
   241  	return strings.EqualFold(r.Header.Get("Upgrade"), "websocket") &&
   242  		strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
   243  }
   244  
   245  func checkPath(r *http.Request, path string) bool {
   246  	if path == "" {
   247  		return r.URL.Path == "/"
   248  	}
   249  	return len(r.URL.Path) >= len(path) && r.URL.Path[:len(path)] == path
   250  }
   251  
   252  func validatePrefix(what, path string) error {
   253  	if path == "" {
   254  		return nil
   255  	}
   256  	if path[0] != '/' {
   257  		return fmt.Errorf(`%s RPC path prefix %q does not contain leading "/"`, what, path)
   258  	}
   259  	if strings.ContainsAny(path, "?#") {
   260  		// This is just to avoid confusion. While these would match correctly (i.e. they'd
   261  		// match if URL-escaped into path), it's not easy to understand for users when
   262  		// setting that on the command line.
   263  		return fmt.Errorf("%s RPC path prefix %q contains URL metadata-characters", what, path)
   264  	}
   265  	return nil
   266  }
   267  
   268  func (h *httpServer) stop() {
   269  	h.mu.Lock()
   270  	defer h.mu.Unlock()
   271  	h.doStop()
   272  }
   273  
   274  func (h *httpServer) doStop() {
   275  	if h.listener == nil {
   276  		return // not running
   277  	}
   278  
   279  	httpHandler := h.httpHandler.Load().(*rpcHandler)
   280  	if httpHandler != nil {
   281  		h.httpHandler.Store((*rpcHandler)(nil))
   282  		httpHandler.server.Stop()
   283  	}
   284  	h.server.Shutdown(context.Background())
   285  	h.listener.Close()
   286  	log.Info("HTTP server stopped", "endpoint", h.listener.Addr())
   287  
   288  	h.host, h.port, h.endpoint = "", 0, ""
   289  	h.server, h.listener = nil, nil
   290  }
   291  
   292  func (h *httpServer) enableRPC(apis []jsonrpc.API, config httpConfig) error {
   293  	h.mu.Lock()
   294  	defer h.mu.Unlock()
   295  
   296  	if h.rpcAllowed() {
   297  		return fmt.Errorf("JSON-RPC over HTTP is already enabled")
   298  	}
   299  
   300  	srv := jsonrpc.NewServer()
   301  	if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
   302  		return err
   303  	}
   304  	h.httpConfig = config
   305  	h.httpHandler.Store(&rpcHandler{
   306  		Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts, config.jwtSecret),
   307  		server:  srv,
   308  	})
   309  	return nil
   310  }
   311  
   312  func (h *httpServer) disableRPC() bool {
   313  	handler := h.httpHandler.Load().(*rpcHandler)
   314  	if handler != nil {
   315  		h.httpHandler.Store((*rpcHandler)(nil))
   316  		handler.server.Stop()
   317  	}
   318  	return handler != nil
   319  }
   320  
   321  func (h *httpServer) rpcAllowed() bool {
   322  	return h.httpHandler.Load().(*rpcHandler) != nil
   323  }
   324  
   325  // wsAllowed returns true when JSON-RPC over WebSocket is enabled.
   326  func (h *httpServer) wsAllowed() bool {
   327  	return h.wsHandler.Load().(*rpcHandler) != nil
   328  }
   329  
   330  func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSecret []byte) http.Handler {
   331  	// Wrap the CORS-handler within a host-handler
   332  	handler := newCorsHandler(srv, cors)
   333  	handler = newVHostHandler(vhosts, handler)
   334  	if len(jwtSecret) != 0 {
   335  		handler = newJWTHandler(jwtSecret, handler)
   336  	}
   337  	return newGzipHandler(handler)
   338  }
   339  
   340  // NewWSHandlerStack returns a wrapped ws-related handler.
   341  func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler {
   342  	if len(jwtSecret) != 0 {
   343  		return newJWTHandler(jwtSecret, srv)
   344  	}
   345  	return srv
   346  }
   347  
   348  func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
   349  	// disable CORS support if user has not specified a custom CORS configuration
   350  	if len(allowedOrigins) == 0 {
   351  		return srv
   352  	}
   353  	c := cors.New(cors.Options{
   354  		AllowedOrigins: allowedOrigins,
   355  		AllowedMethods: []string{http.MethodPost, http.MethodGet},
   356  		AllowedHeaders: []string{"*"},
   357  		MaxAge:         600,
   358  	})
   359  	return c.Handler(srv)
   360  }
   361  
   362  type virtualHostHandler struct {
   363  	vhosts map[string]struct{}
   364  	next   http.Handler
   365  }
   366  
   367  func newVHostHandler(vhosts []string, next http.Handler) http.Handler {
   368  	vhostMap := make(map[string]struct{})
   369  	for _, allowedHost := range vhosts {
   370  		vhostMap[strings.ToLower(allowedHost)] = struct{}{}
   371  	}
   372  	return &virtualHostHandler{vhostMap, next}
   373  }
   374  
   375  func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   376  	// if r.Host is not set, we can continue serving since a browser would set the Host header
   377  	if r.Host == "" {
   378  		h.next.ServeHTTP(w, r)
   379  		return
   380  	}
   381  	host, _, err := net.SplitHostPort(r.Host)
   382  	if err != nil {
   383  		// Either invalid (too many colons) or no port specified
   384  		host = r.Host
   385  	}
   386  	if ipAddr := net.ParseIP(host); ipAddr != nil {
   387  		// It's an IP address, we can serve that
   388  		h.next.ServeHTTP(w, r)
   389  		return
   390  	}
   391  	// Not an IP address, but a hostname. Need to validate
   392  	if _, exist := h.vhosts["*"]; exist {
   393  		h.next.ServeHTTP(w, r)
   394  		return
   395  	}
   396  	if _, exist := h.vhosts[host]; exist {
   397  		h.next.ServeHTTP(w, r)
   398  		return
   399  	}
   400  	http.Error(w, "invalid host specified", http.StatusForbidden)
   401  }
   402  
   403  var gzPool = sync.Pool{
   404  	New: func() interface{} {
   405  		w := gzip.NewWriter(ioutil.Discard)
   406  		return w
   407  	},
   408  }
   409  
   410  type gzipResponseWriter struct {
   411  	io.Writer
   412  	http.ResponseWriter
   413  }
   414  
   415  func (w *gzipResponseWriter) WriteHeader(status int) {
   416  	w.Header().Del("Content-Length")
   417  	w.ResponseWriter.WriteHeader(status)
   418  }
   419  
   420  func (w *gzipResponseWriter) Write(b []byte) (int, error) {
   421  	return w.Writer.Write(b)
   422  }
   423  
   424  func newGzipHandler(next http.Handler) http.Handler {
   425  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   426  		if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
   427  			next.ServeHTTP(w, r)
   428  			return
   429  		}
   430  		w.Header().Set("Content-Encoding", "gzip")
   431  		gz := gzPool.Get().(*gzip.Writer)
   432  		defer gzPool.Put(gz)
   433  
   434  		gz.Reset(w)
   435  		defer gz.Close()
   436  
   437  		next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
   438  	})
   439  }
   440  
// ipcServer manages the JSON-RPC endpoint opened through
// jsonrpc.StartIPCEndpoint.
type ipcServer struct {
	// endpoint is the string passed to jsonrpc.StartIPCEndpoint.
	endpoint string

	// mu guards listener and srv, which are set by start and reset to nil by
	// stop.
	mu       sync.Mutex
	listener net.Listener
	srv      *jsonrpc.Server
}
   448  
   449  func newIPCServer(config *conf.NodeConfig) *ipcServer {
   450  	return &ipcServer{endpoint: fmt.Sprintf("%s/%s", config.DataDir, config.IPCPath)}
   451  }
   452  
   453  func (is *ipcServer) start(apis []jsonrpc.API) error {
   454  	is.mu.Lock()
   455  	defer is.mu.Unlock()
   456  
   457  	if is.listener != nil {
   458  		return nil // already running
   459  	}
   460  	listener, srv, err := jsonrpc.StartIPCEndpoint(is.endpoint, apis)
   461  	if err != nil {
   462  		log.Warn("IPC opening failed", "url", is.endpoint, "error", err)
   463  		return err
   464  	}
   465  	log.Info("IPC endpoint opened", "url", is.endpoint)
   466  	is.listener, is.srv = listener, srv
   467  	return nil
   468  }
   469  
   470  func (is *ipcServer) stop() error {
   471  	is.mu.Lock()
   472  	defer is.mu.Unlock()
   473  
   474  	if is.listener == nil {
   475  		return nil // not running
   476  	}
   477  	err := is.listener.Close()
   478  	is.srv.Stop()
   479  	is.listener, is.srv = nil, nil
   480  	log.Info("IPC endpoint closed", "url", is.endpoint)
   481  	return err
   482  }
   483  
   484  func RegisterApisFromWhitelist(apis []jsonrpc.API, modules []string, srv *jsonrpc.Server, exposeAll bool) error {
   485  	if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
   486  		log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
   487  	}
   488  	whitelist := make(map[string]bool)
   489  	for _, module := range modules {
   490  		whitelist[module] = true
   491  	}
   492  	for _, api := range apis {
   493  		if exposeAll || whitelist[api.Namespace] {
   494  			if err := srv.RegisterName(api.Namespace, api.Service); err != nil {
   495  				return err
   496  			}
   497  		}
   498  	}
   499  	return nil
   500  }