github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/server_test.go (about) 1 package ws 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "math/rand" 10 "net" 11 "net/http" 12 "net/http/httptest" 13 "net/http/httputil" 14 "reflect" 15 "sort" 16 "strconv" 17 "strings" 18 "sync/atomic" 19 "testing" 20 _ "unsafe" // for go:linkname 21 22 "github.com/gobwas/httphead" 23 ) 24 25 // TODO(gobwas): upgradeGenericCase with methods like configureUpgrader, 26 // configureHTTPUpgrader. 27 type upgradeCase struct { 28 label string 29 30 protocol func(string) bool 31 negotiate func(httphead.Option) (httphead.Option, error) 32 onRequest func(u []byte) error 33 onHost func(h []byte) error 34 onHeader func(k, v []byte) error 35 36 nonce []byte 37 removeSecKey bool 38 badSecKey bool 39 secKeyHeader string 40 41 req *http.Request 42 res *http.Response 43 hs Handshake 44 err error 45 } 46 47 var upgradeCases = []upgradeCase{ 48 { 49 label: "base", 50 nonce: mustMakeNonce(), 51 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 52 headerUpgrade: []string{"websocket"}, 53 headerConnection: []string{"Upgrade"}, 54 headerSecVersion: []string{"13"}, 55 }), 56 res: mustMakeResponse(101, http.Header{ 57 headerUpgrade: []string{"websocket"}, 58 headerConnection: []string{"Upgrade"}, 59 }), 60 }, 61 { 62 label: "base_canonical", 63 nonce: mustMakeNonce(), 64 secKeyHeader: headerSecKeyCanonical, 65 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 66 headerUpgrade: []string{"websocket"}, 67 headerConnection: []string{"Upgrade"}, 68 headerSecVersionCanonical: []string{"13"}, 69 }), 70 res: mustMakeResponse(101, http.Header{ 71 headerUpgrade: []string{"websocket"}, 72 headerConnection: []string{"Upgrade"}, 73 }), 74 }, 75 { 76 label: "lowercase_headers", 77 nonce: mustMakeNonce(), 78 secKeyHeader: strings.ToLower(headerSecKey), 79 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 80 strings.ToLower(headerUpgrade): []string{"websocket"}, 81 strings.ToLower(headerConnection): []string{"Upgrade"}, 82 strings.ToLower(headerSecVersion): []string{"13"}, 83 }), 84 res: mustMakeResponse(101, http.Header{ 85 headerUpgrade: []string{"websocket"}, 86 headerConnection: []string{"Upgrade"}, 87 }), 88 }, 89 { 90 label: "uppercase", 91 protocol: func(sub string) bool { return true }, 92 nonce: mustMakeNonce(), 93 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 94 headerUpgrade: []string{"WEBSOCKET"}, 95 headerConnection: []string{"UPGRADE"}, 96 headerSecVersion: []string{"13"}, 97 }), 98 res: mustMakeResponse(101, http.Header{ 99 headerUpgrade: []string{"websocket"}, 100 headerConnection: []string{"Upgrade"}, 101 }), 102 }, 103 { 104 label: "subproto", 105 protocol: SelectFromSlice([]string{"b", "d"}), 106 nonce: mustMakeNonce(), 107 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 108 headerUpgrade: []string{"websocket"}, 109 headerConnection: []string{"Upgrade"}, 110 headerSecVersion: []string{"13"}, 111 headerSecProtocol: []string{"a", "b", "c", "d"}, 112 }), 113 res: mustMakeResponse(101, http.Header{ 114 headerUpgrade: []string{"websocket"}, 115 headerConnection: []string{"Upgrade"}, 116 headerSecProtocol: []string{"b"}, 117 }), 118 hs: Handshake{Protocol: "b"}, 119 }, 120 { 121 label: "subproto_lowercase_headers", 122 protocol: SelectFromSlice([]string{"b", "d"}), 123 nonce: mustMakeNonce(), 124 secKeyHeader: strings.ToLower(headerSecKey), 125 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 126 strings.ToLower(headerUpgrade): []string{"websocket"}, 127 strings.ToLower(headerConnection): []string{"Upgrade"}, 128 strings.ToLower(headerSecVersion): []string{"13"}, 129 strings.ToLower(headerSecProtocol): []string{"a", "b", "c", "d"}, 130 }), 131 res: mustMakeResponse(101, http.Header{ 132 headerUpgrade: []string{"websocket"}, 133 headerConnection: []string{"Upgrade"}, 134 headerSecProtocol: []string{"b"}, 135 }), 136 hs: Handshake{Protocol: "b"}, 137 }, 138 { 139 label: "subproto_comma", 140 protocol: SelectFromSlice([]string{"b", "d"}), 141 nonce: mustMakeNonce(), 142 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 143 headerUpgrade: []string{"websocket"}, 144 headerConnection: []string{"Upgrade"}, 145 headerSecVersion: []string{"13"}, 146 headerSecProtocol: []string{"a, b, c, d"}, 147 }), 148 res: mustMakeResponse(101, http.Header{ 149 headerUpgrade: []string{"websocket"}, 150 headerConnection: []string{"Upgrade"}, 151 headerSecProtocol: []string{"b"}, 152 }), 153 hs: Handshake{Protocol: "b"}, 154 }, 155 { 156 negotiate: func(opt httphead.Option) (ret httphead.Option, err error) { 157 switch string(opt.Name) { 158 case "b", "d": 159 return opt.Clone(), nil 160 default: 161 return ret, nil 162 } 163 }, 164 nonce: mustMakeNonce(), 165 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 166 headerUpgrade: []string{"websocket"}, 167 headerConnection: []string{"Upgrade"}, 168 headerSecVersion: []string{"13"}, 169 headerSecExtensions: []string{"a;foo=1", "b;bar=2", "c", "d;baz=3"}, 170 }), 171 res: mustMakeResponse(101, http.Header{ 172 headerUpgrade: []string{"websocket"}, 173 headerConnection: []string{"Upgrade"}, 174 headerSecExtensions: []string{"b;bar=2,d;baz=3"}, 175 }), 176 hs: Handshake{ 177 Extensions: []httphead.Option{ 178 httphead.NewOption("b", map[string]string{ 179 "bar": "2", 180 }), 181 httphead.NewOption("d", map[string]string{ 182 "baz": "3", 183 }), 184 }, 185 }, 186 }, 187 188 // Error cases. 189 // ------------ 190 191 { 192 label: "bad_http_method", 193 nonce: mustMakeNonce(), 194 req: mustMakeRequest("POST", "ws://example.org", http.Header{ 195 headerUpgrade: []string{"websocket"}, 196 headerConnection: []string{"Upgrade"}, 197 headerSecVersion: []string{"13"}, 198 }), 199 res: mustMakeErrResponse(405, ErrHandshakeBadMethod, nil), 200 err: ErrHandshakeBadMethod, 201 }, 202 { 203 label: "bad_http_proto", 204 nonce: mustMakeNonce(), 205 req: setProto(1, 0, mustMakeRequest("GET", "ws://example.org", http.Header{ 206 headerUpgrade: []string{"websocket"}, 207 headerConnection: []string{"Upgrade"}, 208 headerSecVersion: []string{"13"}, 209 })), 210 res: mustMakeErrResponse(505, ErrHandshakeBadProtocol, nil), 211 err: ErrHandshakeBadProtocol, 212 }, 213 { 214 label: "bad_host", 215 nonce: mustMakeNonce(), 216 req: withoutHeader("Host", mustMakeRequest("GET", "ws://example.org", http.Header{ 217 headerUpgrade: []string{"websocket"}, 218 headerConnection: []string{"Upgrade"}, 219 headerSecVersion: []string{"13"}, 220 })), 221 res: mustMakeErrResponse(400, ErrHandshakeBadHost, nil), 222 err: ErrHandshakeBadHost, 223 }, 224 { 225 label: "bad_upgrade", 226 nonce: mustMakeNonce(), 227 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 228 headerConnection: []string{"Upgrade"}, 229 headerSecVersion: []string{"13"}, 230 }), 231 res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil), 232 err: ErrHandshakeBadUpgrade, 233 }, 234 { 235 label: "bad_upgrade", 236 nonce: mustMakeNonce(), 237 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 238 "X-Custom-Header": []string{"value"}, 239 headerConnection: []string{"Upgrade"}, 240 headerSecVersion: []string{"13"}, 241 }), 242 243 onRequest: func([]byte) error { return nil }, 244 onHost: func([]byte) error { return nil }, 245 onHeader: func(k, v []byte) error { return nil }, 246 247 res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil), 248 err: ErrHandshakeBadUpgrade, 249 }, 250 { 251 label: "bad_upgrade", 252 nonce: mustMakeNonce(), 253 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 254 headerUpgrade: []string{"not-websocket"}, 255 headerConnection: []string{"Upgrade"}, 256 headerSecVersion: []string{"13"}, 257 }), 258 res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil), 259 err: ErrHandshakeBadUpgrade, 260 }, 261 { 262 label: "bad_connection", 263 nonce: mustMakeNonce(), 264 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 265 headerUpgrade: []string{"websocket"}, 266 headerSecVersion: []string{"13"}, 267 }), 268 res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil), 269 err: ErrHandshakeBadConnection, 270 }, 271 { 272 label: "bad_connection", 273 nonce: mustMakeNonce(), 274 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 275 headerUpgrade: []string{"websocket"}, 276 headerConnection: []string{"not-upgrade"}, 277 headerSecVersion: []string{"13"}, 278 }), 279 res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil), 280 err: ErrHandshakeBadConnection, 281 }, 282 { 283 label: "bad_sec_version_x", 284 nonce: mustMakeNonce(), 285 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 286 headerUpgrade: []string{"websocket"}, 287 headerConnection: []string{"Upgrade"}, 288 }), 289 res: mustMakeErrResponse(400, ErrHandshakeBadSecVersion, nil), 290 err: ErrHandshakeBadSecVersion, 291 }, 292 { 293 label: "bad_sec_version", 294 nonce: mustMakeNonce(), 295 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 296 headerUpgrade: []string{"websocket"}, 297 headerConnection: []string{"upgrade"}, 298 headerSecVersion: []string{"15"}, 299 }), 300 res: mustMakeErrResponse(426, ErrHandshakeBadSecVersion, http.Header{ 301 headerSecVersion: []string{"13"}, 302 }), 303 err: ErrHandshakeUpgradeRequired, 304 }, 305 { 306 label: "bad_sec_key", 307 nonce: mustMakeNonce(), 308 removeSecKey: true, 309 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 310 headerUpgrade: []string{"websocket"}, 311 headerConnection: []string{"Upgrade"}, 312 headerSecVersion: []string{"13"}, 313 }), 314 res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil), 315 err: ErrHandshakeBadSecKey, 316 }, 317 { 318 label: "bad_sec_key", 319 nonce: mustMakeNonce(), 320 badSecKey: true, 321 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 322 headerUpgrade: []string{"websocket"}, 323 headerConnection: []string{"Upgrade"}, 324 headerSecVersion: []string{"13"}, 325 }), 326 res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil), 327 err: ErrHandshakeBadSecKey, 328 }, 329 { 330 label: "bad_ws_extension", 331 nonce: mustMakeNonce(), 332 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 333 headerUpgrade: []string{"websocket"}, 334 headerConnection: []string{"Upgrade"}, 335 headerSecVersion: []string{"13"}, 336 headerSecExtensions: []string{"=["}, 337 }), 338 negotiate: func(opt httphead.Option) (ret httphead.Option, err error) { 339 return ret, nil 340 }, 341 res: mustMakeErrResponse(400, ErrMalformedRequest, nil), 342 err: ErrMalformedRequest, 343 }, 344 { 345 label: "bad_subprotocol", 346 nonce: mustMakeNonce(), 347 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 348 headerUpgrade: []string{"websocket"}, 349 headerConnection: []string{"Upgrade"}, 350 headerSecVersion: []string{"13"}, 351 headerSecProtocol: []string{"=["}, 352 }), 353 protocol: func(string) bool { 354 return false 355 }, 356 res: mustMakeErrResponse(400, ErrMalformedRequest, nil), 357 err: ErrMalformedRequest, 358 }, 359 } 360 361 func TestHTTPUpgrader(t *testing.T) { 362 for _, test := range upgradeCases { 363 t.Run(test.label, func(t *testing.T) { 364 if !test.removeSecKey { 365 nonce := test.nonce 366 if test.badSecKey { 367 nonce = nonce[:nonceSize-1] 368 } 369 if test.secKeyHeader == "" { 370 test.secKeyHeader = headerSecKey 371 } 372 test.req.Header[test.secKeyHeader] = []string{string(nonce)} 373 } 374 if test.err == nil { 375 test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))} 376 } 377 378 // Need to emulate http server read request for truth test. 379 // 380 // We use dumpRequest here because test.req.Write is always send 381 // http/1.1 proto version, that does not fits all our testing 382 // cases. 383 reqBytes := dumpRequest(test.req) 384 req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqBytes))) 385 if err != nil { 386 t.Fatal(err) 387 } 388 389 res := newRecorder() 390 391 u := HTTPUpgrader{ 392 Protocol: test.protocol, 393 Negotiate: test.negotiate, 394 } 395 _, _, hs, err := u.Upgrade(req, res) 396 if test.err != err { 397 t.Errorf( 398 "expected error to be '%v', got '%v';\non request:\n====\n%s\n====", 399 test.err, err, dumpRequest(req), 400 ) 401 return 402 } 403 404 actRespBts := sortHeaders(res.Bytes()) 405 expRespBts := sortHeaders(dumpResponse(test.res)) 406 if !bytes.Equal(actRespBts, expRespBts) { 407 t.Errorf( 408 "unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====", 409 actRespBts, expRespBts, dumpRequest(test.req), 410 ) 411 return 412 } 413 414 if act, exp := hs.Protocol, test.hs.Protocol; act != exp { 415 t.Errorf("handshake protocol is %q want %q", act, exp) 416 } 417 if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp { 418 t.Errorf("handshake got %d extensions; want %d", act, exp) 419 } else { 420 for i := 0; i < act; i++ { 421 if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) { 422 t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp) 423 } 424 } 425 } 426 }) 427 } 428 } 429 430 func TestUpgrader(t *testing.T) { 431 for _, test := range upgradeCases { 432 t.Run(test.label, func(t *testing.T) { 433 if !test.removeSecKey { 434 nonce := test.nonce[:] 435 if test.badSecKey { 436 nonce = nonce[:nonceSize-1] 437 } 438 test.req.Header[headerSecKey] = []string{string(nonce)} 439 } 440 if test.err == nil { 441 test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))} 442 } 443 444 u := Upgrader{ 445 Protocol: func(p []byte) bool { 446 return test.protocol(string(p)) 447 }, 448 Negotiate: test.negotiate, 449 OnHeader: test.onHeader, 450 OnRequest: test.onRequest, 451 } 452 453 // We use dumpRequest here because test.req.Write is always send 454 // http/1.1 proto version, that does not fits all our testing 455 // cases. 456 reqBytes := dumpRequest(test.req) 457 conn := bytes.NewBuffer(reqBytes) 458 459 hs, err := u.Upgrade(conn) 460 if test.err != err { 461 462 t.Errorf("expected error to be '%v', got '%v'", test.err, err) 463 return 464 } 465 466 actRespBts := sortHeaders(conn.Bytes()) 467 expRespBts := sortHeaders(dumpResponse(test.res)) 468 if !bytes.Equal(actRespBts, expRespBts) { 469 t.Errorf( 470 "unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====", 471 actRespBts, expRespBts, dumpRequest(test.req), 472 ) 473 return 474 } 475 476 if act, exp := hs.Protocol, test.hs.Protocol; act != exp { 477 t.Errorf("handshake protocol is %q want %q", act, exp) 478 } 479 if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp { 480 t.Errorf("handshake got %d extensions; want %d", act, exp) 481 } else { 482 for i := 0; i < act; i++ { 483 if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) { 484 t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp) 485 } 486 } 487 } 488 }) 489 } 490 } 491 492 func BenchmarkHTTPUpgrader(b *testing.B) { 493 for _, bench := range upgradeCases { 494 bench.req.Header.Set(headerSecKey, string(bench.nonce[:])) 495 496 u := HTTPUpgrader{ 497 Protocol: bench.protocol, 498 Negotiate: bench.negotiate, 499 } 500 501 b.Run(bench.label, func(b *testing.B) { 502 res := make([]http.ResponseWriter, b.N) 503 for i := 0; i < b.N; i++ { 504 res[i] = newRecorder() 505 } 506 507 i := new(int64) 508 509 b.ResetTimer() 510 b.ReportAllocs() 511 b.RunParallel(func(pb *testing.PB) { 512 for pb.Next() { 513 w := res[atomic.AddInt64(i, 1)-1] 514 u.Upgrade(bench.req, w) 515 } 516 }) 517 }) 518 } 519 } 520 521 func BenchmarkUpgrader(b *testing.B) { 522 for _, bench := range upgradeCases { 523 bench.req.Header.Set(headerSecKey, string(bench.nonce[:])) 524 525 u := Upgrader{ 526 Protocol: func(p []byte) bool { 527 return bench.protocol(btsToString(p)) 528 }, 529 Negotiate: bench.negotiate, 530 } 531 532 reqBytes := dumpRequest(bench.req) 533 534 type benchReadWriter struct { 535 io.Reader 536 io.Writer 537 } 538 539 b.Run(bench.label, func(b *testing.B) { 540 conn := make([]io.ReadWriter, b.N) 541 for i := 0; i < b.N; i++ { 542 conn[i] = benchReadWriter{bytes.NewReader(reqBytes), ioutil.Discard} 543 } 544 545 i := new(int64) 546 547 b.ResetTimer() 548 b.ReportAllocs() 549 b.RunParallel(func(pb *testing.PB) { 550 for pb.Next() { 551 c := conn[atomic.AddInt64(i, 1)-1] 552 u.Upgrade(c) 553 } 554 }) 555 }) 556 } 557 } 558 559 func TestHttpStrSelectProtocol(t *testing.T) { 560 for i, test := range []struct { 561 header string 562 }{ 563 {"jsonrpc, soap, grpc"}, 564 } { 565 t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { 566 exp := strings.Split(test.header, ",") 567 for i, p := range exp { 568 exp[i] = strings.TrimSpace(p) 569 } 570 571 var calls []string 572 strSelectProtocol(test.header, func(s string) bool { 573 calls = append(calls, s) 574 return false 575 }) 576 577 if !reflect.DeepEqual(calls, exp) { 578 t.Errorf("selectProtocol(%q, fn); called fn with %v; want %v", test.header, calls, exp) 579 } 580 }) 581 } 582 } 583 584 func BenchmarkSelectProtocol(b *testing.B) { 585 for _, bench := range []struct { 586 label string 587 header string 588 acceptStr func(string) bool 589 acceptBts func([]byte) bool 590 }{ 591 { 592 label: "never accept", 593 header: "jsonrpc, soap, grpc", 594 acceptStr: func(s string) bool { 595 return len(s)%2 == 2 // never ok 596 }, 597 acceptBts: func(v []byte) bool { 598 return len(v)%2 == 2 // never ok 599 }, 600 }, 601 { 602 label: "from slice", 603 header: "a, b, c, d, e, f, g", 604 acceptStr: SelectFromSlice([]string{"g", "f", "e", "d"}), 605 }, 606 { 607 label: "uniq 1024 from slise", 608 header: strings.Join(randProtocols(1024, 16), ", "), 609 acceptStr: SelectFromSlice(randProtocols(1024, 17)), 610 }, 611 } { 612 b.Run(fmt.Sprintf("String/%s", bench.label), func(b *testing.B) { 613 for i := 0; i < b.N; i++ { 614 strSelectProtocol(bench.header, bench.acceptStr) 615 } 616 }) 617 if bench.acceptBts != nil { 618 b.Run(fmt.Sprintf("Bytes/%s", bench.label), func(b *testing.B) { 619 h := []byte(bench.header) 620 b.StartTimer() 621 622 for i := 0; i < b.N; i++ { 623 btsSelectProtocol(h, bench.acceptBts) 624 } 625 }) 626 } 627 } 628 } 629 630 func randProtocols(n, m int) []string { 631 ret := make([]string, n) 632 bts := make([]byte, m) 633 uniq := map[string]bool{} 634 for i := 0; i < n; i++ { 635 for { 636 for j := 0; j < m; j++ { 637 bts[j] = byte(rand.Intn('x'-'a') + 'a') 638 } 639 str := string(bts) 640 if _, has := uniq[str]; !has { 641 ret[i] = str 642 break 643 } 644 } 645 } 646 return ret 647 } 648 649 func dumpRequest(req *http.Request) []byte { 650 bts, err := httputil.DumpRequest(req, true) 651 if err != nil { 652 panic(err) 653 } 654 return bts 655 } 656 657 func dumpResponse(res *http.Response) []byte { 658 if !res.Close { 659 for _, v := range res.Header[headerConnection] { 660 if v == "close" { 661 res.Close = true 662 break 663 } 664 } 665 } 666 bts, err := httputil.DumpResponse(res, true) 667 if err != nil { 668 panic(err) 669 } 670 if !res.Close { 671 bts = bytes.Replace(bts, []byte("Connection: close\r\n"), nil, -1) 672 } 673 674 return bts 675 } 676 677 type headersBytes [][]byte 678 679 func (h headersBytes) Len() int { return len(h) } 680 func (h headersBytes) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 681 func (h headersBytes) Less(i, j int) bool { return bytes.Compare(h[i], h[j]) == -1 } 682 683 func maskHeader(bts []byte, key, mask string) []byte { 684 lines := bytes.Split(bts, []byte("\r\n")) 685 for i, line := range lines { 686 pair := bytes.Split(line, []byte(": ")) 687 if string(pair[0]) == key { 688 lines[i] = []byte(key + ": " + mask) 689 } 690 } 691 return bytes.Join(lines, []byte("\r\n")) 692 } 693 694 func sortHeaders(bts []byte) []byte { 695 lines := bytes.Split(bts, []byte("\r\n")) 696 if len(lines) <= 1 { 697 return bts 698 } 699 sort.Sort(headersBytes(lines[1 : len(lines)-2])) 700 return bytes.Join(lines, []byte("\r\n")) 701 } 702 703 //go:linkname httpPutBufioReader net/http.putBufioReader 704 func httpPutBufioReader(*bufio.Reader) 705 706 //go:linkname httpPutBufioWriter net/http.putBufioWriter 707 func httpPutBufioWriter(*bufio.Writer) 708 709 //go:linkname httpNewBufioReader net/http.newBufioReader 710 func httpNewBufioReader(io.Reader) *bufio.Reader 711 712 //go:linkname httpNewBufioWriterSize net/http.newBufioWriterSize 713 func httpNewBufioWriterSize(io.Writer, int) *bufio.Writer 714 715 type recorder struct { 716 *httptest.ResponseRecorder 717 hijacked bool 718 conn func(*bytes.Buffer) net.Conn 719 } 720 721 func newRecorder() *recorder { 722 return &recorder{ 723 ResponseRecorder: httptest.NewRecorder(), 724 } 725 } 726 727 func (r *recorder) Bytes() []byte { 728 if r.hijacked { 729 return r.ResponseRecorder.Body.Bytes() 730 } 731 732 // TODO(gobwas): remove this when support for go 1.7 will end. 733 resp := r.Result() 734 cs := strings.TrimSpace(resp.Header.Get("Content-Length")) 735 if n, err := strconv.ParseInt(cs, 10, 64); err == nil { 736 resp.ContentLength = n 737 } else { 738 resp.ContentLength = -1 739 } 740 741 return dumpResponse(resp) 742 } 743 744 func (r *recorder) Hijack() (conn net.Conn, brw *bufio.ReadWriter, err error) { 745 if r.hijacked { 746 err = fmt.Errorf("already hijacked") 747 return 748 } 749 750 r.hijacked = true 751 752 var buf *bytes.Buffer 753 if r.ResponseRecorder != nil { 754 buf = r.ResponseRecorder.Body 755 } 756 757 if r.conn != nil { 758 conn = r.conn(buf) 759 } else { 760 conn = stubConn{ 761 read: buf.Read, 762 write: buf.Write, 763 close: func() error { return nil }, 764 } 765 } 766 767 // Use httpNewBufio* linked functions here to make 768 // benchmark more closer to real life usage. 769 br := httpNewBufioReader(conn) 770 bw := httpNewBufioWriterSize(conn, 4<<10) 771 772 brw = bufio.NewReadWriter(br, bw) 773 774 return 775 } 776 777 func mustMakeRequest(method, url string, headers http.Header) *http.Request { 778 req, err := http.NewRequest(method, url, nil) 779 if err != nil { 780 panic(err) 781 } 782 req.Header = headers 783 return req 784 } 785 786 func setProto(major, minor int, req *http.Request) *http.Request { 787 req.ProtoMajor = major 788 req.ProtoMinor = minor 789 return req 790 } 791 792 func withoutHeader(header string, req *http.Request) *http.Request { 793 if strings.EqualFold(header, "Host") { 794 req.URL.Host = "" 795 req.Host = "" 796 } else { 797 delete(req.Header, header) 798 } 799 return req 800 } 801 802 func mustMakeResponse(code int, headers http.Header) *http.Response { 803 res := &http.Response{ 804 StatusCode: code, 805 Status: http.StatusText(code), 806 Header: headers, 807 ProtoMajor: 1, 808 ProtoMinor: 1, 809 ContentLength: -1, 810 } 811 return res 812 } 813 814 func mustMakeErrResponse(code int, err error, headers http.Header) *http.Response { 815 // Body text. 816 body := err.Error() 817 818 res := &http.Response{ 819 StatusCode: code, 820 Status: http.StatusText(code), 821 Header: http.Header{ 822 "Content-Type": []string{"text/plain; charset=utf-8"}, 823 }, 824 ProtoMajor: 1, 825 ProtoMinor: 1, 826 ContentLength: int64(len(body)), 827 } 828 res.Body = ioutil.NopCloser( 829 strings.NewReader(body), 830 ) 831 for k, v := range headers { 832 res.Header[k] = v 833 } 834 return res 835 } 836 837 func mustMakeNonce() (ret []byte) { 838 ret = make([]byte, nonceSize) 839 initNonce(ret) 840 return 841 }