github.com/Unheilbar/quorum@v1.0.0/rpc/websocket.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"encoding/base64"
    23  	"fmt"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"strings"
    28  	"sync"
    29  	"time"
    30  
    31  	mapset "github.com/deckarep/golang-set"
    32  	"github.com/ethereum/go-ethereum/log"
    33  	"github.com/gorilla/websocket"
    34  )
    35  
    36  const (
    37  	wsReadBuffer       = 1024
    38  	wsWriteBuffer      = 1024
    39  	wsPingInterval     = 60 * time.Second
    40  	wsPingWriteTimeout = 5 * time.Second
    41  	wsMessageSizeLimit = 15 * 1024 * 1024
    42  )
    43  
    44  var wsBufferPool = new(sync.Pool)
    45  
    46  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    47  //
    48  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    49  // To allow connections with any origin, pass "*".
    50  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    51  	var upgrader = websocket.Upgrader{
    52  		ReadBufferSize:  wsReadBuffer,
    53  		WriteBufferSize: wsWriteBuffer,
    54  		WriteBufferPool: wsBufferPool,
    55  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    56  	}
    57  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    58  		conn, err := upgrader.Upgrade(w, r, nil)
    59  		if err != nil {
    60  			log.Debug("WebSocket upgrade failed", "err", err)
    61  			return
    62  		}
    63  		codec := newWebsocketCodec(conn)
    64  		s.authenticateHttpRequest(r, codec)
    65  		s.ServeCodec(codec, 0)
    66  	})
    67  }
    68  
    69  // wsHandshakeValidator returns a handler that verifies the origin during the
    70  // websocket upgrade process. When a '*' is specified as an allowed origins all
    71  // connections are accepted.
    72  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    73  	origins := mapset.NewSet()
    74  	allowAllOrigins := false
    75  
    76  	for _, origin := range allowedOrigins {
    77  		if origin == "*" {
    78  			allowAllOrigins = true
    79  		}
    80  		if origin != "" {
    81  			origins.Add(origin)
    82  		}
    83  	}
    84  	// allow localhost if no allowedOrigins are specified.
    85  	if len(origins.ToSlice()) == 0 {
    86  		origins.Add("http://localhost")
    87  		origins.Add("https://localhost")
    88  		if hostname, err := os.Hostname(); err == nil {
    89  			origins.Add("http://" + hostname)
    90  			origins.Add("https://" + hostname)
    91  		}
    92  	}
    93  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
    94  
    95  	f := func(req *http.Request) bool {
    96  		// Skip origin verification if no Origin header is present. The origin check
    97  		// is supposed to protect against browser based attacks. Browsers always set
    98  		// Origin. Non-browser software can put anything in origin and checking it doesn't
    99  		// provide additional security.
   100  		if _, ok := req.Header["Origin"]; !ok {
   101  			return true
   102  		}
   103  		// Verify origin against whitelist.
   104  		origin := strings.ToLower(req.Header.Get("Origin"))
   105  		if allowAllOrigins || originIsAllowed(origins, origin) {
   106  			return true
   107  		}
   108  		log.Warn("Rejected WebSocket connection", "origin", origin)
   109  		return false
   110  	}
   111  
   112  	return f
   113  }
   114  
   115  type wsHandshakeError struct {
   116  	err    error
   117  	status string
   118  }
   119  
   120  func (e wsHandshakeError) Error() string {
   121  	s := e.err.Error()
   122  	if e.status != "" {
   123  		s += " (HTTP status " + e.status + ")"
   124  	}
   125  	return s
   126  }
   127  
   128  func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
   129  	it := allowedOrigins.Iterator()
   130  	for origin := range it.C {
   131  		if ruleAllowsOrigin(origin.(string), browserOrigin) {
   132  			return true
   133  		}
   134  	}
   135  	return false
   136  }
   137  
   138  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   139  	var (
   140  		allowedScheme, allowedHostname, allowedPort string
   141  		browserScheme, browserHostname, browserPort string
   142  		err                                         error
   143  	)
   144  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   145  	if err != nil {
   146  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   147  		return false
   148  	}
   149  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   150  	if err != nil {
   151  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   152  		return false
   153  	}
   154  	if allowedScheme != "" && allowedScheme != browserScheme {
   155  		return false
   156  	}
   157  	if allowedHostname != "" && allowedHostname != browserHostname {
   158  		return false
   159  	}
   160  	if allowedPort != "" && allowedPort != browserPort {
   161  		return false
   162  	}
   163  	return true
   164  }
   165  
   166  func parseOriginURL(origin string) (string, string, string, error) {
   167  	parsedURL, err := url.Parse(strings.ToLower(origin))
   168  	if err != nil {
   169  		return "", "", "", err
   170  	}
   171  	var scheme, hostname, port string
   172  	if strings.Contains(origin, "://") {
   173  		scheme = parsedURL.Scheme
   174  		hostname = parsedURL.Hostname()
   175  		port = parsedURL.Port()
   176  	} else {
   177  		scheme = ""
   178  		hostname = parsedURL.Scheme
   179  		port = parsedURL.Opaque
   180  		if hostname == "" {
   181  			hostname = origin
   182  		}
   183  	}
   184  	return scheme, hostname, port, nil
   185  }
   186  
   187  // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
   188  // that is listening on the given endpoint using the provided dialer.
   189  //
   190  // The context is used for the initial connection establishment. It does not
   191  // affect subsequent interactions with the client.
   192  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   193  	return DialWebsocketWithCustomTLS(ctx, endpoint, origin, nil)
   194  }
   195  
   196  // Quorum
   197  //
   198  // DialWebsocketWithCustomTLS creates a new RPC client that communicates with a JSON-RPC server
   199  // that is listening on the given endpoint.
   200  // At the same time, allowing to customize TLSClientConfig of the dialer
   201  //
   202  // The context is used for the initial connection establishment. It does not
   203  // affect subsequent interactions with the client.
   204  func DialWebsocketWithCustomTLS(ctx context.Context, endpoint, origin string, tlsConfig *tls.Config) (*Client, error) {
   205  	dialer := websocket.Dialer{
   206  		ReadBufferSize:  wsReadBuffer,
   207  		WriteBufferSize: wsWriteBuffer,
   208  		WriteBufferPool: wsBufferPool,
   209  	}
   210  
   211  	endpoint, header, err := wsClientHeaders(endpoint, origin)
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	if tlsConfig != nil {
   216  		dialer.TLSClientConfig = tlsConfig
   217  	}
   218  	ctx = resolvePSIProvider(ctx, endpoint)
   219  
   220  	credProviderFunc := CredentialsProviderFromContext(ctx)
   221  	psiProviderFunc := PSIProviderFromContext(ctx)
   222  	return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
   223  		if credProviderFunc != nil {
   224  			token, err := credProviderFunc(ctx)
   225  			if err != nil {
   226  				log.Warn("unable to obtain credentials from provider", "err", err)
   227  			} else {
   228  				header.Set(HttpAuthorizationHeader, token)
   229  			}
   230  		}
   231  		if psiProviderFunc != nil {
   232  			psi, err := psiProviderFunc(ctx)
   233  			if err != nil {
   234  				log.Warn("unable to obtain PSI from provider", "err", err)
   235  			} else {
   236  				header.Set(HttpPrivateStateIdentifierHeader, psi.String())
   237  			}
   238  		}
   239  		conn, resp, err := dialer.DialContext(ctx, endpoint, header)
   240  		if err != nil {
   241  			hErr := wsHandshakeError{err: err}
   242  			if resp != nil {
   243  				hErr.status = resp.Status
   244  			}
   245  			return nil, hErr
   246  		}
   247  		return newWebsocketCodec(conn), nil
   248  	})
   249  }
   250  
   251  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   252  // that is listening on the given endpoint.
   253  //
   254  // The context is used for the initial connection establishment. It does not
   255  // affect subsequent interactions with the client.
   256  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   257  	dialer := websocket.Dialer{
   258  		ReadBufferSize:  wsReadBuffer,
   259  		WriteBufferSize: wsWriteBuffer,
   260  		WriteBufferPool: wsBufferPool,
   261  	}
   262  	return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
   263  }
   264  
   265  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   266  	endpointURL, err := url.Parse(endpoint)
   267  	if err != nil {
   268  		return endpoint, nil, err
   269  	}
   270  	header := make(http.Header)
   271  	if origin != "" {
   272  		header.Add("origin", origin)
   273  	}
   274  	if endpointURL.User != nil {
   275  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   276  		header.Add(HttpAuthorizationHeader, "Basic "+b64auth)
   277  		endpointURL.User = nil
   278  	}
   279  	return endpointURL.String(), header, nil
   280  }
   281  
   282  type websocketCodec struct {
   283  	*jsonCodec
   284  	conn *websocket.Conn
   285  
   286  	wg        sync.WaitGroup
   287  	pingReset chan struct{}
   288  }
   289  
   290  func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
   291  	conn.SetReadLimit(wsMessageSizeLimit)
   292  	wc := &websocketCodec{
   293  		jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
   294  		conn:      conn,
   295  		pingReset: make(chan struct{}, 1),
   296  	}
   297  	wc.wg.Add(1)
   298  	go wc.pingLoop()
   299  	return wc
   300  }
   301  
   302  func (wc *websocketCodec) close() {
   303  	wc.jsonCodec.close()
   304  	wc.wg.Wait()
   305  }
   306  
   307  func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
   308  	err := wc.jsonCodec.writeJSON(ctx, v)
   309  	if err == nil {
   310  		// Notify pingLoop to delay the next idle ping.
   311  		select {
   312  		case wc.pingReset <- struct{}{}:
   313  		default:
   314  		}
   315  	}
   316  	return err
   317  }
   318  
   319  // pingLoop sends periodic ping frames when the connection is idle.
   320  func (wc *websocketCodec) pingLoop() {
   321  	var timer = time.NewTimer(wsPingInterval)
   322  	defer wc.wg.Done()
   323  	defer timer.Stop()
   324  
   325  	for {
   326  		select {
   327  		case <-wc.closed():
   328  			return
   329  		case <-wc.pingReset:
   330  			if !timer.Stop() {
   331  				<-timer.C
   332  			}
   333  			timer.Reset(wsPingInterval)
   334  		case <-timer.C:
   335  			wc.jsonCodec.encMu.Lock()
   336  			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
   337  			wc.conn.WriteMessage(websocket.PingMessage, nil)
   338  			wc.jsonCodec.encMu.Unlock()
   339  			timer.Reset(wsPingInterval)
   340  		}
   341  	}
   342  }