github.com/SmartMeshFoundation/Spectrum@v0.0.0-20220621030607-452a266fee1e/rpc/websocket.go (about)

     1  // Copyright 2015 The Spectrum Authors
     2  // This file is part of the Spectrum library.
     3  //
     4  // The Spectrum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The Spectrum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the Spectrum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"net"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"strings"
    28  	"time"
    29  
    30  	"github.com/SmartMeshFoundation/Spectrum/log"
    31  	mapset "github.com/deckarep/golang-set"
    32  	"golang.org/x/net/websocket"
    33  )
    34  
    35  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    36  //
    37  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    38  // To allow connections with any origin, pass "*".
    39  func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    40  	return websocket.Server{
    41  		Handshake: wsHandshakeValidator(allowedOrigins),
    42  		Handler: func(conn *websocket.Conn) {
    43  			srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
    44  		},
    45  	}
    46  }
    47  
    48  // NewWSServer creates a new websocket RPC server around an API provider.
    49  //
    50  // Deprecated: use Server.WebsocketHandler
    51  func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
    52  	return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
    53  }
    54  
    55  // wsHandshakeValidator returns a handler that verifies the origin during the
    56  // websocket upgrade process. When a '*' is specified as an allowed origins all
    57  // connections are accepted.
    58  func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error {
    59  	//origins := set.New()
    60  	origins := mapset.NewSet()
    61  
    62  	allowAllOrigins := false
    63  
    64  	for _, origin := range allowedOrigins {
    65  		if origin == "*" {
    66  			allowAllOrigins = true
    67  		}
    68  		if origin != "" {
    69  			origins.Add(strings.ToLower(origin))
    70  		}
    71  	}
    72  
    73  	// allow localhost if no allowedOrigins are specified.
    74  	if len(origins.ToSlice()) == 0 {
    75  		origins.Add("http://localhost")
    76  		if hostname, err := os.Hostname(); err == nil {
    77  			origins.Add("http://" + strings.ToLower(hostname))
    78  		}
    79  	}
    80  
    81  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice()))
    82  
    83  	f := func(cfg *websocket.Config, req *http.Request) error {
    84  		origin := strings.ToLower(req.Header.Get("Origin"))
    85  		if allowAllOrigins || origins.Contains(origin) {
    86  			return nil
    87  		}
    88  		log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin))
    89  		return fmt.Errorf("origin %s not allowed", origin)
    90  	}
    91  
    92  	return f
    93  }
    94  
    95  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
    96  // that is listening on the given endpoint.
    97  //
    98  // The context is used for the initial connection establishment. It does not
    99  // affect subsequent interactions with the client.
   100  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   101  	if origin == "" {
   102  		var err error
   103  		if origin, err = os.Hostname(); err != nil {
   104  			return nil, err
   105  		}
   106  		if strings.HasPrefix(endpoint, "wss") {
   107  			origin = "https://" + strings.ToLower(origin)
   108  		} else {
   109  			origin = "http://" + strings.ToLower(origin)
   110  		}
   111  	}
   112  	config, err := websocket.NewConfig(endpoint, origin)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
   118  		return wsDialContext(ctx, config)
   119  	})
   120  }
   121  
   122  func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) {
   123  	var conn net.Conn
   124  	var err error
   125  	switch config.Location.Scheme {
   126  	case "ws":
   127  		conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location))
   128  	case "wss":
   129  		dialer := contextDialer(ctx)
   130  		conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig)
   131  	default:
   132  		err = websocket.ErrBadScheme
   133  	}
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	ws, err := websocket.NewClient(config, conn)
   138  	if err != nil {
   139  		conn.Close()
   140  		return nil, err
   141  	}
   142  	return ws, err
   143  }
   144  
   145  var wsPortMap = map[string]string{"ws": "80", "wss": "443"}
   146  
   147  func wsDialAddress(location *url.URL) string {
   148  	if _, ok := wsPortMap[location.Scheme]; ok {
   149  		if _, _, err := net.SplitHostPort(location.Host); err != nil {
   150  			return net.JoinHostPort(location.Host, wsPortMap[location.Scheme])
   151  		}
   152  	}
   153  	return location.Host
   154  }
   155  
   156  func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   157  	d := &net.Dialer{KeepAlive: tcpKeepAliveInterval}
   158  	return d.DialContext(ctx, network, addr)
   159  }
   160  
   161  func contextDialer(ctx context.Context) *net.Dialer {
   162  	dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval}
   163  	if deadline, ok := ctx.Deadline(); ok {
   164  		dialer.Deadline = deadline
   165  	} else {
   166  		dialer.Deadline = time.Now().Add(defaultDialTimeout)
   167  	}
   168  	return dialer
   169  }