github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/http.go (about) 1 package ws 2 3 import ( 4 "bufio" 5 "bytes" 6 "io" 7 "net/http" 8 "net/textproto" 9 "net/url" 10 "strconv" 11 12 "github.com/ezoic/httphead" 13 ) 14 15 const ( 16 crlf = "\r\n" 17 colonAndSpace = ": " 18 commaAndSpace = ", " 19 ) 20 21 const ( 22 textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n" 23 ) 24 25 var ( 26 textHeadBadRequest = statusText(http.StatusBadRequest) 27 textHeadInternalServerError = statusText(http.StatusInternalServerError) 28 textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired) 29 30 textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol) 31 textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod) 32 textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost) 33 textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade) 34 textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection) 35 textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept) 36 textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey) 37 textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion) 38 textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired) 39 ) 40 41 var ( 42 headerHost = "Host" 43 headerUpgrade = "Upgrade" 44 headerConnection = "Connection" 45 headerSecVersion = "Sec-WebSocket-Version" 46 headerSecProtocol = "Sec-WebSocket-Protocol" 47 headerSecExtensions = "Sec-WebSocket-Extensions" 48 headerSecKey = "Sec-WebSocket-Key" 49 headerSecAccept = "Sec-WebSocket-Accept" 50 51 headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost) 52 headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade) 53 headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection) 54 headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion) 55 headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol) 56 headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions) 57 headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey) 58 headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept) 59 ) 60 61 var ( 62 specHeaderValueUpgrade = []byte("websocket") 63 specHeaderValueConnection = []byte("Upgrade") 64 specHeaderValueConnectionLower = []byte("upgrade") 65 specHeaderValueSecVersion = []byte("13") 66 ) 67 68 var ( 69 httpVersion1_0 = []byte("HTTP/1.0") 70 httpVersion1_1 = []byte("HTTP/1.1") 71 httpVersionPrefix = []byte("HTTP/") 72 ) 73 74 type httpRequestLine struct { 75 method, uri []byte 76 major, minor int 77 } 78 79 type httpResponseLine struct { 80 major, minor int 81 status int 82 reason []byte 83 } 84 85 // httpParseRequestLine parses http request line like "GET / HTTP/1.0". 86 func httpParseRequestLine(line []byte) (req httpRequestLine, err error) { 87 var proto []byte 88 req.method, req.uri, proto = bsplit3(line, ' ') 89 90 var ok bool 91 req.major, req.minor, ok = httpParseVersion(proto) 92 if !ok { 93 err = ErrMalformedRequest 94 return 95 } 96 97 return 98 } 99 100 func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) { 101 var ( 102 proto []byte 103 status []byte 104 ) 105 proto, status, resp.reason = bsplit3(line, ' ') 106 107 var ok bool 108 resp.major, resp.minor, ok = httpParseVersion(proto) 109 if !ok { 110 return resp, ErrMalformedResponse 111 } 112 113 var convErr error 114 resp.status, convErr = asciiToInt(status) 115 if convErr != nil { 116 return resp, ErrMalformedResponse 117 } 118 119 return resp, nil 120 } 121 122 // httpParseVersion parses major and minor version of HTTP protocol. It returns 123 // parsed values and true if parse is ok. 124 func httpParseVersion(bts []byte) (major, minor int, ok bool) { 125 switch { 126 case bytes.Equal(bts, httpVersion1_0): 127 return 1, 0, true 128 case bytes.Equal(bts, httpVersion1_1): 129 return 1, 1, true 130 case len(bts) < 8: 131 return 132 case !bytes.Equal(bts[:5], httpVersionPrefix): 133 return 134 } 135 136 bts = bts[5:] 137 138 dot := bytes.IndexByte(bts, '.') 139 if dot == -1 { 140 return 141 } 142 var err error 143 major, err = asciiToInt(bts[:dot]) 144 if err != nil { 145 return 146 } 147 minor, err = asciiToInt(bts[dot+1:]) 148 if err != nil { 149 return 150 } 151 152 return major, minor, true 153 } 154 155 // httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed 156 // values and true if parse is ok. 157 func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) { 158 colon := bytes.IndexByte(line, ':') 159 if colon == -1 { 160 return 161 } 162 163 k = btrim(line[:colon]) 164 // TODO(ezoic): maybe use just lower here? 165 canonicalizeHeaderKey(k) 166 167 v = btrim(line[colon+1:]) 168 169 return k, v, true 170 } 171 172 // httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing, 173 // that key is already canonical. This helps to increase performance. 174 func httpGetHeader(h http.Header, key string) string { 175 if h == nil { 176 return "" 177 } 178 v := h[key] 179 if len(v) == 0 { 180 return "" 181 } 182 return v[0] 183 } 184 185 // The request MAY include a header field with the name 186 // |Sec-WebSocket-Protocol|. If present, this value indicates one or more 187 // comma-separated subprotocol the client wishes to speak, ordered by 188 // preference. The elements that comprise this value MUST be non-empty strings 189 // with characters in the range U+0021 to U+007E not including separator 190 // characters as defined in [RFC2616] and MUST all be unique strings. The ABNF 191 // for the value of this header field is 1#token, where the definitions of 192 // constructs and rules are as given in [RFC2616]. 193 func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) { 194 ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool { 195 if check(btsToString(v)) { 196 ret = string(v) 197 return false 198 } 199 return true 200 }) 201 return 202 } 203 func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) { 204 var selected []byte 205 ok = httphead.ScanTokens(h, func(v []byte) bool { 206 if check(v) { 207 selected = v 208 return false 209 } 210 return true 211 }) 212 if ok && selected != nil { 213 return string(selected), true 214 } 215 return 216 } 217 218 func strSelectExtensions(h string, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) { 219 return btsSelectExtensions(strToBytes(h), selected, check) 220 } 221 222 func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) { 223 s := httphead.OptionSelector{ 224 Flags: httphead.SelectUnique | httphead.SelectCopy, 225 Check: check, 226 } 227 return s.Select(h, selected) 228 } 229 230 func httpWriteHeader(bw *bufio.Writer, key, value string) { 231 httpWriteHeaderKey(bw, key) 232 bw.WriteString(value) 233 bw.WriteString(crlf) 234 } 235 236 func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) { 237 httpWriteHeaderKey(bw, key) 238 bw.Write(value) 239 bw.WriteString(crlf) 240 } 241 242 func httpWriteHeaderKey(bw *bufio.Writer, key string) { 243 bw.WriteString(key) 244 bw.WriteString(colonAndSpace) 245 } 246 247 func httpWriteUpgradeRequest( 248 bw *bufio.Writer, 249 u *url.URL, 250 nonce []byte, 251 protocols []string, 252 extensions []httphead.Option, 253 header HandshakeHeader, 254 ) { 255 bw.WriteString("GET ") 256 bw.WriteString(u.RequestURI()) 257 bw.WriteString(" HTTP/1.1\r\n") 258 259 httpWriteHeader(bw, headerHost, u.Host) 260 261 httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade) 262 httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection) 263 httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion) 264 265 // NOTE: write nonce bytes as a string to prevent heap allocation – 266 // WriteString() copy given string into its inner buffer, unlike Write() 267 // which may write p directly to the underlying io.Writer – which in turn 268 // will lead to p escape. 269 httpWriteHeader(bw, headerSecKey, btsToString(nonce)) 270 271 if len(protocols) > 0 { 272 httpWriteHeaderKey(bw, headerSecProtocol) 273 for i, p := range protocols { 274 if i > 0 { 275 bw.WriteString(commaAndSpace) 276 } 277 bw.WriteString(p) 278 } 279 bw.WriteString(crlf) 280 } 281 282 if len(extensions) > 0 { 283 httpWriteHeaderKey(bw, headerSecExtensions) 284 httphead.WriteOptions(bw, extensions) 285 bw.WriteString(crlf) 286 } 287 288 if header != nil { 289 header.WriteTo(bw) 290 } 291 292 bw.WriteString(crlf) 293 } 294 295 func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) { 296 bw.WriteString(textHeadUpgrade) 297 298 httpWriteHeaderKey(bw, headerSecAccept) 299 writeAccept(bw, nonce) 300 bw.WriteString(crlf) 301 302 if hs.Protocol != "" { 303 httpWriteHeader(bw, headerSecProtocol, hs.Protocol) 304 } 305 if len(hs.Extensions) > 0 { 306 httpWriteHeaderKey(bw, headerSecExtensions) 307 httphead.WriteOptions(bw, hs.Extensions) 308 bw.WriteString(crlf) 309 } 310 if header != nil { 311 header(bw) 312 } 313 314 bw.WriteString(crlf) 315 } 316 317 func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) { 318 switch code { 319 case http.StatusBadRequest: 320 bw.WriteString(textHeadBadRequest) 321 case http.StatusInternalServerError: 322 bw.WriteString(textHeadInternalServerError) 323 case http.StatusUpgradeRequired: 324 bw.WriteString(textHeadUpgradeRequired) 325 default: 326 writeStatusText(bw, code) 327 } 328 329 // Write custom headers. 330 if header != nil { 331 header(bw) 332 } 333 334 switch err { 335 case ErrHandshakeBadProtocol: 336 bw.WriteString(textTailErrHandshakeBadProtocol) 337 case ErrHandshakeBadMethod: 338 bw.WriteString(textTailErrHandshakeBadMethod) 339 case ErrHandshakeBadHost: 340 bw.WriteString(textTailErrHandshakeBadHost) 341 case ErrHandshakeBadUpgrade: 342 bw.WriteString(textTailErrHandshakeBadUpgrade) 343 case ErrHandshakeBadConnection: 344 bw.WriteString(textTailErrHandshakeBadConnection) 345 case ErrHandshakeBadSecAccept: 346 bw.WriteString(textTailErrHandshakeBadSecAccept) 347 case ErrHandshakeBadSecKey: 348 bw.WriteString(textTailErrHandshakeBadSecKey) 349 case ErrHandshakeBadSecVersion: 350 bw.WriteString(textTailErrHandshakeBadSecVersion) 351 case ErrHandshakeUpgradeRequired: 352 bw.WriteString(textTailErrUpgradeRequired) 353 case nil: 354 bw.WriteString(crlf) 355 default: 356 writeErrorText(bw, err) 357 } 358 } 359 360 func writeStatusText(bw *bufio.Writer, code int) { 361 bw.WriteString("HTTP/1.1 ") 362 bw.WriteString(strconv.Itoa(code)) 363 bw.WriteByte(' ') 364 bw.WriteString(http.StatusText(code)) 365 bw.WriteString(crlf) 366 bw.WriteString("Content-Type: text/plain; charset=utf-8") 367 bw.WriteString(crlf) 368 } 369 370 func writeErrorText(bw *bufio.Writer, err error) { 371 body := err.Error() 372 bw.WriteString("Content-Length: ") 373 bw.WriteString(strconv.Itoa(len(body))) 374 bw.WriteString(crlf) 375 bw.WriteString(crlf) 376 bw.WriteString(body) 377 } 378 379 // httpError is like the http.Error with WebSocket context exception. 380 func httpError(w http.ResponseWriter, body string, code int) { 381 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 382 w.Header().Set("Content-Length", strconv.Itoa(len(body))) 383 w.WriteHeader(code) 384 w.Write([]byte(body)) 385 } 386 387 // statusText is a non-performant status text generator. 388 // NOTE: Used only to generate constants. 389 func statusText(code int) string { 390 var buf bytes.Buffer 391 bw := bufio.NewWriter(&buf) 392 writeStatusText(bw, code) 393 bw.Flush() 394 return buf.String() 395 } 396 397 // errorText is a non-performant error text generator. 398 // NOTE: Used only to generate constants. 399 func errorText(err error) string { 400 var buf bytes.Buffer 401 bw := bufio.NewWriter(&buf) 402 writeErrorText(bw, err) 403 bw.Flush() 404 return buf.String() 405 } 406 407 // HandshakeHeader is the interface that writes both upgrade request or 408 // response headers into a given io.Writer. 409 type HandshakeHeader interface { 410 io.WriterTo 411 } 412 413 // HandshakeHeaderString is an adapter to allow the use of headers represented 414 // by ordinary string as HandshakeHeader. 415 type HandshakeHeaderString string 416 417 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 418 func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) { 419 n, err := io.WriteString(w, string(s)) 420 return int64(n), err 421 } 422 423 // HandshakeHeaderBytes is an adapter to allow the use of headers represented 424 // by ordinary slice of bytes as HandshakeHeader. 425 type HandshakeHeaderBytes []byte 426 427 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 428 func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) { 429 n, err := w.Write(b) 430 return int64(n), err 431 } 432 433 // HandshakeHeaderFunc is an adapter to allow the use of headers represented by 434 // ordinary function as HandshakeHeader. 435 type HandshakeHeaderFunc func(io.Writer) (int64, error) 436 437 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 438 func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) { 439 return f(w) 440 } 441 442 // HandshakeHeaderHTTP is an adapter to allow the use of http.Header as 443 // HandshakeHeader. 444 type HandshakeHeaderHTTP http.Header 445 446 // WriteTo implements HandshakeHeader (and io.WriterTo) interface. 447 func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) { 448 wr := writer{w: w} 449 err := http.Header(h).Write(&wr) 450 return wr.n, err 451 } 452 453 type writer struct { 454 n int64 455 w io.Writer 456 } 457 458 func (w *writer) WriteString(s string) (int, error) { 459 n, err := io.WriteString(w.w, s) 460 w.n += int64(n) 461 return n, err 462 } 463 464 func (w *writer) Write(p []byte) (int, error) { 465 n, err := w.w.Write(p) 466 w.n += int64(n) 467 return n, err 468 }