github.com/kisexp/xdchain@v0.0.0-20211206025815-490d6b732aa7/rpc/websocket.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"encoding/base64"
    23  	"fmt"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"strings"
    28  	"sync"
    29  	"time"
    30  
    31  	mapset "github.com/deckarep/golang-set"
    32  	"github.com/kisexp/xdchain/log"
    33  	"github.com/gorilla/websocket"
    34  )
    35  
    36  const (
    37  	wsReadBuffer       = 1024
    38  	wsWriteBuffer      = 1024
    39  	wsPingInterval     = 60 * time.Second
    40  	wsPingWriteTimeout = 5 * time.Second
    41  )
    42  
    43  var wsBufferPool = new(sync.Pool)
    44  
    45  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    46  //
    47  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    48  // To allow connections with any origin, pass "*".
    49  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    50  	var upgrader = websocket.Upgrader{
    51  		ReadBufferSize:  wsReadBuffer,
    52  		WriteBufferSize: wsWriteBuffer,
    53  		WriteBufferPool: wsBufferPool,
    54  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    55  	}
    56  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    57  		conn, err := upgrader.Upgrade(w, r, nil)
    58  		if err != nil {
    59  			log.Debug("WebSocket upgrade failed", "err", err)
    60  			return
    61  		}
    62  		codec := newWebsocketCodec(conn)
    63  		s.authenticateHttpRequest(r, codec)
    64  		s.ServeCodec(codec, 0)
    65  	})
    66  }
    67  
    68  // wsHandshakeValidator returns a handler that verifies the origin during the
    69  // websocket upgrade process. When a '*' is specified as an allowed origins all
    70  // connections are accepted.
    71  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    72  	origins := mapset.NewSet()
    73  	allowAllOrigins := false
    74  
    75  	for _, origin := range allowedOrigins {
    76  		if origin == "*" {
    77  			allowAllOrigins = true
    78  		}
    79  		if origin != "" {
    80  			origins.Add(origin)
    81  		}
    82  	}
    83  	// allow localhost if no allowedOrigins are specified.
    84  	if len(origins.ToSlice()) == 0 {
    85  		origins.Add("http://localhost")
    86  		origins.Add("https://localhost")
    87  		if hostname, err := os.Hostname(); err == nil {
    88  			origins.Add("http://" + hostname)
    89  			origins.Add("https://" + hostname)
    90  		}
    91  	}
    92  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
    93  
    94  	f := func(req *http.Request) bool {
    95  		// Skip origin verification if no Origin header is present. The origin check
    96  		// is supposed to protect against browser based attacks. Browsers always set
    97  		// Origin. Non-browser software can put anything in origin and checking it doesn't
    98  		// provide additional security.
    99  		if _, ok := req.Header["Origin"]; !ok {
   100  			return true
   101  		}
   102  		// Verify origin against whitelist.
   103  		origin := strings.ToLower(req.Header.Get("Origin"))
   104  		if allowAllOrigins || originIsAllowed(origins, origin) {
   105  			return true
   106  		}
   107  		log.Warn("Rejected WebSocket connection", "origin", origin)
   108  		return false
   109  	}
   110  
   111  	return f
   112  }
   113  
   114  type wsHandshakeError struct {
   115  	err    error
   116  	status string
   117  }
   118  
   119  func (e wsHandshakeError) Error() string {
   120  	s := e.err.Error()
   121  	if e.status != "" {
   122  		s += " (HTTP status " + e.status + ")"
   123  	}
   124  	return s
   125  }
   126  
   127  func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
   128  	it := allowedOrigins.Iterator()
   129  	for origin := range it.C {
   130  		if ruleAllowsOrigin(origin.(string), browserOrigin) {
   131  			return true
   132  		}
   133  	}
   134  	return false
   135  }
   136  
   137  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   138  	var (
   139  		allowedScheme, allowedHostname, allowedPort string
   140  		browserScheme, browserHostname, browserPort string
   141  		err                                         error
   142  	)
   143  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   144  	if err != nil {
   145  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   146  		return false
   147  	}
   148  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   149  	if err != nil {
   150  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   151  		return false
   152  	}
   153  	if allowedScheme != "" && allowedScheme != browserScheme {
   154  		return false
   155  	}
   156  	if allowedHostname != "" && allowedHostname != browserHostname {
   157  		return false
   158  	}
   159  	if allowedPort != "" && allowedPort != browserPort {
   160  		return false
   161  	}
   162  	return true
   163  }
   164  
   165  func parseOriginURL(origin string) (string, string, string, error) {
   166  	parsedURL, err := url.Parse(strings.ToLower(origin))
   167  	if err != nil {
   168  		return "", "", "", err
   169  	}
   170  	var scheme, hostname, port string
   171  	if strings.Contains(origin, "://") {
   172  		scheme = parsedURL.Scheme
   173  		hostname = parsedURL.Hostname()
   174  		port = parsedURL.Port()
   175  	} else {
   176  		scheme = ""
   177  		hostname = parsedURL.Scheme
   178  		port = parsedURL.Opaque
   179  		if hostname == "" {
   180  			hostname = origin
   181  		}
   182  	}
   183  	return scheme, hostname, port, nil
   184  }
   185  
   186  // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
   187  // that is listening on the given endpoint using the provided dialer.
   188  //
   189  // The context is used for the initial connection establishment. It does not
   190  // affect subsequent interactions with the client.
   191  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   192  	return DialWebsocketWithCustomTLS(ctx, endpoint, origin, nil)
   193  }
   194  
   195  // Quorum
   196  //
   197  // DialWebsocketWithCustomTLS creates a new RPC client that communicates with a JSON-RPC server
   198  // that is listening on the given endpoint.
   199  // At the same time, allowing to customize TLSClientConfig of the dialer
   200  //
   201  // The context is used for the initial connection establishment. It does not
   202  // affect subsequent interactions with the client.
   203  func DialWebsocketWithCustomTLS(ctx context.Context, endpoint, origin string, tlsConfig *tls.Config) (*Client, error) {
   204  	dialer := websocket.Dialer{
   205  		ReadBufferSize:  wsReadBuffer,
   206  		WriteBufferSize: wsWriteBuffer,
   207  		WriteBufferPool: wsBufferPool,
   208  	}
   209  
   210  	endpoint, header, err := wsClientHeaders(endpoint, origin)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	if tlsConfig != nil {
   215  		dialer.TLSClientConfig = tlsConfig
   216  	}
   217  	ctx = resolvePSIProvider(ctx, endpoint)
   218  
   219  	credProviderFunc := CredentialsProviderFromContext(ctx)
   220  	psiProviderFunc := PSIProviderFromContext(ctx)
   221  	return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
   222  		if credProviderFunc != nil {
   223  			token, err := credProviderFunc(ctx)
   224  			if err != nil {
   225  				log.Warn("unable to obtain credentials from provider", "err", err)
   226  			} else {
   227  				header.Set(HttpAuthorizationHeader, token)
   228  			}
   229  		}
   230  		if psiProviderFunc != nil {
   231  			psi, err := psiProviderFunc(ctx)
   232  			if err != nil {
   233  				log.Warn("unable to obtain PSI from provider", "err", err)
   234  			} else {
   235  				header.Set(HttpPrivateStateIdentifierHeader, psi.String())
   236  			}
   237  		}
   238  		conn, resp, err := dialer.DialContext(ctx, endpoint, header)
   239  		if err != nil {
   240  			hErr := wsHandshakeError{err: err}
   241  			if resp != nil {
   242  				hErr.status = resp.Status
   243  			}
   244  			return nil, hErr
   245  		}
   246  		return newWebsocketCodec(conn), nil
   247  	})
   248  }
   249  
   250  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   251  // that is listening on the given endpoint.
   252  //
   253  // The context is used for the initial connection establishment. It does not
   254  // affect subsequent interactions with the client.
   255  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   256  	dialer := websocket.Dialer{
   257  		ReadBufferSize:  wsReadBuffer,
   258  		WriteBufferSize: wsWriteBuffer,
   259  		WriteBufferPool: wsBufferPool,
   260  	}
   261  	return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
   262  }
   263  
   264  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   265  	endpointURL, err := url.Parse(endpoint)
   266  	if err != nil {
   267  		return endpoint, nil, err
   268  	}
   269  	header := make(http.Header)
   270  	if origin != "" {
   271  		header.Add("origin", origin)
   272  	}
   273  	if endpointURL.User != nil {
   274  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   275  		header.Add(HttpAuthorizationHeader, "Basic "+b64auth)
   276  		endpointURL.User = nil
   277  	}
   278  	return endpointURL.String(), header, nil
   279  }
   280  
// websocketCodec is the ServerCodec used for WebSocket connections. It embeds
// the generic jsonCodec (built in newWebsocketCodec over the connection's
// WriteJSON/ReadJSON) and adds a background loop that pings idle connections.
type websocketCodec struct {
	*jsonCodec
	conn *websocket.Conn // underlying connection; pingLoop writes ping frames to it directly

	wg        sync.WaitGroup // tracks the pingLoop goroutine so close can wait for it
	pingReset chan struct{}  // signalled by writeJSON after a successful write to postpone the next ping
}
   288  
   289  func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
   290  	conn.SetReadLimit(maxRequestContentLength)
   291  	wc := &websocketCodec{
   292  		jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
   293  		conn:      conn,
   294  		pingReset: make(chan struct{}, 1),
   295  	}
   296  	wc.wg.Add(1)
   297  	go wc.pingLoop()
   298  	return wc
   299  }
   300  
// close closes the embedded jsonCodec and then blocks until pingLoop has
// returned. The order matters: pingLoop exits when wc.closed() becomes
// readable (presumably triggered by jsonCodec.close — confirm), so waiting
// first would never return.
func (wc *websocketCodec) close() {
	wc.jsonCodec.close()
	wc.wg.Wait()
}
   305  
   306  func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
   307  	err := wc.jsonCodec.writeJSON(ctx, v)
   308  	if err == nil {
   309  		// Notify pingLoop to delay the next idle ping.
   310  		select {
   311  		case wc.pingReset <- struct{}{}:
   312  		default:
   313  		}
   314  	}
   315  	return err
   316  }
   317  
// pingLoop sends periodic ping frames when the connection is idle.
//
// It is started as a goroutine by newWebsocketCodec (after wg.Add(1)) and
// returns once wc.closed() becomes readable. Each value received on pingReset
// (sent by writeJSON after a successful write) restarts the wsPingInterval
// countdown. A ping is therefore only sent after wsPingInterval passes with no
// reset. Errors from SetWriteDeadline / WriteMessage are ignored here.
func (wc *websocketCodec) pingLoop() {
	var timer = time.NewTimer(wsPingInterval)
	defer wc.wg.Done()
	defer timer.Stop()

	for {
		select {
		case <-wc.closed():
			return
		case <-wc.pingReset:
			// Stop the timer; if it had already fired, drain the pending
			// value so the Reset below starts from an empty channel.
			if !timer.Stop() {
				<-timer.C
			}
			timer.Reset(wsPingInterval)
		case <-timer.C:
			// Hold the codec's encoder mutex while writing the ping. Presumably
			// this is the lock jsonCodec takes around its own writes, keeping
			// the ping from interleaving with a JSON message — confirm in jsonCodec.
			wc.jsonCodec.encMu.Lock()
			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
			wc.conn.WriteMessage(websocket.PingMessage, nil)
			wc.jsonCodec.encMu.Unlock()
			timer.Reset(wsPingInterval)
		}
	}
}