github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/http.go (about) 1 package ws 2 3 import ( 4 "bufio" 5 "bytes" 6 "io" 7 "net/http" 8 "net/textproto" 9 "net/url" 10 "strconv" 11 12 "github.com/gobwas/httphead" 13 ) 14 15 const ( 16 crlf = "\r\n" 17 colonAndSpace = ": " 18 commaAndSpace = ", " 19 ) 20 21 const ( 22 textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n" 23 ) 24 25 var ( 26 textHeadBadRequest = statusText(http.StatusBadRequest) 27 textHeadInternalServerError = statusText(http.StatusInternalServerError) 28 textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired) 29 30 textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol) 31 textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod) 32 textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost) 33 textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade) 34 textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection) 35 textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept) 36 textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey) 37 textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion) 38 textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired) 39 ) 40 41 var ( 42 headerHost = "Host" 43 headerUpgrade = "Upgrade" 44 headerConnection = "Connection" 45 headerSecVersion = "Sec-WebSocket-Version" 46 headerSecProtocol = "Sec-WebSocket-Protocol" 47 headerSecExtensions = "Sec-WebSocket-Extensions" 48 headerSecKey = "Sec-WebSocket-Key" 49 headerSecAccept = "Sec-WebSocket-Accept" 50 51 headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost) 52 headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade) 53 headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection) 54 headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion) 55 headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol) 56 headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions) 57 headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey) 58 headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept) 59 ) 60 61 var ( 62 specHeaderValueUpgrade = []byte("websocket") 63 specHeaderValueConnection = []byte("Upgrade") 64 specHeaderValueConnectionLower = []byte("upgrade") 65 specHeaderValueSecVersion = []byte("13") 66 ) 67 68 var ( 69 httpVersion1_0 = []byte("HTTP/1.0") 70 httpVersion1_1 = []byte("HTTP/1.1") 71 httpVersionPrefix = []byte("HTTP/") 72 ) 73 74 type httpRequestLine struct { 75 method, uri []byte 76 major, minor int 77 } 78 79 type httpResponseLine struct { 80 major, minor int 81 status int 82 reason []byte 83 } 84 85 // httpParseRequestLine parses http request line like "GET / HTTP/1.0". 86 func httpParseRequestLine(line []byte) (req httpRequestLine, err error) { 87 var proto []byte 88 req.method, req.uri, proto = bsplit3(line, ' ') 89 90 var ok bool 91 req.major, req.minor, ok = httpParseVersion(proto) 92 if !ok { 93 err = ErrMalformedRequest 94 return 95 } 96 97 return 98 } 99 100 func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) { 101 var ( 102 proto []byte 103 status []byte 104 ) 105 proto, status, resp.reason = bsplit3(line, ' ') 106 107 var ok bool 108 resp.major, resp.minor, ok = httpParseVersion(proto) 109 if !ok { 110 return resp, ErrMalformedResponse 111 } 112 113 var convErr error 114 resp.status, convErr = asciiToInt(status) 115 if convErr != nil { 116 return resp, ErrMalformedResponse 117 } 118 119 return resp, nil 120 } 121 122 // httpParseVersion parses major and minor version of HTTP protocol. It returns 123 // parsed values and true if parse is ok. 124 func httpParseVersion(bts []byte) (major, minor int, ok bool) { 125 switch { 126 case bytes.Equal(bts, httpVersion1_0): 127 return 1, 0, true 128 case bytes.Equal(bts, httpVersion1_1): 129 return 1, 1, true 130 case len(bts) < 8: 131 return 132 case !bytes.Equal(bts[:5], httpVersionPrefix): 133 return 134 } 135 136 bts = bts[5:] 137 138 dot := bytes.IndexByte(bts, '.') 139 if dot == -1 { 140 return 141 } 142 var err error 143 major, err = asciiToInt(bts[:dot]) 144 if err != nil { 145 return 146 } 147 minor, err = asciiToInt(bts[dot+1:]) 148 if err != nil { 149 return 150 } 151 152 return major, minor, true 153 } 154 155 // httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed 156 // values and true if parse is ok. 157 func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) { 158 colon := bytes.IndexByte(line, ':') 159 if colon == -1 { 160 return 161 } 162 163 k = btrim(line[:colon]) 164 // TODO(gobwas): maybe use just lower here? 165 canonicalizeHeaderKey(k) 166 167 v = btrim(line[colon+1:]) 168 169 return k, v, true 170 } 171 172 // httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing, 173 // that key is already canonical. This helps to increase performance. 174 func httpGetHeader(h http.Header, key string) string { 175 if h == nil { 176 return "" 177 } 178 v := h[key] 179 if len(v) == 0 { 180 return "" 181 } 182 return v[0] 183 } 184 185 // The request MAY include a header field with the name 186 // |Sec-WebSocket-Protocol|. If present, this value indicates one or more 187 // comma-separated subprotocol the client wishes to speak, ordered by 188 // preference. The elements that comprise this value MUST be non-empty strings 189 // with characters in the range U+0021 to U+007E not including separator 190 // characters as defined in [RFC2616] and MUST all be unique strings. The ABNF 191 // for the value of this header field is 1#token, where the definitions of 192 // constructs and rules are as given in [RFC2616]. 193 func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) { 194 ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool { 195 if check(btsToString(v)) { 196 ret = string(v) 197 return false 198 } 199 return true 200 }) 201 return 202 } 203 func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) { 204 var selected []byte 205 ok = httphead.ScanTokens(h, func(v []byte) bool { 206 if check(v) { 207 selected = v 208 return false 209 } 210 return true 211 }) 212 if ok && selected != nil { 213 return string(selected), true 214 } 215 return 216 } 217 218 func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) { 219 s := httphead.OptionSelector{ 220 Flags: httphead.SelectCopy, 221 Check: check, 222 } 223 return s.Select(h, selected) 224 } 225 226 func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) { 227 if in.Size() == 0 { 228 return dest, nil 229 } 230 opt, err := f(in) 231 if err != nil { 232 return nil, err 233 } 234 if opt.Size() > 0 { 235 dest = append(dest, opt) 236 } 237 return dest, nil 238 } 239 240 func negotiateExtensions( 241 h []byte, dest []httphead.Option, 242 f func(httphead.Option) (httphead.Option, error), 243 ) (_ []httphead.Option, err error) { 244 index := -1 245 var current httphead.Option 246 ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control { 247 if i != index { 248 dest, err = negotiateMaybe(current, dest, f) 249 if err != nil { 250 return httphead.ControlBreak 251 } 252 index = i 253 current = httphead.Option{Name: name} 254 } 255 if attr != nil { 256 current.Parameters.Set(attr, val) 257 } 258 return httphead.ControlContinue 259 }) 260 if !ok { 261 return nil, ErrMalformedRequest 262 } 263 return negotiateMaybe(current, dest, f) 264 } 265 266 func httpWriteHeader(bw *bufio.Writer, key, value string) { 267 httpWriteHeaderKey(bw, key) 268 bw.WriteString(value) 269 bw.WriteString(crlf) 270 } 271 272 func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) { 273 httpWriteHeaderKey(bw, key) 274 bw.Write(value) 275 bw.WriteString(crlf) 276 } 277 278 func httpWriteHeaderKey(bw *bufio.Writer, key string) { 279 bw.WriteString(key) 280 bw.WriteString(colonAndSpace) 281 } 282 283 func httpWriteUpgradeRequest( 284 bw *bufio.Writer, 285 u *url.URL, 286 nonce []byte, 287 protocols []string, 288 extensions []httphead.Option, 289 header HandshakeHeader, 290 ) { 291 bw.WriteString("GET ") 292 bw.WriteString(u.RequestURI()) 293 bw.WriteString(" HTTP/1.1\r\n") 294 295 httpWriteHeader(bw, headerHost, u.Host) 296 297 httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade) 298 httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection) 299 httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion) 300 301 // NOTE: write nonce bytes as a string to prevent heap allocation – 302 // WriteString() copy given string into its inner buffer, unlike Write() 303 // which may write p directly to the underlying io.Writer – which in turn 304 // will lead to p escape. 305 httpWriteHeader(bw, headerSecKey, btsToString(nonce)) 306 307 if len(protocols) > 0 { 308 httpWriteHeaderKey(bw, headerSecProtocol) 309 for i, p := range protocols { 310 if i > 0 { 311 bw.WriteString(commaAndSpace) 312 } 313 bw.WriteString(p) 314 } 315 bw.WriteString(crlf) 316 } 317 318 if len(extensions) > 0 { 319 httpWriteHeaderKey(bw, headerSecExtensions) 320 httphead.WriteOptions(bw, extensions) 321 bw.WriteString(crlf) 322 } 323 324 if header != nil { 325 header.WriteTo(bw) 326 } 327 328 bw.WriteString(crlf) 329 } 330 331 func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) { 332 bw.WriteString(textHeadUpgrade) 333 334 httpWriteHeaderKey(bw, headerSecAccept) 335 writeAccept(bw, nonce) 336 bw.WriteString(crlf) 337 338 if hs.Protocol != "" { 339 httpWriteHeader(bw, headerSecProtocol, hs.Protocol) 340 } 341 if len(hs.Extensions) > 0 { 342 httpWriteHeaderKey(bw, headerSecExtensions) 343 httphead.WriteOptions(bw, hs.Extensions) 344 bw.WriteString(crlf) 345 } 346 if header != nil { 347 header(bw) 348 } 349 350 bw.WriteString(crlf) 351 } 352 353 func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) { 354 switch code { 355 case http.StatusBadRequest: 356 bw.WriteString(textHeadBadRequest) 357 case http.StatusInternalServerError: 358 bw.WriteString(textHeadInternalServerError) 359 case http.StatusUpgradeRequired: 360 bw.WriteString(textHeadUpgradeRequired) 361 default: 362 writeStatusText(bw, code) 363 } 364 365 // Write custom headers. 366 if header != nil { 367 header(bw) 368 } 369 370 switch err { 371 case ErrHandshakeBadProtocol: 372 bw.WriteString(textTailErrHandshakeBadProtocol) 373 case ErrHandshakeBadMethod: 374 bw.WriteString(textTailErrHandshakeBadMethod) 375 case ErrHandshakeBadHost: 376 bw.WriteString(textTailErrHandshakeBadHost) 377 case ErrHandshakeBadUpgrade: 378 bw.WriteString(textTailErrHandshakeBadUpgrade) 379 case ErrHandshakeBadConnection: 380 bw.WriteString(textTailErrHandshakeBadConnection) 381 case ErrHandshakeBadSecAccept: 382 bw.WriteString(textTailErrHandshakeBadSecAccept) 383 case ErrHandshakeBadSecKey: 384 bw.WriteString(textTailErrHandshakeBadSecKey) 385 case ErrHandshakeBadSecVersion: 386 bw.WriteString(textTailErrHandshakeBadSecVersion) 387 case ErrHandshakeUpgradeRequired: 388 bw.WriteString(textTailErrUpgradeRequired) 389 case nil: 390 bw.WriteString(crlf) 391 default: 392 writeErrorText(bw, err) 393 } 394 } 395 396 func writeStatusText(bw *bufio.Writer, code int) { 397 bw.WriteString("HTTP/1.1 ") 398 bw.WriteString(strconv.Itoa(code)) 399 bw.WriteByte(' ') 400 bw.WriteString(http.StatusText(code)) 401 bw.WriteString(crlf) 402 bw.WriteString("Content-Type: text/plain; charset=utf-8") 403 bw.WriteString(crlf) 404 } 405 406 func writeErrorText(bw *bufio.Writer, err error) { 407 body := err.Error() 408 bw.WriteString("Content-Length: ") 409 bw.WriteString(strconv.Itoa(len(body))) 410 bw.WriteString(crlf) 411 bw.WriteString(crlf) 412 bw.WriteString(body) 413 } 414 415 // httpError is like the http.Error with WebSocket context exception. 416 func httpError(w http.ResponseWriter, body string, code int) { 417 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 418 w.Header().Set("Content-Length", strconv.Itoa(len(body))) 419 w.WriteHeader(code) 420 w.Write([]byte(body)) 421 } 422 423 // statusText is a non-performant status text generator. 424 // NOTE: Used only to generate constants. 425 func statusText(code int) string { 426 var buf bytes.Buffer 427 bw := bufio.NewWriter(&buf) 428 writeStatusText(bw, code) 429 bw.Flush() 430 return buf.String() 431 } 432 433 // errorText is a non-performant error text generator. 434 // NOTE: Used only to generate constants. 435 func errorText(err error) string { 436 var buf bytes.Buffer 437 bw := bufio.NewWriter(&buf) 438 writeErrorText(bw, err) 439 bw.Flush() 440 return buf.String() 441 } 442 443 // HandshakeHeader is the interface that writes both upgrade request or 444 // response headers into a given io.Writer. 445 type HandshakeHeader interface { 446 io.WriterTo 447 } 448 449 // HandshakeHeaderString is an adapter to allow the use of headers represented 450 // by ordinary string as HandshakeHeader. 451 type HandshakeHeaderString string 452 453 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 454 func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) { 455 n, err := io.WriteString(w, string(s)) 456 return int64(n), err 457 } 458 459 // HandshakeHeaderBytes is an adapter to allow the use of headers represented 460 // by ordinary slice of bytes as HandshakeHeader. 461 type HandshakeHeaderBytes []byte 462 463 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 464 func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) { 465 n, err := w.Write(b) 466 return int64(n), err 467 } 468 469 // HandshakeHeaderFunc is an adapter to allow the use of headers represented by 470 // ordinary function as HandshakeHeader. 471 type HandshakeHeaderFunc func(io.Writer) (int64, error) 472 473 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 474 func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) { 475 return f(w) 476 } 477 478 // HandshakeHeaderHTTP is an adapter to allow the use of http.Header as 479 // HandshakeHeader. 480 type HandshakeHeaderHTTP http.Header 481 482 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 483 func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) { 484 wr := writer{w: w} 485 err := http.Header(h).Write(&wr) 486 return wr.n, err 487 } 488 489 type writer struct { 490 n int64 491 w io.Writer 492 } 493 494 func (w *writer) WriteString(s string) (int, error) { 495 n, err := io.WriteString(w.w, s) 496 w.n += int64(n) 497 return n, err 498 } 499 500 func (w *writer) Write(p []byte) (int, error) { 501 n, err := w.w.Write(p) 502 w.n += int64(n) 503 return n, err 504 }