github.com/puppeth/go-ethereum@v0.8.6-0.20171014130046-e9295163aa25/rpc/websocket.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"net"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"strings"
    28  	"time"
    29  
    30  	"github.com/ethereum/go-ethereum/log"
    31  	"golang.org/x/net/websocket"
    32  	"gopkg.in/fatih/set.v0"
    33  )
    34  
    35  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    36  //
    37  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    38  // To allow connections with any origin, pass "*".
    39  func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    40  	return websocket.Server{
    41  		Handshake: wsHandshakeValidator(allowedOrigins),
    42  		Handler: func(conn *websocket.Conn) {
    43  			srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
    44  		},
    45  	}
    46  }
    47  
    48  // NewWSServer creates a new websocket RPC server around an API provider.
    49  //
    50  // Deprecated: use Server.WebsocketHandler
    51  func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
    52  	return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
    53  }
    54  
    55  // wsHandshakeValidator returns a handler that verifies the origin during the
    56  // websocket upgrade process. When a '*' is specified as an allowed origins all
    57  // connections are accepted.
    58  func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error {
    59  	origins := set.New()
    60  	allowAllOrigins := false
    61  
    62  	for _, origin := range allowedOrigins {
    63  		if origin == "*" {
    64  			allowAllOrigins = true
    65  		}
    66  		if origin != "" {
    67  			origins.Add(strings.ToLower(origin))
    68  		}
    69  	}
    70  
    71  	// allow localhost if no allowedOrigins are specified.
    72  	if len(origins.List()) == 0 {
    73  		origins.Add("http://localhost")
    74  		if hostname, err := os.Hostname(); err == nil {
    75  			origins.Add("http://" + strings.ToLower(hostname))
    76  		}
    77  	}
    78  
    79  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.List()))
    80  
    81  	f := func(cfg *websocket.Config, req *http.Request) error {
    82  		origin := strings.ToLower(req.Header.Get("Origin"))
    83  		if allowAllOrigins || origins.Has(origin) {
    84  			return nil
    85  		}
    86  		log.Debug(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin))
    87  		return fmt.Errorf("origin %s not allowed", origin)
    88  	}
    89  
    90  	return f
    91  }
    92  
    93  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
    94  // that is listening on the given endpoint.
    95  //
    96  // The context is used for the initial connection establishment. It does not
    97  // affect subsequent interactions with the client.
    98  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
    99  	if origin == "" {
   100  		var err error
   101  		if origin, err = os.Hostname(); err != nil {
   102  			return nil, err
   103  		}
   104  		if strings.HasPrefix(endpoint, "wss") {
   105  			origin = "https://" + strings.ToLower(origin)
   106  		} else {
   107  			origin = "http://" + strings.ToLower(origin)
   108  		}
   109  	}
   110  	config, err := websocket.NewConfig(endpoint, origin)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
   116  		return wsDialContext(ctx, config)
   117  	})
   118  }
   119  
   120  func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) {
   121  	var conn net.Conn
   122  	var err error
   123  	switch config.Location.Scheme {
   124  	case "ws":
   125  		conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location))
   126  	case "wss":
   127  		dialer := contextDialer(ctx)
   128  		conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig)
   129  	default:
   130  		err = websocket.ErrBadScheme
   131  	}
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	ws, err := websocket.NewClient(config, conn)
   136  	if err != nil {
   137  		conn.Close()
   138  		return nil, err
   139  	}
   140  	return ws, err
   141  }
   142  
   143  var wsPortMap = map[string]string{"ws": "80", "wss": "443"}
   144  
   145  func wsDialAddress(location *url.URL) string {
   146  	if _, ok := wsPortMap[location.Scheme]; ok {
   147  		if _, _, err := net.SplitHostPort(location.Host); err != nil {
   148  			return net.JoinHostPort(location.Host, wsPortMap[location.Scheme])
   149  		}
   150  	}
   151  	return location.Host
   152  }
   153  
   154  func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   155  	d := &net.Dialer{KeepAlive: tcpKeepAliveInterval}
   156  	return d.DialContext(ctx, network, addr)
   157  }
   158  
   159  func contextDialer(ctx context.Context) *net.Dialer {
   160  	dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval}
   161  	if deadline, ok := ctx.Deadline(); ok {
   162  		dialer.Deadline = deadline
   163  	} else {
   164  		dialer.Deadline = time.Now().Add(defaultDialTimeout)
   165  	}
   166  	return dialer
   167  }