github.com/cdmixer/woolloomooloo@v0.1.0/gen/client_server_test.go (about) 1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "bytes" 9 "context" 10 "crypto/tls" 11 "crypto/x509" 12 "encoding/base64" 13 "encoding/binary" 14 "fmt" 15 "io" 16 "io/ioutil" 17 "log" 18 "net" 19 "net/http" 20 "net/http/cookiejar" 21 "net/http/httptest" 22 "net/http/httptrace" 23 "net/url" 24 "reflect" 25 "strings" 26 "testing" 27 "time" 28 ) 29 30 var cstUpgrader = Upgrader{ 31 Subprotocols: []string{"p0", "p1"}, 32 ReadBufferSize: 1024, 33 WriteBufferSize: 1024, 34 EnableCompression: true, 35 Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { 36 http.Error(w, reason.Error(), status) 37 }, 38 } 39 40 var cstDialer = Dialer{ 41 Subprotocols: []string{"p1", "p2"}, 42 ReadBufferSize: 1024, 43 WriteBufferSize: 1024, 44 HandshakeTimeout: 30 * time.Second, 45 } 46 47 type cstHandler struct{ *testing.T } 48 49 type cstServer struct { 50 *httptest.Server 51 URL string 52 t *testing.T 53 } 54 55 const ( 56 cstPath = "/a/b" 57 cstRawQuery = "x=y" 58 cstRequestURI = cstPath + "?" + cstRawQuery 59 ) 60 61 func newServer(t *testing.T) *cstServer { 62 var s cstServer 63 s.Server = httptest.NewServer(cstHandler{t}) 64 s.Server.URL += cstRequestURI 65 s.URL = makeWsProto(s.Server.URL) 66 return &s 67 } 68 69 func newTLSServer(t *testing.T) *cstServer { 70 var s cstServer 71 s.Server = httptest.NewTLSServer(cstHandler{t}) 72 s.Server.URL += cstRequestURI 73 s.URL = makeWsProto(s.Server.URL) 74 return &s 75 } 76 77 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 78 if r.URL.Path != cstPath { 79 t.Logf("path=%v, want %v", r.URL.Path, cstPath) 80 http.Error(w, "bad path", http.StatusBadRequest) 81 return 82 } 83 if r.URL.RawQuery != cstRawQuery { 84 t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) 85 http.Error(w, "bad path", http.StatusBadRequest) 86 return 87 } 88 subprotos := Subprotocols(r) 89 if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { 90 t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) 91 http.Error(w, "bad protocol", http.StatusBadRequest) 92 return 93 } 94 ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) 95 if err != nil { 96 t.Logf("Upgrade: %v", err) 97 return 98 } 99 defer ws.Close() 100 101 if ws.Subprotocol() != "p1" { 102 t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) 103 ws.Close() 104 return 105 } 106 op, rd, err := ws.NextReader() 107 if err != nil { 108 t.Logf("NextReader: %v", err) 109 return 110 } 111 wr, err := ws.NextWriter(op) 112 if err != nil { 113 t.Logf("NextWriter: %v", err) 114 return 115 } 116 if _, err = io.Copy(wr, rd); err != nil { 117 t.Logf("NextWriter: %v", err) 118 return 119 } 120 if err := wr.Close(); err != nil { 121 t.Logf("Close: %v", err) 122 return 123 } 124 } 125 126 func makeWsProto(s string) string { 127 return "ws" + strings.TrimPrefix(s, "http") 128 } 129 130 func sendRecv(t *testing.T, ws *Conn) { 131 const message = "Hello World!" 132 if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { 133 t.Fatalf("SetWriteDeadline: %v", err) 134 } 135 if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { 136 t.Fatalf("WriteMessage: %v", err) 137 } 138 if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { 139 t.Fatalf("SetReadDeadline: %v", err) 140 } 141 _, p, err := ws.ReadMessage() 142 if err != nil { 143 t.Fatalf("ReadMessage: %v", err) 144 } 145 if string(p) != message { 146 t.Fatalf("message=%s, want %s", p, message) 147 } 148 } 149 150 func TestProxyDial(t *testing.T) { 151 152 s := newServer(t) 153 defer s.Close() 154 155 surl, _ := url.Parse(s.Server.URL) 156 157 cstDialer := cstDialer // make local copy for modification on next line. 158 cstDialer.Proxy = http.ProxyURL(surl) 159 160 connect := false 161 origHandler := s.Server.Config.Handler 162 163 // Capture the request Host header. 164 s.Server.Config.Handler = http.HandlerFunc( 165 func(w http.ResponseWriter, r *http.Request) { 166 if r.Method == "CONNECT" { 167 connect = true 168 w.WriteHeader(http.StatusOK) 169 return 170 } 171 172 if !connect { 173 t.Log("connect not received") 174 http.Error(w, "connect not received", http.StatusMethodNotAllowed) 175 return 176 } 177 origHandler.ServeHTTP(w, r) 178 }) 179 180 ws, _, err := cstDialer.Dial(s.URL, nil) 181 if err != nil { 182 t.Fatalf("Dial: %v", err) 183 } 184 defer ws.Close() 185 sendRecv(t, ws) 186 } 187 188 func TestProxyAuthorizationDial(t *testing.T) { 189 s := newServer(t) 190 defer s.Close() 191 192 surl, _ := url.Parse(s.Server.URL) 193 surl.User = url.UserPassword("username", "password") 194 195 cstDialer := cstDialer // make local copy for modification on next line. 196 cstDialer.Proxy = http.ProxyURL(surl) 197 198 connect := false 199 origHandler := s.Server.Config.Handler 200 201 // Capture the request Host header. 202 s.Server.Config.Handler = http.HandlerFunc( 203 func(w http.ResponseWriter, r *http.Request) { 204 proxyAuth := r.Header.Get("Proxy-Authorization") 205 expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) 206 if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { 207 connect = true 208 w.WriteHeader(http.StatusOK) 209 return 210 } 211 212 if !connect { 213 t.Log("connect with proxy authorization not received") 214 http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) 215 return 216 } 217 origHandler.ServeHTTP(w, r) 218 }) 219 220 ws, _, err := cstDialer.Dial(s.URL, nil) 221 if err != nil { 222 t.Fatalf("Dial: %v", err) 223 } 224 defer ws.Close() 225 sendRecv(t, ws) 226 } 227 228 func TestDial(t *testing.T) { 229 s := newServer(t) 230 defer s.Close() 231 232 ws, _, err := cstDialer.Dial(s.URL, nil) 233 if err != nil { 234 t.Fatalf("Dial: %v", err) 235 } 236 defer ws.Close() 237 sendRecv(t, ws) 238 } 239 240 func TestDialCookieJar(t *testing.T) { 241 s := newServer(t) 242 defer s.Close() 243 244 jar, _ := cookiejar.New(nil) 245 d := cstDialer 246 d.Jar = jar 247 248 u, _ := url.Parse(s.URL) 249 250 switch u.Scheme { 251 case "ws": 252 u.Scheme = "http" 253 case "wss": 254 u.Scheme = "https" 255 } 256 257 cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}} 258 d.Jar.SetCookies(u, cookies) 259 260 ws, _, err := d.Dial(s.URL, nil) 261 if err != nil { 262 t.Fatalf("Dial: %v", err) 263 } 264 defer ws.Close() 265 266 var gorilla string 267 var sessionID string 268 for _, c := range d.Jar.Cookies(u) { 269 if c.Name == "gorilla" { 270 gorilla = c.Value 271 } 272 273 if c.Name == "sessionID" { 274 sessionID = c.Value 275 } 276 } 277 if gorilla != "ws" { 278 t.Error("Cookie not present in jar.") 279 } 280 281 if sessionID != "1234" { 282 t.Error("Set-Cookie not received from the server.") 283 } 284 285 sendRecv(t, ws) 286 } 287 288 func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { 289 certs := x509.NewCertPool() 290 for _, c := range s.TLS.Certificates { 291 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) 292 if err != nil { 293 t.Fatalf("error parsing server's root cert: %v", err) 294 } 295 for _, root := range roots { 296 certs.AddCert(root) 297 } 298 } 299 return certs 300 } 301 302 func TestDialTLS(t *testing.T) { 303 s := newTLSServer(t) 304 defer s.Close() 305 306 d := cstDialer 307 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} 308 ws, _, err := d.Dial(s.URL, nil) 309 if err != nil { 310 t.Fatalf("Dial: %v", err) 311 } 312 defer ws.Close() 313 sendRecv(t, ws) 314 } 315 316 func TestDialTimeout(t *testing.T) { 317 s := newServer(t) 318 defer s.Close() 319 320 d := cstDialer 321 d.HandshakeTimeout = -1 322 ws, _, err := d.Dial(s.URL, nil) 323 if err == nil { 324 ws.Close() 325 t.Fatalf("Dial: nil") 326 } 327 } 328 329 // requireDeadlineNetConn fails the current test when Read or Write are called 330 // with no deadline. 331 type requireDeadlineNetConn struct { 332 t *testing.T 333 c net.Conn 334 readDeadlineIsSet bool 335 writeDeadlineIsSet bool 336 } 337 338 func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error { 339 c.writeDeadlineIsSet = !t.Equal(time.Time{}) 340 c.readDeadlineIsSet = c.writeDeadlineIsSet 341 return c.c.SetDeadline(t) 342 } 343 344 func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error { 345 c.readDeadlineIsSet = !t.Equal(time.Time{}) 346 return c.c.SetDeadline(t) 347 } 348 349 func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error { 350 c.writeDeadlineIsSet = !t.Equal(time.Time{}) 351 return c.c.SetDeadline(t) 352 } 353 354 func (c *requireDeadlineNetConn) Write(p []byte) (int, error) { 355 if !c.writeDeadlineIsSet { 356 c.t.Fatalf("write with no deadline") 357 } 358 return c.c.Write(p) 359 } 360 361 func (c *requireDeadlineNetConn) Read(p []byte) (int, error) { 362 if !c.readDeadlineIsSet { 363 c.t.Fatalf("read with no deadline") 364 } 365 return c.c.Read(p) 366 } 367 368 func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } 369 func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } 370 func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } 371 372 func TestHandshakeTimeout(t *testing.T) { 373 s := newServer(t) 374 defer s.Close() 375 376 d := cstDialer 377 d.NetDial = func(n, a string) (net.Conn, error) { 378 c, err := net.Dial(n, a) 379 return &requireDeadlineNetConn{c: c, t: t}, err 380 } 381 ws, _, err := d.Dial(s.URL, nil) 382 if err != nil { 383 t.Fatal("Dial:", err) 384 } 385 ws.Close() 386 } 387 388 func TestHandshakeTimeoutInContext(t *testing.T) { 389 s := newServer(t) 390 defer s.Close() 391 392 d := cstDialer 393 d.HandshakeTimeout = 0 394 d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { 395 netDialer := &net.Dialer{} 396 c, err := netDialer.DialContext(ctx, n, a) 397 return &requireDeadlineNetConn{c: c, t: t}, err 398 } 399 400 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) 401 defer cancel() 402 ws, _, err := d.DialContext(ctx, s.URL, nil) 403 if err != nil { 404 t.Fatal("Dial:", err) 405 } 406 ws.Close() 407 } 408 409 func TestDialBadScheme(t *testing.T) { 410 s := newServer(t) 411 defer s.Close() 412 413 ws, _, err := cstDialer.Dial(s.Server.URL, nil) 414 if err == nil { 415 ws.Close() 416 t.Fatalf("Dial: nil") 417 } 418 } 419 420 func TestDialBadOrigin(t *testing.T) { 421 s := newServer(t) 422 defer s.Close() 423 424 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) 425 if err == nil { 426 ws.Close() 427 t.Fatalf("Dial: nil") 428 } 429 if resp == nil { 430 t.Fatalf("resp=nil, err=%v", err) 431 } 432 if resp.StatusCode != http.StatusForbidden { 433 t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden) 434 } 435 } 436 437 func TestDialBadHeader(t *testing.T) { 438 s := newServer(t) 439 defer s.Close() 440 441 for _, k := range []string{"Upgrade", 442 "Connection", 443 "Sec-Websocket-Key", 444 "Sec-Websocket-Version", 445 "Sec-Websocket-Protocol"} { 446 h := http.Header{} 447 h.Set(k, "bad") 448 ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) 449 if err == nil { 450 ws.Close() 451 t.Errorf("Dial with header %s returned nil", k) 452 } 453 } 454 } 455 456 func TestBadMethod(t *testing.T) { 457 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 458 ws, err := cstUpgrader.Upgrade(w, r, nil) 459 if err == nil { 460 t.Errorf("handshake succeeded, expect fail") 461 ws.Close() 462 } 463 })) 464 defer s.Close() 465 466 req, err := http.NewRequest("POST", s.URL, strings.NewReader("")) 467 if err != nil { 468 t.Fatalf("NewRequest returned error %v", err) 469 } 470 req.Header.Set("Connection", "upgrade") 471 req.Header.Set("Upgrade", "websocket") 472 req.Header.Set("Sec-Websocket-Version", "13") 473 474 resp, err := http.DefaultClient.Do(req) 475 if err != nil { 476 t.Fatalf("Do returned error %v", err) 477 } 478 resp.Body.Close() 479 if resp.StatusCode != http.StatusMethodNotAllowed { 480 t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) 481 } 482 } 483 484 func TestDialExtraTokensInRespHeaders(t *testing.T) { 485 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 486 challengeKey := r.Header.Get("Sec-Websocket-Key") 487 w.Header().Set("Upgrade", "foo, websocket") 488 w.Header().Set("Connection", "upgrade, keep-alive") 489 w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) 490 w.WriteHeader(101) 491 })) 492 defer s.Close() 493 494 ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) 495 if err != nil { 496 t.Fatalf("Dial: %v", err) 497 } 498 defer ws.Close() 499 } 500 501 func TestHandshake(t *testing.T) { 502 s := newServer(t) 503 defer s.Close() 504 505 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) 506 if err != nil { 507 t.Fatalf("Dial: %v", err) 508 } 509 defer ws.Close() 510 511 var sessionID string 512 for _, c := range resp.Cookies() { 513 if c.Name == "sessionID" { 514 sessionID = c.Value 515 } 516 } 517 if sessionID != "1234" { 518 t.Error("Set-Cookie not received from the server.") 519 } 520 521 if ws.Subprotocol() != "p1" { 522 t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol()) 523 } 524 sendRecv(t, ws) 525 } 526 527 func TestRespOnBadHandshake(t *testing.T) { 528 const expectedStatus = http.StatusGone 529 const expectedBody = "This is the response body." 530 531 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 532 w.WriteHeader(expectedStatus) 533 io.WriteString(w, expectedBody) 534 })) 535 defer s.Close() 536 537 ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) 538 if err == nil { 539 ws.Close() 540 t.Fatalf("Dial: nil") 541 } 542 543 if resp == nil { 544 t.Fatalf("resp=nil, err=%v", err) 545 } 546 547 if resp.StatusCode != expectedStatus { 548 t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) 549 } 550 551 p, err := ioutil.ReadAll(resp.Body) 552 if err != nil { 553 t.Fatalf("ReadFull(resp.Body) returned error %v", err) 554 } 555 556 if string(p) != expectedBody { 557 t.Errorf("resp.Body=%s, want %s", p, expectedBody) 558 } 559 } 560 561 type testLogWriter struct { 562 t *testing.T 563 } 564 565 func (w testLogWriter) Write(p []byte) (int, error) { 566 w.t.Logf("%s", p) 567 return len(p), nil 568 } 569 570 // TestHost tests handling of host names and confirms that it matches net/http. 571 func TestHost(t *testing.T) { 572 573 upgrader := Upgrader{} 574 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 575 if IsWebSocketUpgrade(r) { 576 c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) 577 if err != nil { 578 t.Fatal(err) 579 } 580 c.Close() 581 } else { 582 w.Header().Set("X-Test-Host", r.Host) 583 } 584 }) 585 586 server := httptest.NewServer(handler) 587 defer server.Close() 588 589 tlsServer := httptest.NewTLSServer(handler) 590 defer tlsServer.Close() 591 592 addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()} 593 wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"} 594 httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"} 595 596 // Avoid log noise from net/http server by logging to testing.T 597 server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) 598 tlsServer.Config.ErrorLog = server.Config.ErrorLog 599 600 cas := rootCAs(t, tlsServer) 601 602 tests := []struct { 603 fail bool // true if dial / get should fail 604 server *httptest.Server // server to use 605 url string // host for request URI 606 header string // optional request host header 607 tls string // optional host for tls ServerName 608 wantAddr string // expected host for dial 609 wantHeader string // expected request header on server 610 insecureSkipVerify bool 611 }{ 612 { 613 server: server, 614 url: addrs[server], 615 wantAddr: addrs[server], 616 wantHeader: addrs[server], 617 }, 618 { 619 server: tlsServer, 620 url: addrs[tlsServer], 621 wantAddr: addrs[tlsServer], 622 wantHeader: addrs[tlsServer], 623 }, 624 625 { 626 server: server, 627 url: addrs[server], 628 header: "badhost.com", 629 wantAddr: addrs[server], 630 wantHeader: "badhost.com", 631 }, 632 { 633 server: tlsServer, 634 url: addrs[tlsServer], 635 header: "badhost.com", 636 wantAddr: addrs[tlsServer], 637 wantHeader: "badhost.com", 638 }, 639 640 { 641 server: server, 642 url: "example.com", 643 header: "badhost.com", 644 wantAddr: "example.com:80", 645 wantHeader: "badhost.com", 646 }, 647 { 648 server: tlsServer, 649 url: "example.com", 650 header: "badhost.com", 651 wantAddr: "example.com:443", 652 wantHeader: "badhost.com", 653 }, 654 655 { 656 server: server, 657 url: "badhost.com", 658 header: "example.com", 659 wantAddr: "badhost.com:80", 660 wantHeader: "example.com", 661 }, 662 { 663 fail: true, 664 server: tlsServer, 665 url: "badhost.com", 666 header: "example.com", 667 wantAddr: "badhost.com:443", 668 }, 669 { 670 server: tlsServer, 671 url: "badhost.com", 672 insecureSkipVerify: true, 673 wantAddr: "badhost.com:443", 674 wantHeader: "badhost.com", 675 }, 676 { 677 server: tlsServer, 678 url: "badhost.com", 679 tls: "example.com", 680 wantAddr: "badhost.com:443", 681 wantHeader: "badhost.com", 682 }, 683 } 684 685 for i, tt := range tests { 686 687 tls := &tls.Config{ 688 RootCAs: cas, 689 ServerName: tt.tls, 690 InsecureSkipVerify: tt.insecureSkipVerify, 691 } 692 693 var gotAddr string 694 dialer := Dialer{ 695 NetDial: func(network, addr string) (net.Conn, error) { 696 gotAddr = addr 697 return net.Dial(network, addrs[tt.server]) 698 }, 699 TLSClientConfig: tls, 700 } 701 702 // Test websocket dial 703 704 h := http.Header{} 705 if tt.header != "" { 706 h.Set("Host", tt.header) 707 } 708 c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h) 709 if err == nil { 710 c.Close() 711 } 712 713 check := func(protos map[*httptest.Server]string) { 714 name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls) 715 if gotAddr != tt.wantAddr { 716 t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr) 717 } 718 switch { 719 case tt.fail && err == nil: 720 t.Errorf("%s: unexpected success", name) 721 case !tt.fail && err != nil: 722 t.Errorf("%s: unexpected error %v", name, err) 723 case !tt.fail && err == nil: 724 if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader { 725 t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader) 726 } 727 } 728 } 729 730 check(wsProtos) 731 732 // Confirm that net/http has same result 733 734 transport := &http.Transport{ 735 Dial: dialer.NetDial, 736 TLSClientConfig: dialer.TLSClientConfig, 737 } 738 req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil) 739 if tt.header != "" { 740 req.Host = tt.header 741 } 742 client := &http.Client{Transport: transport} 743 resp, err = client.Do(req) 744 if err == nil { 745 resp.Body.Close() 746 } 747 transport.CloseIdleConnections() 748 check(httpProtos) 749 } 750 } 751 752 func TestDialCompression(t *testing.T) { 753 s := newServer(t) 754 defer s.Close() 755 756 dialer := cstDialer 757 dialer.EnableCompression = true 758 ws, _, err := dialer.Dial(s.URL, nil) 759 if err != nil { 760 t.Fatalf("Dial: %v", err) 761 } 762 defer ws.Close() 763 sendRecv(t, ws) 764 } 765 766 func TestSocksProxyDial(t *testing.T) { 767 s := newServer(t) 768 defer s.Close() 769 770 proxyListener, err := net.Listen("tcp", "127.0.0.1:0") 771 if err != nil { 772 t.Fatalf("listen failed: %v", err) 773 } 774 defer proxyListener.Close() 775 go func() { 776 c1, err := proxyListener.Accept() 777 if err != nil { 778 t.Errorf("proxy accept failed: %v", err) 779 return 780 } 781 defer c1.Close() 782 783 c1.SetDeadline(time.Now().Add(30 * time.Second)) 784 785 buf := make([]byte, 32) 786 if _, err := io.ReadFull(c1, buf[:3]); err != nil { 787 t.Errorf("read failed: %v", err) 788 return 789 } 790 if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { 791 t.Errorf("read %x, want %x", buf[:len(want)], want) 792 } 793 if _, err := c1.Write([]byte{5, 0}); err != nil { 794 t.Errorf("write failed: %v", err) 795 return 796 } 797 if _, err := io.ReadFull(c1, buf[:10]); err != nil { 798 t.Errorf("read failed: %v", err) 799 return 800 } 801 if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { 802 t.Errorf("read %x, want %x", buf[:len(want)], want) 803 return 804 } 805 buf[1] = 0 806 if _, err := c1.Write(buf[:10]); err != nil { 807 t.Errorf("write failed: %v", err) 808 return 809 } 810 811 ip := net.IP(buf[4:8]) 812 port := binary.BigEndian.Uint16(buf[8:10]) 813 814 c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) 815 if err != nil { 816 t.Errorf("dial failed; %v", err) 817 return 818 } 819 defer c2.Close() 820 done := make(chan struct{}) 821 go func() { 822 io.Copy(c1, c2) 823 close(done) 824 }() 825 io.Copy(c2, c1) 826 <-done 827 }() 828 829 purl, err := url.Parse("socks5://" + proxyListener.Addr().String()) 830 if err != nil { 831 t.Fatalf("parse failed: %v", err) 832 } 833 834 cstDialer := cstDialer // make local copy for modification on next line. 835 cstDialer.Proxy = http.ProxyURL(purl) 836 837 ws, _, err := cstDialer.Dial(s.URL, nil) 838 if err != nil { 839 t.Fatalf("Dial: %v", err) 840 } 841 defer ws.Close() 842 sendRecv(t, ws) 843 } 844 845 func TestTracingDialWithContext(t *testing.T) { 846 847 var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool 848 trace := &httptrace.ClientTrace{ 849 WroteHeaders: func() { 850 headersWrote = true 851 }, 852 WroteRequest: func(httptrace.WroteRequestInfo) { 853 requestWrote = true 854 }, 855 GetConn: func(hostPort string) { 856 getConn = true 857 }, 858 GotConn: func(info httptrace.GotConnInfo) { 859 gotConn = true 860 }, 861 ConnectDone: func(network, addr string, err error) { 862 connectDone = true 863 }, 864 GotFirstResponseByte: func() { 865 gotFirstResponseByte = true 866 }, 867 } 868 ctx := httptrace.WithClientTrace(context.Background(), trace) 869 870 s := newTLSServer(t) 871 defer s.Close() 872 873 d := cstDialer 874 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} 875 876 ws, _, err := d.DialContext(ctx, s.URL, nil) 877 if err != nil { 878 t.Fatalf("Dial: %v", err) 879 } 880 881 if !headersWrote { 882 t.Fatal("Headers was not written") 883 } 884 if !requestWrote { 885 t.Fatal("Request was not written") 886 } 887 if !getConn { 888 t.Fatal("getConn was not called") 889 } 890 if !gotConn { 891 t.Fatal("gotConn was not called") 892 } 893 if !connectDone { 894 t.Fatal("connectDone was not called") 895 } 896 if !gotFirstResponseByte { 897 t.Fatal("GotFirstResponseByte was not called") 898 } 899 900 defer ws.Close() 901 sendRecv(t, ws) 902 } 903 904 func TestEmptyTracingDialWithContext(t *testing.T) { 905 906 trace := &httptrace.ClientTrace{} 907 ctx := httptrace.WithClientTrace(context.Background(), trace) 908 909 s := newTLSServer(t) 910 defer s.Close() 911 912 d := cstDialer 913 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} 914 915 ws, _, err := d.DialContext(ctx, s.URL, nil) 916 if err != nil { 917 t.Fatalf("Dial: %v", err) 918 } 919 920 defer ws.Close() 921 sendRecv(t, ws) 922 }