gitlab.com/aquachain/aquachain@v1.17.16-rc3.0.20221018032414-e3ddf1e1c055/rpc/rpcclient/websocket.go (about)

     1  // Copyright 2018 The aquachain Authors
     2  // This file is part of the aquachain library.
     3  //
     4  // The aquachain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The aquachain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"net"
    23  	"net/url"
    24  	"os"
    25  	"strings"
    26  	"time"
    27  
    28  	"golang.org/x/net/websocket"
    29  )
    30  
    31  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
    32  // that is listening on the given endpoint.
    33  //
    34  // The context is used for the initial connection establishment. It does not
    35  // affect subsequent interactions with the client.
    36  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
    37  	if origin == "" {
    38  		var err error
    39  		if origin, err = os.Hostname(); err != nil {
    40  			return nil, err
    41  		}
    42  		if strings.HasPrefix(endpoint, "wss") {
    43  			origin = "https://" + strings.ToLower(origin)
    44  		} else {
    45  			origin = "http://" + strings.ToLower(origin)
    46  		}
    47  	}
    48  	config, err := websocket.NewConfig(endpoint, origin)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
    54  		return wsDialContext(ctx, config)
    55  	})
    56  }
    57  
    58  func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) {
    59  	var conn net.Conn
    60  	var err error
    61  	switch config.Location.Scheme {
    62  	case "ws":
    63  		conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location))
    64  	case "wss":
    65  		dialer := contextDialer(ctx)
    66  		conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig)
    67  	default:
    68  		err = websocket.ErrBadScheme
    69  	}
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	ws, err := websocket.NewClient(config, conn)
    74  	if err != nil {
    75  		conn.Close()
    76  		return nil, err
    77  	}
    78  	return ws, err
    79  }
    80  
    81  var wsPortMap = map[string]string{"ws": "80", "wss": "443"}
    82  
    83  func wsDialAddress(location *url.URL) string {
    84  	if _, ok := wsPortMap[location.Scheme]; ok {
    85  		if _, _, err := net.SplitHostPort(location.Host); err != nil {
    86  			return net.JoinHostPort(location.Host, wsPortMap[location.Scheme])
    87  		}
    88  	}
    89  	return location.Host
    90  }
    91  
    92  func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
    93  	d := &net.Dialer{KeepAlive: tcpKeepAliveInterval}
    94  	return d.DialContext(ctx, network, addr)
    95  }
    96  
    97  func contextDialer(ctx context.Context) *net.Dialer {
    98  	dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval}
    99  	if deadline, ok := ctx.Deadline(); ok {
   100  		dialer.Deadline = deadline
   101  	} else {
   102  		dialer.Deadline = time.Now().Add(defaultDialTimeout)
   103  	}
   104  	return dialer
   105  }