github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/dialer.go (about)

     1  package ws
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"crypto/tls"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/url"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/ezoic/httphead"
    17  	"github.com/ezoic/pool/pbufio"
    18  )
    19  
// Default sizes, in bytes, of the buffered reader and writer that Dialer uses
// for the HTTP upgrade exchange. They apply when Dialer.ReadBufferSize or
// Dialer.WriteBufferSize is left at zero.
const (
	DefaultClientReadBufferSize  = 4096
	DefaultClientWriteBufferSize = 4096
)
    25  
// Handshake represents the result of a WebSocket handshake.
type Handshake struct {
	// Protocol is the subprotocol selected during handshake. It is empty
	// when the server did not select any of the requested subprotocols.
	Protocol string

	// Extensions is the list of negotiated extensions, i.e. the requested
	// extensions that the server echoed back in its response.
	Extensions []httphead.Option
}
    34  
    35  // Errors used by the websocket client.
    36  var (
    37  	ErrHandshakeBadStatus      = fmt.Errorf("unexpected http status")
    38  	ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
    39  	ErrHandshakeBadExtensions  = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
    40  )
    41  
// DefaultDialer is the zero-value Dialer (no options set) that the
// package-level Dial function uses.
var DefaultDialer Dialer
    44  
    45  // Dial is like Dialer{}.Dial().
    46  func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
    47  	return DefaultDialer.Dial(ctx, urlstr)
    48  }
    49  
// Dialer contains options for establishing a websocket connection to a url.
// The zero value is usable and dials without any options.
type Dialer struct {
	// ReadBufferSize and WriteBufferSize are the I/O buffer sizes.
	// They are used to read and write http data while upgrading to WebSocket.
	// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
	//
	// If a size is zero then the default value is used.
	ReadBufferSize, WriteBufferSize int

	// Timeout is the maximum amount of time a Dial() will wait for a connect
	// and a handshake to complete.
	//
	// The default is no timeout.
	Timeout time.Duration

	// Protocols is the list of subprotocols that the client wants to speak,
	// ordered by preference.
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	Protocols []string

	// Extensions is the list of extensions that the client wants to speak.
	//
	// Note that if the server decides to use some of these extensions, Dial()
	// will return a Handshake struct containing a slice of items, which are
	// shallow copies of the items from this list. That is, internals of
	// Extensions items are shared during Dial().
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	// See https://tools.ietf.org/html/rfc6455#section-9.1
	Extensions []httphead.Option

	// Header is an optional HandshakeHeader instance that could be used to
	// write additional headers to the handshake request.
	//
	// It is used instead of any key-value mappings to avoid allocations in
	// user land.
	Header HandshakeHeader

	// OnStatusError is the callback that will be called after receiving an
	// HTTP response status other than "101 Switching Protocols". It receives
	// an io.Reader object representing server response bytes (starting with
	// the status line). That is, it gives the ability to parse the HTTP
	// response somehow (probably with an http.ReadResponse call) and make a
	// decision about further logic.
	//
	// The arguments are only valid until the callback returns.
	OnStatusError func(status int, reason []byte, resp io.Reader)

	// OnHeader is the callback that will be called after successful parsing of
	// a header that is not used during the WebSocket handshake procedure.
	// That is, it will be called with non-websocket headers, which could be
	// relevant for application-level logic.
	//
	// The arguments are only valid until the callback returns.
	//
	// A non-nil returned error stops processing of the response and is
	// returned from Upgrade().
	OnHeader func(key, value []byte) (err error)

	// NetDial is the function that is used to get a plain tcp connection.
	// If it is not nil, then it is used instead of net.Dialer.
	NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

	// TLSClient is the callback that will be called after a successful dial
	// with the received connection and its remote host name. If it is nil,
	// then the default tls.Client() will be used.
	// If it is not nil, then the TLSConfig field is ignored.
	TLSClient func(conn net.Conn, hostname string) net.Conn

	// TLSConfig is passed to tls.Client() to start TLS over the established
	// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
	// non-nil and its ServerName is empty, then for every Dial() it will be
	// cloned and the appropriate ServerName will be set.
	TLSConfig *tls.Config

	// WrapConn is the optional callback that will be called when the
	// connection is ready for i/o. That is, it will be called after a
	// successful dial and TLS initialization (for "wss" schemes). It may be
	// helpful for different user land purposes such as end to end encryption.
	//
	// Note that for debugging purposes of an http handshake (e.g. sent request
	// and received response), there is a wsutil.DebugDialer struct.
	WrapConn func(conn net.Conn) net.Conn
}
   133  
   134  // Dial connects to the url host and upgrades connection to WebSocket.
   135  //
   136  // If server has sent frames right after successful handshake then returned
   137  // buffer will be non-nil. In other cases buffer is always nil. For better
   138  // memory efficiency received non-nil bufio.Reader should be returned to the
   139  // inner pool with PutReader() function after use.
   140  //
   141  // Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
   142  // If you want to dial non-ascii host name, take care of its name serialization
   143  // avoiding bad request issues. For more info see net/http Request.Write()
   144  // implementation, especially cleanHost() function.
   145  func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
   146  	u, err := url.ParseRequestURI(urlstr)
   147  	if err != nil {
   148  		return
   149  	}
   150  
   151  	// Prepare context to dial with. Initially it is the same as original, but
   152  	// if d.Timeout is non-zero and points to time that is before ctx.Deadline,
   153  	// we use more shorter context for dial.
   154  	dialctx := ctx
   155  
   156  	var deadline time.Time
   157  	if t := d.Timeout; t != 0 {
   158  		deadline = time.Now().Add(t)
   159  		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
   160  			var cancel context.CancelFunc
   161  			dialctx, cancel = context.WithDeadline(ctx, deadline)
   162  			defer cancel()
   163  		}
   164  	}
   165  	if conn, err = d.dial(dialctx, u); err != nil {
   166  		return
   167  	}
   168  	defer func() {
   169  		if err != nil {
   170  			conn.Close()
   171  		}
   172  	}()
   173  	if ctx == context.Background() {
   174  		// No need to start I/O interrupter goroutine which is not zero-cost.
   175  		conn.SetDeadline(deadline)
   176  		defer conn.SetDeadline(noDeadline)
   177  	} else {
   178  		// Context could be canceled or its deadline could be exceeded.
   179  		// Start the interrupter goroutine to handle context cancelation.
   180  		done := setupContextDeadliner(ctx, conn)
   181  		defer func() {
   182  			// Map Upgrade() error to a possible context expiration error. That
   183  			// is, even if Upgrade() err is nil, context could be already
   184  			// expired and connection be "poisoned" by SetDeadline() call.
   185  			// In that case we must not return ctx.Err() error.
   186  			done(&err)
   187  		}()
   188  	}
   189  
   190  	br, hs, err = d.Upgrade(conn, u)
   191  
   192  	return
   193  }
   194  
   195  var (
   196  	// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
   197  	// Dialer.NetDial is not provided.
   198  	netEmptyDialer net.Dialer
   199  	// tlsEmptyConfig is an empty tls.Config used as default one.
   200  	tlsEmptyConfig tls.Config
   201  )
   202  
   203  func tlsDefaultConfig() *tls.Config {
   204  	return &tlsEmptyConfig
   205  }
   206  
   207  func hostport(host string, defaultPort string) (hostname, addr string) {
   208  	var (
   209  		colon   = strings.LastIndexByte(host, ':')
   210  		bracket = strings.IndexByte(host, ']')
   211  	)
   212  	if colon > bracket {
   213  		return host[:colon], host
   214  	}
   215  	return host, host + defaultPort
   216  }
   217  
   218  func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
   219  	dial := d.NetDial
   220  	if dial == nil {
   221  		dial = netEmptyDialer.DialContext
   222  	}
   223  	switch u.Scheme {
   224  	case "ws":
   225  		_, addr := hostport(u.Host, ":80")
   226  		conn, err = dial(ctx, "tcp", addr)
   227  	case "wss":
   228  		hostname, addr := hostport(u.Host, ":443")
   229  		conn, err = dial(ctx, "tcp", addr)
   230  		if err != nil {
   231  			return
   232  		}
   233  		tlsClient := d.TLSClient
   234  		if tlsClient == nil {
   235  			tlsClient = d.tlsClient
   236  		}
   237  		conn = tlsClient(conn, hostname)
   238  	default:
   239  		return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
   240  	}
   241  	if wrap := d.WrapConn; wrap != nil {
   242  		conn = wrap(conn)
   243  	}
   244  	return
   245  }
   246  
   247  func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
   248  	config := d.TLSConfig
   249  	if config == nil {
   250  		config = tlsDefaultConfig()
   251  	}
   252  	if config.ServerName == "" {
   253  		config = tlsCloneConfig(config)
   254  		config.ServerName = hostname
   255  	}
   256  	// Do not make conn.Handshake() here because downstairs we will prepare
   257  	// i/o on this conn with proper context's timeout handling.
   258  	return tls.Client(conn, config)
   259  }
   260  
   261  var (
   262  	// This variables are set like in net/net.go.
   263  	// noDeadline is just zero value for readability.
   264  	noDeadline = time.Time{}
   265  	// aLongTimeAgo is a non-zero time, far in the past, used for immediate
   266  	// cancelation of dials.
   267  	aLongTimeAgo = time.Unix(42, 0)
   268  )
   269  
   270  // Upgrade writes an upgrade request to the given io.ReadWriter conn at given
   271  // url u and reads a response from it.
   272  //
   273  // It is a caller responsibility to manage I/O deadlines on conn.
   274  //
   275  // It returns handshake info and some bytes which could be written by the peer
   276  // right after response and be caught by us during buffered read.
   277  func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
   278  	// headerSeen constants helps to report whether or not some header was seen
   279  	// during reading request bytes.
   280  	const (
   281  		headerSeenUpgrade = 1 << iota
   282  		headerSeenConnection
   283  		headerSeenSecAccept
   284  
   285  		// headerSeenAll is the value that we expect to receive at the end of
   286  		// headers read/parse loop.
   287  		headerSeenAll = 0 |
   288  			headerSeenUpgrade |
   289  			headerSeenConnection |
   290  			headerSeenSecAccept
   291  	)
   292  
   293  	br = pbufio.GetReader(conn,
   294  		nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
   295  	)
   296  	bw := pbufio.GetWriter(conn,
   297  		nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
   298  	)
   299  	defer func() {
   300  		pbufio.PutWriter(bw)
   301  		if br.Buffered() == 0 || err != nil {
   302  			// Server does not wrote additional bytes to the connection or
   303  			// error occurred. That is, no reason to return buffer.
   304  			pbufio.PutReader(br)
   305  			br = nil
   306  		}
   307  	}()
   308  
   309  	nonce := make([]byte, nonceSize)
   310  	initNonce(nonce)
   311  
   312  	httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
   313  	if err = bw.Flush(); err != nil {
   314  		return
   315  	}
   316  
   317  	// Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
   318  	sl, err := readLine(br)
   319  	if err != nil {
   320  		return
   321  	}
   322  	// Begin validation of the response.
   323  	// See https://tools.ietf.org/html/rfc6455#section-4.2.2
   324  	// Parse request line data like HTTP version, uri and method.
   325  	resp, err := httpParseResponseLine(sl)
   326  	if err != nil {
   327  		return
   328  	}
   329  	// Even if RFC says "1.1 or higher" without mentioning the part of the
   330  	// version, we apply it only to minor part.
   331  	if resp.major != 1 || resp.minor < 1 {
   332  		err = ErrHandshakeBadProtocol
   333  		return
   334  	}
   335  	if resp.status != 101 {
   336  		err = StatusError(resp.status)
   337  		if onStatusError := d.OnStatusError; onStatusError != nil {
   338  			// Invoke callback with multireader of status-line bytes br.
   339  			onStatusError(resp.status, resp.reason,
   340  				io.MultiReader(
   341  					bytes.NewReader(sl),
   342  					strings.NewReader(crlf),
   343  					br,
   344  				),
   345  			)
   346  		}
   347  		return
   348  	}
   349  	// If response status is 101 then we expect all technical headers to be
   350  	// valid. If not, then we stop processing response without giving user
   351  	// ability to read non-technical headers. That is, we do not distinguish
   352  	// technical errors (such as parsing error) and protocol errors.
   353  	var headerSeen byte
   354  	for {
   355  		line, e := readLine(br)
   356  		if e != nil {
   357  			err = e
   358  			return
   359  		}
   360  		if len(line) == 0 {
   361  			// Blank line, no more lines to read.
   362  			break
   363  		}
   364  
   365  		k, v, ok := httpParseHeaderLine(line)
   366  		if !ok {
   367  			err = ErrMalformedResponse
   368  			return
   369  		}
   370  
   371  		switch btsToString(k) {
   372  		case headerUpgradeCanonical:
   373  			headerSeen |= headerSeenUpgrade
   374  			if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
   375  				err = ErrHandshakeBadUpgrade
   376  				return
   377  			}
   378  
   379  		case headerConnectionCanonical:
   380  			headerSeen |= headerSeenConnection
   381  			// Note that as RFC6455 says:
   382  			//   > A |Connection| header field with value "Upgrade".
   383  			// That is, in server side, "Connection" header could contain
   384  			// multiple token. But in response it must contains exactly one.
   385  			if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
   386  				err = ErrHandshakeBadConnection
   387  				return
   388  			}
   389  
   390  		case headerSecAcceptCanonical:
   391  			headerSeen |= headerSeenSecAccept
   392  			if !checkAcceptFromNonce(v, nonce) {
   393  				err = ErrHandshakeBadSecAccept
   394  				return
   395  			}
   396  
   397  		case headerSecProtocolCanonical:
   398  			// RFC6455 1.3:
   399  			//   "The server selects one or none of the acceptable protocols
   400  			//   and echoes that value in its handshake to indicate that it has
   401  			//   selected that protocol."
   402  			for _, want := range d.Protocols {
   403  				if string(v) == want {
   404  					hs.Protocol = want
   405  					break
   406  				}
   407  			}
   408  			if hs.Protocol == "" {
   409  				// Server echoed subprotocol that is not present in client
   410  				// requested protocols.
   411  				err = ErrHandshakeBadSubProtocol
   412  				return
   413  			}
   414  
   415  		case headerSecExtensionsCanonical:
   416  			hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
   417  			if err != nil {
   418  				return
   419  			}
   420  
   421  		default:
   422  			if onHeader := d.OnHeader; onHeader != nil {
   423  				if e := onHeader(k, v); e != nil {
   424  					err = e
   425  					return
   426  				}
   427  			}
   428  		}
   429  	}
   430  	if err == nil && headerSeen != headerSeenAll {
   431  		switch {
   432  		case headerSeen&headerSeenUpgrade == 0:
   433  			err = ErrHandshakeBadUpgrade
   434  		case headerSeen&headerSeenConnection == 0:
   435  			err = ErrHandshakeBadConnection
   436  		case headerSeen&headerSeenSecAccept == 0:
   437  			err = ErrHandshakeBadSecAccept
   438  		default:
   439  			panic("unknown headers state")
   440  		}
   441  	}
   442  	return
   443  }
   444  
   445  // PutReader returns bufio.Reader instance to the inner reuse pool.
   446  // It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
   447  // contains unprocessed buffered data, that was sent by the server quickly
   448  // right after handshake.
   449  func PutReader(br *bufio.Reader) {
   450  	pbufio.PutReader(br)
   451  }
   452  
   453  // StatusError contains an unexpected status-line code from the server.
   454  type StatusError int
   455  
   456  func (s StatusError) Error() string {
   457  	return "unexpected HTTP response status: " + strconv.Itoa(int(s))
   458  }
   459  
   460  func isTimeoutError(err error) bool {
   461  	t, ok := err.(net.Error)
   462  	return ok && t.Timeout()
   463  }
   464  
   465  func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
   466  	if len(selected) == 0 {
   467  		return received, nil
   468  	}
   469  	var (
   470  		index  int
   471  		option httphead.Option
   472  		err    error
   473  	)
   474  	index = -1
   475  	match := func() (ok bool) {
   476  		for _, want := range wanted {
   477  			if option.Equal(want) {
   478  				// Check parsed extension to be present in client
   479  				// requested extensions. We move matched extension
   480  				// from client list to avoid allocation.
   481  				received = append(received, want)
   482  				return true
   483  			}
   484  		}
   485  		return false
   486  	}
   487  	ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
   488  		if i != index {
   489  			// Met next option.
   490  			index = i
   491  			if i != 0 && !match() {
   492  				// Server returned non-requested extension.
   493  				err = ErrHandshakeBadExtensions
   494  				return httphead.ControlBreak
   495  			}
   496  			option = httphead.Option{Name: name}
   497  		}
   498  		if attr != nil {
   499  			option.Parameters.Set(attr, val)
   500  		}
   501  		return httphead.ControlContinue
   502  	})
   503  	if !ok {
   504  		err = ErrMalformedResponse
   505  		return received, err
   506  	}
   507  	if !match() {
   508  		return received, ErrHandshakeBadExtensions
   509  	}
   510  	return received, err
   511  }
   512  
   513  // setupContextDeadliner is a helper function that starts connection I/O
   514  // interrupter goroutine.
   515  //
   516  // Started goroutine calls SetDeadline() with long time ago value when context
   517  // become expired to make any I/O operations failed. It returns done function
   518  // that stops started goroutine and maps error received from conn I/O methods
   519  // to possible context expiration error.
   520  //
   521  // In concern with possible SetDeadline() call inside interrupter goroutine,
   522  // caller passes pointer to its I/O error (even if it is nil) to done(&err).
   523  // That is, even if I/O error is nil, context could be already expired and
   524  // connection "poisoned" by SetDeadline() call. In that case done(&err) will
   525  // store at *err ctx.Err() result. If err is caused not by timeout, it will
   526  // leaved untouched.
   527  func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
   528  	var (
   529  		quit      = make(chan struct{})
   530  		interrupt = make(chan error, 1)
   531  	)
   532  	go func() {
   533  		select {
   534  		case <-quit:
   535  			interrupt <- nil
   536  		case <-ctx.Done():
   537  			// Cancel i/o immediately.
   538  			conn.SetDeadline(aLongTimeAgo)
   539  			interrupt <- ctx.Err()
   540  		}
   541  	}()
   542  	return func(err *error) {
   543  		close(quit)
   544  		// If ctx.Err() is non-nil and the original err is net.Error with
   545  		// Timeout() == true, then it means that I/O was canceled by us by
   546  		// SetDeadline(aLongTimeAgo) call, or by somebody else previously
   547  		// by conn.SetDeadline(x).
   548  		//
   549  		// Even on race condition when both deadlines are expired
   550  		// (SetDeadline() made not by us and context's), we prefer ctx.Err() to
   551  		// be returned.
   552  		if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
   553  			*err = ctxErr
   554  		}
   555  	}
   556  }