github.com/ConsenSys/Quorum@v20.10.0+incompatible/rpc/websocket.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"encoding/base64"
    23  	"fmt"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"strings"
    28  	"sync"
    29  
    30  	mapset "github.com/deckarep/golang-set"
    31  	"github.com/ethereum/go-ethereum/log"
    32  	"github.com/gorilla/websocket"
    33  )
    34  
    35  const (
    36  	wsReadBuffer  = 1024
    37  	wsWriteBuffer = 1024
    38  )
    39  
    40  var wsBufferPool = new(sync.Pool)
    41  
    42  // NewWSServer creates a new websocket RPC server around an API provider.
    43  //
    44  // Deprecated: use Server.WebsocketHandler
    45  func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
    46  	return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
    47  }
    48  
    49  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    50  //
    51  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    52  // To allow connections with any origin, pass "*".
    53  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    54  	var upgrader = websocket.Upgrader{
    55  		ReadBufferSize:  wsReadBuffer,
    56  		WriteBufferSize: wsWriteBuffer,
    57  		WriteBufferPool: wsBufferPool,
    58  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    59  	}
    60  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    61  		conn, err := upgrader.Upgrade(w, r, nil)
    62  		if err != nil {
    63  			log.Debug("WebSocket upgrade failed", "err", err)
    64  			return
    65  		}
    66  		codec := newWebsocketCodec(conn)
    67  		s.authenticateHttpRequest(r, codec)
    68  		s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions)
    69  	})
    70  }
    71  
    72  // wsHandshakeValidator returns a handler that verifies the origin during the
    73  // websocket upgrade process. When a '*' is specified as an allowed origins all
    74  // connections are accepted.
    75  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    76  	origins := mapset.NewSet()
    77  	allowAllOrigins := false
    78  
    79  	for _, origin := range allowedOrigins {
    80  		if origin == "*" {
    81  			allowAllOrigins = true
    82  		}
    83  		if origin != "" {
    84  			origins.Add(strings.ToLower(origin))
    85  		}
    86  	}
    87  	// allow localhost if no allowedOrigins are specified.
    88  	if len(origins.ToSlice()) == 0 {
    89  		origins.Add("http://localhost")
    90  		origins.Add("https://localhost")
    91  		if hostname, err := os.Hostname(); err == nil {
    92  			origins.Add("http://" + strings.ToLower(hostname))
    93  			origins.Add("https://" + strings.ToLower(hostname))
    94  		}
    95  	}
    96  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
    97  
    98  	f := func(req *http.Request) bool {
    99  		// Skip origin verification if no Origin header is present. The origin check
   100  		// is supposed to protect against browser based attacks. Browsers always set
   101  		// Origin. Non-browser software can put anything in origin and checking it doesn't
   102  		// provide additional security.
   103  		if _, ok := req.Header["Origin"]; !ok {
   104  			return true
   105  		}
   106  		// Verify origin against whitelist.
   107  		origin := strings.ToLower(req.Header.Get("Origin"))
   108  		if allowAllOrigins || origins.Contains(origin) {
   109  			return true
   110  		}
   111  		log.Warn("Rejected WebSocket connection", "origin", origin)
   112  		return false
   113  	}
   114  
   115  	return f
   116  }
   117  
   118  type wsHandshakeError struct {
   119  	err    error
   120  	status string
   121  }
   122  
   123  func (e wsHandshakeError) Error() string {
   124  	s := e.err.Error()
   125  	if e.status != "" {
   126  		s += " (HTTP status " + e.status + ")"
   127  	}
   128  	return s
   129  }
   130  
   131  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   132  // that is listening on the given endpoint.
   133  //
   134  // The context is used for the initial connection establishment. It does not
   135  // affect subsequent interactions with the client.
   136  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   137  	return DialWebsocketWithCustomTLS(ctx, endpoint, origin, nil)
   138  }
   139  
   140  // Quorum
   141  //
   142  // DialWebsocketWithCustomTLS creates a new RPC client that communicates with a JSON-RPC server
   143  // that is listening on the given endpoint.
   144  // At the same time, allowing to customize TLSClientConfig of the dialer
   145  //
   146  // The context is used for the initial connection establishment. It does not
   147  // affect subsequent interactions with the client.
   148  func DialWebsocketWithCustomTLS(ctx context.Context, endpoint, origin string, tlsConfig *tls.Config) (*Client, error) {
   149  	endpoint, header, err := wsClientHeaders(endpoint, origin)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	dialer := websocket.Dialer{
   154  		ReadBufferSize:  wsReadBuffer,
   155  		WriteBufferSize: wsWriteBuffer,
   156  		WriteBufferPool: wsBufferPool,
   157  	}
   158  	if tlsConfig != nil {
   159  		dialer.TLSClientConfig = tlsConfig
   160  	}
   161  	credProviderFunc, hasCredProviderFunc := ctx.Value(CtxCredentialsProvider).(HttpCredentialsProviderFunc)
   162  	return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
   163  		if hasCredProviderFunc {
   164  			token, err := credProviderFunc(ctx)
   165  			if err != nil {
   166  				log.Warn("unable to obtain credentials from provider", "err", err)
   167  			} else {
   168  				header.Set(HttpAuthorizationHeader, token)
   169  			}
   170  		}
   171  		conn, resp, err := dialer.DialContext(ctx, endpoint, header)
   172  		if err != nil {
   173  			hErr := wsHandshakeError{err: err}
   174  			if resp != nil {
   175  				hErr.status = resp.Status
   176  			}
   177  			return nil, hErr
   178  		}
   179  		return newWebsocketCodec(conn), nil
   180  	})
   181  }
   182  
   183  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   184  	endpointURL, err := url.Parse(endpoint)
   185  	if err != nil {
   186  		return endpoint, nil, err
   187  	}
   188  	header := make(http.Header)
   189  	if origin != "" {
   190  		header.Add("origin", origin)
   191  	}
   192  	if endpointURL.User != nil {
   193  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   194  		header.Add(HttpAuthorizationHeader, "Basic "+b64auth)
   195  		endpointURL.User = nil
   196  	}
   197  	return endpointURL.String(), header, nil
   198  }
   199  
   200  func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
   201  	conn.SetReadLimit(maxRequestContentLength)
   202  	return newCodec(conn, conn.WriteJSON, conn.ReadJSON)
   203  }