github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/dial.go (about)

     1  // +build !js
     2  
     3  package websocket
     4  
     5  import (
     6  	"bufio"
     7  	"bytes"
     8  	"context"
     9  	"crypto/rand"
    10  	"encoding/base64"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"net/http"
    15  	"net/url"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	"nhooyr.io/websocket/internal/errd"
    21  )
    22  
// DialOptions represents Dial's options.
//
// A nil *DialOptions passed to Dial is treated like an empty DialOptions.
type DialOptions struct {
	// HTTPClient is used to send the handshake request. When nil,
	// http.DefaultClient is used.
	//
	// Its Transport must return writable response bodies for WebSocket
	// handshakes; http.Transport does so beginning with Go 1.12.
	//
	// A positive HTTPClient.Timeout bounds only the handshake: it is applied
	// as a deadline on the dial context, and the request is sent with a copy
	// of the client whose Timeout is zero.
	HTTPClient *http.Client

	// HTTPHeader specifies additional HTTP headers included in the handshake
	// request. The map is cloned before use. The WebSocket handshake headers
	// (Connection, Upgrade, Sec-WebSocket-Version, Sec-WebSocket-Key and, when
	// Subprotocols is non-empty, Sec-WebSocket-Protocol) overwrite any entries
	// of the same name.
	HTTPHeader http.Header

	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
	// They are sent comma-separated in Sec-WebSocket-Protocol. If the server
	// selects a subprotocol that is not in this list (compared
	// case-insensitively), the dial fails.
	Subprotocols []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionNoContextTakeover.
	// With CompressionDisabled no compression extension is offered.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
    48  
    49  // Dial performs a WebSocket handshake on url.
    50  //
    51  // The response is the WebSocket handshake response from the server.
    52  // You never need to close resp.Body yourself.
    53  //
    54  // If an error occurs, the returned response may be non nil.
    55  // However, you can only read the first 1024 bytes of the body.
    56  //
    57  // This function requires at least Go 1.12 as it uses a new feature
    58  // in net/http to perform WebSocket handshakes.
    59  // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
    60  //
    61  // URLs with http/https schemes will work and are interpreted as ws/wss.
    62  func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
    63  	return dial(ctx, u, opts, nil)
    64  }
    65  
    66  func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
    67  	defer errd.Wrap(&err, "failed to WebSocket dial")
    68  
    69  	if opts == nil {
    70  		opts = &DialOptions{}
    71  	}
    72  
    73  	opts = &*opts
    74  	if opts.HTTPClient == nil {
    75  		opts.HTTPClient = http.DefaultClient
    76  	} else if opts.HTTPClient.Timeout > 0 {
    77  		var cancel context.CancelFunc
    78  
    79  		ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
    80  		defer cancel()
    81  
    82  		newClient := *opts.HTTPClient
    83  		newClient.Timeout = 0
    84  		opts.HTTPClient = &newClient
    85  	}
    86  
    87  	if opts.HTTPHeader == nil {
    88  		opts.HTTPHeader = http.Header{}
    89  	}
    90  
    91  	secWebSocketKey, err := secWebSocketKey(rand)
    92  	if err != nil {
    93  		return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
    94  	}
    95  
    96  	var copts *compressionOptions
    97  	if opts.CompressionMode != CompressionDisabled {
    98  		copts = opts.CompressionMode.opts()
    99  	}
   100  
   101  	resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
   102  	if err != nil {
   103  		return nil, resp, err
   104  	}
   105  	respBody := resp.Body
   106  	resp.Body = nil
   107  	defer func() {
   108  		if err != nil {
   109  			// We read a bit of the body for easier debugging.
   110  			r := io.LimitReader(respBody, 1024)
   111  
   112  			timer := time.AfterFunc(time.Second*3, func() {
   113  				respBody.Close()
   114  			})
   115  			defer timer.Stop()
   116  
   117  			b, _ := ioutil.ReadAll(r)
   118  			respBody.Close()
   119  			resp.Body = ioutil.NopCloser(bytes.NewReader(b))
   120  		}
   121  	}()
   122  
   123  	copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
   124  	if err != nil {
   125  		return nil, resp, err
   126  	}
   127  
   128  	rwc, ok := respBody.(io.ReadWriteCloser)
   129  	if !ok {
   130  		return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
   131  	}
   132  
   133  	return newConn(connConfig{
   134  		subprotocol:    resp.Header.Get("Sec-WebSocket-Protocol"),
   135  		rwc:            rwc,
   136  		client:         true,
   137  		copts:          copts,
   138  		flateThreshold: opts.CompressionThreshold,
   139  		br:             getBufioReader(rwc),
   140  		bw:             getBufioWriter(rwc),
   141  	}), resp, nil
   142  }
   143  
   144  func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
   145  	u, err := url.Parse(urls)
   146  	if err != nil {
   147  		return nil, fmt.Errorf("failed to parse url: %w", err)
   148  	}
   149  
   150  	switch u.Scheme {
   151  	case "ws":
   152  		u.Scheme = "http"
   153  	case "wss":
   154  		u.Scheme = "https"
   155  	case "http", "https":
   156  	default:
   157  		return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
   158  	}
   159  
   160  	req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
   161  	req.Header = opts.HTTPHeader.Clone()
   162  	req.Header.Set("Connection", "Upgrade")
   163  	req.Header.Set("Upgrade", "websocket")
   164  	req.Header.Set("Sec-WebSocket-Version", "13")
   165  	req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
   166  	if len(opts.Subprotocols) > 0 {
   167  		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
   168  	}
   169  	if copts != nil {
   170  		copts.setHeader(req.Header)
   171  	}
   172  
   173  	resp, err := opts.HTTPClient.Do(req)
   174  	if err != nil {
   175  		return nil, fmt.Errorf("failed to send handshake request: %w", err)
   176  	}
   177  	return resp, nil
   178  }
   179  
   180  func secWebSocketKey(rr io.Reader) (string, error) {
   181  	if rr == nil {
   182  		rr = rand.Reader
   183  	}
   184  	b := make([]byte, 16)
   185  	_, err := io.ReadFull(rr, b)
   186  	if err != nil {
   187  		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
   188  	}
   189  	return base64.StdEncoding.EncodeToString(b), nil
   190  }
   191  
   192  func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
   193  	if resp.StatusCode != http.StatusSwitchingProtocols {
   194  		return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
   195  	}
   196  
   197  	if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
   198  		return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
   199  	}
   200  
   201  	if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
   202  		return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
   203  	}
   204  
   205  	if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
   206  		return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
   207  			resp.Header.Get("Sec-WebSocket-Accept"),
   208  			secWebSocketKey,
   209  		)
   210  	}
   211  
   212  	err := verifySubprotocol(opts.Subprotocols, resp)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  
   217  	return verifyServerExtensions(copts, resp.Header)
   218  }
   219  
   220  func verifySubprotocol(subprotos []string, resp *http.Response) error {
   221  	proto := resp.Header.Get("Sec-WebSocket-Protocol")
   222  	if proto == "" {
   223  		return nil
   224  	}
   225  
   226  	for _, sp2 := range subprotos {
   227  		if strings.EqualFold(sp2, proto) {
   228  			return nil
   229  		}
   230  	}
   231  
   232  	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
   233  }
   234  
   235  func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
   236  	exts := websocketExtensions(h)
   237  	if len(exts) == 0 {
   238  		return nil, nil
   239  	}
   240  
   241  	ext := exts[0]
   242  	if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
   243  		return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
   244  	}
   245  
   246  	copts = &*copts
   247  
   248  	for _, p := range ext.params {
   249  		switch p {
   250  		case "client_no_context_takeover":
   251  			copts.clientNoContextTakeover = true
   252  			continue
   253  		case "server_no_context_takeover":
   254  			copts.serverNoContextTakeover = true
   255  			continue
   256  		}
   257  
   258  		return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
   259  	}
   260  
   261  	return copts, nil
   262  }
   263  
   264  var bufioReaderPool sync.Pool
   265  
   266  func getBufioReader(r io.Reader) *bufio.Reader {
   267  	br, ok := bufioReaderPool.Get().(*bufio.Reader)
   268  	if !ok {
   269  		return bufio.NewReader(r)
   270  	}
   271  	br.Reset(r)
   272  	return br
   273  }
   274  
   275  func putBufioReader(br *bufio.Reader) {
   276  	bufioReaderPool.Put(br)
   277  }
   278  
   279  var bufioWriterPool sync.Pool
   280  
   281  func getBufioWriter(w io.Writer) *bufio.Writer {
   282  	bw, ok := bufioWriterPool.Get().(*bufio.Writer)
   283  	if !ok {
   284  		return bufio.NewWriter(w)
   285  	}
   286  	bw.Reset(w)
   287  	return bw
   288  }
   289  
   290  func putBufioWriter(bw *bufio.Writer) {
   291  	bufioWriterPool.Put(bw)
   292  }