gitlab.com/flarenetwork/coreth@v0.1.1/rpc/websocket.go (about)

     1  // (c) 2019-2020, Ava Labs, Inc.
     2  //
     3  // This file is a derived work, based on the go-ethereum library whose original
     4  // notices appear below.
     5  //
     6  // It is distributed under a license compatible with the licensing terms of the
     7  // original code from which it is derived.
     8  //
     9  // Much love to the original authors for their work.
    10  // **********
    11  // Copyright 2015 The go-ethereum Authors
    12  // This file is part of the go-ethereum library.
    13  //
    14  // The go-ethereum library is free software: you can redistribute it and/or modify
    15  // it under the terms of the GNU Lesser General Public License as published by
    16  // the Free Software Foundation, either version 3 of the License, or
    17  // (at your option) any later version.
    18  //
    19  // The go-ethereum library is distributed in the hope that it will be useful,
    20  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    21  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    22  // GNU Lesser General Public License for more details.
    23  //
    24  // You should have received a copy of the GNU Lesser General Public License
    25  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    26  
    27  package rpc
    28  
    29  import (
    30  	"context"
    31  	"encoding/base64"
    32  	"fmt"
    33  	"net/http"
    34  	"net/url"
    35  	"os"
    36  	"strings"
    37  	"sync"
    38  	"time"
    39  
    40  	mapset "github.com/deckarep/golang-set"
    41  	"github.com/ethereum/go-ethereum/log"
    42  	"github.com/gorilla/websocket"
    43  )
    44  
    45  const (
    46  	wsReadBuffer       = 1024
    47  	wsWriteBuffer      = 1024
    48  	wsPingInterval     = 30 * time.Second
    49  	wsPingWriteTimeout = 5 * time.Second
    50  	wsMessageSizeLimit = 15 * 1024 * 1024
    51  )
    52  
    53  var wsBufferPool = new(sync.Pool)
    54  
    55  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    56  //
    57  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    58  // To allow connections with any origin, pass "*".
    59  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    60  	return s.WebsocketHandlerWithDuration(allowedOrigins, 0)
    61  }
    62  
    63  func (s *Server) WebsocketHandlerWithDuration(allowedOrigins []string, apiMaxDuration time.Duration) http.Handler {
    64  	var upgrader = websocket.Upgrader{
    65  		ReadBufferSize:  wsReadBuffer,
    66  		WriteBufferSize: wsWriteBuffer,
    67  		WriteBufferPool: wsBufferPool,
    68  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    69  	}
    70  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    71  		conn, err := upgrader.Upgrade(w, r, nil)
    72  		if err != nil {
    73  			log.Debug("WebSocket upgrade failed", "err", err)
    74  			return
    75  		}
    76  		codec := newWebsocketCodec(conn)
    77  		s.ServeCodec(codec, 0, apiMaxDuration)
    78  	})
    79  }
    80  
    81  // wsHandshakeValidator returns a handler that verifies the origin during the
    82  // websocket upgrade process. When a '*' is specified as an allowed origins all
    83  // connections are accepted.
    84  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    85  	origins := mapset.NewSet()
    86  	allowAllOrigins := false
    87  
    88  	for _, origin := range allowedOrigins {
    89  		if origin == "*" {
    90  			allowAllOrigins = true
    91  		}
    92  		if origin != "" {
    93  			origins.Add(origin)
    94  		}
    95  	}
    96  	// allow localhost if no allowedOrigins are specified.
    97  	if len(origins.ToSlice()) == 0 {
    98  		origins.Add("http://localhost")
    99  		if hostname, err := os.Hostname(); err == nil {
   100  			origins.Add("http://" + hostname)
   101  		}
   102  	}
   103  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
   104  
   105  	f := func(req *http.Request) bool {
   106  		// Skip origin verification if no Origin header is present. The origin check
   107  		// is supposed to protect against browser based attacks. Browsers always set
   108  		// Origin. Non-browser software can put anything in origin and checking it doesn't
   109  		// provide additional security.
   110  		if _, ok := req.Header["Origin"]; !ok {
   111  			return true
   112  		}
   113  		// Verify origin against allow list.
   114  		origin := strings.ToLower(req.Header.Get("Origin"))
   115  		if allowAllOrigins || originIsAllowed(origins, origin) {
   116  			return true
   117  		}
   118  		log.Warn("Rejected WebSocket connection", "origin", origin)
   119  		return false
   120  	}
   121  
   122  	return f
   123  }
   124  
   125  type wsHandshakeError struct {
   126  	err    error
   127  	status string
   128  }
   129  
   130  func (e wsHandshakeError) Error() string {
   131  	s := e.err.Error()
   132  	if e.status != "" {
   133  		s += " (HTTP status " + e.status + ")"
   134  	}
   135  	return s
   136  }
   137  
   138  func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
   139  	it := allowedOrigins.Iterator()
   140  	for origin := range it.C {
   141  		if ruleAllowsOrigin(origin.(string), browserOrigin) {
   142  			return true
   143  		}
   144  	}
   145  	return false
   146  }
   147  
   148  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   149  	var (
   150  		allowedScheme, allowedHostname, allowedPort string
   151  		browserScheme, browserHostname, browserPort string
   152  		err                                         error
   153  	)
   154  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   155  	if err != nil {
   156  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   157  		return false
   158  	}
   159  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   160  	if err != nil {
   161  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   162  		return false
   163  	}
   164  	if allowedScheme != "" && allowedScheme != browserScheme {
   165  		return false
   166  	}
   167  	if allowedHostname != "" && allowedHostname != browserHostname {
   168  		return false
   169  	}
   170  	if allowedPort != "" && allowedPort != browserPort {
   171  		return false
   172  	}
   173  	return true
   174  }
   175  
   176  func parseOriginURL(origin string) (string, string, string, error) {
   177  	parsedURL, err := url.Parse(strings.ToLower(origin))
   178  	if err != nil {
   179  		return "", "", "", err
   180  	}
   181  	var scheme, hostname, port string
   182  	if strings.Contains(origin, "://") {
   183  		scheme = parsedURL.Scheme
   184  		hostname = parsedURL.Hostname()
   185  		port = parsedURL.Port()
   186  	} else {
   187  		scheme = ""
   188  		hostname = parsedURL.Scheme
   189  		port = parsedURL.Opaque
   190  		if hostname == "" {
   191  			hostname = origin
   192  		}
   193  	}
   194  	return scheme, hostname, port, nil
   195  }
   196  
   197  // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
   198  // that is listening on the given endpoint using the provided dialer.
   199  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   200  	endpoint, header, err := wsClientHeaders(endpoint, origin)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
   205  		conn, resp, err := dialer.DialContext(ctx, endpoint, header)
   206  		if err != nil {
   207  			hErr := wsHandshakeError{err: err}
   208  			if resp != nil {
   209  				hErr.status = resp.Status
   210  			}
   211  			return nil, hErr
   212  		}
   213  		return newWebsocketCodec(conn), nil
   214  	})
   215  }
   216  
   217  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   218  // that is listening on the given endpoint.
   219  //
   220  // The context is used for the initial connection establishment. It does not
   221  // affect subsequent interactions with the client.
   222  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   223  	dialer := websocket.Dialer{
   224  		ReadBufferSize:  wsReadBuffer,
   225  		WriteBufferSize: wsWriteBuffer,
   226  		WriteBufferPool: wsBufferPool,
   227  	}
   228  	return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
   229  }
   230  
   231  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   232  	endpointURL, err := url.Parse(endpoint)
   233  	if err != nil {
   234  		return endpoint, nil, err
   235  	}
   236  	header := make(http.Header)
   237  	if origin != "" {
   238  		header.Add("origin", origin)
   239  	}
   240  	if endpointURL.User != nil {
   241  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   242  		header.Add("authorization", "Basic "+b64auth)
   243  		endpointURL.User = nil
   244  	}
   245  	return endpointURL.String(), header, nil
   246  }
   247  
// websocketCodec is the ServerCodec used for websocket connections. It
// delegates message encoding/decoding to the embedded jsonCodec (built over
// conn.WriteJSON/conn.ReadJSON in newWebsocketCodec) and adds an idle
// keep-alive ping loop.
type websocketCodec struct {
	*jsonCodec
	conn *websocket.Conn // the underlying connection; pingLoop writes ping frames to it directly

	wg        sync.WaitGroup // counts the pingLoop goroutine; close waits on it
	pingReset chan struct{}  // capacity 1; a pending token postpones the next idle ping
}
   255  
   256  func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
   257  	conn.SetReadLimit(wsMessageSizeLimit)
   258  	conn.SetPongHandler(func(appData string) error {
   259  		conn.SetReadDeadline(time.Time{})
   260  		return nil
   261  	})
   262  	wc := &websocketCodec{
   263  		jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
   264  		conn:      conn,
   265  		pingReset: make(chan struct{}, 1),
   266  	}
   267  	wc.wg.Add(1)
   268  	go wc.pingLoop()
   269  	return wc
   270  }
   271  
// close shuts down the embedded jsonCodec and then blocks until pingLoop has
// returned. The order matters: pingLoop only exits when wc.closed() becomes
// ready, which presumably happens inside jsonCodec.close (TODO confirm), so
// waiting first could block forever.
func (wc *websocketCodec) close() {
	wc.jsonCodec.close()
	wc.wg.Wait()
}
   276  
   277  func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
   278  	return wc.writeJSONSkipDeadline(ctx, v, false)
   279  }
   280  
   281  func (wc *websocketCodec) writeJSONSkipDeadline(ctx context.Context, v interface{}, skip bool) error {
   282  	err := wc.jsonCodec.writeJSONSkipDeadline(ctx, v, skip)
   283  	if err == nil {
   284  		// Notify pingLoop to delay the next idle ping.
   285  		select {
   286  		case wc.pingReset <- struct{}{}:
   287  		default:
   288  		}
   289  	}
   290  	return err
   291  }
   292  
   293  // pingLoop sends periodic ping frames when the connection is idle.
   294  func (wc *websocketCodec) pingLoop() {
   295  	var timer = time.NewTimer(wsPingInterval)
   296  	defer wc.wg.Done()
   297  	defer timer.Stop()
   298  
   299  	for {
   300  		select {
   301  		case <-wc.closed():
   302  			return
   303  		case <-wc.pingReset:
   304  			if !timer.Stop() {
   305  				<-timer.C
   306  			}
   307  			timer.Reset(wsPingInterval)
   308  		case <-timer.C:
   309  			wc.jsonCodec.encMu.Lock()
   310  			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
   311  			wc.conn.WriteMessage(websocket.PingMessage, nil)
   312  			wc.jsonCodec.encMu.Unlock()
   313  			timer.Reset(wsPingInterval)
   314  		}
   315  	}
   316  }