github.com/core-coin/go-core/v2@v2.1.9/rpc/websocket.go (about)

     1  // Copyright 2015 by the Authors
     2  // This file is part of the go-core library.
     3  //
     4  // The go-core library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-core library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-core library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"net/http"
    24  	"net/url"
    25  	"os"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	mapset "github.com/deckarep/golang-set"
    31  	"github.com/gorilla/websocket"
    32  
    33  	"github.com/core-coin/go-core/v2/log"
    34  )
    35  
    // Buffer sizes and keepalive timings shared by the websocket server and
    // client code in this file.
    const (
    	// wsReadBuffer and wsWriteBuffer are the I/O buffer sizes handed to the
    	// gorilla websocket Upgrader and the default Dialer.
    	wsReadBuffer  = 1 << 10
    	wsWriteBuffer = 1 << 10

    	// wsPingInterval is the period of pingLoop's timer. The timer is restarted
    	// after every ping it sends and whenever pingLoop receives a reset signal.
    	wsPingInterval = time.Minute
    	// wsPingWriteTimeout is the write deadline applied to each ping frame.
    	wsPingWriteTimeout = 5 * time.Second
    )

    // wsBufferPool is installed as WriteBufferPool on the Upgrader built by
    // WebsocketHandler and on the default Dialer built by newClientTransportWS,
    // so those connections draw their write buffers from one shared pool.
    var wsBufferPool = &sync.Pool{}
    44  
    45  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    46  //
    47  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    48  // To allow connections with any origin, pass "*".
    49  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    50  	var upgrader = websocket.Upgrader{
    51  		ReadBufferSize:  wsReadBuffer,
    52  		WriteBufferSize: wsWriteBuffer,
    53  		WriteBufferPool: wsBufferPool,
    54  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    55  	}
    56  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    57  		conn, err := upgrader.Upgrade(w, r, nil)
    58  		if err != nil {
    59  			log.Debug("WebSocket upgrade failed", "err", err)
    60  			return
    61  		}
    62  		codec := newWebsocketCodec(conn, r.Host, r.Header)
    63  		s.ServeCodec(codec, 0)
    64  	})
    65  }
    66  
    67  // wsHandshakeValidator returns a handler that verifies the origin during the
    68  // websocket upgrade process. When a '*' is specified as an allowed origins all
    69  // connections are accepted.
    70  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    71  	origins := mapset.NewSet()
    72  	allowAllOrigins := false
    73  
    74  	for _, origin := range allowedOrigins {
    75  		if origin == "*" {
    76  			allowAllOrigins = true
    77  		}
    78  		if origin != "" {
    79  			origins.Add(origin)
    80  		}
    81  	}
    82  	// allow localhost if no allowedOrigins are specified.
    83  	if len(origins.ToSlice()) == 0 {
    84  		origins.Add("http://localhost")
    85  		if hostname, err := os.Hostname(); err == nil {
    86  			origins.Add("http://" + hostname)
    87  		}
    88  	}
    89  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
    90  
    91  	f := func(req *http.Request) bool {
    92  		// Skip origin verification if no Origin header is present. The origin check
    93  		// is supposed to protect against browser based attacks. Browsers always set
    94  		// Origin. Non-browser software can put anything in origin and checking it doesn't
    95  		// provide additional security.
    96  		if _, ok := req.Header["Origin"]; !ok {
    97  			return true
    98  		}
    99  		// Verify origin against whitelist.
   100  		origin := strings.ToLower(req.Header.Get("Origin"))
   101  		if allowAllOrigins || originIsAllowed(origins, origin) {
   102  			return true
   103  		}
   104  		log.Warn("Rejected WebSocket connection", "origin", origin)
   105  		return false
   106  	}
   107  
   108  	return f
   109  }
   110  
   // wsHandshakeError is returned by the client transport when dialing the
   // websocket endpoint fails. It wraps the dialer's error and, if the server
   // answered with an HTTP response, that response's status.
   type wsHandshakeError struct {
   	err    error  // error returned by the websocket dialer
   	status string // HTTP status of the failed handshake response, or ""
   }

   // Error returns the dial error's message, followed by the HTTP status when
   // one was recorded.
   func (e wsHandshakeError) Error() string {
   	s := e.err.Error()
   	if e.status != "" {
   		s += " (HTTP status " + e.status + ")"
   	}
   	return s
   }

   // Unwrap returns the underlying dial error so callers can inspect it with
   // errors.Is and errors.As.
   func (e wsHandshakeError) Unwrap() error {
   	return e.err
   }
   123  
   124  func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
   125  	it := allowedOrigins.Iterator()
   126  	for origin := range it.C {
   127  		if ruleAllowsOrigin(origin.(string), browserOrigin) {
   128  			return true
   129  		}
   130  	}
   131  	return false
   132  }
   133  
   134  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   135  	var (
   136  		allowedScheme, allowedHostname, allowedPort string
   137  		browserScheme, browserHostname, browserPort string
   138  		err                                         error
   139  	)
   140  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   141  	if err != nil {
   142  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   143  		return false
   144  	}
   145  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   146  	if err != nil {
   147  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   148  		return false
   149  	}
   150  	if allowedScheme != "" && allowedScheme != browserScheme {
   151  		return false
   152  	}
   153  	if allowedHostname != "" && allowedHostname != browserHostname {
   154  		return false
   155  	}
   156  	if allowedPort != "" && allowedPort != browserPort {
   157  		return false
   158  	}
   159  	return true
   160  }
   161  
   // parseOriginURL splits an origin rule or Origin header value into its
   // lowercased scheme, hostname and port. Besides full URLs such as
   // "http://example.com:8080", it accepts the scheme-less forms "hostname" and
   // "hostname:port"; for those the returned scheme is empty. Components that are
   // absent are returned as "". The error is whatever url.Parse reports.
   func parseOriginURL(origin string) (string, string, string, error) {
   	lowered := strings.ToLower(origin)
   	parsedURL, err := url.Parse(lowered)
   	if err != nil {
   		return "", "", "", err
   	}
   	if strings.Contains(lowered, "://") {
   		return parsedURL.Scheme, parsedURL.Hostname(), parsedURL.Port(), nil
   	}
   	// Without "://", url.Parse reads "host:port" as scheme "host" with opaque
   	// data "port" (when host is a valid scheme token), and a bare "host" as a
   	// relative path with no scheme.
   	hostname, port := parsedURL.Scheme, parsedURL.Opaque
   	if hostname == "" {
   		// Use the lowercased input so bare hostnames compare case-insensitively,
   		// like every other form handled above.
   		hostname = lowered
   	}
   	return "", hostname, port, nil
   }
   182  
   183  // DialWebsocketWithDialer creates a new RPC client using WebSocket.
   184  //
   185  // The context is used for the initial connection establishment. It does not
   186  // affect subsequent interactions with the client.
   187  //
   188  // Deprecated: use DialOptions and the WithWebsocketDialer option.
   189  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   190  	cfg := new(clientConfig)
   191  	cfg.wsDialer = &dialer
   192  	if origin != "" {
   193  		cfg.setHeader("origin", origin)
   194  	}
   195  	connect, err := newClientTransportWS(endpoint, cfg)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	return newClient(ctx, connect)
   200  }
   201  
   202  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   203  // that is listening on the given endpoint.
   204  //
   205  // The context is used for the initial connection establishment. It does not
   206  // affect subsequent interactions with the client.
   207  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   208  	cfg := new(clientConfig)
   209  	if origin != "" {
   210  		cfg.setHeader("origin", origin)
   211  	}
   212  	connect, err := newClientTransportWS(endpoint, cfg)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	return newClient(ctx, connect)
   217  }
   218  
   219  func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
   220  	dialer := cfg.wsDialer
   221  	if dialer == nil {
   222  		dialer = &websocket.Dialer{
   223  			ReadBufferSize:  wsReadBuffer,
   224  			WriteBufferSize: wsWriteBuffer,
   225  			WriteBufferPool: wsBufferPool,
   226  		}
   227  	}
   228  
   229  	dialURL, header, err := wsClientHeaders(endpoint, "")
   230  	if err != nil {
   231  		return nil, err
   232  	}
   233  	for key, values := range cfg.httpHeaders {
   234  		header[key] = values
   235  	}
   236  
   237  	connect := func(ctx context.Context) (ServerCodec, error) {
   238  		header := header.Clone()
   239  		if cfg.httpAuth != nil {
   240  			if err := cfg.httpAuth(header); err != nil {
   241  				return nil, err
   242  			}
   243  		}
   244  		conn, resp, err := dialer.DialContext(ctx, dialURL, header)
   245  		if err != nil {
   246  			hErr := wsHandshakeError{err: err}
   247  			if resp != nil {
   248  				hErr.status = resp.Status
   249  			}
   250  			return nil, hErr
   251  		}
   252  		return newWebsocketCodec(conn, dialURL, header), nil
   253  	}
   254  	return connect, nil
   255  }
   256  
   // wsClientHeaders parses endpoint and builds the headers for the websocket
   // handshake. A non-empty origin becomes the Origin header. Credentials in the
   // URL's userinfo are moved into a Basic Authorization header (RFC 7617) and
   // removed from the returned URL. On a parse error the endpoint is returned
   // unchanged together with a nil header.
   func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   	endpointURL, err := url.Parse(endpoint)
   	if err != nil {
   		return endpoint, nil, err
   	}
   	header := make(http.Header)
   	if origin != "" {
   		header.Add("origin", origin)
   	}
   	if user := endpointURL.User; user != nil {
   		// Use the decoded username and password. Userinfo.String() yields the
   		// percent-encoded form, which would corrupt credentials containing
   		// characters such as '@' or ':'. As net/http's client does, the two
   		// parts are always joined with ':' (user-pass = user-id ":" password).
   		password, _ := user.Password()
   		credentials := user.Username() + ":" + password
   		header.Add("authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(credentials)))
   		endpointURL.User = nil
   	}
   	return endpointURL.String(), header, nil
   }
   273  
   274  type websocketCodec struct {
   275  	*jsonCodec
   276  	conn *websocket.Conn
   277  	info PeerInfo
   278  
   279  	wg        sync.WaitGroup
   280  	pingReset chan struct{}
   281  }
   282  
   283  func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
   284  	conn.SetReadLimit(maxRequestContentLength)
   285  	wc := &websocketCodec{
   286  		jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
   287  		conn:      conn,
   288  		info: PeerInfo{
   289  			Transport:  "ws",
   290  			RemoteAddr: conn.RemoteAddr().String(),
   291  		},
   292  	}
   293  	// Fill in connection details.
   294  	wc.info.HTTP.Host = host
   295  	wc.info.HTTP.Origin = req.Get("Origin")
   296  	wc.info.HTTP.UserAgent = req.Get("User-Agent")
   297  	// Start pinger.
   298  	wc.wg.Add(1)
   299  	go wc.pingLoop()
   300  	return wc
   301  }
   302  
   303  func (wc *websocketCodec) close() {
   304  	wc.jsonCodec.close()
   305  	wc.wg.Wait()
   306  }
   307  
   308  func (wc *websocketCodec) peerInfo() PeerInfo {
   309  	return wc.info
   310  }
   311  
   312  func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
   313  	err := wc.jsonCodec.writeJSON(ctx, v)
   314  	if err == nil {
   315  		// Notify pingLoop to delay the next idle ping.
   316  		select {
   317  		case wc.pingReset <- struct{}{}:
   318  		default:
   319  		}
   320  	}
   321  	return err
   322  }
   323  
   324  // pingLoop sends periodic ping frames when the connection is idle.
   325  func (wc *websocketCodec) pingLoop() {
   326  	var timer = time.NewTimer(wsPingInterval)
   327  	defer wc.wg.Done()
   328  	defer timer.Stop()
   329  
   330  	for {
   331  		select {
   332  		case <-wc.closed():
   333  			return
   334  		case <-wc.pingReset:
   335  			if !timer.Stop() {
   336  				<-timer.C
   337  			}
   338  			timer.Reset(wsPingInterval)
   339  		case <-timer.C:
   340  			wc.jsonCodec.encMu.Lock()
   341  			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
   342  			wc.conn.WriteMessage(websocket.PingMessage, nil)
   343  			wc.jsonCodec.encMu.Unlock()
   344  			timer.Reset(wsPingInterval)
   345  		}
   346  	}
   347  }