github.com/sberex/go-sberex@v1.8.2-0.20181113200658-ed96ac38f7d7/rpc/websocket.go (about)

     1  // This file is part of the go-sberex library. The go-sberex library is 
     2  // free software: you can redistribute it and/or modify it under the terms 
     3  // of the GNU Lesser General Public License as published by the Free 
     4  // Software Foundation, either version 3 of the License, or (at your option)
     5  // any later version.
     6  //
     7  // The go-sberex library is distributed in the hope that it will be useful, 
     8  // but WITHOUT ANY WARRANTY; without even the implied warranty of
     9  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser 
    10  // General Public License <http://www.gnu.org/licenses/> for more details.
    11  
    12  package rpc
    13  
    14  import (
    15  	"context"
    16  	"crypto/tls"
    17  	"fmt"
    18  	"net"
    19  	"net/http"
    20  	"net/url"
    21  	"os"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/Sberex/go-sberex/log"
    26  	"golang.org/x/net/websocket"
    27  	"gopkg.in/fatih/set.v0"
    28  )
    29  
    30  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    31  //
    32  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    33  // To allow connections with any origin, pass "*".
    34  func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    35  	return websocket.Server{
    36  		Handshake: wsHandshakeValidator(allowedOrigins),
    37  		Handler: func(conn *websocket.Conn) {
    38  			srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
    39  		},
    40  	}
    41  }
    42  
    43  // NewWSServer creates a new websocket RPC server around an API provider.
    44  //
    45  // Deprecated: use Server.WebsocketHandler
    46  func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
    47  	return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
    48  }
    49  
    50  // wsHandshakeValidator returns a handler that verifies the origin during the
    51  // websocket upgrade process. When a '*' is specified as an allowed origins all
    52  // connections are accepted.
    53  func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error {
    54  	origins := set.New()
    55  	allowAllOrigins := false
    56  
    57  	for _, origin := range allowedOrigins {
    58  		if origin == "*" {
    59  			allowAllOrigins = true
    60  		}
    61  		if origin != "" {
    62  			origins.Add(strings.ToLower(origin))
    63  		}
    64  	}
    65  
    66  	// allow localhost if no allowedOrigins are specified.
    67  	if len(origins.List()) == 0 {
    68  		origins.Add("http://localhost")
    69  		if hostname, err := os.Hostname(); err == nil {
    70  			origins.Add("http://" + strings.ToLower(hostname))
    71  		}
    72  	}
    73  
    74  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.List()))
    75  
    76  	f := func(cfg *websocket.Config, req *http.Request) error {
    77  		origin := strings.ToLower(req.Header.Get("Origin"))
    78  		if allowAllOrigins || origins.Has(origin) {
    79  			return nil
    80  		}
    81  		log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin))
    82  		return fmt.Errorf("origin %s not allowed", origin)
    83  	}
    84  
    85  	return f
    86  }
    87  
    88  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
    89  // that is listening on the given endpoint.
    90  //
    91  // The context is used for the initial connection establishment. It does not
    92  // affect subsequent interactions with the client.
    93  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
    94  	if origin == "" {
    95  		var err error
    96  		if origin, err = os.Hostname(); err != nil {
    97  			return nil, err
    98  		}
    99  		if strings.HasPrefix(endpoint, "wss") {
   100  			origin = "https://" + strings.ToLower(origin)
   101  		} else {
   102  			origin = "http://" + strings.ToLower(origin)
   103  		}
   104  	}
   105  	config, err := websocket.NewConfig(endpoint, origin)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
   111  		return wsDialContext(ctx, config)
   112  	})
   113  }
   114  
   115  func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) {
   116  	var conn net.Conn
   117  	var err error
   118  	switch config.Location.Scheme {
   119  	case "ws":
   120  		conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location))
   121  	case "wss":
   122  		dialer := contextDialer(ctx)
   123  		conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig)
   124  	default:
   125  		err = websocket.ErrBadScheme
   126  	}
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	ws, err := websocket.NewClient(config, conn)
   131  	if err != nil {
   132  		conn.Close()
   133  		return nil, err
   134  	}
   135  	return ws, err
   136  }
   137  
   138  var wsPortMap = map[string]string{"ws": "80", "wss": "443"}
   139  
   140  func wsDialAddress(location *url.URL) string {
   141  	if _, ok := wsPortMap[location.Scheme]; ok {
   142  		if _, _, err := net.SplitHostPort(location.Host); err != nil {
   143  			return net.JoinHostPort(location.Host, wsPortMap[location.Scheme])
   144  		}
   145  	}
   146  	return location.Host
   147  }
   148  
   149  func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   150  	d := &net.Dialer{KeepAlive: tcpKeepAliveInterval}
   151  	return d.DialContext(ctx, network, addr)
   152  }
   153  
   154  func contextDialer(ctx context.Context) *net.Dialer {
   155  	dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval}
   156  	if deadline, ok := ctx.Deadline(); ok {
   157  		dialer.Deadline = deadline
   158  	} else {
   159  		dialer.Deadline = time.Now().Add(defaultDialTimeout)
   160  	}
   161  	return dialer
   162  }