github.com/ava-labs/subnet-evm@v0.6.4/rpc/websocket.go (about)

     1  // (c) 2019-2020, Ava Labs, Inc.
     2  //
     3  // This file is a derived work, based on the go-ethereum library whose original
     4  // notices appear below.
     5  //
     6  // It is distributed under a license compatible with the licensing terms of the
     7  // original code from which it is derived.
     8  //
     9  // Much love to the original authors for their work.
    10  // **********
    11  // Copyright 2015 The go-ethereum Authors
    12  // This file is part of the go-ethereum library.
    13  //
    14  // The go-ethereum library is free software: you can redistribute it and/or modify
    15  // it under the terms of the GNU Lesser General Public License as published by
    16  // the Free Software Foundation, either version 3 of the License, or
    17  // (at your option) any later version.
    18  //
    19  // The go-ethereum library is distributed in the hope that it will be useful,
    20  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    21  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    22  // GNU Lesser General Public License for more details.
    23  //
    24  // You should have received a copy of the GNU Lesser General Public License
    25  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    26  
    27  package rpc
    28  
    29  import (
    30  	"context"
    31  	"encoding/base64"
    32  	"fmt"
    33  	"net/http"
    34  	"net/url"
    35  	"os"
    36  	"strings"
    37  	"sync"
    38  	"time"
    39  
    40  	mapset "github.com/deckarep/golang-set/v2"
    41  	"github.com/ethereum/go-ethereum/log"
    42  	"github.com/gorilla/websocket"
    43  )
    44  
    45  const (
    46  	wsReadBuffer       = 1024
    47  	wsWriteBuffer      = 1024
    48  	wsPingInterval     = 30 * time.Second
    49  	wsPingWriteTimeout = 5 * time.Second
    50  	wsPongTimeout      = 30 * time.Second
    51  	wsMessageSizeLimit = 32 * 1024 * 1024
    52  )
    53  
    54  var wsBufferPool = new(sync.Pool)
    55  
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins is the list of origin rules a browser's Origin header is
// checked against (see wsHandshakeValidator). To allow connections with any
// origin, include "*".
//
// It is shorthand for WebsocketHandlerWithDuration with all three duration
// arguments set to zero.
func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
	return s.WebsocketHandlerWithDuration(allowedOrigins, 0, 0, 0)
}
    63  
    64  func (s *Server) WebsocketHandlerWithDuration(allowedOrigins []string, apiMaxDuration, refillRate, maxStored time.Duration) http.Handler {
    65  	var upgrader = websocket.Upgrader{
    66  		ReadBufferSize:  wsReadBuffer,
    67  		WriteBufferSize: wsWriteBuffer,
    68  		WriteBufferPool: wsBufferPool,
    69  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    70  	}
    71  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    72  		conn, err := upgrader.Upgrade(w, r, nil)
    73  		if err != nil {
    74  			log.Debug("WebSocket upgrade failed", "err", err)
    75  			return
    76  		}
    77  		codec := newWebsocketCodec(conn, r.Host, r.Header)
    78  		s.ServeCodec(codec, 0, apiMaxDuration, refillRate, maxStored)
    79  	})
    80  }
    81  
    82  // wsHandshakeValidator returns a handler that verifies the origin during the
    83  // websocket upgrade process. When a '*' is specified as an allowed origins all
    84  // connections are accepted.
    85  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    86  	origins := mapset.NewSet[string]()
    87  	allowAllOrigins := false
    88  
    89  	for _, origin := range allowedOrigins {
    90  		if origin == "*" {
    91  			allowAllOrigins = true
    92  		}
    93  		if origin != "" {
    94  			origins.Add(origin)
    95  		}
    96  	}
    97  	// allow localhost if no allowedOrigins are specified.
    98  	if len(origins.ToSlice()) == 0 {
    99  		origins.Add("http://localhost")
   100  		if hostname, err := os.Hostname(); err == nil {
   101  			origins.Add("http://" + hostname)
   102  		}
   103  	}
   104  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
   105  
   106  	f := func(req *http.Request) bool {
   107  		// Skip origin verification if no Origin header is present. The origin check
   108  		// is supposed to protect against browser based attacks. Browsers always set
   109  		// Origin. Non-browser software can put anything in origin and checking it doesn't
   110  		// provide additional security.
   111  		if _, ok := req.Header["Origin"]; !ok {
   112  			return true
   113  		}
   114  		// Verify origin against allow list.
   115  		origin := strings.ToLower(req.Header.Get("Origin"))
   116  		if allowAllOrigins || originIsAllowed(origins, origin) {
   117  			return true
   118  		}
   119  		log.Warn("Rejected WebSocket connection", "origin", origin)
   120  		return false
   121  	}
   122  
   123  	return f
   124  }
   125  
   126  type wsHandshakeError struct {
   127  	err    error
   128  	status string
   129  }
   130  
   131  func (e wsHandshakeError) Error() string {
   132  	s := e.err.Error()
   133  	if e.status != "" {
   134  		s += " (HTTP status " + e.status + ")"
   135  	}
   136  	return s
   137  }
   138  
   139  func originIsAllowed(allowedOrigins mapset.Set[string], browserOrigin string) bool {
   140  	it := allowedOrigins.Iterator()
   141  	for origin := range it.C {
   142  		if ruleAllowsOrigin(origin, browserOrigin) {
   143  			return true
   144  		}
   145  	}
   146  	return false
   147  }
   148  
   149  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   150  	var (
   151  		allowedScheme, allowedHostname, allowedPort string
   152  		browserScheme, browserHostname, browserPort string
   153  		err                                         error
   154  	)
   155  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   156  	if err != nil {
   157  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   158  		return false
   159  	}
   160  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   161  	if err != nil {
   162  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   163  		return false
   164  	}
   165  	if allowedScheme != "" && allowedScheme != browserScheme {
   166  		return false
   167  	}
   168  	if allowedHostname != "" && allowedHostname != browserHostname {
   169  		return false
   170  	}
   171  	if allowedPort != "" && allowedPort != browserPort {
   172  		return false
   173  	}
   174  	return true
   175  }
   176  
   177  func parseOriginURL(origin string) (string, string, string, error) {
   178  	parsedURL, err := url.Parse(strings.ToLower(origin))
   179  	if err != nil {
   180  		return "", "", "", err
   181  	}
   182  	var scheme, hostname, port string
   183  	if strings.Contains(origin, "://") {
   184  		scheme = parsedURL.Scheme
   185  		hostname = parsedURL.Hostname()
   186  		port = parsedURL.Port()
   187  	} else {
   188  		scheme = ""
   189  		hostname = parsedURL.Scheme
   190  		port = parsedURL.Opaque
   191  		if hostname == "" {
   192  			hostname = origin
   193  		}
   194  	}
   195  	return scheme, hostname, port, nil
   196  }
   197  
   198  // DialWebsocketWithDialer creates a new RPC client using WebSocket.
   199  //
   200  // The context is used for the initial connection establishment. It does not
   201  // affect subsequent interactions with the client.
   202  //
   203  // Deprecated: use DialOptions and the WithWebsocketDialer option.
   204  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   205  	cfg := new(clientConfig)
   206  	cfg.wsDialer = &dialer
   207  	if origin != "" {
   208  		cfg.setHeader("origin", origin)
   209  	}
   210  	connect, err := newClientTransportWS(endpoint, cfg)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	return newClient(ctx, cfg, connect)
   215  }
   216  
   217  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   218  // that is listening on the given endpoint.
   219  //
   220  // The context is used for the initial connection establishment. It does not
   221  // affect subsequent interactions with the client.
   222  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   223  	cfg := new(clientConfig)
   224  	if origin != "" {
   225  		cfg.setHeader("origin", origin)
   226  	}
   227  	connect, err := newClientTransportWS(endpoint, cfg)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  	return newClient(ctx, cfg, connect)
   232  }
   233  
   234  func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
   235  	dialer := cfg.wsDialer
   236  	if dialer == nil {
   237  		dialer = &websocket.Dialer{
   238  			ReadBufferSize:  wsReadBuffer,
   239  			WriteBufferSize: wsWriteBuffer,
   240  			WriteBufferPool: wsBufferPool,
   241  			Proxy:           http.ProxyFromEnvironment,
   242  		}
   243  	}
   244  
   245  	dialURL, header, err := wsClientHeaders(endpoint, "")
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	for key, values := range cfg.httpHeaders {
   250  		header[key] = values
   251  	}
   252  
   253  	connect := func(ctx context.Context) (ServerCodec, error) {
   254  		header := header.Clone()
   255  		if cfg.httpAuth != nil {
   256  			if err := cfg.httpAuth(header); err != nil {
   257  				return nil, err
   258  			}
   259  		}
   260  		conn, resp, err := dialer.DialContext(ctx, dialURL, header)
   261  		if err != nil {
   262  			hErr := wsHandshakeError{err: err}
   263  			if resp != nil {
   264  				hErr.status = resp.Status
   265  			}
   266  			return nil, hErr
   267  		}
   268  		return newWebsocketCodec(conn, dialURL, header), nil
   269  	}
   270  	return connect, nil
   271  }
   272  
   273  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   274  	endpointURL, err := url.Parse(endpoint)
   275  	if err != nil {
   276  		return endpoint, nil, err
   277  	}
   278  	header := make(http.Header)
   279  	if origin != "" {
   280  		header.Add("origin", origin)
   281  	}
   282  	if endpointURL.User != nil {
   283  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   284  		header.Add("authorization", "Basic "+b64auth)
   285  		endpointURL.User = nil
   286  	}
   287  	return endpointURL.String(), header, nil
   288  }
   289  
// websocketCodec is the codec used for WebSocket connections. It embeds
// jsonCodec for message handling and adds the state needed by pingLoop to
// send pings while the connection is idle.
type websocketCodec struct {
	*jsonCodec
	// conn is the underlying WebSocket connection.
	conn *websocket.Conn
	// info holds the connection details returned by peerInfo.
	info PeerInfo

	// wg tracks the pingLoop goroutine so that close can wait for it.
	wg sync.WaitGroup
	// pingReset is signalled after each successful write to postpone the next
	// idle ping.
	pingReset chan struct{}
	// pongReceived is signalled by the pong handler when a pong frame arrives.
	pongReceived chan struct{}
}
   299  
   300  func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
   301  	conn.SetReadLimit(wsMessageSizeLimit)
   302  	encode := func(v interface{}, isErrorResponse bool) error {
   303  		return conn.WriteJSON(v)
   304  	}
   305  	wc := &websocketCodec{
   306  		jsonCodec:    NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec),
   307  		conn:         conn,
   308  		pingReset:    make(chan struct{}, 1),
   309  		pongReceived: make(chan struct{}),
   310  		info: PeerInfo{
   311  			Transport:  "ws",
   312  			RemoteAddr: conn.RemoteAddr().String(),
   313  		},
   314  	}
   315  	// Fill in connection details.
   316  	wc.info.HTTP.Host = host
   317  	wc.info.HTTP.Origin = req.Get("Origin")
   318  	wc.info.HTTP.UserAgent = req.Get("User-Agent")
   319  	// Start pinger.
   320  	conn.SetPongHandler(func(appData string) error {
   321  		select {
   322  		case wc.pongReceived <- struct{}{}:
   323  		case <-wc.closed():
   324  		}
   325  		return nil
   326  	})
   327  	wc.wg.Add(1)
   328  	go wc.pingLoop()
   329  	return wc
   330  }
   331  
// close closes the embedded jsonCodec and then blocks until the pingLoop
// goroutine tracked by wc.wg has exited.
// NOTE(review): this relies on jsonCodec.close causing wc.closed() to fire,
// which is what makes pingLoop return — confirm in jsonCodec.
func (wc *websocketCodec) close() {
	wc.jsonCodec.close()
	wc.wg.Wait()
}
   336  
// peerInfo returns the connection details recorded when the codec was created.
func (wc *websocketCodec) peerInfo() PeerInfo {
	return wc.info
}
   340  
// writeJSON writes v to the connection. It is writeJSONSkipDeadline with skip
// set to false.
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error {
	return wc.writeJSONSkipDeadline(ctx, v, isError, false)
}
   344  
   345  func (wc *websocketCodec) writeJSONSkipDeadline(ctx context.Context, v interface{}, isError bool, skip bool) error {
   346  	err := wc.jsonCodec.writeJSONSkipDeadline(ctx, v, isError, skip)
   347  	if err == nil {
   348  		// Notify pingLoop to delay the next idle ping.
   349  		select {
   350  		case wc.pingReset <- struct{}{}:
   351  		default:
   352  		}
   353  	}
   354  	return err
   355  }
   356  
// pingLoop sends periodic ping frames when the connection is idle.
//
// It is started as a goroutine by newWebsocketCodec and runs until wc.closed()
// fires, marking wc.wg done on exit so that close can wait for it.
func (wc *websocketCodec) pingLoop() {
	var pingTimer = time.NewTimer(wsPingInterval)
	defer wc.wg.Done()
	defer pingTimer.Stop()

	for {
		select {
		case <-wc.closed():
			return

		case <-wc.pingReset:
			// A message was just written, so the connection is not idle:
			// restart the interval. If the timer has already fired, drain its
			// channel first so the stale tick is not picked up next iteration.
			if !pingTimer.Stop() {
				<-pingTimer.C
			}
			pingTimer.Reset(wsPingInterval)

		case <-pingTimer.C:
			// Idle for a full interval: send a ping while holding the codec's
			// encMu. Errors from the write are deliberately not handled here.
			// NOTE(review): encMu is presumably the lock jsonCodec takes for
			// its own writes, keeping the ping from interleaving with a
			// message — confirm in jsonCodec.
			wc.jsonCodec.encMu.Lock()
			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
			wc.conn.WriteMessage(websocket.PingMessage, nil)
			// Expect the pong within wsPongTimeout; the deadline is lifted in
			// the pongReceived case below.
			wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout))
			wc.jsonCodec.encMu.Unlock()
			pingTimer.Reset(wsPingInterval)

		case <-wc.pongReceived:
			// Pong arrived: remove the read deadline set when the ping was sent.
			wc.conn.SetReadDeadline(time.Time{})
		}
	}
}