github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/server_test.go (about) 1 package ws 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "math/rand" 10 "net" 11 "net/http" 12 "net/http/httptest" 13 "net/http/httputil" 14 "reflect" 15 "sort" 16 "strconv" 17 "strings" 18 "sync/atomic" 19 "testing" 20 _ "unsafe" // for go:linkname 21 22 "github.com/ezoic/httphead" 23 ) 24 25 // TODO(ezoic): upgradeGenericCase with methods like configureUpgrader, 26 // configureHTTPUpgrader. 27 type upgradeCase struct { 28 label string 29 30 protocol func(string) bool 31 extension func(httphead.Option) bool 32 onRequest func(u []byte) error 33 onHost func(h []byte) error 34 onHeader func(k, v []byte) error 35 36 nonce []byte 37 removeSecKey bool 38 badSecKey bool 39 secKeyHeader string 40 41 req *http.Request 42 res *http.Response 43 hs Handshake 44 err error 45 } 46 47 var upgradeCases = []upgradeCase{ 48 { 49 label: "base", 50 nonce: mustMakeNonce(), 51 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 52 headerUpgrade: []string{"websocket"}, 53 headerConnection: []string{"Upgrade"}, 54 headerSecVersion: []string{"13"}, 55 }), 56 res: mustMakeResponse(101, http.Header{ 57 headerUpgrade: []string{"websocket"}, 58 headerConnection: []string{"Upgrade"}, 59 }), 60 }, 61 { 62 label: "base_canonical", 63 nonce: mustMakeNonce(), 64 secKeyHeader: headerSecKeyCanonical, 65 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 66 headerUpgrade: []string{"websocket"}, 67 headerConnection: []string{"Upgrade"}, 68 headerSecVersionCanonical: []string{"13"}, 69 }), 70 res: mustMakeResponse(101, http.Header{ 71 headerUpgrade: []string{"websocket"}, 72 headerConnection: []string{"Upgrade"}, 73 }), 74 }, 75 { 76 label: "lowercase_headers", 77 nonce: mustMakeNonce(), 78 secKeyHeader: strings.ToLower(headerSecKey), 79 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 80 strings.ToLower(headerUpgrade): []string{"websocket"}, 81 strings.ToLower(headerConnection): []string{"Upgrade"}, 82 strings.ToLower(headerSecVersion): []string{"13"}, 83 }), 84 res: mustMakeResponse(101, http.Header{ 85 headerUpgrade: []string{"websocket"}, 86 headerConnection: []string{"Upgrade"}, 87 }), 88 }, 89 { 90 label: "uppercase", 91 protocol: func(sub string) bool { return true }, 92 nonce: mustMakeNonce(), 93 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 94 headerUpgrade: []string{"WEBSOCKET"}, 95 headerConnection: []string{"UPGRADE"}, 96 headerSecVersion: []string{"13"}, 97 }), 98 res: mustMakeResponse(101, http.Header{ 99 headerUpgrade: []string{"websocket"}, 100 headerConnection: []string{"Upgrade"}, 101 }), 102 }, 103 { 104 label: "subproto", 105 protocol: SelectFromSlice([]string{"b", "d"}), 106 nonce: mustMakeNonce(), 107 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 108 headerUpgrade: []string{"websocket"}, 109 headerConnection: []string{"Upgrade"}, 110 headerSecVersion: []string{"13"}, 111 headerSecProtocol: []string{"a", "b", "c", "d"}, 112 }), 113 res: mustMakeResponse(101, http.Header{ 114 headerUpgrade: []string{"websocket"}, 115 headerConnection: []string{"Upgrade"}, 116 headerSecProtocol: []string{"b"}, 117 }), 118 hs: Handshake{Protocol: "b"}, 119 }, 120 { 121 label: "subproto_lowercase_headers", 122 protocol: SelectFromSlice([]string{"b", "d"}), 123 nonce: mustMakeNonce(), 124 secKeyHeader: strings.ToLower(headerSecKey), 125 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 126 strings.ToLower(headerUpgrade): []string{"websocket"}, 127 strings.ToLower(headerConnection): []string{"Upgrade"}, 128 strings.ToLower(headerSecVersion): []string{"13"}, 129 strings.ToLower(headerSecProtocol): []string{"a", "b", "c", "d"}, 130 }), 131 res: mustMakeResponse(101, http.Header{ 132 headerUpgrade: []string{"websocket"}, 133 headerConnection: []string{"Upgrade"}, 134 headerSecProtocol: []string{"b"}, 135 }), 136 hs: Handshake{Protocol: "b"}, 137 }, 138 { 139 label: "subproto_comma", 140 protocol: SelectFromSlice([]string{"b", "d"}), 141 nonce: mustMakeNonce(), 142 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 143 headerUpgrade: []string{"websocket"}, 144 headerConnection: []string{"Upgrade"}, 145 headerSecVersion: []string{"13"}, 146 headerSecProtocol: []string{"a, b, c, d"}, 147 }), 148 res: mustMakeResponse(101, http.Header{ 149 headerUpgrade: []string{"websocket"}, 150 headerConnection: []string{"Upgrade"}, 151 headerSecProtocol: []string{"b"}, 152 }), 153 hs: Handshake{Protocol: "b"}, 154 }, 155 { 156 extension: func(opt httphead.Option) bool { 157 switch string(opt.Name) { 158 case "b", "d": 159 return true 160 default: 161 return false 162 } 163 }, 164 nonce: mustMakeNonce(), 165 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 166 headerUpgrade: []string{"websocket"}, 167 headerConnection: []string{"Upgrade"}, 168 headerSecVersion: []string{"13"}, 169 headerSecExtensions: []string{"a;foo=1", "b;bar=2", "c", "d;baz=3"}, 170 }), 171 res: mustMakeResponse(101, http.Header{ 172 headerUpgrade: []string{"websocket"}, 173 headerConnection: []string{"Upgrade"}, 174 headerSecExtensions: []string{"b;bar=2,d;baz=3"}, 175 }), 176 hs: Handshake{ 177 Extensions: []httphead.Option{ 178 httphead.NewOption("b", map[string]string{ 179 "bar": "2", 180 }), 181 httphead.NewOption("d", map[string]string{ 182 "baz": "3", 183 }), 184 }, 185 }, 186 }, 187 188 // Error cases. 189 // ------------ 190 191 { 192 label: "bad_http_method", 193 nonce: mustMakeNonce(), 194 req: mustMakeRequest("POST", "ws://example.org", http.Header{ 195 headerUpgrade: []string{"websocket"}, 196 headerConnection: []string{"Upgrade"}, 197 headerSecVersion: []string{"13"}, 198 }), 199 res: mustMakeErrResponse(405, ErrHandshakeBadMethod, nil), 200 err: ErrHandshakeBadMethod, 201 }, 202 { 203 label: "bad_http_proto", 204 nonce: mustMakeNonce(), 205 req: setProto(1, 0, mustMakeRequest("GET", "ws://example.org", http.Header{ 206 headerUpgrade: []string{"websocket"}, 207 headerConnection: []string{"Upgrade"}, 208 headerSecVersion: []string{"13"}, 209 })), 210 res: mustMakeErrResponse(505, ErrHandshakeBadProtocol, nil), 211 err: ErrHandshakeBadProtocol, 212 }, 213 { 214 label: "bad_host", 215 nonce: mustMakeNonce(), 216 req: withoutHeader("Host", mustMakeRequest("GET", "ws://example.org", http.Header{ 217 headerUpgrade: []string{"websocket"}, 218 headerConnection: []string{"Upgrade"}, 219 headerSecVersion: []string{"13"}, 220 })), 221 res: mustMakeErrResponse(400, ErrHandshakeBadHost, nil), 222 err: ErrHandshakeBadHost, 223 }, 224 { 225 label: "bad_upgrade", 226 nonce: mustMakeNonce(), 227 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 228 headerConnection: []string{"Upgrade"}, 229 headerSecVersion: []string{"13"}, 230 }), 231 res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil), 232 err: ErrHandshakeBadUpgrade, 233 }, 234 { 235 label: "bad_upgrade", 236 nonce: mustMakeNonce(), 237 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 238 "X-Custom-Header": []string{"value"}, 239 headerConnection: []string{"Upgrade"}, 240 headerSecVersion: []string{"13"}, 241 }), 242 243 onRequest: func([]byte) error { return nil }, 244 onHost: func([]byte) error { return nil }, 245 onHeader: func(k, v []byte) error { return nil }, 246 247 res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil), 248 err: ErrHandshakeBadUpgrade, 249 }, 250 { 251 label: "bad_upgrade", 252 nonce: mustMakeNonce(), 253 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 254 headerUpgrade: []string{"not-websocket"}, 255 headerConnection: []string{"Upgrade"}, 256 headerSecVersion: []string{"13"}, 257 }), 258 res: mustMakeErrResponse(400, ErrHandshakeBadUpgrade, nil), 259 err: ErrHandshakeBadUpgrade, 260 }, 261 { 262 label: "bad_connection", 263 nonce: mustMakeNonce(), 264 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 265 headerUpgrade: []string{"websocket"}, 266 headerSecVersion: []string{"13"}, 267 }), 268 res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil), 269 err: ErrHandshakeBadConnection, 270 }, 271 { 272 label: "bad_connection", 273 nonce: mustMakeNonce(), 274 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 275 headerUpgrade: []string{"websocket"}, 276 headerConnection: []string{"not-upgrade"}, 277 headerSecVersion: []string{"13"}, 278 }), 279 res: mustMakeErrResponse(400, ErrHandshakeBadConnection, nil), 280 err: ErrHandshakeBadConnection, 281 }, 282 { 283 label: "bad_sec_version_x", 284 nonce: mustMakeNonce(), 285 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 286 headerUpgrade: []string{"websocket"}, 287 headerConnection: []string{"Upgrade"}, 288 }), 289 res: mustMakeErrResponse(400, ErrHandshakeBadSecVersion, nil), 290 err: ErrHandshakeBadSecVersion, 291 }, 292 { 293 label: "bad_sec_version", 294 nonce: mustMakeNonce(), 295 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 296 headerUpgrade: []string{"websocket"}, 297 headerConnection: []string{"upgrade"}, 298 headerSecVersion: []string{"15"}, 299 }), 300 res: mustMakeErrResponse(426, ErrHandshakeBadSecVersion, http.Header{ 301 headerSecVersion: []string{"13"}, 302 }), 303 err: ErrHandshakeUpgradeRequired, 304 }, 305 { 306 label: "bad_sec_key", 307 nonce: mustMakeNonce(), 308 removeSecKey: true, 309 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 310 headerUpgrade: []string{"websocket"}, 311 headerConnection: []string{"Upgrade"}, 312 headerSecVersion: []string{"13"}, 313 }), 314 res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil), 315 err: ErrHandshakeBadSecKey, 316 }, 317 { 318 label: "bad_sec_key", 319 nonce: mustMakeNonce(), 320 badSecKey: true, 321 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 322 headerUpgrade: []string{"websocket"}, 323 headerConnection: []string{"Upgrade"}, 324 headerSecVersion: []string{"13"}, 325 }), 326 res: mustMakeErrResponse(400, ErrHandshakeBadSecKey, nil), 327 err: ErrHandshakeBadSecKey, 328 }, 329 { 330 label: "bad_ws_extension", 331 nonce: mustMakeNonce(), 332 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 333 headerUpgrade: []string{"websocket"}, 334 headerConnection: []string{"Upgrade"}, 335 headerSecVersion: []string{"13"}, 336 headerSecExtensions: []string{"=["}, 337 }), 338 extension: func(opt httphead.Option) bool { 339 return false 340 }, 341 res: mustMakeErrResponse(400, ErrMalformedRequest, nil), 342 err: ErrMalformedRequest, 343 }, 344 { 345 label: "bad_subprotocol", 346 nonce: mustMakeNonce(), 347 req: mustMakeRequest("GET", "ws://example.org", http.Header{ 348 headerUpgrade: []string{"websocket"}, 349 headerConnection: []string{"Upgrade"}, 350 headerSecVersion: []string{"13"}, 351 headerSecProtocol: []string{"=["}, 352 }), 353 protocol: func(string) bool { 354 return false 355 }, 356 res: mustMakeErrResponse(400, ErrMalformedRequest, nil), 357 err: ErrMalformedRequest, 358 }, 359 } 360 361 func TestHTTPUpgrader(t *testing.T) { 362 for _, test := range upgradeCases { 363 t.Run(test.label, func(t *testing.T) { 364 if !test.removeSecKey { 365 nonce := test.nonce 366 if test.badSecKey { 367 nonce = nonce[:nonceSize-1] 368 } 369 if test.secKeyHeader == "" { 370 test.secKeyHeader = headerSecKey 371 } 372 test.req.Header[test.secKeyHeader] = []string{string(nonce)} 373 } 374 if test.err == nil { 375 test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))} 376 } 377 378 // Need to emulate http server read request for truth test. 379 // 380 // We use dumpRequest here because test.req.Write is always send 381 // http/1.1 proto version, that does not fits all our testing 382 // cases. 383 reqBytes := dumpRequest(test.req) 384 req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqBytes))) 385 if err != nil { 386 t.Fatal(err) 387 } 388 389 res := newRecorder() 390 391 u := HTTPUpgrader{ 392 Protocol: test.protocol, 393 Extension: test.extension, 394 } 395 _, _, hs, err := u.Upgrade(req, res) 396 if test.err != err { 397 t.Errorf( 398 "expected error to be '%v', got '%v';\non request:\n====\n%s\n====", 399 test.err, err, dumpRequest(req), 400 ) 401 return 402 } 403 404 actRespBts := sortHeaders(res.Bytes()) 405 expRespBts := sortHeaders(dumpResponse(test.res)) 406 if !bytes.Equal(actRespBts, expRespBts) { 407 t.Errorf( 408 "unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====", 409 actRespBts, expRespBts, dumpRequest(test.req), 410 ) 411 return 412 } 413 414 if act, exp := hs.Protocol, test.hs.Protocol; act != exp { 415 t.Errorf("handshake protocol is %q want %q", act, exp) 416 } 417 if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp { 418 t.Errorf("handshake got %d extensions; want %d", act, exp) 419 } else { 420 for i := 0; i < act; i++ { 421 if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) { 422 t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp) 423 } 424 } 425 } 426 }) 427 } 428 } 429 430 func TestUpgrader(t *testing.T) { 431 for _, test := range upgradeCases { 432 t.Run(test.label, func(t *testing.T) { 433 if !test.removeSecKey { 434 nonce := test.nonce[:] 435 if test.badSecKey { 436 nonce = nonce[:nonceSize-1] 437 } 438 test.req.Header[headerSecKey] = []string{string(nonce)} 439 } 440 if test.err == nil { 441 test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))} 442 } 443 444 u := Upgrader{ 445 Protocol: func(p []byte) bool { 446 return test.protocol(string(p)) 447 }, 448 Extension: func(e httphead.Option) bool { 449 return test.extension(e) 450 }, 451 OnHeader: test.onHeader, 452 OnRequest: test.onRequest, 453 } 454 455 // We use dumpRequest here because test.req.Write is always send 456 // http/1.1 proto version, that does not fits all our testing 457 // cases. 458 reqBytes := dumpRequest(test.req) 459 conn := bytes.NewBuffer(reqBytes) 460 461 hs, err := u.Upgrade(conn) 462 if test.err != err { 463 464 t.Errorf("expected error to be '%v', got '%v'", test.err, err) 465 return 466 } 467 468 actRespBts := sortHeaders(conn.Bytes()) 469 expRespBts := sortHeaders(dumpResponse(test.res)) 470 if !bytes.Equal(actRespBts, expRespBts) { 471 t.Errorf( 472 "unexpected http response:\n---- act:\n%s\n---- want:\n%s\n==== on request:\n%s\n====", 473 actRespBts, expRespBts, dumpRequest(test.req), 474 ) 475 return 476 } 477 478 if act, exp := hs.Protocol, test.hs.Protocol; act != exp { 479 t.Errorf("handshake protocol is %q want %q", act, exp) 480 } 481 if act, exp := len(hs.Extensions), len(test.hs.Extensions); act != exp { 482 t.Errorf("handshake got %d extensions; want %d", act, exp) 483 } else { 484 for i := 0; i < act; i++ { 485 if act, exp := hs.Extensions[i], test.hs.Extensions[i]; !act.Equal(exp) { 486 t.Errorf("handshake %d-th extension is %s; want %s", i, act, exp) 487 } 488 } 489 } 490 }) 491 } 492 } 493 494 func BenchmarkHTTPUpgrader(b *testing.B) { 495 for _, bench := range upgradeCases { 496 bench.req.Header.Set(headerSecKey, string(bench.nonce[:])) 497 498 u := HTTPUpgrader{ 499 Protocol: bench.protocol, 500 Extension: bench.extension, 501 } 502 503 b.Run(bench.label, func(b *testing.B) { 504 res := make([]http.ResponseWriter, b.N) 505 for i := 0; i < b.N; i++ { 506 res[i] = newRecorder() 507 } 508 509 i := new(int64) 510 511 b.ResetTimer() 512 b.ReportAllocs() 513 b.RunParallel(func(pb *testing.PB) { 514 for pb.Next() { 515 w := res[atomic.AddInt64(i, 1)-1] 516 u.Upgrade(bench.req, w) 517 } 518 }) 519 }) 520 } 521 } 522 523 func BenchmarkUpgrader(b *testing.B) { 524 for _, bench := range upgradeCases { 525 bench.req.Header.Set(headerSecKey, string(bench.nonce[:])) 526 527 u := Upgrader{ 528 Protocol: func(p []byte) bool { 529 return bench.protocol(btsToString(p)) 530 }, 531 Extension: func(e httphead.Option) bool { 532 return bench.extension(e) 533 }, 534 } 535 536 reqBytes := dumpRequest(bench.req) 537 538 type benchReadWriter struct { 539 io.Reader 540 io.Writer 541 } 542 543 b.Run(bench.label, func(b *testing.B) { 544 conn := make([]io.ReadWriter, b.N) 545 for i := 0; i < b.N; i++ { 546 conn[i] = benchReadWriter{bytes.NewReader(reqBytes), ioutil.Discard} 547 } 548 549 i := new(int64) 550 551 b.ResetTimer() 552 b.ReportAllocs() 553 b.RunParallel(func(pb *testing.PB) { 554 for pb.Next() { 555 c := conn[atomic.AddInt64(i, 1)-1] 556 u.Upgrade(c) 557 } 558 }) 559 }) 560 } 561 } 562 563 func TestHttpStrSelectProtocol(t *testing.T) { 564 for i, test := range []struct { 565 header string 566 }{ 567 {"jsonrpc, soap, grpc"}, 568 } { 569 t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { 570 exp := strings.Split(test.header, ",") 571 for i, p := range exp { 572 exp[i] = strings.TrimSpace(p) 573 } 574 575 var calls []string 576 strSelectProtocol(test.header, func(s string) bool { 577 calls = append(calls, s) 578 return false 579 }) 580 581 if !reflect.DeepEqual(calls, exp) { 582 t.Errorf("selectProtocol(%q, fn); called fn with %v; want %v", test.header, calls, exp) 583 } 584 }) 585 } 586 } 587 588 func BenchmarkSelectProtocol(b *testing.B) { 589 for _, bench := range []struct { 590 label string 591 header string 592 acceptStr func(string) bool 593 acceptBts func([]byte) bool 594 }{ 595 { 596 label: "never accept", 597 header: "jsonrpc, soap, grpc", 598 acceptStr: func(s string) bool { 599 return len(s)%2 == 2 // never ok 600 }, 601 acceptBts: func(v []byte) bool { 602 return len(v)%2 == 2 // never ok 603 }, 604 }, 605 { 606 label: "from slice", 607 header: "a, b, c, d, e, f, g", 608 acceptStr: SelectFromSlice([]string{"g", "f", "e", "d"}), 609 }, 610 { 611 label: "uniq 1024 from slise", 612 header: strings.Join(randProtocols(1024, 16), ", "), 613 acceptStr: SelectFromSlice(randProtocols(1024, 17)), 614 }, 615 } { 616 b.Run(fmt.Sprintf("String/%s", bench.label), func(b *testing.B) { 617 for i := 0; i < b.N; i++ { 618 strSelectProtocol(bench.header, bench.acceptStr) 619 } 620 }) 621 if bench.acceptBts != nil { 622 b.Run(fmt.Sprintf("Bytes/%s", bench.label), func(b *testing.B) { 623 h := []byte(bench.header) 624 b.StartTimer() 625 626 for i := 0; i < b.N; i++ { 627 btsSelectProtocol(h, bench.acceptBts) 628 } 629 }) 630 } 631 } 632 } 633 634 func randProtocols(n, m int) []string { 635 ret := make([]string, n) 636 bts := make([]byte, m) 637 uniq := map[string]bool{} 638 for i := 0; i < n; i++ { 639 for { 640 for j := 0; j < m; j++ { 641 bts[j] = byte(rand.Intn('x'-'a') + 'a') 642 } 643 str := string(bts) 644 if _, has := uniq[str]; !has { 645 ret[i] = str 646 break 647 } 648 } 649 } 650 return ret 651 } 652 653 func dumpRequest(req *http.Request) []byte { 654 bts, err := httputil.DumpRequest(req, true) 655 if err != nil { 656 panic(err) 657 } 658 return bts 659 } 660 661 func dumpResponse(res *http.Response) []byte { 662 if !res.Close { 663 for _, v := range res.Header[headerConnection] { 664 if v == "close" { 665 res.Close = true 666 break 667 } 668 } 669 } 670 bts, err := httputil.DumpResponse(res, true) 671 if err != nil { 672 panic(err) 673 } 674 if !res.Close { 675 bts = bytes.Replace(bts, []byte("Connection: close\r\n"), nil, -1) 676 } 677 678 return bts 679 } 680 681 type headersBytes [][]byte 682 683 func (h headersBytes) Len() int { return len(h) } 684 func (h headersBytes) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 685 func (h headersBytes) Less(i, j int) bool { return bytes.Compare(h[i], h[j]) == -1 } 686 687 func maskHeader(bts []byte, key, mask string) []byte { 688 lines := bytes.Split(bts, []byte("\r\n")) 689 for i, line := range lines { 690 pair := bytes.Split(line, []byte(": ")) 691 if string(pair[0]) == key { 692 lines[i] = []byte(key + ": " + mask) 693 } 694 } 695 return bytes.Join(lines, []byte("\r\n")) 696 } 697 698 func sortHeaders(bts []byte) []byte { 699 lines := bytes.Split(bts, []byte("\r\n")) 700 if len(lines) <= 1 { 701 return bts 702 } 703 sort.Sort(headersBytes(lines[1 : len(lines)-2])) 704 return bytes.Join(lines, []byte("\r\n")) 705 } 706 707 //go:linkname httpPutBufioReader net/http.putBufioReader 708 func httpPutBufioReader(*bufio.Reader) 709 710 //go:linkname httpPutBufioWriter net/http.putBufioWriter 711 func httpPutBufioWriter(*bufio.Writer) 712 713 //go:linkname httpNewBufioReader net/http.newBufioReader 714 func httpNewBufioReader(io.Reader) *bufio.Reader 715 716 //go:linkname httpNewBufioWriterSize net/http.newBufioWriterSize 717 func httpNewBufioWriterSize(io.Writer, int) *bufio.Writer 718 719 type recorder struct { 720 *httptest.ResponseRecorder 721 hijacked bool 722 conn func(*bytes.Buffer) net.Conn 723 } 724 725 func newRecorder() *recorder { 726 return &recorder{ 727 ResponseRecorder: httptest.NewRecorder(), 728 } 729 } 730 731 func (r *recorder) Bytes() []byte { 732 if r.hijacked { 733 return r.ResponseRecorder.Body.Bytes() 734 } 735 736 // TODO(ezoic): remove this when support for go 1.7 will end. 737 resp := r.Result() 738 cs := strings.TrimSpace(resp.Header.Get("Content-Length")) 739 if n, err := strconv.ParseInt(cs, 10, 64); err == nil { 740 resp.ContentLength = n 741 } else { 742 resp.ContentLength = -1 743 } 744 745 return dumpResponse(resp) 746 } 747 748 func (r *recorder) Hijack() (conn net.Conn, brw *bufio.ReadWriter, err error) { 749 if r.hijacked { 750 err = fmt.Errorf("already hijacked") 751 return 752 } 753 754 r.hijacked = true 755 756 var buf *bytes.Buffer 757 if r.ResponseRecorder != nil { 758 buf = r.ResponseRecorder.Body 759 } 760 761 if r.conn != nil { 762 conn = r.conn(buf) 763 } else { 764 conn = stubConn{ 765 read: buf.Read, 766 write: buf.Write, 767 close: func() error { return nil }, 768 } 769 } 770 771 // Use httpNewBufio* linked functions here to make 772 // benchmark more closer to real life usage. 773 br := httpNewBufioReader(conn) 774 bw := httpNewBufioWriterSize(conn, 4<<10) 775 776 brw = bufio.NewReadWriter(br, bw) 777 778 return 779 } 780 781 func mustMakeRequest(method, url string, headers http.Header) *http.Request { 782 req, err := http.NewRequest(method, url, nil) 783 if err != nil { 784 panic(err) 785 } 786 req.Header = headers 787 return req 788 } 789 790 func setProto(major, minor int, req *http.Request) *http.Request { 791 req.ProtoMajor = major 792 req.ProtoMinor = minor 793 return req 794 } 795 796 func withoutHeader(header string, req *http.Request) *http.Request { 797 if strings.EqualFold(header, "Host") { 798 req.URL.Host = "" 799 req.Host = "" 800 } else { 801 delete(req.Header, header) 802 } 803 return req 804 } 805 806 func mustMakeResponse(code int, headers http.Header) *http.Response { 807 res := &http.Response{ 808 StatusCode: code, 809 Status: http.StatusText(code), 810 Header: headers, 811 ProtoMajor: 1, 812 ProtoMinor: 1, 813 ContentLength: -1, 814 } 815 return res 816 } 817 818 func mustMakeErrResponse(code int, err error, headers http.Header) *http.Response { 819 // Body text. 820 body := err.Error() 821 822 res := &http.Response{ 823 StatusCode: code, 824 Status: http.StatusText(code), 825 Header: http.Header{ 826 "Content-Type": []string{"text/plain; charset=utf-8"}, 827 }, 828 ProtoMajor: 1, 829 ProtoMinor: 1, 830 ContentLength: int64(len(body)), 831 } 832 res.Body = ioutil.NopCloser( 833 strings.NewReader(body), 834 ) 835 for k, v := range headers { 836 res.Header[k] = v 837 } 838 return res 839 } 840 841 func mustMakeNonce() (ret []byte) { 842 ret = make([]byte, nonceSize) 843 initNonce(ret) 844 return 845 }