github.com/ethereum/go-ethereum@v1.16.1/rpc/websocket.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"net/http"
    24  	"net/url"
    25  	"os"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	mapset "github.com/deckarep/golang-set/v2"
    31  	"github.com/ethereum/go-ethereum/log"
    32  	"github.com/gorilla/websocket"
    33  )
    34  
    35  const (
    36  	wsReadBuffer       = 1024
    37  	wsWriteBuffer      = 1024
    38  	wsPingInterval     = 30 * time.Second
    39  	wsPingWriteTimeout = 5 * time.Second
    40  	wsPongTimeout      = 30 * time.Second
    41  	wsDefaultReadLimit = 32 * 1024 * 1024
    42  )
    43  
    44  var wsBufferPool = new(sync.Pool)
    45  
    46  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    47  //
    48  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    49  // To allow connections with any origin, pass "*".
    50  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    51  	var upgrader = websocket.Upgrader{
    52  		ReadBufferSize:  wsReadBuffer,
    53  		WriteBufferSize: wsWriteBuffer,
    54  		WriteBufferPool: wsBufferPool,
    55  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    56  	}
    57  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    58  		conn, err := upgrader.Upgrade(w, r, nil)
    59  		if err != nil {
    60  			log.Debug("WebSocket upgrade failed", "err", err)
    61  			return
    62  		}
    63  		codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
    64  		s.ServeCodec(codec, 0)
    65  	})
    66  }
    67  
    68  // wsHandshakeValidator returns a handler that verifies the origin during the
    69  // websocket upgrade process. When a '*' is specified as an allowed origins all
    70  // connections are accepted.
    71  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    72  	origins := mapset.NewSet[string]()
    73  	allowAllOrigins := false
    74  
    75  	for _, origin := range allowedOrigins {
    76  		if origin == "*" {
    77  			allowAllOrigins = true
    78  		}
    79  		if origin != "" {
    80  			origins.Add(origin)
    81  		}
    82  	}
    83  	// allow localhost if no allowedOrigins are specified.
    84  	if len(origins.ToSlice()) == 0 {
    85  		origins.Add("http://localhost")
    86  		if hostname, err := os.Hostname(); err == nil {
    87  			origins.Add("http://" + hostname)
    88  		}
    89  	}
    90  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
    91  
    92  	f := func(req *http.Request) bool {
    93  		// Skip origin verification if no Origin header is present. The origin check
    94  		// is supposed to protect against browser based attacks. Browsers always set
    95  		// Origin. Non-browser software can put anything in origin and checking it doesn't
    96  		// provide additional security.
    97  		if _, ok := req.Header["Origin"]; !ok {
    98  			return true
    99  		}
   100  		// Verify origin against allow list.
   101  		origin := strings.ToLower(req.Header.Get("Origin"))
   102  		if allowAllOrigins || originIsAllowed(origins, origin) {
   103  			return true
   104  		}
   105  		log.Warn("Rejected WebSocket connection", "origin", origin)
   106  		return false
   107  	}
   108  
   109  	return f
   110  }
   111  
   112  type wsHandshakeError struct {
   113  	err    error
   114  	status string
   115  }
   116  
   117  func (e wsHandshakeError) Error() string {
   118  	s := e.err.Error()
   119  	if e.status != "" {
   120  		s += " (HTTP status " + e.status + ")"
   121  	}
   122  	return s
   123  }
   124  
   125  func (e wsHandshakeError) Unwrap() error {
   126  	return e.err
   127  }
   128  
   129  func originIsAllowed(allowedOrigins mapset.Set[string], browserOrigin string) bool {
   130  	it := allowedOrigins.Iterator()
   131  	for origin := range it.C {
   132  		if ruleAllowsOrigin(origin, browserOrigin) {
   133  			return true
   134  		}
   135  	}
   136  	return false
   137  }
   138  
   139  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   140  	var (
   141  		allowedScheme, allowedHostname, allowedPort string
   142  		browserScheme, browserHostname, browserPort string
   143  		err                                         error
   144  	)
   145  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   146  	if err != nil {
   147  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   148  		return false
   149  	}
   150  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   151  	if err != nil {
   152  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   153  		return false
   154  	}
   155  	if allowedScheme != "" && allowedScheme != browserScheme {
   156  		return false
   157  	}
   158  	if allowedHostname != "" && allowedHostname != browserHostname {
   159  		return false
   160  	}
   161  	if allowedPort != "" && allowedPort != browserPort {
   162  		return false
   163  	}
   164  	return true
   165  }
   166  
   167  func parseOriginURL(origin string) (string, string, string, error) {
   168  	parsedURL, err := url.Parse(strings.ToLower(origin))
   169  	if err != nil {
   170  		return "", "", "", err
   171  	}
   172  	var scheme, hostname, port string
   173  	if strings.Contains(origin, "://") {
   174  		scheme = parsedURL.Scheme
   175  		hostname = parsedURL.Hostname()
   176  		port = parsedURL.Port()
   177  	} else {
   178  		scheme = ""
   179  		hostname = parsedURL.Scheme
   180  		port = parsedURL.Opaque
   181  		if hostname == "" {
   182  			hostname = origin
   183  		}
   184  	}
   185  	return scheme, hostname, port, nil
   186  }
   187  
   188  // DialWebsocketWithDialer creates a new RPC client using WebSocket.
   189  //
   190  // The context is used for the initial connection establishment. It does not
   191  // affect subsequent interactions with the client.
   192  //
   193  // Deprecated: use DialOptions and the WithWebsocketDialer option.
   194  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   195  	cfg := new(clientConfig)
   196  	cfg.wsDialer = &dialer
   197  	if origin != "" {
   198  		cfg.setHeader("origin", origin)
   199  	}
   200  	connect, err := newClientTransportWS(endpoint, cfg)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	return newClient(ctx, cfg, connect)
   205  }
   206  
   207  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   208  // that is listening on the given endpoint.
   209  //
   210  // The context is used for the initial connection establishment. It does not
   211  // affect subsequent interactions with the client.
   212  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   213  	cfg := new(clientConfig)
   214  	if origin != "" {
   215  		cfg.setHeader("origin", origin)
   216  	}
   217  	connect, err := newClientTransportWS(endpoint, cfg)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  	return newClient(ctx, cfg, connect)
   222  }
   223  
   224  func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
   225  	dialer := cfg.wsDialer
   226  	if dialer == nil {
   227  		dialer = &websocket.Dialer{
   228  			ReadBufferSize:  wsReadBuffer,
   229  			WriteBufferSize: wsWriteBuffer,
   230  			WriteBufferPool: wsBufferPool,
   231  			Proxy:           http.ProxyFromEnvironment,
   232  		}
   233  	}
   234  
   235  	dialURL, header, err := wsClientHeaders(endpoint, "")
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	for key, values := range cfg.httpHeaders {
   240  		header[key] = values
   241  	}
   242  
   243  	connect := func(ctx context.Context) (ServerCodec, error) {
   244  		header := header.Clone()
   245  		if cfg.httpAuth != nil {
   246  			if err := cfg.httpAuth(header); err != nil {
   247  				return nil, err
   248  			}
   249  		}
   250  		conn, resp, err := dialer.DialContext(ctx, dialURL, header)
   251  		if err != nil {
   252  			hErr := wsHandshakeError{err: err}
   253  			if resp != nil {
   254  				hErr.status = resp.Status
   255  			}
   256  			return nil, hErr
   257  		}
   258  		messageSizeLimit := int64(wsDefaultReadLimit)
   259  		if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 {
   260  			messageSizeLimit = *cfg.wsMessageSizeLimit
   261  		}
   262  		return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil
   263  	}
   264  	return connect, nil
   265  }
   266  
   267  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   268  	endpointURL, err := url.Parse(endpoint)
   269  	if err != nil {
   270  		return endpoint, nil, err
   271  	}
   272  	header := make(http.Header)
   273  	if origin != "" {
   274  		header.Add("origin", origin)
   275  	}
   276  	if endpointURL.User != nil {
   277  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   278  		header.Add("authorization", "Basic "+b64auth)
   279  		endpointURL.User = nil
   280  	}
   281  	return endpointURL.String(), header, nil
   282  }
   283  
   284  type websocketCodec struct {
   285  	*jsonCodec
   286  	conn *websocket.Conn
   287  	info PeerInfo
   288  
   289  	wg           sync.WaitGroup
   290  	pingReset    chan struct{}
   291  	pongReceived chan struct{}
   292  }
   293  
   294  func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec {
   295  	conn.SetReadLimit(readLimit)
   296  	encode := func(v interface{}, isErrorResponse bool) error {
   297  		return conn.WriteJSON(v)
   298  	}
   299  	wc := &websocketCodec{
   300  		jsonCodec:    NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec),
   301  		conn:         conn,
   302  		pingReset:    make(chan struct{}, 1),
   303  		pongReceived: make(chan struct{}),
   304  		info: PeerInfo{
   305  			Transport:  "ws",
   306  			RemoteAddr: conn.RemoteAddr().String(),
   307  		},
   308  	}
   309  	// Fill in connection details.
   310  	wc.info.HTTP.Host = host
   311  	wc.info.HTTP.Origin = req.Get("Origin")
   312  	wc.info.HTTP.UserAgent = req.Get("User-Agent")
   313  	// Start pinger.
   314  	conn.SetPongHandler(func(appData string) error {
   315  		select {
   316  		case wc.pongReceived <- struct{}{}:
   317  		case <-wc.closed():
   318  		}
   319  		return nil
   320  	})
   321  	wc.wg.Add(1)
   322  	go wc.pingLoop()
   323  	return wc
   324  }
   325  
   326  func (wc *websocketCodec) close() {
   327  	wc.jsonCodec.close()
   328  	wc.wg.Wait()
   329  }
   330  
   331  func (wc *websocketCodec) peerInfo() PeerInfo {
   332  	return wc.info
   333  }
   334  
   335  func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error {
   336  	err := wc.jsonCodec.writeJSON(ctx, v, isError)
   337  	if err == nil {
   338  		// Notify pingLoop to delay the next idle ping.
   339  		select {
   340  		case wc.pingReset <- struct{}{}:
   341  		default:
   342  		}
   343  	}
   344  	return err
   345  }
   346  
   347  // pingLoop sends periodic ping frames when the connection is idle.
   348  func (wc *websocketCodec) pingLoop() {
   349  	var pingTimer = time.NewTimer(wsPingInterval)
   350  	defer wc.wg.Done()
   351  	defer pingTimer.Stop()
   352  
   353  	for {
   354  		select {
   355  		case <-wc.closed():
   356  			return
   357  
   358  		case <-wc.pingReset:
   359  			if !pingTimer.Stop() {
   360  				<-pingTimer.C
   361  			}
   362  			pingTimer.Reset(wsPingInterval)
   363  
   364  		case <-pingTimer.C:
   365  			wc.jsonCodec.encMu.Lock()
   366  			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
   367  			wc.conn.WriteMessage(websocket.PingMessage, nil)
   368  			wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout))
   369  			wc.jsonCodec.encMu.Unlock()
   370  			pingTimer.Reset(wsPingInterval)
   371  
   372  		case <-wc.pongReceived:
   373  			wc.conn.SetReadDeadline(time.Time{})
   374  		}
   375  	}
   376  }