github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/http.go (about)

     1  package ws
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"io"
     7  	"net/http"
     8  	"net/textproto"
     9  	"net/url"
    10  	"strconv"
    11  
    12  	"github.com/gobwas/httphead"
    13  )
    14  
    15  const (
    16  	crlf          = "\r\n"
    17  	colonAndSpace = ": "
    18  	commaAndSpace = ", "
    19  )
    20  
    21  const (
    22  	textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
    23  )
    24  
    25  var (
    26  	textHeadBadRequest          = statusText(http.StatusBadRequest)
    27  	textHeadInternalServerError = statusText(http.StatusInternalServerError)
    28  	textHeadUpgradeRequired     = statusText(http.StatusUpgradeRequired)
    29  
    30  	textTailErrHandshakeBadProtocol   = errorText(ErrHandshakeBadProtocol)
    31  	textTailErrHandshakeBadMethod     = errorText(ErrHandshakeBadMethod)
    32  	textTailErrHandshakeBadHost       = errorText(ErrHandshakeBadHost)
    33  	textTailErrHandshakeBadUpgrade    = errorText(ErrHandshakeBadUpgrade)
    34  	textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
    35  	textTailErrHandshakeBadSecAccept  = errorText(ErrHandshakeBadSecAccept)
    36  	textTailErrHandshakeBadSecKey     = errorText(ErrHandshakeBadSecKey)
    37  	textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
    38  	textTailErrUpgradeRequired        = errorText(ErrHandshakeUpgradeRequired)
    39  )
    40  
    41  var (
    42  	headerHost          = "Host"
    43  	headerUpgrade       = "Upgrade"
    44  	headerConnection    = "Connection"
    45  	headerSecVersion    = "Sec-WebSocket-Version"
    46  	headerSecProtocol   = "Sec-WebSocket-Protocol"
    47  	headerSecExtensions = "Sec-WebSocket-Extensions"
    48  	headerSecKey        = "Sec-WebSocket-Key"
    49  	headerSecAccept     = "Sec-WebSocket-Accept"
    50  
    51  	headerHostCanonical          = textproto.CanonicalMIMEHeaderKey(headerHost)
    52  	headerUpgradeCanonical       = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
    53  	headerConnectionCanonical    = textproto.CanonicalMIMEHeaderKey(headerConnection)
    54  	headerSecVersionCanonical    = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
    55  	headerSecProtocolCanonical   = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
    56  	headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
    57  	headerSecKeyCanonical        = textproto.CanonicalMIMEHeaderKey(headerSecKey)
    58  	headerSecAcceptCanonical     = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
    59  )
    60  
    61  var (
    62  	specHeaderValueUpgrade         = []byte("websocket")
    63  	specHeaderValueConnection      = []byte("Upgrade")
    64  	specHeaderValueConnectionLower = []byte("upgrade")
    65  	specHeaderValueSecVersion      = []byte("13")
    66  )
    67  
    68  var (
    69  	httpVersion1_0    = []byte("HTTP/1.0")
    70  	httpVersion1_1    = []byte("HTTP/1.1")
    71  	httpVersionPrefix = []byte("HTTP/")
    72  )
    73  
    74  type httpRequestLine struct {
    75  	method, uri  []byte
    76  	major, minor int
    77  }
    78  
    79  type httpResponseLine struct {
    80  	major, minor int
    81  	status       int
    82  	reason       []byte
    83  }
    84  
    85  // httpParseRequestLine parses http request line like "GET / HTTP/1.0".
    86  func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
    87  	var proto []byte
    88  	req.method, req.uri, proto = bsplit3(line, ' ')
    89  
    90  	var ok bool
    91  	req.major, req.minor, ok = httpParseVersion(proto)
    92  	if !ok {
    93  		err = ErrMalformedRequest
    94  		return
    95  	}
    96  
    97  	return
    98  }
    99  
   100  func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
   101  	var (
   102  		proto  []byte
   103  		status []byte
   104  	)
   105  	proto, status, resp.reason = bsplit3(line, ' ')
   106  
   107  	var ok bool
   108  	resp.major, resp.minor, ok = httpParseVersion(proto)
   109  	if !ok {
   110  		return resp, ErrMalformedResponse
   111  	}
   112  
   113  	var convErr error
   114  	resp.status, convErr = asciiToInt(status)
   115  	if convErr != nil {
   116  		return resp, ErrMalformedResponse
   117  	}
   118  
   119  	return resp, nil
   120  }
   121  
   122  // httpParseVersion parses major and minor version of HTTP protocol. It returns
   123  // parsed values and true if parse is ok.
   124  func httpParseVersion(bts []byte) (major, minor int, ok bool) {
   125  	switch {
   126  	case bytes.Equal(bts, httpVersion1_0):
   127  		return 1, 0, true
   128  	case bytes.Equal(bts, httpVersion1_1):
   129  		return 1, 1, true
   130  	case len(bts) < 8:
   131  		return
   132  	case !bytes.Equal(bts[:5], httpVersionPrefix):
   133  		return
   134  	}
   135  
   136  	bts = bts[5:]
   137  
   138  	dot := bytes.IndexByte(bts, '.')
   139  	if dot == -1 {
   140  		return
   141  	}
   142  	var err error
   143  	major, err = asciiToInt(bts[:dot])
   144  	if err != nil {
   145  		return
   146  	}
   147  	minor, err = asciiToInt(bts[dot+1:])
   148  	if err != nil {
   149  		return
   150  	}
   151  
   152  	return major, minor, true
   153  }
   154  
   155  // httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
   156  // values and true if parse is ok.
   157  func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
   158  	colon := bytes.IndexByte(line, ':')
   159  	if colon == -1 {
   160  		return
   161  	}
   162  
   163  	k = btrim(line[:colon])
   164  	// TODO(gobwas): maybe use just lower here?
   165  	canonicalizeHeaderKey(k)
   166  
   167  	v = btrim(line[colon+1:])
   168  
   169  	return k, v, true
   170  }
   171  
   172  // httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
   173  // that key is already canonical. This helps to increase performance.
   174  func httpGetHeader(h http.Header, key string) string {
   175  	if h == nil {
   176  		return ""
   177  	}
   178  	v := h[key]
   179  	if len(v) == 0 {
   180  		return ""
   181  	}
   182  	return v[0]
   183  }
   184  
   185  // The request MAY include a header field with the name
   186  // |Sec-WebSocket-Protocol|.  If present, this value indicates one or more
   187  // comma-separated subprotocol the client wishes to speak, ordered by
   188  // preference.  The elements that comprise this value MUST be non-empty strings
   189  // with characters in the range U+0021 to U+007E not including separator
   190  // characters as defined in [RFC2616] and MUST all be unique strings.  The ABNF
   191  // for the value of this header field is 1#token, where the definitions of
   192  // constructs and rules are as given in [RFC2616].
   193  func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
   194  	ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
   195  		if check(btsToString(v)) {
   196  			ret = string(v)
   197  			return false
   198  		}
   199  		return true
   200  	})
   201  	return
   202  }
   203  func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
   204  	var selected []byte
   205  	ok = httphead.ScanTokens(h, func(v []byte) bool {
   206  		if check(v) {
   207  			selected = v
   208  			return false
   209  		}
   210  		return true
   211  	})
   212  	if ok && selected != nil {
   213  		return string(selected), true
   214  	}
   215  	return
   216  }
   217  
   218  func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
   219  	s := httphead.OptionSelector{
   220  		Flags: httphead.SelectCopy,
   221  		Check: check,
   222  	}
   223  	return s.Select(h, selected)
   224  }
   225  
   226  func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
   227  	if in.Size() == 0 {
   228  		return dest, nil
   229  	}
   230  	opt, err := f(in)
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  	if opt.Size() > 0 {
   235  		dest = append(dest, opt)
   236  	}
   237  	return dest, nil
   238  }
   239  
   240  func negotiateExtensions(
   241  	h []byte, dest []httphead.Option,
   242  	f func(httphead.Option) (httphead.Option, error),
   243  ) (_ []httphead.Option, err error) {
   244  	index := -1
   245  	var current httphead.Option
   246  	ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
   247  		if i != index {
   248  			dest, err = negotiateMaybe(current, dest, f)
   249  			if err != nil {
   250  				return httphead.ControlBreak
   251  			}
   252  			index = i
   253  			current = httphead.Option{Name: name}
   254  		}
   255  		if attr != nil {
   256  			current.Parameters.Set(attr, val)
   257  		}
   258  		return httphead.ControlContinue
   259  	})
   260  	if !ok {
   261  		return nil, ErrMalformedRequest
   262  	}
   263  	return negotiateMaybe(current, dest, f)
   264  }
   265  
   266  func httpWriteHeader(bw *bufio.Writer, key, value string) {
   267  	httpWriteHeaderKey(bw, key)
   268  	bw.WriteString(value)
   269  	bw.WriteString(crlf)
   270  }
   271  
   272  func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
   273  	httpWriteHeaderKey(bw, key)
   274  	bw.Write(value)
   275  	bw.WriteString(crlf)
   276  }
   277  
   278  func httpWriteHeaderKey(bw *bufio.Writer, key string) {
   279  	bw.WriteString(key)
   280  	bw.WriteString(colonAndSpace)
   281  }
   282  
   283  func httpWriteUpgradeRequest(
   284  	bw *bufio.Writer,
   285  	u *url.URL,
   286  	nonce []byte,
   287  	protocols []string,
   288  	extensions []httphead.Option,
   289  	header HandshakeHeader,
   290  ) {
   291  	bw.WriteString("GET ")
   292  	bw.WriteString(u.RequestURI())
   293  	bw.WriteString(" HTTP/1.1\r\n")
   294  
   295  	httpWriteHeader(bw, headerHost, u.Host)
   296  
   297  	httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
   298  	httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
   299  	httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
   300  
   301  	// NOTE: write nonce bytes as a string to prevent heap allocation –
   302  	// WriteString() copy given string into its inner buffer, unlike Write()
   303  	// which may write p directly to the underlying io.Writer – which in turn
   304  	// will lead to p escape.
   305  	httpWriteHeader(bw, headerSecKey, btsToString(nonce))
   306  
   307  	if len(protocols) > 0 {
   308  		httpWriteHeaderKey(bw, headerSecProtocol)
   309  		for i, p := range protocols {
   310  			if i > 0 {
   311  				bw.WriteString(commaAndSpace)
   312  			}
   313  			bw.WriteString(p)
   314  		}
   315  		bw.WriteString(crlf)
   316  	}
   317  
   318  	if len(extensions) > 0 {
   319  		httpWriteHeaderKey(bw, headerSecExtensions)
   320  		httphead.WriteOptions(bw, extensions)
   321  		bw.WriteString(crlf)
   322  	}
   323  
   324  	if header != nil {
   325  		header.WriteTo(bw)
   326  	}
   327  
   328  	bw.WriteString(crlf)
   329  }
   330  
   331  func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
   332  	bw.WriteString(textHeadUpgrade)
   333  
   334  	httpWriteHeaderKey(bw, headerSecAccept)
   335  	writeAccept(bw, nonce)
   336  	bw.WriteString(crlf)
   337  
   338  	if hs.Protocol != "" {
   339  		httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
   340  	}
   341  	if len(hs.Extensions) > 0 {
   342  		httpWriteHeaderKey(bw, headerSecExtensions)
   343  		httphead.WriteOptions(bw, hs.Extensions)
   344  		bw.WriteString(crlf)
   345  	}
   346  	if header != nil {
   347  		header(bw)
   348  	}
   349  
   350  	bw.WriteString(crlf)
   351  }
   352  
   353  func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
   354  	switch code {
   355  	case http.StatusBadRequest:
   356  		bw.WriteString(textHeadBadRequest)
   357  	case http.StatusInternalServerError:
   358  		bw.WriteString(textHeadInternalServerError)
   359  	case http.StatusUpgradeRequired:
   360  		bw.WriteString(textHeadUpgradeRequired)
   361  	default:
   362  		writeStatusText(bw, code)
   363  	}
   364  
   365  	// Write custom headers.
   366  	if header != nil {
   367  		header(bw)
   368  	}
   369  
   370  	switch err {
   371  	case ErrHandshakeBadProtocol:
   372  		bw.WriteString(textTailErrHandshakeBadProtocol)
   373  	case ErrHandshakeBadMethod:
   374  		bw.WriteString(textTailErrHandshakeBadMethod)
   375  	case ErrHandshakeBadHost:
   376  		bw.WriteString(textTailErrHandshakeBadHost)
   377  	case ErrHandshakeBadUpgrade:
   378  		bw.WriteString(textTailErrHandshakeBadUpgrade)
   379  	case ErrHandshakeBadConnection:
   380  		bw.WriteString(textTailErrHandshakeBadConnection)
   381  	case ErrHandshakeBadSecAccept:
   382  		bw.WriteString(textTailErrHandshakeBadSecAccept)
   383  	case ErrHandshakeBadSecKey:
   384  		bw.WriteString(textTailErrHandshakeBadSecKey)
   385  	case ErrHandshakeBadSecVersion:
   386  		bw.WriteString(textTailErrHandshakeBadSecVersion)
   387  	case ErrHandshakeUpgradeRequired:
   388  		bw.WriteString(textTailErrUpgradeRequired)
   389  	case nil:
   390  		bw.WriteString(crlf)
   391  	default:
   392  		writeErrorText(bw, err)
   393  	}
   394  }
   395  
   396  func writeStatusText(bw *bufio.Writer, code int) {
   397  	bw.WriteString("HTTP/1.1 ")
   398  	bw.WriteString(strconv.Itoa(code))
   399  	bw.WriteByte(' ')
   400  	bw.WriteString(http.StatusText(code))
   401  	bw.WriteString(crlf)
   402  	bw.WriteString("Content-Type: text/plain; charset=utf-8")
   403  	bw.WriteString(crlf)
   404  }
   405  
   406  func writeErrorText(bw *bufio.Writer, err error) {
   407  	body := err.Error()
   408  	bw.WriteString("Content-Length: ")
   409  	bw.WriteString(strconv.Itoa(len(body)))
   410  	bw.WriteString(crlf)
   411  	bw.WriteString(crlf)
   412  	bw.WriteString(body)
   413  }
   414  
   415  // httpError is like the http.Error with WebSocket context exception.
   416  func httpError(w http.ResponseWriter, body string, code int) {
   417  	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   418  	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
   419  	w.WriteHeader(code)
   420  	w.Write([]byte(body))
   421  }
   422  
   423  // statusText is a non-performant status text generator.
   424  // NOTE: Used only to generate constants.
   425  func statusText(code int) string {
   426  	var buf bytes.Buffer
   427  	bw := bufio.NewWriter(&buf)
   428  	writeStatusText(bw, code)
   429  	bw.Flush()
   430  	return buf.String()
   431  }
   432  
   433  // errorText is a non-performant error text generator.
   434  // NOTE: Used only to generate constants.
   435  func errorText(err error) string {
   436  	var buf bytes.Buffer
   437  	bw := bufio.NewWriter(&buf)
   438  	writeErrorText(bw, err)
   439  	bw.Flush()
   440  	return buf.String()
   441  }
   442  
   443  // HandshakeHeader is the interface that writes both upgrade request or
   444  // response headers into a given io.Writer.
   445  type HandshakeHeader interface {
   446  	io.WriterTo
   447  }
   448  
   449  // HandshakeHeaderString is an adapter to allow the use of headers represented
   450  // by ordinary string as HandshakeHeader.
   451  type HandshakeHeaderString string
   452  
   453  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   454  func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
   455  	n, err := io.WriteString(w, string(s))
   456  	return int64(n), err
   457  }
   458  
   459  // HandshakeHeaderBytes is an adapter to allow the use of headers represented
   460  // by ordinary slice of bytes as HandshakeHeader.
   461  type HandshakeHeaderBytes []byte
   462  
   463  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   464  func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
   465  	n, err := w.Write(b)
   466  	return int64(n), err
   467  }
   468  
   469  // HandshakeHeaderFunc is an adapter to allow the use of headers represented by
   470  // ordinary function as HandshakeHeader.
   471  type HandshakeHeaderFunc func(io.Writer) (int64, error)
   472  
   473  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   474  func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
   475  	return f(w)
   476  }
   477  
   478  // HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
   479  // HandshakeHeader.
   480  type HandshakeHeaderHTTP http.Header
   481  
   482  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   483  func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
   484  	wr := writer{w: w}
   485  	err := http.Header(h).Write(&wr)
   486  	return wr.n, err
   487  }
   488  
   489  type writer struct {
   490  	n int64
   491  	w io.Writer
   492  }
   493  
   494  func (w *writer) WriteString(s string) (int, error) {
   495  	n, err := io.WriteString(w.w, s)
   496  	w.n += int64(n)
   497  	return n, err
   498  }
   499  
   500  func (w *writer) Write(p []byte) (int, error) {
   501  	n, err := w.w.Write(p)
   502  	w.n += int64(n)
   503  	return n, err
   504  }