github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/rpc/websocket.go (about)

     1  package rpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/quickchainproject/quickchain/log"
    17  	"golang.org/x/net/websocket"
    18  	"gopkg.in/fatih/set.v0"
    19  )
    20  
    21  // websocketJSONCodec is a custom JSON codec with payload size enforcement and
    22  // special number parsing.
    23  var websocketJSONCodec = websocket.Codec{
    24  	// Marshal is the stock JSON marshaller used by the websocket library too.
    25  	Marshal: func(v interface{}) ([]byte, byte, error) {
    26  		msg, err := json.Marshal(v)
    27  		return msg, websocket.TextFrame, err
    28  	},
    29  	// Unmarshal is a specialized unmarshaller to properly convert numbers.
    30  	Unmarshal: func(msg []byte, payloadType byte, v interface{}) error {
    31  		dec := json.NewDecoder(bytes.NewReader(msg))
    32  		dec.UseNumber()
    33  
    34  		return dec.Decode(v)
    35  	},
    36  }
    37  
    38  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    39  //
    40  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    41  // To allow connections with any origin, pass "*".
    42  func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    43  	return websocket.Server{
    44  		Handshake: wsHandshakeValidator(allowedOrigins),
    45  		Handler: func(conn *websocket.Conn) {
    46  			// Create a custom encode/decode pair to enforce payload size and number encoding
    47  			conn.MaxPayloadBytes = maxRequestContentLength
    48  
    49  			encoder := func(v interface{}) error {
    50  				return websocketJSONCodec.Send(conn, v)
    51  			}
    52  			decoder := func(v interface{}) error {
    53  				return websocketJSONCodec.Receive(conn, v)
    54  			}
    55  			srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
    56  		},
    57  	}
    58  }
    59  
    60  // NewWSServer creates a new websocket RPC server around an API provider.
    61  //
    62  // Deprecated: use Server.WebsocketHandler
    63  func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
    64  	return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
    65  }
    66  
    67  // wsHandshakeValidator returns a handler that verifies the origin during the
    68  // websocket upgrade process. When a '*' is specified as an allowed origins all
    69  // connections are accepted.
    70  func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error {
    71  	origins := set.New()
    72  	allowAllOrigins := false
    73  
    74  	for _, origin := range allowedOrigins {
    75  		if origin == "*" {
    76  			allowAllOrigins = true
    77  		}
    78  		if origin != "" {
    79  			origins.Add(strings.ToLower(origin))
    80  		}
    81  	}
    82  
    83  	// allow localhost if no allowedOrigins are specified.
    84  	if len(origins.List()) == 0 {
    85  		origins.Add("http://localhost")
    86  		if hostname, err := os.Hostname(); err == nil {
    87  			origins.Add("http://" + strings.ToLower(hostname))
    88  		}
    89  	}
    90  
    91  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.List()))
    92  
    93  	f := func(cfg *websocket.Config, req *http.Request) error {
    94  		origin := strings.ToLower(req.Header.Get("Origin"))
    95  		if allowAllOrigins || origins.Has(origin) {
    96  			return nil
    97  		}
    98  		log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin))
    99  		return fmt.Errorf("origin %s not allowed", origin)
   100  	}
   101  
   102  	return f
   103  }
   104  
   105  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   106  // that is listening on the given endpoint.
   107  //
   108  // The context is used for the initial connection establishment. It does not
   109  // affect subsequent interactions with the client.
   110  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   111  	if origin == "" {
   112  		var err error
   113  		if origin, err = os.Hostname(); err != nil {
   114  			return nil, err
   115  		}
   116  		if strings.HasPrefix(endpoint, "wss") {
   117  			origin = "https://" + strings.ToLower(origin)
   118  		} else {
   119  			origin = "http://" + strings.ToLower(origin)
   120  		}
   121  	}
   122  	config, err := websocket.NewConfig(endpoint, origin)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
   128  		return wsDialContext(ctx, config)
   129  	})
   130  }
   131  
   132  func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) {
   133  	var conn net.Conn
   134  	var err error
   135  	switch config.Location.Scheme {
   136  	case "ws":
   137  		conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location))
   138  	case "wss":
   139  		dialer := contextDialer(ctx)
   140  		conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig)
   141  	default:
   142  		err = websocket.ErrBadScheme
   143  	}
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	ws, err := websocket.NewClient(config, conn)
   148  	if err != nil {
   149  		conn.Close()
   150  		return nil, err
   151  	}
   152  	return ws, err
   153  }
   154  
   155  var wsPortMap = map[string]string{"ws": "80", "wss": "443"}
   156  
   157  func wsDialAddress(location *url.URL) string {
   158  	if _, ok := wsPortMap[location.Scheme]; ok {
   159  		if _, _, err := net.SplitHostPort(location.Host); err != nil {
   160  			return net.JoinHostPort(location.Host, wsPortMap[location.Scheme])
   161  		}
   162  	}
   163  	return location.Host
   164  }
   165  
   166  func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   167  	d := &net.Dialer{KeepAlive: tcpKeepAliveInterval}
   168  	return d.DialContext(ctx, network, addr)
   169  }
   170  
   171  func contextDialer(ctx context.Context) *net.Dialer {
   172  	dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval}
   173  	if deadline, ok := ctx.Deadline(); ok {
   174  		dialer.Deadline = deadline
   175  	} else {
   176  		dialer.Deadline = time.Now().Add(defaultDialTimeout)
   177  	}
   178  	return dialer
   179  }