github.com/google/martian/v3@v3.3.3/proxy_test.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package martian
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"io/ioutil"
    26  	"net"
    27  	"net/http"
    28  	"net/url"
    29  	"os"
    30  	"strings"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/google/martian/v3/log"
    35  	"github.com/google/martian/v3/martiantest"
    36  	"github.com/google/martian/v3/mitm"
    37  	"github.com/google/martian/v3/proxyutil"
    38  )
    39  
    40  type tempError struct{}
    41  
    42  func (e *tempError) Error() string   { return "temporary" }
    43  func (e *tempError) Timeout() bool   { return true }
    44  func (e *tempError) Temporary() bool { return true }
    45  
    46  type timeoutListener struct {
    47  	net.Listener
    48  	errCount int
    49  	err      error
    50  }
    51  
    52  func newTimeoutListener(l net.Listener, errCount int) net.Listener {
    53  	return &timeoutListener{
    54  		Listener: l,
    55  		errCount: errCount,
    56  		err:      &tempError{},
    57  	}
    58  }
    59  
    60  func (l *timeoutListener) Accept() (net.Conn, error) {
    61  	if l.errCount > 0 {
    62  		l.errCount--
    63  		return nil, l.err
    64  	}
    65  
    66  	return l.Listener.Accept()
    67  }
    68  
    69  func TestIntegrationTemporaryTimeout(t *testing.T) {
    70  	t.Parallel()
    71  
    72  	l, err := net.Listen("tcp", "[::]:0")
    73  	if err != nil {
    74  		t.Fatalf("net.Listen(): got %v, want no error", err)
    75  	}
    76  
    77  	p := NewProxy()
    78  	defer p.Close()
    79  
    80  	tr := martiantest.NewTransport()
    81  	p.SetRoundTripper(tr)
    82  	p.SetTimeout(200 * time.Millisecond)
    83  
    84  	// Start the proxy with a listener that will return a temporary error on
    85  	// Accept() three times.
    86  	go p.Serve(newTimeoutListener(l, 3))
    87  
    88  	conn, err := net.Dial("tcp", l.Addr().String())
    89  	if err != nil {
    90  		t.Fatalf("net.Dial(): got %v, want no error", err)
    91  	}
    92  	defer conn.Close()
    93  
    94  	req, err := http.NewRequest("GET", "http://example.com", nil)
    95  	if err != nil {
    96  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
    97  	}
    98  	req.Header.Set("Connection", "close")
    99  
   100  	// GET http://example.com/ HTTP/1.1
   101  	// Host: example.com
   102  	if err := req.WriteProxy(conn); err != nil {
   103  		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   104  	}
   105  
   106  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   107  	if err != nil {
   108  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   109  	}
   110  	defer res.Body.Close()
   111  
   112  	if got, want := res.StatusCode, 200; got != want {
   113  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
   114  	}
   115  }
   116  
   117  func TestIntegrationHTTP(t *testing.T) {
   118  	t.Parallel()
   119  
   120  	l, err := net.Listen("tcp", "[::]:0")
   121  	if err != nil {
   122  		t.Fatalf("net.Listen(): got %v, want no error", err)
   123  	}
   124  
   125  	p := NewProxy()
   126  	defer p.Close()
   127  
   128  	p.SetRequestModifier(nil)
   129  	p.SetResponseModifier(nil)
   130  
   131  	tr := martiantest.NewTransport()
   132  	p.SetRoundTripper(tr)
   133  	p.SetTimeout(200 * time.Millisecond)
   134  
   135  	tm := martiantest.NewModifier()
   136  
   137  	tm.RequestFunc(func(req *http.Request) {
   138  		ctx := NewContext(req)
   139  		ctx.Set("martian.test", "true")
   140  	})
   141  
   142  	tm.ResponseFunc(func(res *http.Response) {
   143  		ctx := NewContext(res.Request)
   144  		v, _ := ctx.Get("martian.test")
   145  
   146  		res.Header.Set("Martian-Test", v.(string))
   147  	})
   148  
   149  	p.SetRequestModifier(tm)
   150  	p.SetResponseModifier(tm)
   151  
   152  	go p.Serve(l)
   153  
   154  	conn, err := net.Dial("tcp", l.Addr().String())
   155  	if err != nil {
   156  		t.Fatalf("net.Dial(): got %v, want no error", err)
   157  	}
   158  	defer conn.Close()
   159  
   160  	req, err := http.NewRequest("GET", "http://example.com", nil)
   161  	if err != nil {
   162  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   163  	}
   164  
   165  	// GET http://example.com/ HTTP/1.1
   166  	// Host: example.com
   167  	if err := req.WriteProxy(conn); err != nil {
   168  		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   169  	}
   170  
   171  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   172  	if err != nil {
   173  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   174  	}
   175  
   176  	if got, want := res.StatusCode, 200; got != want {
   177  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   178  	}
   179  
   180  	if got, want := res.Header.Get("Martian-Test"), "true"; got != want {
   181  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want)
   182  	}
   183  }
   184  
   185  func TestIntegrationHTTP100Continue(t *testing.T) {
   186  	t.Parallel()
   187  
   188  	l, err := net.Listen("tcp", "[::]:0")
   189  	if err != nil {
   190  		t.Fatalf("net.Listen(): got %v, want no error", err)
   191  	}
   192  
   193  	p := NewProxy()
   194  	defer p.Close()
   195  
   196  	p.SetTimeout(2 * time.Second)
   197  
   198  	sl, err := net.Listen("tcp", "[::]:0")
   199  	if err != nil {
   200  		t.Fatalf("net.Listen(): got %v, want no error", err)
   201  	}
   202  
   203  	go func() {
   204  		conn, err := sl.Accept()
   205  		if err != nil {
   206  			log.Errorf("proxy_test: failed to accept connection: %v", err)
   207  			return
   208  		}
   209  		defer conn.Close()
   210  
   211  		log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr())
   212  
   213  		req, err := http.ReadRequest(bufio.NewReader(conn))
   214  		if err != nil {
   215  			log.Errorf("proxy_test: failed to read request: %v", err)
   216  			return
   217  		}
   218  
   219  		if req.Header.Get("Expect") == "100-continue" {
   220  			log.Infof("proxy_test: received 100-continue request")
   221  
   222  			conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n"))
   223  
   224  			log.Infof("proxy_test: sent 100-continue response")
   225  		} else {
   226  			log.Infof("proxy_test: received non 100-continue request")
   227  
   228  			res := proxyutil.NewResponse(417, nil, req)
   229  			res.Header.Set("Connection", "close")
   230  			res.Write(conn)
   231  			return
   232  		}
   233  
   234  		res := proxyutil.NewResponse(200, req.Body, req)
   235  		res.Header.Set("Connection", "close")
   236  		res.Write(conn)
   237  
   238  		log.Infof("proxy_test: sent 200 response")
   239  	}()
   240  
   241  	tm := martiantest.NewModifier()
   242  	p.SetRequestModifier(tm)
   243  	p.SetResponseModifier(tm)
   244  
   245  	go p.Serve(l)
   246  
   247  	conn, err := net.Dial("tcp", l.Addr().String())
   248  	if err != nil {
   249  		t.Fatalf("net.Dial(): got %v, want no error", err)
   250  	}
   251  	defer conn.Close()
   252  
   253  	host := sl.Addr().String()
   254  	raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+
   255  		"Host: %s\r\n"+
   256  		"Content-Length: 12\r\n"+
   257  		"Expect: 100-continue\r\n\r\n", host, host)
   258  
   259  	if _, err := conn.Write([]byte(raw)); err != nil {
   260  		t.Fatalf("conn.Write(headers): got %v, want no error", err)
   261  	}
   262  
   263  	go func() {
   264  		select {
   265  		case <-time.After(time.Second):
   266  			conn.Write([]byte("body content"))
   267  		}
   268  	}()
   269  
   270  	res, err := http.ReadResponse(bufio.NewReader(conn), nil)
   271  	if err != nil {
   272  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   273  	}
   274  	defer res.Body.Close()
   275  
   276  	if got, want := res.StatusCode, 200; got != want {
   277  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   278  	}
   279  
   280  	got, err := ioutil.ReadAll(res.Body)
   281  	if err != nil {
   282  		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
   283  	}
   284  
   285  	if want := []byte("body content"); !bytes.Equal(got, want) {
   286  		t.Errorf("res.Body: got %q, want %q", got, want)
   287  	}
   288  
   289  	if !tm.RequestModified() {
   290  		t.Error("tm.RequestModified(): got false, want true")
   291  	}
   292  	if !tm.ResponseModified() {
   293  		t.Error("tm.ResponseModified(): got false, want true")
   294  	}
   295  }
   296  
   297  func TestIntegrationHTTPDownstreamProxy(t *testing.T) {
   298  	t.Parallel()
   299  
   300  	// Start first proxy to use as downstream.
   301  	dl, err := net.Listen("tcp", "[::]:0")
   302  	if err != nil {
   303  		t.Fatalf("net.Listen(): got %v, want no error", err)
   304  	}
   305  
   306  	downstream := NewProxy()
   307  	defer downstream.Close()
   308  
   309  	dtr := martiantest.NewTransport()
   310  	dtr.Respond(299)
   311  	downstream.SetRoundTripper(dtr)
   312  	downstream.SetTimeout(600 * time.Millisecond)
   313  
   314  	go downstream.Serve(dl)
   315  
   316  	// Start second proxy as upstream proxy, will write to downstream proxy.
   317  	ul, err := net.Listen("tcp", "[::]:0")
   318  	if err != nil {
   319  		t.Fatalf("net.Listen(): got %v, want no error", err)
   320  	}
   321  
   322  	upstream := NewProxy()
   323  	defer upstream.Close()
   324  
   325  	// Set upstream proxy's downstream proxy to the host:port of the first proxy.
   326  	upstream.SetDownstreamProxy(&url.URL{
   327  		Host: dl.Addr().String(),
   328  	})
   329  	upstream.SetTimeout(600 * time.Millisecond)
   330  
   331  	go upstream.Serve(ul)
   332  
   333  	// Open connection to upstream proxy.
   334  	conn, err := net.Dial("tcp", ul.Addr().String())
   335  	if err != nil {
   336  		t.Fatalf("net.Dial(): got %v, want no error", err)
   337  	}
   338  	defer conn.Close()
   339  
   340  	req, err := http.NewRequest("GET", "http://example.com", nil)
   341  	if err != nil {
   342  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   343  	}
   344  
   345  	// GET http://example.com/ HTTP/1.1
   346  	// Host: example.com
   347  	if err := req.WriteProxy(conn); err != nil {
   348  		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   349  	}
   350  
   351  	// Response from downstream proxy.
   352  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   353  	if err != nil {
   354  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   355  	}
   356  
   357  	if got, want := res.StatusCode, 299; got != want {
   358  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   359  	}
   360  }
   361  
   362  func TestIntegrationHTTPDownstreamProxyError(t *testing.T) {
   363  	t.Parallel()
   364  
   365  	l, err := net.Listen("tcp", "[::]:0")
   366  	if err != nil {
   367  		t.Fatalf("net.Listen(): got %v, want no error", err)
   368  	}
   369  
   370  	p := NewProxy()
   371  	defer p.Close()
   372  
   373  	// Set proxy's downstream proxy to invalid host:port to force failure.
   374  	p.SetDownstreamProxy(&url.URL{
   375  		Host: "[::]:0",
   376  	})
   377  	p.SetTimeout(600 * time.Millisecond)
   378  
   379  	tm := martiantest.NewModifier()
   380  	reserr := errors.New("response error")
   381  	tm.ResponseError(reserr)
   382  
   383  	p.SetResponseModifier(tm)
   384  
   385  	go p.Serve(l)
   386  
   387  	// Open connection to upstream proxy.
   388  	conn, err := net.Dial("tcp", l.Addr().String())
   389  	if err != nil {
   390  		t.Fatalf("net.Dial(): got %v, want no error", err)
   391  	}
   392  	defer conn.Close()
   393  
   394  	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
   395  	if err != nil {
   396  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   397  	}
   398  
   399  	// CONNECT example.com:443 HTTP/1.1
   400  	// Host: example.com
   401  	if err := req.Write(conn); err != nil {
   402  		t.Fatalf("req.Write(): got %v, want no error", err)
   403  	}
   404  
   405  	// Response from upstream proxy, assuming downstream proxy failed to CONNECT.
   406  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   407  	if err != nil {
   408  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   409  	}
   410  
   411  	if got, want := res.StatusCode, 502; got != want {
   412  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   413  	}
   414  	if got, want := res.Header["Warning"][1], reserr.Error(); !strings.Contains(got, want) {
   415  		t.Errorf("res.Header.get(%q): got %q, want to contain %q", "Warning", got, want)
   416  	}
   417  }
   418  
   419  func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) {
   420  	t.Parallel()
   421  
   422  	l, err := net.Listen("tcp", "[::]:0")
   423  	if err != nil {
   424  		t.Fatalf("net.Listen(): got %v, want no error", err)
   425  	}
   426  
   427  	p := NewProxy()
   428  	defer p.Close()
   429  
   430  	// Test TLS server.
   431  	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
   432  	if err != nil {
   433  		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
   434  	}
   435  	mc, err := mitm.NewConfig(ca, priv)
   436  	if err != nil {
   437  		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
   438  	}
   439  
   440  	var herr error
   441  	mc.SetHandshakeErrorCallback(func(_ *http.Request, err error) { herr = fmt.Errorf("handshake error") })
   442  	p.SetMITM(mc)
   443  
   444  	tl, err := net.Listen("tcp", "[::]:0")
   445  	if err != nil {
   446  		t.Fatalf("tls.Listen(): got %v, want no error", err)
   447  	}
   448  	tl = tls.NewListener(tl, mc.TLS())
   449  
   450  	go http.Serve(tl, http.HandlerFunc(
   451  		func(rw http.ResponseWriter, req *http.Request) {
   452  			rw.WriteHeader(200)
   453  		}))
   454  
   455  	tm := martiantest.NewModifier()
   456  
   457  	// Force the CONNECT request to dial the local TLS server.
   458  	tm.RequestFunc(func(req *http.Request) {
   459  		req.URL.Host = tl.Addr().String()
   460  	})
   461  
   462  	go p.Serve(l)
   463  
   464  	conn, err := net.Dial("tcp", l.Addr().String())
   465  	if err != nil {
   466  		t.Fatalf("net.Dial(): got %v, want no error", err)
   467  	}
   468  	defer conn.Close()
   469  
   470  	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
   471  	if err != nil {
   472  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   473  	}
   474  
   475  	// CONNECT example.com:443 HTTP/1.1
   476  	// Host: example.com
   477  	//
   478  	// Rewritten to CONNECT to host:port in CONNECT request modifier.
   479  	if err := req.Write(conn); err != nil {
   480  		t.Fatalf("req.Write(): got %v, want no error", err)
   481  	}
   482  
   483  	// CONNECT response after establishing tunnel.
   484  	if _, err := http.ReadResponse(bufio.NewReader(conn), req); err != nil {
   485  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   486  	}
   487  
   488  	tlsconn := tls.Client(conn, &tls.Config{
   489  		ServerName: "example.com",
   490  		// Client has no cert so it will get "x509: certificate signed by unknown authority" from the
   491  		// handshake and send "remote error: bad certificate" to the server.
   492  		RootCAs: x509.NewCertPool(),
   493  	})
   494  	defer tlsconn.Close()
   495  
   496  	req, err = http.NewRequest("GET", "https://example.com", nil)
   497  	if err != nil {
   498  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   499  	}
   500  	req.Header.Set("Connection", "close")
   501  
   502  	if got, want := req.Write(tlsconn), "x509: certificate signed by unknown authority"; !strings.Contains(got.Error(), want) {
   503  		t.Fatalf("Got incorrect error from Client Handshake(), got: %v, want: %v", got, want)
   504  	}
   505  
   506  	// TODO: herr is not being asserted against. It should be pushed on to a channel
   507  	// of err, and the assertion should pull off of it and assert. That design resulted in the test
   508  	// hanging for unknown reasons.
   509  	t.Skip("skipping assertion of handshake error callback error due to mysterious deadlock")
   510  	if got, want := herr, "remote error: bad certificate"; !strings.Contains(got.Error(), want) {
   511  		t.Fatalf("Got incorrect error from Server Handshake(), got: %v, want: %v", got, want)
   512  	}
   513  }
   514  
   515  func TestIntegrationConnect(t *testing.T) {
   516  	t.Parallel()
   517  
   518  	l, err := net.Listen("tcp", "[::]:0")
   519  	if err != nil {
   520  		t.Fatalf("net.Listen(): got %v, want no error", err)
   521  	}
   522  
   523  	p := NewProxy()
   524  	defer p.Close()
   525  
   526  	// Test TLS server.
   527  	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
   528  	if err != nil {
   529  		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
   530  	}
   531  	mc, err := mitm.NewConfig(ca, priv)
   532  	if err != nil {
   533  		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
   534  	}
   535  
   536  	tl, err := net.Listen("tcp", "[::]:0")
   537  	if err != nil {
   538  		t.Fatalf("tls.Listen(): got %v, want no error", err)
   539  	}
   540  	tl = tls.NewListener(tl, mc.TLS())
   541  
   542  	go http.Serve(tl, http.HandlerFunc(
   543  		func(rw http.ResponseWriter, req *http.Request) {
   544  			rw.WriteHeader(299)
   545  		}))
   546  
   547  	tm := martiantest.NewModifier()
   548  	reqerr := errors.New("request error")
   549  	reserr := errors.New("response error")
   550  
   551  	// Force the CONNECT request to dial the local TLS server.
   552  	tm.RequestFunc(func(req *http.Request) {
   553  		req.URL.Host = tl.Addr().String()
   554  	})
   555  
   556  	tm.RequestError(reqerr)
   557  	tm.ResponseError(reserr)
   558  
   559  	p.SetRequestModifier(tm)
   560  	p.SetResponseModifier(tm)
   561  
   562  	go p.Serve(l)
   563  
   564  	conn, err := net.Dial("tcp", l.Addr().String())
   565  	if err != nil {
   566  		t.Fatalf("net.Dial(): got %v, want no error", err)
   567  	}
   568  	defer conn.Close()
   569  
   570  	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
   571  	if err != nil {
   572  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   573  	}
   574  
   575  	// CONNECT example.com:443 HTTP/1.1
   576  	// Host: example.com
   577  	//
   578  	// Rewritten to CONNECT to host:port in CONNECT request modifier.
   579  	if err := req.Write(conn); err != nil {
   580  		t.Fatalf("req.Write(): got %v, want no error", err)
   581  	}
   582  
   583  	// CONNECT response after establishing tunnel.
   584  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   585  	if err != nil {
   586  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   587  	}
   588  
   589  	if got, want := res.StatusCode, 200; got != want {
   590  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   591  	}
   592  
   593  	if !tm.RequestModified() {
   594  		t.Error("tm.RequestModified(): got false, want true")
   595  	}
   596  	if !tm.ResponseModified() {
   597  		t.Error("tm.ResponseModified(): got false, want true")
   598  	}
   599  	if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
   600  		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
   601  	}
   602  
   603  	roots := x509.NewCertPool()
   604  	roots.AddCert(ca)
   605  
   606  	tlsconn := tls.Client(conn, &tls.Config{
   607  		ServerName: "example.com",
   608  		RootCAs:    roots,
   609  	})
   610  	defer tlsconn.Close()
   611  
   612  	req, err = http.NewRequest("GET", "https://example.com", nil)
   613  	if err != nil {
   614  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   615  	}
   616  	req.Header.Set("Connection", "close")
   617  
   618  	// GET / HTTP/1.1
   619  	// Host: example.com
   620  	// Connection: close
   621  	if err := req.Write(tlsconn); err != nil {
   622  		t.Fatalf("req.Write(): got %v, want no error", err)
   623  	}
   624  
   625  	res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
   626  	if err != nil {
   627  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   628  	}
   629  	defer res.Body.Close()
   630  
   631  	if got, want := res.StatusCode, 299; got != want {
   632  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   633  	}
   634  	if got, want := res.Header.Get("Warning"), reserr.Error(); strings.Contains(got, want) {
   635  		t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", got, want)
   636  	}
   637  }
   638  
   639  func TestIntegrationConnectDownstreamProxy(t *testing.T) {
   640  	t.Parallel()
   641  
   642  	// Start first proxy to use as downstream.
   643  	dl, err := net.Listen("tcp", "[::]:0")
   644  	if err != nil {
   645  		t.Fatalf("net.Listen(): got %v, want no error", err)
   646  	}
   647  
   648  	downstream := NewProxy()
   649  	defer downstream.Close()
   650  
   651  	dtr := martiantest.NewTransport()
   652  	dtr.Respond(299)
   653  	downstream.SetRoundTripper(dtr)
   654  
   655  	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
   656  	if err != nil {
   657  		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
   658  	}
   659  
   660  	mc, err := mitm.NewConfig(ca, priv)
   661  	if err != nil {
   662  		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
   663  	}
   664  	downstream.SetMITM(mc)
   665  
   666  	go downstream.Serve(dl)
   667  
   668  	// Start second proxy as upstream proxy, will CONNECT to downstream proxy.
   669  	ul, err := net.Listen("tcp", "[::]:0")
   670  	if err != nil {
   671  		t.Fatalf("net.Listen(): got %v, want no error", err)
   672  	}
   673  
   674  	upstream := NewProxy()
   675  	defer upstream.Close()
   676  
   677  	// Set upstream proxy's downstream proxy to the host:port of the first proxy.
   678  	upstream.SetDownstreamProxy(&url.URL{
   679  		Host: dl.Addr().String(),
   680  	})
   681  
   682  	go upstream.Serve(ul)
   683  
   684  	// Open connection to upstream proxy.
   685  	conn, err := net.Dial("tcp", ul.Addr().String())
   686  	if err != nil {
   687  		t.Fatalf("net.Dial(): got %v, want no error", err)
   688  	}
   689  	defer conn.Close()
   690  
   691  	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
   692  	if err != nil {
   693  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   694  	}
   695  
   696  	// CONNECT example.com:443 HTTP/1.1
   697  	// Host: example.com
   698  	if err := req.Write(conn); err != nil {
   699  		t.Fatalf("req.Write(): got %v, want no error", err)
   700  	}
   701  
   702  	// Response from downstream proxy starting MITM.
   703  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   704  	if err != nil {
   705  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   706  	}
   707  
   708  	if got, want := res.StatusCode, 200; got != want {
   709  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   710  	}
   711  
   712  	roots := x509.NewCertPool()
   713  	roots.AddCert(ca)
   714  
   715  	tlsconn := tls.Client(conn, &tls.Config{
   716  		// Validate the hostname.
   717  		ServerName: "example.com",
   718  		// The certificate will have been MITM'd, verify using the MITM CA
   719  		// certificate.
   720  		RootCAs: roots,
   721  	})
   722  	defer tlsconn.Close()
   723  
   724  	req, err = http.NewRequest("GET", "https://example.com", nil)
   725  	if err != nil {
   726  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   727  	}
   728  
   729  	// GET / HTTP/1.1
   730  	// Host: example.com
   731  	if err := req.Write(tlsconn); err != nil {
   732  		t.Fatalf("req.Write(): got %v, want no error", err)
   733  	}
   734  
   735  	// Response from MITM in downstream proxy.
   736  	res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
   737  	if err != nil {
   738  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   739  	}
   740  	defer res.Body.Close()
   741  
   742  	if got, want := res.StatusCode, 299; got != want {
   743  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   744  	}
   745  }
   746  
   747  func TestIntegrationMITM(t *testing.T) {
   748  	t.Parallel()
   749  
   750  	l, err := net.Listen("tcp", "[::]:0")
   751  	if err != nil {
   752  		t.Fatalf("net.Listen(): got %v, want no error", err)
   753  	}
   754  
   755  	p := NewProxy()
   756  	defer p.Close()
   757  
   758  	tr := martiantest.NewTransport()
   759  	tr.Func(func(req *http.Request) (*http.Response, error) {
   760  		res := proxyutil.NewResponse(200, nil, req)
   761  		res.Header.Set("Request-Scheme", req.URL.Scheme)
   762  
   763  		return res, nil
   764  	})
   765  
   766  	p.SetRoundTripper(tr)
   767  	p.SetTimeout(600 * time.Millisecond)
   768  
   769  	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
   770  	if err != nil {
   771  		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
   772  	}
   773  
   774  	mc, err := mitm.NewConfig(ca, priv)
   775  	if err != nil {
   776  		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
   777  	}
   778  	p.SetMITM(mc)
   779  
   780  	tm := martiantest.NewModifier()
   781  	reqerr := errors.New("request error")
   782  	reserr := errors.New("response error")
   783  	tm.RequestError(reqerr)
   784  	tm.ResponseError(reserr)
   785  
   786  	p.SetRequestModifier(tm)
   787  	p.SetResponseModifier(tm)
   788  
   789  	go p.Serve(l)
   790  
   791  	conn, err := net.Dial("tcp", l.Addr().String())
   792  	if err != nil {
   793  		t.Fatalf("net.Dial(): got %v, want no error", err)
   794  	}
   795  	defer conn.Close()
   796  
   797  	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
   798  	if err != nil {
   799  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   800  	}
   801  
   802  	// CONNECT example.com:443 HTTP/1.1
   803  	// Host: example.com
   804  	if err := req.Write(conn); err != nil {
   805  		t.Fatalf("req.Write(): got %v, want no error", err)
   806  	}
   807  
   808  	// Response MITM'd from proxy.
   809  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   810  	if err != nil {
   811  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   812  	}
   813  	if got, want := res.StatusCode, 200; got != want {
   814  
   815  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
   816  	}
   817  	if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
   818  		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
   819  	}
   820  
   821  	roots := x509.NewCertPool()
   822  	roots.AddCert(ca)
   823  
   824  	tlsconn := tls.Client(conn, &tls.Config{
   825  		ServerName: "example.com",
   826  		RootCAs:    roots,
   827  	})
   828  	defer tlsconn.Close()
   829  
   830  	req, err = http.NewRequest("GET", "https://example.com", nil)
   831  	if err != nil {
   832  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   833  	}
   834  
   835  	// GET / HTTP/1.1
   836  	// Host: example.com
   837  	if err := req.Write(tlsconn); err != nil {
   838  		t.Fatalf("req.Write(): got %v, want no error", err)
   839  	}
   840  
   841  	// Response from MITM proxy.
   842  	res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
   843  	if err != nil {
   844  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   845  	}
   846  	defer res.Body.Close()
   847  
   848  	if got, want := res.StatusCode, 200; got != want {
   849  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
   850  	}
   851  	if got, want := res.Header.Get("Request-Scheme"), "https"; got != want {
   852  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want)
   853  	}
   854  	if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
   855  		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
   856  	}
   857  }
   858  
   859  func TestIntegrationTransparentHTTP(t *testing.T) {
   860  	t.Parallel()
   861  
   862  	l, err := net.Listen("tcp", "[::]:0")
   863  	if err != nil {
   864  		t.Fatalf("net.Listen(): got %v, want no error", err)
   865  	}
   866  
   867  	p := NewProxy()
   868  	defer p.Close()
   869  
   870  	tr := martiantest.NewTransport()
   871  	p.SetRoundTripper(tr)
   872  
   873  	if got, want := p.GetRoundTripper(), tr; got != want {
   874  		t.Errorf("proxy.GetRoundTripper: got %v, want %v", got, want)
   875  	}
   876  
   877  	p.SetTimeout(200 * time.Millisecond)
   878  
   879  	tm := martiantest.NewModifier()
   880  	p.SetRequestModifier(tm)
   881  	p.SetResponseModifier(tm)
   882  
   883  	go p.Serve(l)
   884  
   885  	conn, err := net.Dial("tcp", l.Addr().String())
   886  	if err != nil {
   887  		t.Fatalf("net.Dial(): got %v, want no error", err)
   888  	}
   889  	defer conn.Close()
   890  
   891  	req, err := http.NewRequest("GET", "http://example.com", nil)
   892  	if err != nil {
   893  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   894  	}
   895  
   896  	// GET / HTTP/1.1
   897  	// Host: www.example.com
   898  	if err := req.Write(conn); err != nil {
   899  		t.Fatalf("req.Write(): got %v, want no error", err)
   900  	}
   901  
   902  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
   903  	if err != nil {
   904  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   905  	}
   906  
   907  	if got, want := res.StatusCode, 200; got != want {
   908  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   909  	}
   910  
   911  	if !tm.RequestModified() {
   912  		t.Error("tm.RequestModified(): got false, want true")
   913  	}
   914  	if !tm.ResponseModified() {
   915  		t.Error("tm.ResponseModified(): got false, want true")
   916  	}
   917  }
   918  
   919  func TestIntegrationTransparentMITM(t *testing.T) {
   920  	t.Parallel()
   921  
   922  	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
   923  	if err != nil {
   924  		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
   925  	}
   926  
   927  	mc, err := mitm.NewConfig(ca, priv)
   928  	if err != nil {
   929  		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
   930  	}
   931  
   932  	// Start TLS listener with config that will generate certificates based on
   933  	// SNI from connection.
   934  	//
   935  	// BUG: tls.Listen will not accept a tls.Config where Certificates is empty,
   936  	// even though it is supported by tls.Server when GetCertificate is not nil.
   937  	l, err := net.Listen("tcp", "[::]:0")
   938  	if err != nil {
   939  		t.Fatalf("net.Listen(): got %v, want no error", err)
   940  	}
   941  	l = tls.NewListener(l, mc.TLS())
   942  
   943  	p := NewProxy()
   944  	defer p.Close()
   945  
   946  	tr := martiantest.NewTransport()
   947  	tr.Func(func(req *http.Request) (*http.Response, error) {
   948  		res := proxyutil.NewResponse(200, nil, req)
   949  		res.Header.Set("Request-Scheme", req.URL.Scheme)
   950  
   951  		return res, nil
   952  	})
   953  
   954  	p.SetRoundTripper(tr)
   955  
   956  	tm := martiantest.NewModifier()
   957  	p.SetRequestModifier(tm)
   958  	p.SetResponseModifier(tm)
   959  
   960  	go p.Serve(l)
   961  
   962  	roots := x509.NewCertPool()
   963  	roots.AddCert(ca)
   964  
   965  	tlsconn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{
   966  		// Verify the hostname is example.com.
   967  		ServerName: "example.com",
   968  		// The certificate will have been generated during MITM, so we need to
   969  		// verify it with the generated CA certificate.
   970  		RootCAs: roots,
   971  	})
   972  	if err != nil {
   973  		t.Fatalf("tls.Dial(): got %v, want no error", err)
   974  	}
   975  	defer tlsconn.Close()
   976  
   977  	req, err := http.NewRequest("GET", "https://example.com", nil)
   978  	if err != nil {
   979  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
   980  	}
   981  
   982  	// Write Encrypted request directly, no CONNECT.
   983  	// GET / HTTP/1.1
   984  	// Host: example.com
   985  	if err := req.Write(tlsconn); err != nil {
   986  		t.Fatalf("req.Write(): got %v, want no error", err)
   987  	}
   988  
   989  	res, err := http.ReadResponse(bufio.NewReader(tlsconn), req)
   990  	if err != nil {
   991  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   992  	}
   993  	defer res.Body.Close()
   994  
   995  	if got, want := res.StatusCode, 200; got != want {
   996  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   997  	}
   998  	if got, want := res.Header.Get("Request-Scheme"), "https"; got != want {
   999  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want)
  1000  	}
  1001  
  1002  	if !tm.RequestModified() {
  1003  		t.Errorf("tm.RequestModified(): got false, want true")
  1004  	}
  1005  	if !tm.ResponseModified() {
  1006  		t.Errorf("tm.ResponseModified(): got false, want true")
  1007  	}
  1008  }
  1009  
  1010  func TestIntegrationFailedRoundTrip(t *testing.T) {
  1011  	t.Parallel()
  1012  
  1013  	l, err := net.Listen("tcp", "[::]:0")
  1014  	if err != nil {
  1015  		t.Fatalf("net.Listen(): got %v, want no error", err)
  1016  	}
  1017  
  1018  	p := NewProxy()
  1019  	defer p.Close()
  1020  
  1021  	tr := martiantest.NewTransport()
  1022  	trerr := errors.New("round trip error")
  1023  	tr.RespondError(trerr)
  1024  	p.SetRoundTripper(tr)
  1025  	p.SetTimeout(200 * time.Millisecond)
  1026  
  1027  	go p.Serve(l)
  1028  
  1029  	conn, err := net.Dial("tcp", l.Addr().String())
  1030  	if err != nil {
  1031  		t.Fatalf("net.Dial(): got %v, want no error", err)
  1032  	}
  1033  	defer conn.Close()
  1034  
  1035  	req, err := http.NewRequest("GET", "http://example.com", nil)
  1036  	if err != nil {
  1037  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
  1038  	}
  1039  
  1040  	// GET http://example.com/ HTTP/1.1
  1041  	// Host: example.com
  1042  	if err := req.WriteProxy(conn); err != nil {
  1043  		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  1044  	}
  1045  
  1046  	// Response from failed round trip.
  1047  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
  1048  	if err != nil {
  1049  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1050  	}
  1051  	defer res.Body.Close()
  1052  
  1053  	if got, want := res.StatusCode, 502; got != want {
  1054  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
  1055  	}
  1056  
  1057  	if got, want := res.Header.Get("Warning"), trerr.Error(); !strings.Contains(got, want) {
  1058  		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
  1059  	}
  1060  }
  1061  
  1062  func TestIntegrationSkipRoundTrip(t *testing.T) {
  1063  	t.Parallel()
  1064  
  1065  	l, err := net.Listen("tcp", "[::]:0")
  1066  	if err != nil {
  1067  		t.Fatalf("net.Listen(): got %v, want no error", err)
  1068  	}
  1069  
  1070  	p := NewProxy()
  1071  	defer p.Close()
  1072  
  1073  	// Transport will be skipped, no 500.
  1074  	tr := martiantest.NewTransport()
  1075  	tr.Respond(500)
  1076  	p.SetRoundTripper(tr)
  1077  	p.SetTimeout(200 * time.Millisecond)
  1078  
  1079  	tm := martiantest.NewModifier()
  1080  	tm.RequestFunc(func(req *http.Request) {
  1081  		ctx := NewContext(req)
  1082  		ctx.SkipRoundTrip()
  1083  	})
  1084  	p.SetRequestModifier(tm)
  1085  
  1086  	go p.Serve(l)
  1087  
  1088  	conn, err := net.Dial("tcp", l.Addr().String())
  1089  	if err != nil {
  1090  		t.Fatalf("net.Dial(): got %v, want no error", err)
  1091  	}
  1092  	defer conn.Close()
  1093  
  1094  	req, err := http.NewRequest("GET", "http://example.com", nil)
  1095  	if err != nil {
  1096  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
  1097  	}
  1098  
  1099  	// GET http://example.com/ HTTP/1.1
  1100  	// Host: example.com
  1101  	if err := req.WriteProxy(conn); err != nil {
  1102  		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  1103  	}
  1104  
  1105  	// Response from skipped round trip.
  1106  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
  1107  	if err != nil {
  1108  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1109  	}
  1110  	defer res.Body.Close()
  1111  
  1112  	if got, want := res.StatusCode, 200; got != want {
  1113  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
  1114  	}
  1115  }
  1116  
  1117  func TestHTTPThroughConnectWithMITM(t *testing.T) {
  1118  	t.Parallel()
  1119  
  1120  	l, err := net.Listen("tcp", "[::]:0")
  1121  	if err != nil {
  1122  		t.Fatalf("net.Listen(): got %v, want no error", err)
  1123  	}
  1124  
  1125  	p := NewProxy()
  1126  	defer p.Close()
  1127  
  1128  	tm := martiantest.NewModifier()
  1129  	tm.RequestFunc(func(req *http.Request) {
  1130  		ctx := NewContext(req)
  1131  		ctx.SkipRoundTrip()
  1132  
  1133  		if req.Method != "GET" && req.Method != "CONNECT" {
  1134  			t.Errorf("unexpected method on request handler: %v", req.Method)
  1135  		}
  1136  	})
  1137  	p.SetRequestModifier(tm)
  1138  
  1139  	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
  1140  	if err != nil {
  1141  		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  1142  	}
  1143  
  1144  	mc, err := mitm.NewConfig(ca, priv)
  1145  	if err != nil {
  1146  		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  1147  	}
  1148  	p.SetMITM(mc)
  1149  
  1150  	go p.Serve(l)
  1151  
  1152  	conn, err := net.Dial("tcp", l.Addr().String())
  1153  	if err != nil {
  1154  		t.Fatalf("net.Dial(): got %v, want no error", err)
  1155  	}
  1156  	defer conn.Close()
  1157  
  1158  	req, err := http.NewRequest("CONNECT", "//example.com:80", nil)
  1159  	if err != nil {
  1160  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
  1161  	}
  1162  
  1163  	// CONNECT example.com:80 HTTP/1.1
  1164  	// Host: example.com
  1165  	if err := req.Write(conn); err != nil {
  1166  		t.Fatalf("req.Write(): got %v, want no error", err)
  1167  	}
  1168  
  1169  	// Response skipped round trip.
  1170  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
  1171  	if err != nil {
  1172  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1173  	}
  1174  	res.Body.Close()
  1175  
  1176  	if got, want := res.StatusCode, 200; got != want {
  1177  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
  1178  	}
  1179  
  1180  	req, err = http.NewRequest("GET", "http://example.com", nil)
  1181  	if err != nil {
  1182  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
  1183  	}
  1184  
  1185  	// GET http://example.com/ HTTP/1.1
  1186  	// Host: example.com
  1187  	if err := req.WriteProxy(conn); err != nil {
  1188  		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  1189  	}
  1190  
  1191  	// Response from skipped round trip.
  1192  	res, err = http.ReadResponse(bufio.NewReader(conn), req)
  1193  	if err != nil {
  1194  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1195  	}
  1196  	res.Body.Close()
  1197  
  1198  	if got, want := res.StatusCode, 200; got != want {
  1199  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
  1200  	}
  1201  
  1202  	req, err = http.NewRequest("GET", "http://example.com", nil)
  1203  	if err != nil {
  1204  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
  1205  	}
  1206  
  1207  	// GET http://example.com/ HTTP/1.1
  1208  	// Host: example.com
  1209  	if err := req.WriteProxy(conn); err != nil {
  1210  		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  1211  	}
  1212  
  1213  	// Response from skipped round trip.
  1214  	res, err = http.ReadResponse(bufio.NewReader(conn), req)
  1215  	if err != nil {
  1216  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1217  	}
  1218  	res.Body.Close()
  1219  
  1220  	if got, want := res.StatusCode, 200; got != want {
  1221  		t.Errorf("res.StatusCode: got %d, want %d", got, want)
  1222  	}
  1223  }
  1224  
  1225  func TestServerClosesConnection(t *testing.T) {
  1226  	t.Parallel()
  1227  
  1228  	dstl, err := net.Listen("tcp", "[::]:0")
  1229  	if err != nil {
  1230  		t.Fatalf("Failed to create http listener: %v", err)
  1231  	}
  1232  	defer dstl.Close()
  1233  
  1234  	go func() {
  1235  		t.Logf("Waiting for server side connection")
  1236  		conn, err := dstl.Accept()
  1237  		if err != nil {
  1238  			t.Fatalf("Got error while accepting connection on destination listener: %v", err)
  1239  		}
  1240  		t.Logf("Accepted server side connection")
  1241  
  1242  		buf := make([]byte, 16384)
  1243  		if _, err := conn.Read(buf); err != nil {
  1244  			t.Fatalf("Error reading: %v", err)
  1245  		}
  1246  
  1247  		_, err = conn.Write([]byte("HTTP/1.1 301 MOVED PERMANENTLY\r\n" +
  1248  			"Server:  \r\n" +
  1249  			"Date:  \r\n" +
  1250  			"Referer:  \r\n" +
  1251  			"Location: http://www.foo.com/\r\n" +
  1252  			"Content-type: text/html\r\n" +
  1253  			"Connection: close\r\n\r\n"))
  1254  		if err != nil {
  1255  			t.Fatalf("Got error while writting to connection on destination listener: %v", err)
  1256  		}
  1257  		conn.Close()
  1258  	}()
  1259  
  1260  	l, err := net.Listen("tcp", "[::]:0")
  1261  	if err != nil {
  1262  		t.Fatalf("net.Listen(): got %v, want no error", err)
  1263  	}
  1264  
  1265  	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
  1266  	if err != nil {
  1267  		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  1268  	}
  1269  
  1270  	mc, err := mitm.NewConfig(ca, priv)
  1271  	if err != nil {
  1272  		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  1273  	}
  1274  	p := NewProxy()
  1275  	p.SetMITM(mc)
  1276  	defer p.Close()
  1277  
  1278  	// Start the proxy with a listener that will return a temporary error on
  1279  	// Accept() three times.
  1280  	go p.Serve(newTimeoutListener(l, 3))
  1281  
  1282  	conn, err := net.Dial("tcp", l.Addr().String())
  1283  	if err != nil {
  1284  		t.Fatalf("net.Dial(): got %v, want no error", err)
  1285  	}
  1286  	defer conn.Close()
  1287  
  1288  	req, err := http.NewRequest("CONNECT", fmt.Sprintf("//%s", dstl.Addr().String()), nil)
  1289  	if err != nil {
  1290  		t.Fatalf("http.NewRequest(): got %v, want no error", err)
  1291  	}
  1292  
  1293  	// CONNECT example.com:443 HTTP/1.1
  1294  	// Host: example.com
  1295  	if err := req.Write(conn); err != nil {
  1296  		t.Fatalf("req.Write(): got %v, want no error", err)
  1297  	}
  1298  	res, err := http.ReadResponse(bufio.NewReader(conn), req)
  1299  	if err != nil {
  1300  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1301  	}
  1302  	res.Body.Close()
  1303  
  1304  	_, err = conn.Write([]byte("GET / HTTP/1.1\r\n" +
  1305  		"User-Agent: curl/7.35.0\r\n" +
  1306  		fmt.Sprintf("Host: %s\r\n", dstl.Addr()) +
  1307  		"Accept: */*\r\n\r\n"))
  1308  	if err != nil {
  1309  		t.Fatalf("Error while writing GET request: %v", err)
  1310  	}
  1311  
  1312  	res, err = http.ReadResponse(bufio.NewReader(io.TeeReader(conn, os.Stderr)), req)
  1313  	if err != nil {
  1314  		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1315  	}
  1316  	_, err = ioutil.ReadAll(res.Body)
  1317  	if err != nil {
  1318  		t.Fatalf("error while ReadAll: %v", err)
  1319  	}
  1320  	defer res.Body.Close()
  1321  }
  1322  
  1323  // TestRacyClose checks that creating a proxy, serving from it, and closing
  1324  // it in rapid succession doesn't result in race warnings.
  1325  // See https://github.com/google/martian/issues/286.
  1326  func TestRacyClose(t *testing.T) {
  1327  	t.Parallel()
  1328  
  1329  	log.SetLevel(log.Silent) // avoid "failed to accept" messages because we close l
  1330  	openAndConnect := func() {
  1331  		l, err := net.Listen("tcp", "[::]:0")
  1332  		if err != nil {
  1333  			t.Fatalf("net.Listen(): got %v, want no error", err)
  1334  		}
  1335  		defer l.Close() // to make p.Serve exit
  1336  
  1337  		p := NewProxy()
  1338  		go p.Serve(l)
  1339  		defer p.Close()
  1340  
  1341  		conn, err := net.Dial("tcp", l.Addr().String())
  1342  		if err != nil {
  1343  			t.Fatalf("net.Dial(): got %v, want no error", err)
  1344  		}
  1345  		defer conn.Close()
  1346  	}
  1347  
  1348  	// Repeat a bunch of times to make failures more repeatable.
  1349  	for i := 0; i < 100; i++ {
  1350  		openAndConnect()
  1351  	}
  1352  }