github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/http.go (about)

     1  package ws
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"io"
     7  	"net/http"
     8  	"net/textproto"
     9  	"net/url"
    10  	"strconv"
    11  
    12  	"github.com/ezoic/httphead"
    13  )
    14  
    15  const (
    16  	crlf          = "\r\n"
    17  	colonAndSpace = ": "
    18  	commaAndSpace = ", "
    19  )
    20  
    21  const (
    22  	textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
    23  )
    24  
    25  var (
    26  	textHeadBadRequest          = statusText(http.StatusBadRequest)
    27  	textHeadInternalServerError = statusText(http.StatusInternalServerError)
    28  	textHeadUpgradeRequired     = statusText(http.StatusUpgradeRequired)
    29  
    30  	textTailErrHandshakeBadProtocol   = errorText(ErrHandshakeBadProtocol)
    31  	textTailErrHandshakeBadMethod     = errorText(ErrHandshakeBadMethod)
    32  	textTailErrHandshakeBadHost       = errorText(ErrHandshakeBadHost)
    33  	textTailErrHandshakeBadUpgrade    = errorText(ErrHandshakeBadUpgrade)
    34  	textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
    35  	textTailErrHandshakeBadSecAccept  = errorText(ErrHandshakeBadSecAccept)
    36  	textTailErrHandshakeBadSecKey     = errorText(ErrHandshakeBadSecKey)
    37  	textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
    38  	textTailErrUpgradeRequired        = errorText(ErrHandshakeUpgradeRequired)
    39  )
    40  
    41  var (
    42  	headerHost          = "Host"
    43  	headerUpgrade       = "Upgrade"
    44  	headerConnection    = "Connection"
    45  	headerSecVersion    = "Sec-WebSocket-Version"
    46  	headerSecProtocol   = "Sec-WebSocket-Protocol"
    47  	headerSecExtensions = "Sec-WebSocket-Extensions"
    48  	headerSecKey        = "Sec-WebSocket-Key"
    49  	headerSecAccept     = "Sec-WebSocket-Accept"
    50  
    51  	headerHostCanonical          = textproto.CanonicalMIMEHeaderKey(headerHost)
    52  	headerUpgradeCanonical       = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
    53  	headerConnectionCanonical    = textproto.CanonicalMIMEHeaderKey(headerConnection)
    54  	headerSecVersionCanonical    = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
    55  	headerSecProtocolCanonical   = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
    56  	headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
    57  	headerSecKeyCanonical        = textproto.CanonicalMIMEHeaderKey(headerSecKey)
    58  	headerSecAcceptCanonical     = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
    59  )
    60  
    61  var (
    62  	specHeaderValueUpgrade         = []byte("websocket")
    63  	specHeaderValueConnection      = []byte("Upgrade")
    64  	specHeaderValueConnectionLower = []byte("upgrade")
    65  	specHeaderValueSecVersion      = []byte("13")
    66  )
    67  
    68  var (
    69  	httpVersion1_0    = []byte("HTTP/1.0")
    70  	httpVersion1_1    = []byte("HTTP/1.1")
    71  	httpVersionPrefix = []byte("HTTP/")
    72  )
    73  
    74  type httpRequestLine struct {
    75  	method, uri  []byte
    76  	major, minor int
    77  }
    78  
    79  type httpResponseLine struct {
    80  	major, minor int
    81  	status       int
    82  	reason       []byte
    83  }
    84  
    85  // httpParseRequestLine parses http request line like "GET / HTTP/1.0".
    86  func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
    87  	var proto []byte
    88  	req.method, req.uri, proto = bsplit3(line, ' ')
    89  
    90  	var ok bool
    91  	req.major, req.minor, ok = httpParseVersion(proto)
    92  	if !ok {
    93  		err = ErrMalformedRequest
    94  		return
    95  	}
    96  
    97  	return
    98  }
    99  
   100  func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
   101  	var (
   102  		proto  []byte
   103  		status []byte
   104  	)
   105  	proto, status, resp.reason = bsplit3(line, ' ')
   106  
   107  	var ok bool
   108  	resp.major, resp.minor, ok = httpParseVersion(proto)
   109  	if !ok {
   110  		return resp, ErrMalformedResponse
   111  	}
   112  
   113  	var convErr error
   114  	resp.status, convErr = asciiToInt(status)
   115  	if convErr != nil {
   116  		return resp, ErrMalformedResponse
   117  	}
   118  
   119  	return resp, nil
   120  }
   121  
   122  // httpParseVersion parses major and minor version of HTTP protocol. It returns
   123  // parsed values and true if parse is ok.
   124  func httpParseVersion(bts []byte) (major, minor int, ok bool) {
   125  	switch {
   126  	case bytes.Equal(bts, httpVersion1_0):
   127  		return 1, 0, true
   128  	case bytes.Equal(bts, httpVersion1_1):
   129  		return 1, 1, true
   130  	case len(bts) < 8:
   131  		return
   132  	case !bytes.Equal(bts[:5], httpVersionPrefix):
   133  		return
   134  	}
   135  
   136  	bts = bts[5:]
   137  
   138  	dot := bytes.IndexByte(bts, '.')
   139  	if dot == -1 {
   140  		return
   141  	}
   142  	var err error
   143  	major, err = asciiToInt(bts[:dot])
   144  	if err != nil {
   145  		return
   146  	}
   147  	minor, err = asciiToInt(bts[dot+1:])
   148  	if err != nil {
   149  		return
   150  	}
   151  
   152  	return major, minor, true
   153  }
   154  
   155  // httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
   156  // values and true if parse is ok.
   157  func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
   158  	colon := bytes.IndexByte(line, ':')
   159  	if colon == -1 {
   160  		return
   161  	}
   162  
   163  	k = btrim(line[:colon])
   164  	// TODO(ezoic): maybe use just lower here?
   165  	canonicalizeHeaderKey(k)
   166  
   167  	v = btrim(line[colon+1:])
   168  
   169  	return k, v, true
   170  }
   171  
   172  // httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
   173  // that key is already canonical. This helps to increase performance.
   174  func httpGetHeader(h http.Header, key string) string {
   175  	if h == nil {
   176  		return ""
   177  	}
   178  	v := h[key]
   179  	if len(v) == 0 {
   180  		return ""
   181  	}
   182  	return v[0]
   183  }
   184  
   185  // The request MAY include a header field with the name
   186  // |Sec-WebSocket-Protocol|.  If present, this value indicates one or more
   187  // comma-separated subprotocol the client wishes to speak, ordered by
   188  // preference.  The elements that comprise this value MUST be non-empty strings
   189  // with characters in the range U+0021 to U+007E not including separator
   190  // characters as defined in [RFC2616] and MUST all be unique strings.  The ABNF
   191  // for the value of this header field is 1#token, where the definitions of
   192  // constructs and rules are as given in [RFC2616].
   193  func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
   194  	ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
   195  		if check(btsToString(v)) {
   196  			ret = string(v)
   197  			return false
   198  		}
   199  		return true
   200  	})
   201  	return
   202  }
   203  func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
   204  	var selected []byte
   205  	ok = httphead.ScanTokens(h, func(v []byte) bool {
   206  		if check(v) {
   207  			selected = v
   208  			return false
   209  		}
   210  		return true
   211  	})
   212  	if ok && selected != nil {
   213  		return string(selected), true
   214  	}
   215  	return
   216  }
   217  
   218  func strSelectExtensions(h string, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
   219  	return btsSelectExtensions(strToBytes(h), selected, check)
   220  }
   221  
   222  func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
   223  	s := httphead.OptionSelector{
   224  		Flags: httphead.SelectUnique | httphead.SelectCopy,
   225  		Check: check,
   226  	}
   227  	return s.Select(h, selected)
   228  }
   229  
   230  func httpWriteHeader(bw *bufio.Writer, key, value string) {
   231  	httpWriteHeaderKey(bw, key)
   232  	bw.WriteString(value)
   233  	bw.WriteString(crlf)
   234  }
   235  
   236  func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
   237  	httpWriteHeaderKey(bw, key)
   238  	bw.Write(value)
   239  	bw.WriteString(crlf)
   240  }
   241  
   242  func httpWriteHeaderKey(bw *bufio.Writer, key string) {
   243  	bw.WriteString(key)
   244  	bw.WriteString(colonAndSpace)
   245  }
   246  
   247  func httpWriteUpgradeRequest(
   248  	bw *bufio.Writer,
   249  	u *url.URL,
   250  	nonce []byte,
   251  	protocols []string,
   252  	extensions []httphead.Option,
   253  	header HandshakeHeader,
   254  ) {
   255  	bw.WriteString("GET ")
   256  	bw.WriteString(u.RequestURI())
   257  	bw.WriteString(" HTTP/1.1\r\n")
   258  
   259  	httpWriteHeader(bw, headerHost, u.Host)
   260  
   261  	httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
   262  	httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
   263  	httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
   264  
   265  	// NOTE: write nonce bytes as a string to prevent heap allocation –
   266  	// WriteString() copy given string into its inner buffer, unlike Write()
   267  	// which may write p directly to the underlying io.Writer – which in turn
   268  	// will lead to p escape.
   269  	httpWriteHeader(bw, headerSecKey, btsToString(nonce))
   270  
   271  	if len(protocols) > 0 {
   272  		httpWriteHeaderKey(bw, headerSecProtocol)
   273  		for i, p := range protocols {
   274  			if i > 0 {
   275  				bw.WriteString(commaAndSpace)
   276  			}
   277  			bw.WriteString(p)
   278  		}
   279  		bw.WriteString(crlf)
   280  	}
   281  
   282  	if len(extensions) > 0 {
   283  		httpWriteHeaderKey(bw, headerSecExtensions)
   284  		httphead.WriteOptions(bw, extensions)
   285  		bw.WriteString(crlf)
   286  	}
   287  
   288  	if header != nil {
   289  		header.WriteTo(bw)
   290  	}
   291  
   292  	bw.WriteString(crlf)
   293  }
   294  
   295  func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
   296  	bw.WriteString(textHeadUpgrade)
   297  
   298  	httpWriteHeaderKey(bw, headerSecAccept)
   299  	writeAccept(bw, nonce)
   300  	bw.WriteString(crlf)
   301  
   302  	if hs.Protocol != "" {
   303  		httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
   304  	}
   305  	if len(hs.Extensions) > 0 {
   306  		httpWriteHeaderKey(bw, headerSecExtensions)
   307  		httphead.WriteOptions(bw, hs.Extensions)
   308  		bw.WriteString(crlf)
   309  	}
   310  	if header != nil {
   311  		header(bw)
   312  	}
   313  
   314  	bw.WriteString(crlf)
   315  }
   316  
   317  func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
   318  	switch code {
   319  	case http.StatusBadRequest:
   320  		bw.WriteString(textHeadBadRequest)
   321  	case http.StatusInternalServerError:
   322  		bw.WriteString(textHeadInternalServerError)
   323  	case http.StatusUpgradeRequired:
   324  		bw.WriteString(textHeadUpgradeRequired)
   325  	default:
   326  		writeStatusText(bw, code)
   327  	}
   328  
   329  	// Write custom headers.
   330  	if header != nil {
   331  		header(bw)
   332  	}
   333  
   334  	switch err {
   335  	case ErrHandshakeBadProtocol:
   336  		bw.WriteString(textTailErrHandshakeBadProtocol)
   337  	case ErrHandshakeBadMethod:
   338  		bw.WriteString(textTailErrHandshakeBadMethod)
   339  	case ErrHandshakeBadHost:
   340  		bw.WriteString(textTailErrHandshakeBadHost)
   341  	case ErrHandshakeBadUpgrade:
   342  		bw.WriteString(textTailErrHandshakeBadUpgrade)
   343  	case ErrHandshakeBadConnection:
   344  		bw.WriteString(textTailErrHandshakeBadConnection)
   345  	case ErrHandshakeBadSecAccept:
   346  		bw.WriteString(textTailErrHandshakeBadSecAccept)
   347  	case ErrHandshakeBadSecKey:
   348  		bw.WriteString(textTailErrHandshakeBadSecKey)
   349  	case ErrHandshakeBadSecVersion:
   350  		bw.WriteString(textTailErrHandshakeBadSecVersion)
   351  	case ErrHandshakeUpgradeRequired:
   352  		bw.WriteString(textTailErrUpgradeRequired)
   353  	case nil:
   354  		bw.WriteString(crlf)
   355  	default:
   356  		writeErrorText(bw, err)
   357  	}
   358  }
   359  
   360  func writeStatusText(bw *bufio.Writer, code int) {
   361  	bw.WriteString("HTTP/1.1 ")
   362  	bw.WriteString(strconv.Itoa(code))
   363  	bw.WriteByte(' ')
   364  	bw.WriteString(http.StatusText(code))
   365  	bw.WriteString(crlf)
   366  	bw.WriteString("Content-Type: text/plain; charset=utf-8")
   367  	bw.WriteString(crlf)
   368  }
   369  
   370  func writeErrorText(bw *bufio.Writer, err error) {
   371  	body := err.Error()
   372  	bw.WriteString("Content-Length: ")
   373  	bw.WriteString(strconv.Itoa(len(body)))
   374  	bw.WriteString(crlf)
   375  	bw.WriteString(crlf)
   376  	bw.WriteString(body)
   377  }
   378  
   379  // httpError is like the http.Error with WebSocket context exception.
   380  func httpError(w http.ResponseWriter, body string, code int) {
   381  	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   382  	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
   383  	w.WriteHeader(code)
   384  	w.Write([]byte(body))
   385  }
   386  
   387  // statusText is a non-performant status text generator.
   388  // NOTE: Used only to generate constants.
   389  func statusText(code int) string {
   390  	var buf bytes.Buffer
   391  	bw := bufio.NewWriter(&buf)
   392  	writeStatusText(bw, code)
   393  	bw.Flush()
   394  	return buf.String()
   395  }
   396  
   397  // errorText is a non-performant error text generator.
   398  // NOTE: Used only to generate constants.
   399  func errorText(err error) string {
   400  	var buf bytes.Buffer
   401  	bw := bufio.NewWriter(&buf)
   402  	writeErrorText(bw, err)
   403  	bw.Flush()
   404  	return buf.String()
   405  }
   406  
   407  // HandshakeHeader is the interface that writes both upgrade request or
   408  // response headers into a given io.Writer.
   409  type HandshakeHeader interface {
   410  	io.WriterTo
   411  }
   412  
   413  // HandshakeHeaderString is an adapter to allow the use of headers represented
   414  // by ordinary string as HandshakeHeader.
   415  type HandshakeHeaderString string
   416  
   417  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   418  func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
   419  	n, err := io.WriteString(w, string(s))
   420  	return int64(n), err
   421  }
   422  
   423  // HandshakeHeaderBytes is an adapter to allow the use of headers represented
   424  // by ordinary slice of bytes as HandshakeHeader.
   425  type HandshakeHeaderBytes []byte
   426  
   427  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   428  func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
   429  	n, err := w.Write(b)
   430  	return int64(n), err
   431  }
   432  
   433  // HandshakeHeaderFunc is an adapter to allow the use of headers represented by
   434  // ordinary function as HandshakeHeader.
   435  type HandshakeHeaderFunc func(io.Writer) (int64, error)
   436  
   437  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   438  func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
   439  	return f(w)
   440  }
   441  
   442  // HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
   443  // HandshakeHeader.
   444  type HandshakeHeaderHTTP http.Header
   445  
   446  // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
   447  func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
   448  	wr := writer{w: w}
   449  	err := http.Header(h).Write(&wr)
   450  	return wr.n, err
   451  }
   452  
   453  type writer struct {
   454  	n int64
   455  	w io.Writer
   456  }
   457  
   458  func (w *writer) WriteString(s string) (int, error) {
   459  	n, err := io.WriteString(w.w, s)
   460  	w.n += int64(n)
   461  	return n, err
   462  }
   463  
   464  func (w *writer) Write(p []byte) (int, error) {
   465  	n, err := w.w.Write(p)
   466  	w.n += int64(n)
   467  	return n, err
   468  }