github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/server.go (about) 1 package ws 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "strings" 11 "time" 12 13 "github.com/ezoic/httphead" 14 "github.com/ezoic/pool/pbufio" 15 ) 16 17 // Constants used by ConnUpgrader. 18 const ( 19 DefaultServerReadBufferSize = 4096 20 DefaultServerWriteBufferSize = 512 21 ) 22 23 // Errors used by both client and server when preparing WebSocket handshake. 24 var ( 25 ErrHandshakeBadProtocol = RejectConnectionError( 26 RejectionStatus(http.StatusHTTPVersionNotSupported), 27 RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")), 28 ) 29 ErrHandshakeBadMethod = RejectConnectionError( 30 RejectionStatus(http.StatusMethodNotAllowed), 31 RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")), 32 ) 33 ErrHandshakeBadHost = RejectConnectionError( 34 RejectionStatus(http.StatusBadRequest), 35 RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)), 36 ) 37 ErrHandshakeBadUpgrade = RejectConnectionError( 38 RejectionStatus(http.StatusBadRequest), 39 RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)), 40 ) 41 ErrHandshakeBadConnection = RejectConnectionError( 42 RejectionStatus(http.StatusBadRequest), 43 RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)), 44 ) 45 ErrHandshakeBadSecAccept = RejectConnectionError( 46 RejectionStatus(http.StatusBadRequest), 47 RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)), 48 ) 49 ErrHandshakeBadSecKey = RejectConnectionError( 50 RejectionStatus(http.StatusBadRequest), 51 RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)), 52 ) 53 ErrHandshakeBadSecVersion = RejectConnectionError( 54 RejectionStatus(http.StatusBadRequest), 55 RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)), 56 ) 57 ) 58 59 // ErrMalformedResponse is returned by Dialer to indicate that server response 60 // can not be parsed. 61 var ErrMalformedResponse = fmt.Errorf("malformed HTTP response") 62 63 // ErrMalformedRequest is returned when HTTP request can not be parsed. 64 var ErrMalformedRequest = RejectConnectionError( 65 RejectionStatus(http.StatusBadRequest), 66 RejectionReason("malformed HTTP request"), 67 ) 68 69 // ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that 70 // connection is rejected because given WebSocket version is malformed. 71 // 72 // According to RFC6455: 73 // If this version does not match a version understood by the server, the 74 // server MUST abort the WebSocket handshake described in this section and 75 // instead send an appropriate HTTP error code (such as 426 Upgrade Required) 76 // and a |Sec-WebSocket-Version| header field indicating the version(s) the 77 // server is capable of understanding. 78 var ErrHandshakeUpgradeRequired = RejectConnectionError( 79 RejectionStatus(http.StatusUpgradeRequired), 80 RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")), 81 RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)), 82 ) 83 84 // ErrNotHijacker is an error returned when http.ResponseWriter does not 85 // implement http.Hijacker interface. 86 var ErrNotHijacker = RejectConnectionError( 87 RejectionStatus(http.StatusInternalServerError), 88 RejectionReason("given http.ResponseWriter is not a http.Hijacker"), 89 ) 90 91 // DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by 92 // UpgradeHTTP function. 93 var DefaultHTTPUpgrader HTTPUpgrader 94 95 // UpgradeHTTP is like HTTPUpgrader{}.Upgrade(). 96 func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) { 97 return DefaultHTTPUpgrader.Upgrade(r, w) 98 } 99 100 // DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade 101 // function. 102 var DefaultUpgrader Upgrader 103 104 // Upgrade is like Upgrader{}.Upgrade(). 105 func Upgrade(conn io.ReadWriter) (Handshake, error) { 106 return DefaultUpgrader.Upgrade(conn) 107 } 108 109 // HTTPUpgrader contains options for upgrading connection to websocket from 110 // net/http Handler arguments. 111 type HTTPUpgrader struct { 112 // Timeout is the maximum amount of time an Upgrade() will spent while 113 // writing handshake response. 114 // 115 // The default is no timeout. 116 Timeout time.Duration 117 118 // Header is an optional http.Header mapping that could be used to 119 // write additional headers to the handshake response. 120 // 121 // Note that if present, it will be written in any result of handshake. 122 Header http.Header 123 124 // Protocol is the select function that is used to select subprotocol from 125 // list requested by client. If this field is set, then the first matched 126 // protocol is sent to a client as negotiated. 127 Protocol func(string) bool 128 129 // Extension is the select function that is used to select extensions from 130 // list requested by client. If this field is set, then the all matched 131 // extensions are sent to a client as negotiated. 132 Extension func(httphead.Option) bool 133 } 134 135 // Upgrade upgrades http connection to the websocket connection. 136 // 137 // It hijacks net.Conn from w and returns received net.Conn and 138 // bufio.ReadWriter. On successful handshake it returns Handshake struct 139 // describing handshake info. 140 func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) { 141 // Hijack connection first to get the ability to write rejection errors the 142 // same way as in Upgrader. 143 hj, ok := w.(http.Hijacker) 144 if ok { 145 conn, rw, err = hj.Hijack() 146 } else { 147 err = ErrNotHijacker 148 } 149 if err != nil { 150 httpError(w, err.Error(), http.StatusInternalServerError) 151 return 152 } 153 154 // See https://tools.ietf.org/html/rfc6455#section-4.1 155 // The method of the request MUST be GET, and the HTTP version MUST be at least 1.1. 156 var nonce string 157 if r.Method != http.MethodGet { 158 err = ErrHandshakeBadMethod 159 } else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) { 160 err = ErrHandshakeBadProtocol 161 } else if r.Host == "" { 162 err = ErrHandshakeBadHost 163 } else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") { 164 err = ErrHandshakeBadUpgrade 165 } else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") { 166 err = ErrHandshakeBadConnection 167 } else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize { 168 err = ErrHandshakeBadSecKey 169 } else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" { 170 // According to RFC6455: 171 // 172 // If this version does not match a version understood by the server, 173 // the server MUST abort the WebSocket handshake described in this 174 // section and instead send an appropriate HTTP error code (such as 426 175 // Upgrade Required) and a |Sec-WebSocket-Version| header field 176 // indicating the version(s) the server is capable of understanding. 177 // 178 // So we branching here cause empty or not present version does not 179 // meet the ABNF rules of RFC6455: 180 // 181 // version = DIGIT | (NZDIGIT DIGIT) | 182 // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT) 183 // ; Limited to 0-255 range, with no leading zeros 184 // 185 // That is, if version is really invalid – we sent 426 status, if it 186 // not present or empty – it is 400. 187 if v != "" { 188 err = ErrHandshakeUpgradeRequired 189 } else { 190 err = ErrHandshakeBadSecVersion 191 } 192 } 193 if check := u.Protocol; err == nil && check != nil { 194 ps := r.Header[headerSecProtocolCanonical] 195 for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ { 196 var ok bool 197 hs.Protocol, ok = strSelectProtocol(ps[i], check) 198 if !ok { 199 err = ErrMalformedRequest 200 } 201 } 202 } 203 if check := u.Extension; err == nil && check != nil { 204 xs := r.Header[headerSecExtensionsCanonical] 205 for i := 0; i < len(xs) && err == nil; i++ { 206 var ok bool 207 hs.Extensions, ok = strSelectExtensions(xs[i], hs.Extensions, check) 208 if !ok { 209 err = ErrMalformedRequest 210 } 211 } 212 } 213 214 // Clear deadlines set by server. 215 conn.SetDeadline(noDeadline) 216 if t := u.Timeout; t != 0 { 217 conn.SetWriteDeadline(time.Now().Add(t)) 218 defer conn.SetWriteDeadline(noDeadline) 219 } 220 221 var header handshakeHeader 222 if h := u.Header; h != nil { 223 header[0] = HandshakeHeaderHTTP(h) 224 } 225 if err == nil { 226 httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo) 227 err = rw.Writer.Flush() 228 } else { 229 var code int 230 if rej, ok := err.(*rejectConnectionError); ok { 231 code = rej.code 232 header[1] = rej.header 233 } 234 if code == 0 { 235 code = http.StatusInternalServerError 236 } 237 httpWriteResponseError(rw.Writer, err, code, header.WriteTo) 238 // Do not store Flush() error to not override already existing one. 239 rw.Writer.Flush() 240 } 241 return 242 } 243 244 // Upgrader contains options for upgrading connection to websocket. 245 type Upgrader struct { 246 // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. 247 // They used to read and write http data while upgrading to WebSocket. 248 // Allocated buffers are pooled with sync.Pool to avoid extra allocations. 249 // 250 // If a size is zero then default value is used. 251 // 252 // Usually it is useful to set read buffer size bigger than write buffer 253 // size because incoming request could contain long header values, such as 254 // Cookie. Response, in other way, could be big only if user write multiple 255 // custom headers. Usually response takes less than 256 bytes. 256 ReadBufferSize, WriteBufferSize int 257 258 // Protocol is a select function that is used to select subprotocol 259 // from list requested by client. If this field is set, then the first matched 260 // protocol is sent to a client as negotiated. 261 // 262 // The argument is only valid until the callback returns. 263 Protocol func([]byte) bool 264 265 // ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually. 266 // Note that returned bytes must be valid until Upgrade returns. 267 // If ProtocolCustom is set, it used instead of Protocol function. 268 ProtocolCustom func([]byte) (string, bool) 269 270 // Extension is a select function that is used to select extensions 271 // from list requested by client. If this field is set, then the all matched 272 // extensions are sent to a client as negotiated. 273 // 274 // The argument is only valid until the callback returns. 275 // 276 // According to the RFC6455 order of extensions passed by a client is 277 // significant. That is, returning true from this function means that no 278 // other extension with the same name should be checked because server 279 // accepted the most preferable extension right now: 280 // "Note that the order of extensions is significant. Any interactions between 281 // multiple extensions MAY be defined in the documents defining the extensions. 282 // In the absence of such definitions, the interpretation is that the header 283 // fields listed by the client in its request represent a preference of the 284 // header fields it wishes to use, with the first options listed being most 285 // preferable." 286 Extension func(httphead.Option) bool 287 288 // ExtensionCustorm allow user to parse Sec-WebSocket-Extensions header manually. 289 // Note that returned options should be valid until Upgrade returns. 290 // If ExtensionCustom is set, it used instead of Extension function. 291 ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool) 292 293 // Header is an optional HandshakeHeader instance that could be used to 294 // write additional headers to the handshake response. 295 // 296 // It used instead of any key-value mappings to avoid allocations in user 297 // land. 298 // 299 // Note that if present, it will be written in any result of handshake. 300 Header HandshakeHeader 301 302 // OnRequest is a callback that will be called after request line 303 // successful parsing. 304 // 305 // The arguments are only valid until the callback returns. 306 // 307 // If returned error is non-nil then connection is rejected and response is 308 // sent with appropriate HTTP error code and body set to error message. 309 // 310 // RejectConnectionError could be used to get more control on response. 311 OnRequest func(uri []byte) error 312 313 // OnHost is a callback that will be called after "Host" header successful 314 // parsing. 315 // 316 // It is separated from OnHeader callback because the Host header must be 317 // present in each request since HTTP/1.1. Thus Host header is non-optional 318 // and required for every WebSocket handshake. 319 // 320 // The arguments are only valid until the callback returns. 321 // 322 // If returned error is non-nil then connection is rejected and response is 323 // sent with appropriate HTTP error code and body set to error message. 324 // 325 // RejectConnectionError could be used to get more control on response. 326 OnHost func(host []byte) error 327 328 // OnHeader is a callback that will be called after successful parsing of 329 // header, that is not used during WebSocket handshake procedure. That is, 330 // it will be called with non-websocket headers, which could be relevant 331 // for application-level logic. 332 // 333 // The arguments are only valid until the callback returns. 334 // 335 // If returned error is non-nil then connection is rejected and response is 336 // sent with appropriate HTTP error code and body set to error message. 337 // 338 // RejectConnectionError could be used to get more control on response. 339 OnHeader func(key, value []byte) error 340 341 // OnBeforeUpgrade is a callback that will be called before sending 342 // successful upgrade response. 343 // 344 // Setting OnBeforeUpgrade allows user to make final application-level 345 // checks and decide whether this connection is allowed to successfully 346 // upgrade to WebSocket. 347 // 348 // It must return non-nil either HandshakeHeader or error and never both. 349 // 350 // If returned error is non-nil then connection is rejected and response is 351 // sent with appropriate HTTP error code and body set to error message. 352 // 353 // RejectConnectionError could be used to get more control on response. 354 OnBeforeUpgrade func() (header HandshakeHeader, err error) 355 } 356 357 // Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn 358 // as connection with incoming HTTP Upgrade request. 359 // 360 // It is a caller responsibility to manage i/o timeouts on conn. 361 // 362 // Non-nil error means that request for the WebSocket upgrade is invalid or 363 // malformed and usually connection should be closed. 364 // Even when error is non-nil Upgrade will write appropriate response into 365 // connection in compliance with RFC. 366 func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { 367 // headerSeen constants helps to report whether or not some header was seen 368 // during reading request bytes. 369 const ( 370 headerSeenHost = 1 << iota 371 headerSeenUpgrade 372 headerSeenConnection 373 headerSeenSecVersion 374 headerSeenSecKey 375 376 // headerSeenAll is the value that we expect to receive at the end of 377 // headers read/parse loop. 378 headerSeenAll = 0 | 379 headerSeenHost | 380 headerSeenUpgrade | 381 headerSeenConnection | 382 headerSeenSecVersion | 383 headerSeenSecKey 384 ) 385 386 // Prepare I/O buffers. 387 // TODO(ezoic): make it configurable. 388 br := pbufio.GetReader(conn, 389 nonZero(u.ReadBufferSize, DefaultServerReadBufferSize), 390 ) 391 bw := pbufio.GetWriter(conn, 392 nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize), 393 ) 394 defer func() { 395 pbufio.PutReader(br) 396 pbufio.PutWriter(bw) 397 }() 398 399 // Read HTTP request line like "GET /ws HTTP/1.1". 400 rl, err := readLine(br) 401 if err != nil { 402 return 403 } 404 // Parse request line data like HTTP version, uri and method. 405 req, err := httpParseRequestLine(rl) 406 if err != nil { 407 return 408 } 409 410 // Prepare stack-based handshake header list. 411 header := handshakeHeader{ 412 0: u.Header, 413 } 414 415 // Parse and check HTTP request. 416 // As RFC6455 says: 417 // The client's opening handshake consists of the following parts. If the 418 // server, while reading the handshake, finds that the client did not 419 // send a handshake that matches the description below (note that as per 420 // [RFC2616], the order of the header fields is not important), including 421 // but not limited to any violations of the ABNF grammar specified for 422 // the components of the handshake, the server MUST stop processing the 423 // client's handshake and return an HTTP response with an appropriate 424 // error code (such as 400 Bad Request). 425 // 426 // See https://tools.ietf.org/html/rfc6455#section-4.2.1 427 428 // An HTTP/1.1 or higher GET request, including a "Request-URI". 429 // 430 // Even if RFC says "1.1 or higher" without mentioning the part of the 431 // version, we apply it only to minor part. 432 switch { 433 case req.major != 1 || req.minor < 1: 434 // Abort processing the whole request because we do not even know how 435 // to actually parse it. 436 err = ErrHandshakeBadProtocol 437 438 case btsToString(req.method) != http.MethodGet: 439 err = ErrHandshakeBadMethod 440 441 default: 442 if onRequest := u.OnRequest; onRequest != nil { 443 err = onRequest(req.uri) 444 } 445 } 446 // Start headers read/parse loop. 447 var ( 448 // headerSeen reports which header was seen by setting corresponding 449 // bit on. 450 headerSeen byte 451 452 nonce = make([]byte, nonceSize) 453 ) 454 for err == nil { 455 line, e := readLine(br) 456 if e != nil { 457 return hs, e 458 } 459 if len(line) == 0 { 460 // Blank line, no more lines to read. 461 break 462 } 463 464 k, v, ok := httpParseHeaderLine(line) 465 if !ok { 466 err = ErrMalformedRequest 467 break 468 } 469 470 switch btsToString(k) { 471 case headerHostCanonical: 472 headerSeen |= headerSeenHost 473 if onHost := u.OnHost; onHost != nil { 474 err = onHost(v) 475 } 476 477 case headerUpgradeCanonical: 478 headerSeen |= headerSeenUpgrade 479 if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { 480 err = ErrHandshakeBadUpgrade 481 } 482 483 case headerConnectionCanonical: 484 headerSeen |= headerSeenConnection 485 if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) { 486 err = ErrHandshakeBadConnection 487 } 488 489 case headerSecVersionCanonical: 490 headerSeen |= headerSeenSecVersion 491 if !bytes.Equal(v, specHeaderValueSecVersion) { 492 err = ErrHandshakeUpgradeRequired 493 } 494 495 case headerSecKeyCanonical: 496 headerSeen |= headerSeenSecKey 497 if len(v) != nonceSize { 498 err = ErrHandshakeBadSecKey 499 } else { 500 copy(nonce[:], v) 501 } 502 503 case headerSecProtocolCanonical: 504 if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) { 505 var ok bool 506 if custom != nil { 507 hs.Protocol, ok = custom(v) 508 } else { 509 hs.Protocol, ok = btsSelectProtocol(v, check) 510 } 511 if !ok { 512 err = ErrMalformedRequest 513 } 514 } 515 516 case headerSecExtensionsCanonical: 517 if custom, check := u.ExtensionCustom, u.Extension; custom != nil || check != nil { 518 var ok bool 519 if custom != nil { 520 hs.Extensions, ok = custom(v, hs.Extensions) 521 } else { 522 hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check) 523 } 524 if !ok { 525 err = ErrMalformedRequest 526 } 527 } 528 529 default: 530 if onHeader := u.OnHeader; onHeader != nil { 531 err = onHeader(k, v) 532 } 533 } 534 } 535 switch { 536 case err == nil && headerSeen != headerSeenAll: 537 switch { 538 case headerSeen&headerSeenHost == 0: 539 // As RFC2616 says: 540 // A client MUST include a Host header field in all HTTP/1.1 541 // request messages. If the requested URI does not include an 542 // Internet host name for the service being requested, then the 543 // Host header field MUST be given with an empty value. An 544 // HTTP/1.1 proxy MUST ensure that any request message it 545 // forwards does contain an appropriate Host header field that 546 // identifies the service being requested by the proxy. All 547 // Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad 548 // Request) status code to any HTTP/1.1 request message which 549 // lacks a Host header field. 550 err = ErrHandshakeBadHost 551 case headerSeen&headerSeenUpgrade == 0: 552 err = ErrHandshakeBadUpgrade 553 case headerSeen&headerSeenConnection == 0: 554 err = ErrHandshakeBadConnection 555 case headerSeen&headerSeenSecVersion == 0: 556 // In case of empty or not present version we do not send 426 status, 557 // because it does not meet the ABNF rules of RFC6455: 558 // 559 // version = DIGIT | (NZDIGIT DIGIT) | 560 // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT) 561 // ; Limited to 0-255 range, with no leading zeros 562 // 563 // That is, if version is really invalid – we sent 426 status as above, if it 564 // not present – it is 400. 565 err = ErrHandshakeBadSecVersion 566 case headerSeen&headerSeenSecKey == 0: 567 err = ErrHandshakeBadSecKey 568 default: 569 panic("unknown headers state") 570 } 571 572 case err == nil && u.OnBeforeUpgrade != nil: 573 header[1], err = u.OnBeforeUpgrade() 574 } 575 if err != nil { 576 var code int 577 if rej, ok := err.(*rejectConnectionError); ok { 578 code = rej.code 579 header[1] = rej.header 580 } 581 if code == 0 { 582 code = http.StatusInternalServerError 583 } 584 httpWriteResponseError(bw, err, code, header.WriteTo) 585 // Do not store Flush() error to not override already existing one. 586 bw.Flush() 587 return 588 } 589 590 httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo) 591 err = bw.Flush() 592 593 return 594 } 595 596 type handshakeHeader [2]HandshakeHeader 597 598 func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) { 599 for i := 0; i < len(hs) && err == nil; i++ { 600 if h := hs[i]; h != nil { 601 var m int64 602 m, err = h.WriteTo(w) 603 n += m 604 } 605 } 606 return n, err 607 }