github.com/juliankolbe/go-ethereum@v1.9.992/rpc/websocket.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"net/http"
    24  	"net/url"
    25  	"os"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	mapset "github.com/deckarep/golang-set"
    31  	"github.com/juliankolbe/go-ethereum/log"
    32  	"github.com/gorilla/websocket"
    33  )
    34  
    35  const (
    36  	wsReadBuffer       = 1024
    37  	wsWriteBuffer      = 1024
    38  	wsPingInterval     = 60 * time.Second
    39  	wsPingWriteTimeout = 5 * time.Second
    40  )
    41  
    42  var wsBufferPool = new(sync.Pool)
    43  
    44  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    45  //
    46  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    47  // To allow connections with any origin, pass "*".
    48  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    49  	var upgrader = websocket.Upgrader{
    50  		ReadBufferSize:  wsReadBuffer,
    51  		WriteBufferSize: wsWriteBuffer,
    52  		WriteBufferPool: wsBufferPool,
    53  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    54  	}
    55  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    56  		conn, err := upgrader.Upgrade(w, r, nil)
    57  		if err != nil {
    58  			log.Debug("WebSocket upgrade failed", "err", err)
    59  			return
    60  		}
    61  		codec := newWebsocketCodec(conn)
    62  		s.ServeCodec(codec, 0)
    63  	})
    64  }
    65  
    66  // wsHandshakeValidator returns a handler that verifies the origin during the
    67  // websocket upgrade process. When a '*' is specified as an allowed origins all
    68  // connections are accepted.
    69  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    70  	origins := mapset.NewSet()
    71  	allowAllOrigins := false
    72  
    73  	for _, origin := range allowedOrigins {
    74  		if origin == "*" {
    75  			allowAllOrigins = true
    76  		}
    77  		if origin != "" {
    78  			origins.Add(origin)
    79  		}
    80  	}
    81  	// allow localhost if no allowedOrigins are specified.
    82  	if len(origins.ToSlice()) == 0 {
    83  		origins.Add("http://localhost")
    84  		if hostname, err := os.Hostname(); err == nil {
    85  			origins.Add("http://" + hostname)
    86  		}
    87  	}
    88  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
    89  
    90  	f := func(req *http.Request) bool {
    91  		// Skip origin verification if no Origin header is present. The origin check
    92  		// is supposed to protect against browser based attacks. Browsers always set
    93  		// Origin. Non-browser software can put anything in origin and checking it doesn't
    94  		// provide additional security.
    95  		if _, ok := req.Header["Origin"]; !ok {
    96  			return true
    97  		}
    98  		// Verify origin against whitelist.
    99  		origin := strings.ToLower(req.Header.Get("Origin"))
   100  		if allowAllOrigins || originIsAllowed(origins, origin) {
   101  			return true
   102  		}
   103  		log.Warn("Rejected WebSocket connection", "origin", origin)
   104  		return false
   105  	}
   106  
   107  	return f
   108  }
   109  
   110  type wsHandshakeError struct {
   111  	err    error
   112  	status string
   113  }
   114  
   115  func (e wsHandshakeError) Error() string {
   116  	s := e.err.Error()
   117  	if e.status != "" {
   118  		s += " (HTTP status " + e.status + ")"
   119  	}
   120  	return s
   121  }
   122  
   123  func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
   124  	it := allowedOrigins.Iterator()
   125  	for origin := range it.C {
   126  		if ruleAllowsOrigin(origin.(string), browserOrigin) {
   127  			return true
   128  		}
   129  	}
   130  	return false
   131  }
   132  
   133  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   134  	var (
   135  		allowedScheme, allowedHostname, allowedPort string
   136  		browserScheme, browserHostname, browserPort string
   137  		err                                         error
   138  	)
   139  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   140  	if err != nil {
   141  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   142  		return false
   143  	}
   144  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   145  	if err != nil {
   146  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   147  		return false
   148  	}
   149  	if allowedScheme != "" && allowedScheme != browserScheme {
   150  		return false
   151  	}
   152  	if allowedHostname != "" && allowedHostname != browserHostname {
   153  		return false
   154  	}
   155  	if allowedPort != "" && allowedPort != browserPort {
   156  		return false
   157  	}
   158  	return true
   159  }
   160  
   161  func parseOriginURL(origin string) (string, string, string, error) {
   162  	parsedURL, err := url.Parse(strings.ToLower(origin))
   163  	if err != nil {
   164  		return "", "", "", err
   165  	}
   166  	var scheme, hostname, port string
   167  	if strings.Contains(origin, "://") {
   168  		scheme = parsedURL.Scheme
   169  		hostname = parsedURL.Hostname()
   170  		port = parsedURL.Port()
   171  	} else {
   172  		scheme = ""
   173  		hostname = parsedURL.Scheme
   174  		port = parsedURL.Opaque
   175  		if hostname == "" {
   176  			hostname = origin
   177  		}
   178  	}
   179  	return scheme, hostname, port, nil
   180  }
   181  
   182  // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
   183  // that is listening on the given endpoint using the provided dialer.
   184  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   185  	endpoint, header, err := wsClientHeaders(endpoint, origin)
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  	return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
   190  		conn, resp, err := dialer.DialContext(ctx, endpoint, header)
   191  		if err != nil {
   192  			hErr := wsHandshakeError{err: err}
   193  			if resp != nil {
   194  				hErr.status = resp.Status
   195  			}
   196  			return nil, hErr
   197  		}
   198  		return newWebsocketCodec(conn), nil
   199  	})
   200  }
   201  
   202  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   203  // that is listening on the given endpoint.
   204  //
   205  // The context is used for the initial connection establishment. It does not
   206  // affect subsequent interactions with the client.
   207  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   208  	dialer := websocket.Dialer{
   209  		ReadBufferSize:  wsReadBuffer,
   210  		WriteBufferSize: wsWriteBuffer,
   211  		WriteBufferPool: wsBufferPool,
   212  	}
   213  	return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
   214  }
   215  
   216  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   217  	endpointURL, err := url.Parse(endpoint)
   218  	if err != nil {
   219  		return endpoint, nil, err
   220  	}
   221  	header := make(http.Header)
   222  	if origin != "" {
   223  		header.Add("origin", origin)
   224  	}
   225  	if endpointURL.User != nil {
   226  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   227  		header.Add("authorization", "Basic "+b64auth)
   228  		endpointURL.User = nil
   229  	}
   230  	return endpointURL.String(), header, nil
   231  }
   232  
   233  type websocketCodec struct {
   234  	*jsonCodec
   235  	conn *websocket.Conn
   236  
   237  	wg        sync.WaitGroup
   238  	pingReset chan struct{}
   239  }
   240  
   241  func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
   242  	conn.SetReadLimit(maxRequestContentLength)
   243  	wc := &websocketCodec{
   244  		jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
   245  		conn:      conn,
   246  		pingReset: make(chan struct{}, 1),
   247  	}
   248  	wc.wg.Add(1)
   249  	go wc.pingLoop()
   250  	return wc
   251  }
   252  
   253  func (wc *websocketCodec) close() {
   254  	wc.jsonCodec.close()
   255  	wc.wg.Wait()
   256  }
   257  
   258  func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
   259  	err := wc.jsonCodec.writeJSON(ctx, v)
   260  	if err == nil {
   261  		// Notify pingLoop to delay the next idle ping.
   262  		select {
   263  		case wc.pingReset <- struct{}{}:
   264  		default:
   265  		}
   266  	}
   267  	return err
   268  }
   269  
// pingLoop sends periodic ping frames when the connection is idle.
//
// It runs in its own goroutine (started by newWebsocketCodec) and returns when
// the codec's closed() channel fires. Each successful writeJSON sends on
// pingReset, which pushes the next ping a full wsPingInterval into the future.
// As a result, a ping is written only after wsPingInterval passes with no
// outgoing message.
func (wc *websocketCodec) pingLoop() {
	var timer = time.NewTimer(wsPingInterval)
	defer wc.wg.Done()
	defer timer.Stop()

	for {
		select {
		case <-wc.closed():
			return
		case <-wc.pingReset:
			// Stop before Reset. If Stop reports the timer already fired, its
			// value is still pending in timer.C: this loop only receives from C
			// in the case below, which always re-arms the timer. Drain it so a
			// stale tick does not trigger a ping right after the reset.
			if !timer.Stop() {
				<-timer.C
			}
			timer.Reset(wsPingInterval)
		case <-timer.C:
			// encMu is held around the ping, presumably because jsonCodec's
			// writes take the same lock and a websocket connection permits only
			// one concurrent writer. Confirm against jsonCodec.
			// Errors from SetWriteDeadline/WriteMessage are discarded (best
			// effort). NOTE(review): the ping's write deadline stays set after
			// the ping; presumably the next writeJSON sets its own. Confirm.
			wc.jsonCodec.encMu.Lock()
			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
			wc.conn.WriteMessage(websocket.PingMessage, nil)
			wc.jsonCodec.encMu.Unlock()
			timer.Reset(wsPingInterval)
		}
	}
}