github.com/google/martian/v3@v3.3.3/proxy_test.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package martian 16 17 import ( 18 "bufio" 19 "bytes" 20 "crypto/tls" 21 "crypto/x509" 22 "errors" 23 "fmt" 24 "io" 25 "io/ioutil" 26 "net" 27 "net/http" 28 "net/url" 29 "os" 30 "strings" 31 "testing" 32 "time" 33 34 "github.com/google/martian/v3/log" 35 "github.com/google/martian/v3/martiantest" 36 "github.com/google/martian/v3/mitm" 37 "github.com/google/martian/v3/proxyutil" 38 ) 39 40 type tempError struct{} 41 42 func (e *tempError) Error() string { return "temporary" } 43 func (e *tempError) Timeout() bool { return true } 44 func (e *tempError) Temporary() bool { return true } 45 46 type timeoutListener struct { 47 net.Listener 48 errCount int 49 err error 50 } 51 52 func newTimeoutListener(l net.Listener, errCount int) net.Listener { 53 return &timeoutListener{ 54 Listener: l, 55 errCount: errCount, 56 err: &tempError{}, 57 } 58 } 59 60 func (l *timeoutListener) Accept() (net.Conn, error) { 61 if l.errCount > 0 { 62 l.errCount-- 63 return nil, l.err 64 } 65 66 return l.Listener.Accept() 67 } 68 69 func TestIntegrationTemporaryTimeout(t *testing.T) { 70 t.Parallel() 71 72 l, err := net.Listen("tcp", "[::]:0") 73 if err != nil { 74 t.Fatalf("net.Listen(): got %v, want no error", err) 75 } 76 77 p := NewProxy() 78 defer p.Close() 79 80 tr := martiantest.NewTransport() 81 p.SetRoundTripper(tr) 82 p.SetTimeout(200 * time.Millisecond) 83 84 // Start the proxy with a listener that will return a temporary error on 85 // Accept() three times. 86 go p.Serve(newTimeoutListener(l, 3)) 87 88 conn, err := net.Dial("tcp", l.Addr().String()) 89 if err != nil { 90 t.Fatalf("net.Dial(): got %v, want no error", err) 91 } 92 defer conn.Close() 93 94 req, err := http.NewRequest("GET", "http://example.com", nil) 95 if err != nil { 96 t.Fatalf("http.NewRequest(): got %v, want no error", err) 97 } 98 req.Header.Set("Connection", "close") 99 100 // GET http://example.com/ HTTP/1.1 101 // Host: example.com 102 if err := req.WriteProxy(conn); err != nil { 103 t.Fatalf("req.WriteProxy(): got %v, want no error", err) 104 } 105 106 res, err := http.ReadResponse(bufio.NewReader(conn), req) 107 if err != nil { 108 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 109 } 110 defer res.Body.Close() 111 112 if got, want := res.StatusCode, 200; got != want { 113 t.Errorf("res.StatusCode: got %d, want %d", got, want) 114 } 115 } 116 117 func TestIntegrationHTTP(t *testing.T) { 118 t.Parallel() 119 120 l, err := net.Listen("tcp", "[::]:0") 121 if err != nil { 122 t.Fatalf("net.Listen(): got %v, want no error", err) 123 } 124 125 p := NewProxy() 126 defer p.Close() 127 128 p.SetRequestModifier(nil) 129 p.SetResponseModifier(nil) 130 131 tr := martiantest.NewTransport() 132 p.SetRoundTripper(tr) 133 p.SetTimeout(200 * time.Millisecond) 134 135 tm := martiantest.NewModifier() 136 137 tm.RequestFunc(func(req *http.Request) { 138 ctx := NewContext(req) 139 ctx.Set("martian.test", "true") 140 }) 141 142 tm.ResponseFunc(func(res *http.Response) { 143 ctx := NewContext(res.Request) 144 v, _ := ctx.Get("martian.test") 145 146 res.Header.Set("Martian-Test", v.(string)) 147 }) 148 149 p.SetRequestModifier(tm) 150 p.SetResponseModifier(tm) 151 152 go p.Serve(l) 153 154 conn, err := net.Dial("tcp", l.Addr().String()) 155 if err != nil { 156 t.Fatalf("net.Dial(): got %v, want no error", err) 157 } 158 defer conn.Close() 159 160 req, err := http.NewRequest("GET", "http://example.com", nil) 161 if err != nil { 162 t.Fatalf("http.NewRequest(): got %v, want no error", err) 163 } 164 165 // GET http://example.com/ HTTP/1.1 166 // Host: example.com 167 if err := req.WriteProxy(conn); err != nil { 168 t.Fatalf("req.WriteProxy(): got %v, want no error", err) 169 } 170 171 res, err := http.ReadResponse(bufio.NewReader(conn), req) 172 if err != nil { 173 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 174 } 175 176 if got, want := res.StatusCode, 200; got != want { 177 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 178 } 179 180 if got, want := res.Header.Get("Martian-Test"), "true"; got != want { 181 t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) 182 } 183 } 184 185 func TestIntegrationHTTP100Continue(t *testing.T) { 186 t.Parallel() 187 188 l, err := net.Listen("tcp", "[::]:0") 189 if err != nil { 190 t.Fatalf("net.Listen(): got %v, want no error", err) 191 } 192 193 p := NewProxy() 194 defer p.Close() 195 196 p.SetTimeout(2 * time.Second) 197 198 sl, err := net.Listen("tcp", "[::]:0") 199 if err != nil { 200 t.Fatalf("net.Listen(): got %v, want no error", err) 201 } 202 203 go func() { 204 conn, err := sl.Accept() 205 if err != nil { 206 log.Errorf("proxy_test: failed to accept connection: %v", err) 207 return 208 } 209 defer conn.Close() 210 211 log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr()) 212 213 req, err := http.ReadRequest(bufio.NewReader(conn)) 214 if err != nil { 215 log.Errorf("proxy_test: failed to read request: %v", err) 216 return 217 } 218 219 if req.Header.Get("Expect") == "100-continue" { 220 log.Infof("proxy_test: received 100-continue request") 221 222 conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n")) 223 224 log.Infof("proxy_test: sent 100-continue response") 225 } else { 226 log.Infof("proxy_test: received non 100-continue request") 227 228 res := proxyutil.NewResponse(417, nil, req) 229 res.Header.Set("Connection", "close") 230 res.Write(conn) 231 return 232 } 233 234 res := proxyutil.NewResponse(200, req.Body, req) 235 res.Header.Set("Connection", "close") 236 res.Write(conn) 237 238 log.Infof("proxy_test: sent 200 response") 239 }() 240 241 tm := martiantest.NewModifier() 242 p.SetRequestModifier(tm) 243 p.SetResponseModifier(tm) 244 245 go p.Serve(l) 246 247 conn, err := net.Dial("tcp", l.Addr().String()) 248 if err != nil { 249 t.Fatalf("net.Dial(): got %v, want no error", err) 250 } 251 defer conn.Close() 252 253 host := sl.Addr().String() 254 raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+ 255 "Host: %s\r\n"+ 256 "Content-Length: 12\r\n"+ 257 "Expect: 100-continue\r\n\r\n", host, host) 258 259 if _, err := conn.Write([]byte(raw)); err != nil { 260 t.Fatalf("conn.Write(headers): got %v, want no error", err) 261 } 262 263 go func() { 264 select { 265 case <-time.After(time.Second): 266 conn.Write([]byte("body content")) 267 } 268 }() 269 270 res, err := http.ReadResponse(bufio.NewReader(conn), nil) 271 if err != nil { 272 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 273 } 274 defer res.Body.Close() 275 276 if got, want := res.StatusCode, 200; got != want { 277 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 278 } 279 280 got, err := ioutil.ReadAll(res.Body) 281 if err != nil { 282 t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) 283 } 284 285 if want := []byte("body content"); !bytes.Equal(got, want) { 286 t.Errorf("res.Body: got %q, want %q", got, want) 287 } 288 289 if !tm.RequestModified() { 290 t.Error("tm.RequestModified(): got false, want true") 291 } 292 if !tm.ResponseModified() { 293 t.Error("tm.ResponseModified(): got false, want true") 294 } 295 } 296 297 func TestIntegrationHTTPDownstreamProxy(t *testing.T) { 298 t.Parallel() 299 300 // Start first proxy to use as downstream. 301 dl, err := net.Listen("tcp", "[::]:0") 302 if err != nil { 303 t.Fatalf("net.Listen(): got %v, want no error", err) 304 } 305 306 downstream := NewProxy() 307 defer downstream.Close() 308 309 dtr := martiantest.NewTransport() 310 dtr.Respond(299) 311 downstream.SetRoundTripper(dtr) 312 downstream.SetTimeout(600 * time.Millisecond) 313 314 go downstream.Serve(dl) 315 316 // Start second proxy as upstream proxy, will write to downstream proxy. 317 ul, err := net.Listen("tcp", "[::]:0") 318 if err != nil { 319 t.Fatalf("net.Listen(): got %v, want no error", err) 320 } 321 322 upstream := NewProxy() 323 defer upstream.Close() 324 325 // Set upstream proxy's downstream proxy to the host:port of the first proxy. 326 upstream.SetDownstreamProxy(&url.URL{ 327 Host: dl.Addr().String(), 328 }) 329 upstream.SetTimeout(600 * time.Millisecond) 330 331 go upstream.Serve(ul) 332 333 // Open connection to upstream proxy. 334 conn, err := net.Dial("tcp", ul.Addr().String()) 335 if err != nil { 336 t.Fatalf("net.Dial(): got %v, want no error", err) 337 } 338 defer conn.Close() 339 340 req, err := http.NewRequest("GET", "http://example.com", nil) 341 if err != nil { 342 t.Fatalf("http.NewRequest(): got %v, want no error", err) 343 } 344 345 // GET http://example.com/ HTTP/1.1 346 // Host: example.com 347 if err := req.WriteProxy(conn); err != nil { 348 t.Fatalf("req.WriteProxy(): got %v, want no error", err) 349 } 350 351 // Response from downstream proxy. 352 res, err := http.ReadResponse(bufio.NewReader(conn), req) 353 if err != nil { 354 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 355 } 356 357 if got, want := res.StatusCode, 299; got != want { 358 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 359 } 360 } 361 362 func TestIntegrationHTTPDownstreamProxyError(t *testing.T) { 363 t.Parallel() 364 365 l, err := net.Listen("tcp", "[::]:0") 366 if err != nil { 367 t.Fatalf("net.Listen(): got %v, want no error", err) 368 } 369 370 p := NewProxy() 371 defer p.Close() 372 373 // Set proxy's downstream proxy to invalid host:port to force failure. 374 p.SetDownstreamProxy(&url.URL{ 375 Host: "[::]:0", 376 }) 377 p.SetTimeout(600 * time.Millisecond) 378 379 tm := martiantest.NewModifier() 380 reserr := errors.New("response error") 381 tm.ResponseError(reserr) 382 383 p.SetResponseModifier(tm) 384 385 go p.Serve(l) 386 387 // Open connection to upstream proxy. 388 conn, err := net.Dial("tcp", l.Addr().String()) 389 if err != nil { 390 t.Fatalf("net.Dial(): got %v, want no error", err) 391 } 392 defer conn.Close() 393 394 req, err := http.NewRequest("CONNECT", "//example.com:443", nil) 395 if err != nil { 396 t.Fatalf("http.NewRequest(): got %v, want no error", err) 397 } 398 399 // CONNECT example.com:443 HTTP/1.1 400 // Host: example.com 401 if err := req.Write(conn); err != nil { 402 t.Fatalf("req.Write(): got %v, want no error", err) 403 } 404 405 // Response from upstream proxy, assuming downstream proxy failed to CONNECT. 406 res, err := http.ReadResponse(bufio.NewReader(conn), req) 407 if err != nil { 408 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 409 } 410 411 if got, want := res.StatusCode, 502; got != want { 412 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 413 } 414 if got, want := res.Header["Warning"][1], reserr.Error(); !strings.Contains(got, want) { 415 t.Errorf("res.Header.get(%q): got %q, want to contain %q", "Warning", got, want) 416 } 417 } 418 419 func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) { 420 t.Parallel() 421 422 l, err := net.Listen("tcp", "[::]:0") 423 if err != nil { 424 t.Fatalf("net.Listen(): got %v, want no error", err) 425 } 426 427 p := NewProxy() 428 defer p.Close() 429 430 // Test TLS server. 431 ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour) 432 if err != nil { 433 t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) 434 } 435 mc, err := mitm.NewConfig(ca, priv) 436 if err != nil { 437 t.Fatalf("mitm.NewConfig(): got %v, want no error", err) 438 } 439 440 var herr error 441 mc.SetHandshakeErrorCallback(func(_ *http.Request, err error) { herr = fmt.Errorf("handshake error") }) 442 p.SetMITM(mc) 443 444 tl, err := net.Listen("tcp", "[::]:0") 445 if err != nil { 446 t.Fatalf("tls.Listen(): got %v, want no error", err) 447 } 448 tl = tls.NewListener(tl, mc.TLS()) 449 450 go http.Serve(tl, http.HandlerFunc( 451 func(rw http.ResponseWriter, req *http.Request) { 452 rw.WriteHeader(200) 453 })) 454 455 tm := martiantest.NewModifier() 456 457 // Force the CONNECT request to dial the local TLS server. 458 tm.RequestFunc(func(req *http.Request) { 459 req.URL.Host = tl.Addr().String() 460 }) 461 462 go p.Serve(l) 463 464 conn, err := net.Dial("tcp", l.Addr().String()) 465 if err != nil { 466 t.Fatalf("net.Dial(): got %v, want no error", err) 467 } 468 defer conn.Close() 469 470 req, err := http.NewRequest("CONNECT", "//example.com:443", nil) 471 if err != nil { 472 t.Fatalf("http.NewRequest(): got %v, want no error", err) 473 } 474 475 // CONNECT example.com:443 HTTP/1.1 476 // Host: example.com 477 // 478 // Rewritten to CONNECT to host:port in CONNECT request modifier. 479 if err := req.Write(conn); err != nil { 480 t.Fatalf("req.Write(): got %v, want no error", err) 481 } 482 483 // CONNECT response after establishing tunnel. 484 if _, err := http.ReadResponse(bufio.NewReader(conn), req); err != nil { 485 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 486 } 487 488 tlsconn := tls.Client(conn, &tls.Config{ 489 ServerName: "example.com", 490 // Client has no cert so it will get "x509: certificate signed by unknown authority" from the 491 // handshake and send "remote error: bad certificate" to the server. 492 RootCAs: x509.NewCertPool(), 493 }) 494 defer tlsconn.Close() 495 496 req, err = http.NewRequest("GET", "https://example.com", nil) 497 if err != nil { 498 t.Fatalf("http.NewRequest(): got %v, want no error", err) 499 } 500 req.Header.Set("Connection", "close") 501 502 if got, want := req.Write(tlsconn), "x509: certificate signed by unknown authority"; !strings.Contains(got.Error(), want) { 503 t.Fatalf("Got incorrect error from Client Handshake(), got: %v, want: %v", got, want) 504 } 505 506 // TODO: herr is not being asserted against. It should be pushed on to a channel 507 // of err, and the assertion should pull off of it and assert. That design resulted in the test 508 // hanging for unknown reasons. 509 t.Skip("skipping assertion of handshake error callback error due to mysterious deadlock") 510 if got, want := herr, "remote error: bad certificate"; !strings.Contains(got.Error(), want) { 511 t.Fatalf("Got incorrect error from Server Handshake(), got: %v, want: %v", got, want) 512 } 513 } 514 515 func TestIntegrationConnect(t *testing.T) { 516 t.Parallel() 517 518 l, err := net.Listen("tcp", "[::]:0") 519 if err != nil { 520 t.Fatalf("net.Listen(): got %v, want no error", err) 521 } 522 523 p := NewProxy() 524 defer p.Close() 525 526 // Test TLS server. 527 ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour) 528 if err != nil { 529 t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) 530 } 531 mc, err := mitm.NewConfig(ca, priv) 532 if err != nil { 533 t.Fatalf("mitm.NewConfig(): got %v, want no error", err) 534 } 535 536 tl, err := net.Listen("tcp", "[::]:0") 537 if err != nil { 538 t.Fatalf("tls.Listen(): got %v, want no error", err) 539 } 540 tl = tls.NewListener(tl, mc.TLS()) 541 542 go http.Serve(tl, http.HandlerFunc( 543 func(rw http.ResponseWriter, req *http.Request) { 544 rw.WriteHeader(299) 545 })) 546 547 tm := martiantest.NewModifier() 548 reqerr := errors.New("request error") 549 reserr := errors.New("response error") 550 551 // Force the CONNECT request to dial the local TLS server. 552 tm.RequestFunc(func(req *http.Request) { 553 req.URL.Host = tl.Addr().String() 554 }) 555 556 tm.RequestError(reqerr) 557 tm.ResponseError(reserr) 558 559 p.SetRequestModifier(tm) 560 p.SetResponseModifier(tm) 561 562 go p.Serve(l) 563 564 conn, err := net.Dial("tcp", l.Addr().String()) 565 if err != nil { 566 t.Fatalf("net.Dial(): got %v, want no error", err) 567 } 568 defer conn.Close() 569 570 req, err := http.NewRequest("CONNECT", "//example.com:443", nil) 571 if err != nil { 572 t.Fatalf("http.NewRequest(): got %v, want no error", err) 573 } 574 575 // CONNECT example.com:443 HTTP/1.1 576 // Host: example.com 577 // 578 // Rewritten to CONNECT to host:port in CONNECT request modifier. 579 if err := req.Write(conn); err != nil { 580 t.Fatalf("req.Write(): got %v, want no error", err) 581 } 582 583 // CONNECT response after establishing tunnel. 584 res, err := http.ReadResponse(bufio.NewReader(conn), req) 585 if err != nil { 586 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 587 } 588 589 if got, want := res.StatusCode, 200; got != want { 590 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 591 } 592 593 if !tm.RequestModified() { 594 t.Error("tm.RequestModified(): got false, want true") 595 } 596 if !tm.ResponseModified() { 597 t.Error("tm.ResponseModified(): got false, want true") 598 } 599 if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { 600 t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) 601 } 602 603 roots := x509.NewCertPool() 604 roots.AddCert(ca) 605 606 tlsconn := tls.Client(conn, &tls.Config{ 607 ServerName: "example.com", 608 RootCAs: roots, 609 }) 610 defer tlsconn.Close() 611 612 req, err = http.NewRequest("GET", "https://example.com", nil) 613 if err != nil { 614 t.Fatalf("http.NewRequest(): got %v, want no error", err) 615 } 616 req.Header.Set("Connection", "close") 617 618 // GET / HTTP/1.1 619 // Host: example.com 620 // Connection: close 621 if err := req.Write(tlsconn); err != nil { 622 t.Fatalf("req.Write(): got %v, want no error", err) 623 } 624 625 res, err = http.ReadResponse(bufio.NewReader(tlsconn), req) 626 if err != nil { 627 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 628 } 629 defer res.Body.Close() 630 631 if got, want := res.StatusCode, 299; got != want { 632 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 633 } 634 if got, want := res.Header.Get("Warning"), reserr.Error(); strings.Contains(got, want) { 635 t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", got, want) 636 } 637 } 638 639 func TestIntegrationConnectDownstreamProxy(t *testing.T) { 640 t.Parallel() 641 642 // Start first proxy to use as downstream. 643 dl, err := net.Listen("tcp", "[::]:0") 644 if err != nil { 645 t.Fatalf("net.Listen(): got %v, want no error", err) 646 } 647 648 downstream := NewProxy() 649 defer downstream.Close() 650 651 dtr := martiantest.NewTransport() 652 dtr.Respond(299) 653 downstream.SetRoundTripper(dtr) 654 655 ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) 656 if err != nil { 657 t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) 658 } 659 660 mc, err := mitm.NewConfig(ca, priv) 661 if err != nil { 662 t.Fatalf("mitm.NewConfig(): got %v, want no error", err) 663 } 664 downstream.SetMITM(mc) 665 666 go downstream.Serve(dl) 667 668 // Start second proxy as upstream proxy, will CONNECT to downstream proxy. 669 ul, err := net.Listen("tcp", "[::]:0") 670 if err != nil { 671 t.Fatalf("net.Listen(): got %v, want no error", err) 672 } 673 674 upstream := NewProxy() 675 defer upstream.Close() 676 677 // Set upstream proxy's downstream proxy to the host:port of the first proxy. 678 upstream.SetDownstreamProxy(&url.URL{ 679 Host: dl.Addr().String(), 680 }) 681 682 go upstream.Serve(ul) 683 684 // Open connection to upstream proxy. 685 conn, err := net.Dial("tcp", ul.Addr().String()) 686 if err != nil { 687 t.Fatalf("net.Dial(): got %v, want no error", err) 688 } 689 defer conn.Close() 690 691 req, err := http.NewRequest("CONNECT", "//example.com:443", nil) 692 if err != nil { 693 t.Fatalf("http.NewRequest(): got %v, want no error", err) 694 } 695 696 // CONNECT example.com:443 HTTP/1.1 697 // Host: example.com 698 if err := req.Write(conn); err != nil { 699 t.Fatalf("req.Write(): got %v, want no error", err) 700 } 701 702 // Response from downstream proxy starting MITM. 703 res, err := http.ReadResponse(bufio.NewReader(conn), req) 704 if err != nil { 705 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 706 } 707 708 if got, want := res.StatusCode, 200; got != want { 709 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 710 } 711 712 roots := x509.NewCertPool() 713 roots.AddCert(ca) 714 715 tlsconn := tls.Client(conn, &tls.Config{ 716 // Validate the hostname. 717 ServerName: "example.com", 718 // The certificate will have been MITM'd, verify using the MITM CA 719 // certificate. 720 RootCAs: roots, 721 }) 722 defer tlsconn.Close() 723 724 req, err = http.NewRequest("GET", "https://example.com", nil) 725 if err != nil { 726 t.Fatalf("http.NewRequest(): got %v, want no error", err) 727 } 728 729 // GET / HTTP/1.1 730 // Host: example.com 731 if err := req.Write(tlsconn); err != nil { 732 t.Fatalf("req.Write(): got %v, want no error", err) 733 } 734 735 // Response from MITM in downstream proxy. 736 res, err = http.ReadResponse(bufio.NewReader(tlsconn), req) 737 if err != nil { 738 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 739 } 740 defer res.Body.Close() 741 742 if got, want := res.StatusCode, 299; got != want { 743 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 744 } 745 } 746 747 func TestIntegrationMITM(t *testing.T) { 748 t.Parallel() 749 750 l, err := net.Listen("tcp", "[::]:0") 751 if err != nil { 752 t.Fatalf("net.Listen(): got %v, want no error", err) 753 } 754 755 p := NewProxy() 756 defer p.Close() 757 758 tr := martiantest.NewTransport() 759 tr.Func(func(req *http.Request) (*http.Response, error) { 760 res := proxyutil.NewResponse(200, nil, req) 761 res.Header.Set("Request-Scheme", req.URL.Scheme) 762 763 return res, nil 764 }) 765 766 p.SetRoundTripper(tr) 767 p.SetTimeout(600 * time.Millisecond) 768 769 ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) 770 if err != nil { 771 t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) 772 } 773 774 mc, err := mitm.NewConfig(ca, priv) 775 if err != nil { 776 t.Fatalf("mitm.NewConfig(): got %v, want no error", err) 777 } 778 p.SetMITM(mc) 779 780 tm := martiantest.NewModifier() 781 reqerr := errors.New("request error") 782 reserr := errors.New("response error") 783 tm.RequestError(reqerr) 784 tm.ResponseError(reserr) 785 786 p.SetRequestModifier(tm) 787 p.SetResponseModifier(tm) 788 789 go p.Serve(l) 790 791 conn, err := net.Dial("tcp", l.Addr().String()) 792 if err != nil { 793 t.Fatalf("net.Dial(): got %v, want no error", err) 794 } 795 defer conn.Close() 796 797 req, err := http.NewRequest("CONNECT", "//example.com:443", nil) 798 if err != nil { 799 t.Fatalf("http.NewRequest(): got %v, want no error", err) 800 } 801 802 // CONNECT example.com:443 HTTP/1.1 803 // Host: example.com 804 if err := req.Write(conn); err != nil { 805 t.Fatalf("req.Write(): got %v, want no error", err) 806 } 807 808 // Response MITM'd from proxy. 809 res, err := http.ReadResponse(bufio.NewReader(conn), req) 810 if err != nil { 811 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 812 } 813 if got, want := res.StatusCode, 200; got != want { 814 815 t.Errorf("res.StatusCode: got %d, want %d", got, want) 816 } 817 if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { 818 t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) 819 } 820 821 roots := x509.NewCertPool() 822 roots.AddCert(ca) 823 824 tlsconn := tls.Client(conn, &tls.Config{ 825 ServerName: "example.com", 826 RootCAs: roots, 827 }) 828 defer tlsconn.Close() 829 830 req, err = http.NewRequest("GET", "https://example.com", nil) 831 if err != nil { 832 t.Fatalf("http.NewRequest(): got %v, want no error", err) 833 } 834 835 // GET / HTTP/1.1 836 // Host: example.com 837 if err := req.Write(tlsconn); err != nil { 838 t.Fatalf("req.Write(): got %v, want no error", err) 839 } 840 841 // Response from MITM proxy. 842 res, err = http.ReadResponse(bufio.NewReader(tlsconn), req) 843 if err != nil { 844 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 845 } 846 defer res.Body.Close() 847 848 if got, want := res.StatusCode, 200; got != want { 849 t.Errorf("res.StatusCode: got %d, want %d", got, want) 850 } 851 if got, want := res.Header.Get("Request-Scheme"), "https"; got != want { 852 t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want) 853 } 854 if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) { 855 t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) 856 } 857 } 858 859 func TestIntegrationTransparentHTTP(t *testing.T) { 860 t.Parallel() 861 862 l, err := net.Listen("tcp", "[::]:0") 863 if err != nil { 864 t.Fatalf("net.Listen(): got %v, want no error", err) 865 } 866 867 p := NewProxy() 868 defer p.Close() 869 870 tr := martiantest.NewTransport() 871 p.SetRoundTripper(tr) 872 873 if got, want := p.GetRoundTripper(), tr; got != want { 874 t.Errorf("proxy.GetRoundTripper: got %v, want %v", got, want) 875 } 876 877 p.SetTimeout(200 * time.Millisecond) 878 879 tm := martiantest.NewModifier() 880 p.SetRequestModifier(tm) 881 p.SetResponseModifier(tm) 882 883 go p.Serve(l) 884 885 conn, err := net.Dial("tcp", l.Addr().String()) 886 if err != nil { 887 t.Fatalf("net.Dial(): got %v, want no error", err) 888 } 889 defer conn.Close() 890 891 req, err := http.NewRequest("GET", "http://example.com", nil) 892 if err != nil { 893 t.Fatalf("http.NewRequest(): got %v, want no error", err) 894 } 895 896 // GET / HTTP/1.1 897 // Host: www.example.com 898 if err := req.Write(conn); err != nil { 899 t.Fatalf("req.Write(): got %v, want no error", err) 900 } 901 902 res, err := http.ReadResponse(bufio.NewReader(conn), req) 903 if err != nil { 904 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 905 } 906 907 if got, want := res.StatusCode, 200; got != want { 908 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 909 } 910 911 if !tm.RequestModified() { 912 t.Error("tm.RequestModified(): got false, want true") 913 } 914 if !tm.ResponseModified() { 915 t.Error("tm.ResponseModified(): got false, want true") 916 } 917 } 918 919 func TestIntegrationTransparentMITM(t *testing.T) { 920 t.Parallel() 921 922 ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) 923 if err != nil { 924 t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) 925 } 926 927 mc, err := mitm.NewConfig(ca, priv) 928 if err != nil { 929 t.Fatalf("mitm.NewConfig(): got %v, want no error", err) 930 } 931 932 // Start TLS listener with config that will generate certificates based on 933 // SNI from connection. 934 // 935 // BUG: tls.Listen will not accept a tls.Config where Certificates is empty, 936 // even though it is supported by tls.Server when GetCertificate is not nil. 937 l, err := net.Listen("tcp", "[::]:0") 938 if err != nil { 939 t.Fatalf("net.Listen(): got %v, want no error", err) 940 } 941 l = tls.NewListener(l, mc.TLS()) 942 943 p := NewProxy() 944 defer p.Close() 945 946 tr := martiantest.NewTransport() 947 tr.Func(func(req *http.Request) (*http.Response, error) { 948 res := proxyutil.NewResponse(200, nil, req) 949 res.Header.Set("Request-Scheme", req.URL.Scheme) 950 951 return res, nil 952 }) 953 954 p.SetRoundTripper(tr) 955 956 tm := martiantest.NewModifier() 957 p.SetRequestModifier(tm) 958 p.SetResponseModifier(tm) 959 960 go p.Serve(l) 961 962 roots := x509.NewCertPool() 963 roots.AddCert(ca) 964 965 tlsconn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{ 966 // Verify the hostname is example.com. 967 ServerName: "example.com", 968 // The certificate will have been generated during MITM, so we need to 969 // verify it with the generated CA certificate. 970 RootCAs: roots, 971 }) 972 if err != nil { 973 t.Fatalf("tls.Dial(): got %v, want no error", err) 974 } 975 defer tlsconn.Close() 976 977 req, err := http.NewRequest("GET", "https://example.com", nil) 978 if err != nil { 979 t.Fatalf("http.NewRequest(): got %v, want no error", err) 980 } 981 982 // Write Encrypted request directly, no CONNECT. 983 // GET / HTTP/1.1 984 // Host: example.com 985 if err := req.Write(tlsconn); err != nil { 986 t.Fatalf("req.Write(): got %v, want no error", err) 987 } 988 989 res, err := http.ReadResponse(bufio.NewReader(tlsconn), req) 990 if err != nil { 991 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 992 } 993 defer res.Body.Close() 994 995 if got, want := res.StatusCode, 200; got != want { 996 t.Fatalf("res.StatusCode: got %d, want %d", got, want) 997 } 998 if got, want := res.Header.Get("Request-Scheme"), "https"; got != want { 999 t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want) 1000 } 1001 1002 if !tm.RequestModified() { 1003 t.Errorf("tm.RequestModified(): got false, want true") 1004 } 1005 if !tm.ResponseModified() { 1006 t.Errorf("tm.ResponseModified(): got false, want true") 1007 } 1008 } 1009 1010 func TestIntegrationFailedRoundTrip(t *testing.T) { 1011 t.Parallel() 1012 1013 l, err := net.Listen("tcp", "[::]:0") 1014 if err != nil { 1015 t.Fatalf("net.Listen(): got %v, want no error", err) 1016 } 1017 1018 p := NewProxy() 1019 defer p.Close() 1020 1021 tr := martiantest.NewTransport() 1022 trerr := errors.New("round trip error") 1023 tr.RespondError(trerr) 1024 p.SetRoundTripper(tr) 1025 p.SetTimeout(200 * time.Millisecond) 1026 1027 go p.Serve(l) 1028 1029 conn, err := net.Dial("tcp", l.Addr().String()) 1030 if err != nil { 1031 t.Fatalf("net.Dial(): got %v, want no error", err) 1032 } 1033 defer conn.Close() 1034 1035 req, err := http.NewRequest("GET", "http://example.com", nil) 1036 if err != nil { 1037 t.Fatalf("http.NewRequest(): got %v, want no error", err) 1038 } 1039 1040 // GET http://example.com/ HTTP/1.1 1041 // Host: example.com 1042 if err := req.WriteProxy(conn); err != nil { 1043 t.Fatalf("req.WriteProxy(): got %v, want no error", err) 1044 } 1045 1046 // Response from failed round trip. 1047 res, err := http.ReadResponse(bufio.NewReader(conn), req) 1048 if err != nil { 1049 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 1050 } 1051 defer res.Body.Close() 1052 1053 if got, want := res.StatusCode, 502; got != want { 1054 t.Errorf("res.StatusCode: got %d, want %d", got, want) 1055 } 1056 1057 if got, want := res.Header.Get("Warning"), trerr.Error(); !strings.Contains(got, want) { 1058 t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want) 1059 } 1060 } 1061 1062 func TestIntegrationSkipRoundTrip(t *testing.T) { 1063 t.Parallel() 1064 1065 l, err := net.Listen("tcp", "[::]:0") 1066 if err != nil { 1067 t.Fatalf("net.Listen(): got %v, want no error", err) 1068 } 1069 1070 p := NewProxy() 1071 defer p.Close() 1072 1073 // Transport will be skipped, no 500. 1074 tr := martiantest.NewTransport() 1075 tr.Respond(500) 1076 p.SetRoundTripper(tr) 1077 p.SetTimeout(200 * time.Millisecond) 1078 1079 tm := martiantest.NewModifier() 1080 tm.RequestFunc(func(req *http.Request) { 1081 ctx := NewContext(req) 1082 ctx.SkipRoundTrip() 1083 }) 1084 p.SetRequestModifier(tm) 1085 1086 go p.Serve(l) 1087 1088 conn, err := net.Dial("tcp", l.Addr().String()) 1089 if err != nil { 1090 t.Fatalf("net.Dial(): got %v, want no error", err) 1091 } 1092 defer conn.Close() 1093 1094 req, err := http.NewRequest("GET", "http://example.com", nil) 1095 if err != nil { 1096 t.Fatalf("http.NewRequest(): got %v, want no error", err) 1097 } 1098 1099 // GET http://example.com/ HTTP/1.1 1100 // Host: example.com 1101 if err := req.WriteProxy(conn); err != nil { 1102 t.Fatalf("req.WriteProxy(): got %v, want no error", err) 1103 } 1104 1105 // Response from skipped round trip. 1106 res, err := http.ReadResponse(bufio.NewReader(conn), req) 1107 if err != nil { 1108 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 1109 } 1110 defer res.Body.Close() 1111 1112 if got, want := res.StatusCode, 200; got != want { 1113 t.Errorf("res.StatusCode: got %d, want %d", got, want) 1114 } 1115 } 1116 1117 func TestHTTPThroughConnectWithMITM(t *testing.T) { 1118 t.Parallel() 1119 1120 l, err := net.Listen("tcp", "[::]:0") 1121 if err != nil { 1122 t.Fatalf("net.Listen(): got %v, want no error", err) 1123 } 1124 1125 p := NewProxy() 1126 defer p.Close() 1127 1128 tm := martiantest.NewModifier() 1129 tm.RequestFunc(func(req *http.Request) { 1130 ctx := NewContext(req) 1131 ctx.SkipRoundTrip() 1132 1133 if req.Method != "GET" && req.Method != "CONNECT" { 1134 t.Errorf("unexpected method on request handler: %v", req.Method) 1135 } 1136 }) 1137 p.SetRequestModifier(tm) 1138 1139 ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) 1140 if err != nil { 1141 t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) 1142 } 1143 1144 mc, err := mitm.NewConfig(ca, priv) 1145 if err != nil { 1146 t.Fatalf("mitm.NewConfig(): got %v, want no error", err) 1147 } 1148 p.SetMITM(mc) 1149 1150 go p.Serve(l) 1151 1152 conn, err := net.Dial("tcp", l.Addr().String()) 1153 if err != nil { 1154 t.Fatalf("net.Dial(): got %v, want no error", err) 1155 } 1156 defer conn.Close() 1157 1158 req, err := http.NewRequest("CONNECT", "//example.com:80", nil) 1159 if err != nil { 1160 t.Fatalf("http.NewRequest(): got %v, want no error", err) 1161 } 1162 1163 // CONNECT example.com:80 HTTP/1.1 1164 // Host: example.com 1165 if err := req.Write(conn); err != nil { 1166 t.Fatalf("req.Write(): got %v, want no error", err) 1167 } 1168 1169 // Response skipped round trip. 1170 res, err := http.ReadResponse(bufio.NewReader(conn), req) 1171 if err != nil { 1172 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 1173 } 1174 res.Body.Close() 1175 1176 if got, want := res.StatusCode, 200; got != want { 1177 t.Errorf("res.StatusCode: got %d, want %d", got, want) 1178 } 1179 1180 req, err = http.NewRequest("GET", "http://example.com", nil) 1181 if err != nil { 1182 t.Fatalf("http.NewRequest(): got %v, want no error", err) 1183 } 1184 1185 // GET http://example.com/ HTTP/1.1 1186 // Host: example.com 1187 if err := req.WriteProxy(conn); err != nil { 1188 t.Fatalf("req.WriteProxy(): got %v, want no error", err) 1189 } 1190 1191 // Response from skipped round trip. 1192 res, err = http.ReadResponse(bufio.NewReader(conn), req) 1193 if err != nil { 1194 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 1195 } 1196 res.Body.Close() 1197 1198 if got, want := res.StatusCode, 200; got != want { 1199 t.Errorf("res.StatusCode: got %d, want %d", got, want) 1200 } 1201 1202 req, err = http.NewRequest("GET", "http://example.com", nil) 1203 if err != nil { 1204 t.Fatalf("http.NewRequest(): got %v, want no error", err) 1205 } 1206 1207 // GET http://example.com/ HTTP/1.1 1208 // Host: example.com 1209 if err := req.WriteProxy(conn); err != nil { 1210 t.Fatalf("req.WriteProxy(): got %v, want no error", err) 1211 } 1212 1213 // Response from skipped round trip. 1214 res, err = http.ReadResponse(bufio.NewReader(conn), req) 1215 if err != nil { 1216 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 1217 } 1218 res.Body.Close() 1219 1220 if got, want := res.StatusCode, 200; got != want { 1221 t.Errorf("res.StatusCode: got %d, want %d", got, want) 1222 } 1223 } 1224 1225 func TestServerClosesConnection(t *testing.T) { 1226 t.Parallel() 1227 1228 dstl, err := net.Listen("tcp", "[::]:0") 1229 if err != nil { 1230 t.Fatalf("Failed to create http listener: %v", err) 1231 } 1232 defer dstl.Close() 1233 1234 go func() { 1235 t.Logf("Waiting for server side connection") 1236 conn, err := dstl.Accept() 1237 if err != nil { 1238 t.Fatalf("Got error while accepting connection on destination listener: %v", err) 1239 } 1240 t.Logf("Accepted server side connection") 1241 1242 buf := make([]byte, 16384) 1243 if _, err := conn.Read(buf); err != nil { 1244 t.Fatalf("Error reading: %v", err) 1245 } 1246 1247 _, err = conn.Write([]byte("HTTP/1.1 301 MOVED PERMANENTLY\r\n" + 1248 "Server: \r\n" + 1249 "Date: \r\n" + 1250 "Referer: \r\n" + 1251 "Location: http://www.foo.com/\r\n" + 1252 "Content-type: text/html\r\n" + 1253 "Connection: close\r\n\r\n")) 1254 if err != nil { 1255 t.Fatalf("Got error while writting to connection on destination listener: %v", err) 1256 } 1257 conn.Close() 1258 }() 1259 1260 l, err := net.Listen("tcp", "[::]:0") 1261 if err != nil { 1262 t.Fatalf("net.Listen(): got %v, want no error", err) 1263 } 1264 1265 ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) 1266 if err != nil { 1267 t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) 1268 } 1269 1270 mc, err := mitm.NewConfig(ca, priv) 1271 if err != nil { 1272 t.Fatalf("mitm.NewConfig(): got %v, want no error", err) 1273 } 1274 p := NewProxy() 1275 p.SetMITM(mc) 1276 defer p.Close() 1277 1278 // Start the proxy with a listener that will return a temporary error on 1279 // Accept() three times. 1280 go p.Serve(newTimeoutListener(l, 3)) 1281 1282 conn, err := net.Dial("tcp", l.Addr().String()) 1283 if err != nil { 1284 t.Fatalf("net.Dial(): got %v, want no error", err) 1285 } 1286 defer conn.Close() 1287 1288 req, err := http.NewRequest("CONNECT", fmt.Sprintf("//%s", dstl.Addr().String()), nil) 1289 if err != nil { 1290 t.Fatalf("http.NewRequest(): got %v, want no error", err) 1291 } 1292 1293 // CONNECT example.com:443 HTTP/1.1 1294 // Host: example.com 1295 if err := req.Write(conn); err != nil { 1296 t.Fatalf("req.Write(): got %v, want no error", err) 1297 } 1298 res, err := http.ReadResponse(bufio.NewReader(conn), req) 1299 if err != nil { 1300 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 1301 } 1302 res.Body.Close() 1303 1304 _, err = conn.Write([]byte("GET / HTTP/1.1\r\n" + 1305 "User-Agent: curl/7.35.0\r\n" + 1306 fmt.Sprintf("Host: %s\r\n", dstl.Addr()) + 1307 "Accept: */*\r\n\r\n")) 1308 if err != nil { 1309 t.Fatalf("Error while writing GET request: %v", err) 1310 } 1311 1312 res, err = http.ReadResponse(bufio.NewReader(io.TeeReader(conn, os.Stderr)), req) 1313 if err != nil { 1314 t.Fatalf("http.ReadResponse(): got %v, want no error", err) 1315 } 1316 _, err = ioutil.ReadAll(res.Body) 1317 if err != nil { 1318 t.Fatalf("error while ReadAll: %v", err) 1319 } 1320 defer res.Body.Close() 1321 } 1322 1323 // TestRacyClose checks that creating a proxy, serving from it, and closing 1324 // it in rapid succession doesn't result in race warnings. 1325 // See https://github.com/google/martian/issues/286. 1326 func TestRacyClose(t *testing.T) { 1327 t.Parallel() 1328 1329 log.SetLevel(log.Silent) // avoid "failed to accept" messages because we close l 1330 openAndConnect := func() { 1331 l, err := net.Listen("tcp", "[::]:0") 1332 if err != nil { 1333 t.Fatalf("net.Listen(): got %v, want no error", err) 1334 } 1335 defer l.Close() // to make p.Serve exit 1336 1337 p := NewProxy() 1338 go p.Serve(l) 1339 defer p.Close() 1340 1341 conn, err := net.Dial("tcp", l.Addr().String()) 1342 if err != nil { 1343 t.Fatalf("net.Dial(): got %v, want no error", err) 1344 } 1345 defer conn.Close() 1346 } 1347 1348 // Repeat a bunch of times to make failures more repeatable. 1349 for i := 0; i < 100; i++ { 1350 openAndConnect() 1351 } 1352 }