github.com/miolini/go@v0.0.0-20160405192216-fca68c8cb408/src/net/http/clientserver_test.go (about)

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
     6  
     7  package http_test
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"crypto/tls"
    13  	"fmt"
    14  	"io"
    15  	"io/ioutil"
    16  	"log"
    17  	"net"
    18  	. "net/http"
    19  	"net/http/httptest"
    20  	"net/url"
    21  	"os"
    22  	"reflect"
    23  	"runtime"
    24  	"sort"
    25  	"strings"
    26  	"sync"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  )
    31  
// clientServerTest bundles a test HTTP server with a client configured
// to talk to it, in either HTTP/1 or HTTP/2 mode. Create one with
// newClientServerTest and release it with close.
type clientServerTest struct {
	t  *testing.T
	h2 bool             // whether the server and transport are set up for HTTP/2
	h  Handler          // handler served by ts
	ts *httptest.Server // the running test server
	tr *Transport       // transport used by c
	c  *Client          // client whose Transport is tr
}
    40  
    41  func (t *clientServerTest) close() {
    42  	t.tr.CloseIdleConnections()
    43  	t.ts.Close()
    44  }
    45  
    46  const (
    47  	h1Mode = false
    48  	h2Mode = true
    49  )
    50  
    51  func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
    52  	cst := &clientServerTest{
    53  		t:  t,
    54  		h2: h2,
    55  		h:  h,
    56  		tr: &Transport{},
    57  	}
    58  	cst.c = &Client{Transport: cst.tr}
    59  
    60  	for _, opt := range opts {
    61  		switch opt := opt.(type) {
    62  		case func(*Transport):
    63  			opt(cst.tr)
    64  		default:
    65  			t.Fatalf("unhandled option type %T", opt)
    66  		}
    67  	}
    68  
    69  	if !h2 {
    70  		cst.ts = httptest.NewServer(h)
    71  		return cst
    72  	}
    73  	cst.ts = httptest.NewUnstartedServer(h)
    74  	ExportHttp2ConfigureServer(cst.ts.Config, nil)
    75  	cst.ts.TLS = cst.ts.Config.TLSConfig
    76  	cst.ts.StartTLS()
    77  
    78  	cst.tr.TLSClientConfig = &tls.Config{
    79  		InsecureSkipVerify: true,
    80  	}
    81  	if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	return cst
    85  }
    86  
    87  // Testing the newClientServerTest helper itself.
    88  func TestNewClientServerTest(t *testing.T) {
    89  	var got struct {
    90  		sync.Mutex
    91  		log []string
    92  	}
    93  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
    94  		got.Lock()
    95  		defer got.Unlock()
    96  		got.log = append(got.log, r.Proto)
    97  	})
    98  	for _, v := range [2]bool{false, true} {
    99  		cst := newClientServerTest(t, v, h)
   100  		if _, err := cst.c.Head(cst.ts.URL); err != nil {
   101  			t.Fatal(err)
   102  		}
   103  		cst.close()
   104  	}
   105  	got.Lock() // no need to unlock
   106  	if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) {
   107  		t.Errorf("got %q; want %q", got.log, want)
   108  	}
   109  }
   110  
   111  func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) }
   112  func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) }
   113  
   114  func testChunkedResponseHeaders(t *testing.T, h2 bool) {
   115  	defer afterTest(t)
   116  	log.SetOutput(ioutil.Discard) // is noisy otherwise
   117  	defer log.SetOutput(os.Stderr)
   118  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   119  		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
   120  		w.(Flusher).Flush()
   121  		fmt.Fprintf(w, "I am a chunked response.")
   122  	}))
   123  	defer cst.close()
   124  
   125  	res, err := cst.c.Get(cst.ts.URL)
   126  	if err != nil {
   127  		t.Fatalf("Get error: %v", err)
   128  	}
   129  	defer res.Body.Close()
   130  	if g, e := res.ContentLength, int64(-1); g != e {
   131  		t.Errorf("expected ContentLength of %d; got %d", e, g)
   132  	}
   133  	wantTE := []string{"chunked"}
   134  	if h2 {
   135  		wantTE = nil
   136  	}
   137  	if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
   138  		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
   139  	}
   140  	if got, haveCL := res.Header["Content-Length"]; haveCL {
   141  		t.Errorf("Unexpected Content-Length: %q", got)
   142  	}
   143  }
   144  
   145  type reqFunc func(c *Client, url string) (*Response, error)
   146  
// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
// against each other. Its run method sends the same request to an
// HTTP/1 and an HTTP/2 test server, both serving Handler, and reports
// any difference between the two responses.
type h12Compare struct {
	Handler       func(ResponseWriter, *Request)    // required
	ReqFunc       reqFunc                           // optional; defaults to (*Client).Get
	CheckResponse func(proto string, res *Response) // optional; called once per protocol after the comparison
	Opts          []interface{}                     // options forwarded to newClientServerTest
}
   155  
   156  func (tt h12Compare) reqFunc() reqFunc {
   157  	if tt.ReqFunc == nil {
   158  		return (*Client).Get
   159  	}
   160  	return tt.ReqFunc
   161  }
   162  
   163  func (tt h12Compare) run(t *testing.T) {
   164  	cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
   165  	defer cst1.close()
   166  	cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
   167  	defer cst2.close()
   168  
   169  	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
   170  	if err != nil {
   171  		t.Errorf("HTTP/1 request: %v", err)
   172  		return
   173  	}
   174  	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
   175  	if err != nil {
   176  		t.Errorf("HTTP/2 request: %v", err)
   177  		return
   178  	}
   179  	tt.normalizeRes(t, res1, "HTTP/1.1")
   180  	tt.normalizeRes(t, res2, "HTTP/2.0")
   181  	res1body, res2body := res1.Body, res2.Body
   182  
   183  	eres1 := mostlyCopy(res1)
   184  	eres2 := mostlyCopy(res2)
   185  	if !reflect.DeepEqual(eres1, eres2) {
   186  		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
   187  			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
   188  	}
   189  	if !reflect.DeepEqual(res1body, res2body) {
   190  		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
   191  	}
   192  	if fn := tt.CheckResponse; fn != nil {
   193  		res1.Body, res2.Body = res1body, res2body
   194  		fn("HTTP/1.1", res1)
   195  		fn("HTTP/2.0", res2)
   196  	}
   197  }
   198  
   199  func mostlyCopy(r *Response) *Response {
   200  	c := *r
   201  	c.Body = nil
   202  	c.TransferEncoding = nil
   203  	c.TLS = nil
   204  	c.Request = nil
   205  	return &c
   206  }
   207  
   208  type slurpResult struct {
   209  	io.ReadCloser
   210  	body []byte
   211  	err  error
   212  }
   213  
   214  func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
   215  
   216  func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
   217  	if res.Proto == wantProto {
   218  		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
   219  	} else {
   220  		t.Errorf("got %q response; want %q", res.Proto, wantProto)
   221  	}
   222  	slurp, err := ioutil.ReadAll(res.Body)
   223  	res.Body.Close()
   224  	res.Body = slurpResult{
   225  		ReadCloser: ioutil.NopCloser(bytes.NewReader(slurp)),
   226  		body:       slurp,
   227  		err:        err,
   228  	}
   229  	for i, v := range res.Header["Date"] {
   230  		res.Header["Date"][i] = strings.Repeat("x", len(v))
   231  	}
   232  	if res.Request == nil {
   233  		t.Errorf("for %s, no request", wantProto)
   234  	}
   235  	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
   236  		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
   237  	}
   238  }
   239  
   240  // Issue 13532
   241  func TestH12_HeadContentLengthNoBody(t *testing.T) {
   242  	h12Compare{
   243  		ReqFunc: (*Client).Head,
   244  		Handler: func(w ResponseWriter, r *Request) {
   245  		},
   246  	}.run(t)
   247  }
   248  
   249  func TestH12_HeadContentLengthSmallBody(t *testing.T) {
   250  	h12Compare{
   251  		ReqFunc: (*Client).Head,
   252  		Handler: func(w ResponseWriter, r *Request) {
   253  			io.WriteString(w, "small")
   254  		},
   255  	}.run(t)
   256  }
   257  
   258  func TestH12_HeadContentLengthLargeBody(t *testing.T) {
   259  	h12Compare{
   260  		ReqFunc: (*Client).Head,
   261  		Handler: func(w ResponseWriter, r *Request) {
   262  			chunk := strings.Repeat("x", 512<<10)
   263  			for i := 0; i < 10; i++ {
   264  				io.WriteString(w, chunk)
   265  			}
   266  		},
   267  	}.run(t)
   268  }
   269  
   270  func TestH12_200NoBody(t *testing.T) {
   271  	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
   272  }
   273  
   274  func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
   275  func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
   276  func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
   277  
   278  func testH12_noBody(t *testing.T, status int) {
   279  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   280  		w.WriteHeader(status)
   281  	}}.run(t)
   282  }
   283  
   284  func TestH12_SmallBody(t *testing.T) {
   285  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   286  		io.WriteString(w, "small body")
   287  	}}.run(t)
   288  }
   289  
   290  func TestH12_ExplicitContentLength(t *testing.T) {
   291  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   292  		w.Header().Set("Content-Length", "3")
   293  		io.WriteString(w, "foo")
   294  	}}.run(t)
   295  }
   296  
   297  func TestH12_FlushBeforeBody(t *testing.T) {
   298  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   299  		w.(Flusher).Flush()
   300  		io.WriteString(w, "foo")
   301  	}}.run(t)
   302  }
   303  
   304  func TestH12_FlushMidBody(t *testing.T) {
   305  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   306  		io.WriteString(w, "foo")
   307  		w.(Flusher).Flush()
   308  		io.WriteString(w, "bar")
   309  	}}.run(t)
   310  }
   311  
   312  func TestH12_Head_ExplicitLen(t *testing.T) {
   313  	h12Compare{
   314  		ReqFunc: (*Client).Head,
   315  		Handler: func(w ResponseWriter, r *Request) {
   316  			if r.Method != "HEAD" {
   317  				t.Errorf("unexpected method %q", r.Method)
   318  			}
   319  			w.Header().Set("Content-Length", "1235")
   320  		},
   321  	}.run(t)
   322  }
   323  
   324  func TestH12_Head_ImplicitLen(t *testing.T) {
   325  	h12Compare{
   326  		ReqFunc: (*Client).Head,
   327  		Handler: func(w ResponseWriter, r *Request) {
   328  			if r.Method != "HEAD" {
   329  				t.Errorf("unexpected method %q", r.Method)
   330  			}
   331  			io.WriteString(w, "foo")
   332  		},
   333  	}.run(t)
   334  }
   335  
   336  func TestH12_HandlerWritesTooLittle(t *testing.T) {
   337  	h12Compare{
   338  		Handler: func(w ResponseWriter, r *Request) {
   339  			w.Header().Set("Content-Length", "3")
   340  			io.WriteString(w, "12") // one byte short
   341  		},
   342  		CheckResponse: func(proto string, res *Response) {
   343  			sr, ok := res.Body.(slurpResult)
   344  			if !ok {
   345  				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
   346  				return
   347  			}
   348  			if sr.err != io.ErrUnexpectedEOF {
   349  				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
   350  			}
   351  			if string(sr.body) != "12" {
   352  				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
   353  			}
   354  		},
   355  	}.run(t)
   356  }
   357  
   358  // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
   359  // writing more than they declared. This test does not test whether
   360  // the transport deals with too much data, though, since the server
   361  // doesn't make it possible to send bogus data. For those tests, see
   362  // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
   363  // (for HTTP/2).
   364  func TestH12_HandlerWritesTooMuch(t *testing.T) {
   365  	h12Compare{
   366  		Handler: func(w ResponseWriter, r *Request) {
   367  			w.Header().Set("Content-Length", "3")
   368  			w.(Flusher).Flush()
   369  			io.WriteString(w, "123")
   370  			w.(Flusher).Flush()
   371  			n, err := io.WriteString(w, "x") // too many
   372  			if n > 0 || err == nil {
   373  				t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
   374  			}
   375  		},
   376  	}.run(t)
   377  }
   378  
   379  // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
   380  // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
   381  func TestH12_AutoGzip(t *testing.T) {
   382  	h12Compare{
   383  		Handler: func(w ResponseWriter, r *Request) {
   384  			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
   385  				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
   386  			}
   387  			w.Header().Set("Content-Encoding", "gzip")
   388  			gz := gzip.NewWriter(w)
   389  			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
   390  			gz.Close()
   391  		},
   392  	}.run(t)
   393  }
   394  
   395  func TestH12_AutoGzip_Disabled(t *testing.T) {
   396  	h12Compare{
   397  		Opts: []interface{}{
   398  			func(tr *Transport) { tr.DisableCompression = true },
   399  		},
   400  		Handler: func(w ResponseWriter, r *Request) {
   401  			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
   402  			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
   403  				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
   404  			}
   405  		},
   406  	}.run(t)
   407  }
   408  
   409  // Test304Responses verifies that 304s don't declare that they're
   410  // chunking in their response headers and aren't allowed to produce
   411  // output.
   412  func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) }
   413  func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) }
   414  
   415  func test304Responses(t *testing.T, h2 bool) {
   416  	defer afterTest(t)
   417  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   418  		w.WriteHeader(StatusNotModified)
   419  		_, err := w.Write([]byte("illegal body"))
   420  		if err != ErrBodyNotAllowed {
   421  			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
   422  		}
   423  	}))
   424  	defer cst.close()
   425  	res, err := cst.c.Get(cst.ts.URL)
   426  	if err != nil {
   427  		t.Fatal(err)
   428  	}
   429  	if len(res.TransferEncoding) > 0 {
   430  		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
   431  	}
   432  	body, err := ioutil.ReadAll(res.Body)
   433  	if err != nil {
   434  		t.Error(err)
   435  	}
   436  	if len(body) > 0 {
   437  		t.Errorf("got unexpected body %q", string(body))
   438  	}
   439  }
   440  
   441  func TestH12_ServerEmptyContentLength(t *testing.T) {
   442  	h12Compare{
   443  		Handler: func(w ResponseWriter, r *Request) {
   444  			w.Header()["Content-Type"] = []string{""}
   445  			io.WriteString(w, "<html><body>hi</body></html>")
   446  		},
   447  	}.run(t)
   448  }
   449  
   450  func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
   451  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
   452  }
   453  
   454  func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
   455  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0)
   456  }
   457  
   458  func TestH12_RequestContentLength_Unknown(t *testing.T) {
   459  	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
   460  }
   461  
   462  func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
   463  	h12Compare{
   464  		Handler: func(w ResponseWriter, r *Request) {
   465  			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
   466  			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
   467  		},
   468  		ReqFunc: func(c *Client, url string) (*Response, error) {
   469  			return c.Post(url, "text/plain", bodyfn())
   470  		},
   471  		CheckResponse: func(proto string, res *Response) {
   472  			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
   473  				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
   474  			}
   475  		},
   476  	}.run(t)
   477  }
   478  
   479  // Tests that closing the Request.Cancel channel also while still
   480  // reading the response body. Issue 13159.
   481  func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) }
   482  func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) }
   483  func testCancelRequestMidBody(t *testing.T, h2 bool) {
   484  	defer afterTest(t)
   485  	unblock := make(chan bool)
   486  	didFlush := make(chan bool, 1)
   487  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   488  		io.WriteString(w, "Hello")
   489  		w.(Flusher).Flush()
   490  		didFlush <- true
   491  		<-unblock
   492  		io.WriteString(w, ", world.")
   493  	}))
   494  	defer cst.close()
   495  	defer close(unblock)
   496  
   497  	req, _ := NewRequest("GET", cst.ts.URL, nil)
   498  	cancel := make(chan struct{})
   499  	req.Cancel = cancel
   500  
   501  	res, err := cst.c.Do(req)
   502  	if err != nil {
   503  		t.Fatal(err)
   504  	}
   505  	defer res.Body.Close()
   506  	<-didFlush
   507  
   508  	// Read a bit before we cancel. (Issue 13626)
   509  	// We should have "Hello" at least sitting there.
   510  	firstRead := make([]byte, 10)
   511  	n, err := res.Body.Read(firstRead)
   512  	if err != nil {
   513  		t.Fatal(err)
   514  	}
   515  	firstRead = firstRead[:n]
   516  
   517  	close(cancel)
   518  
   519  	rest, err := ioutil.ReadAll(res.Body)
   520  	all := string(firstRead) + string(rest)
   521  	if all != "Hello" {
   522  		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
   523  	}
   524  	if !reflect.DeepEqual(err, ExportErrRequestCanceled) {
   525  		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
   526  	}
   527  }
   528  
   529  // Tests that clients can send trailers to a server and that the server can read them.
   530  func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) }
   531  func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) }
   532  
   533  func testTrailersClientToServer(t *testing.T, h2 bool) {
   534  	defer afterTest(t)
   535  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   536  		var decl []string
   537  		for k := range r.Trailer {
   538  			decl = append(decl, k)
   539  		}
   540  		sort.Strings(decl)
   541  
   542  		slurp, err := ioutil.ReadAll(r.Body)
   543  		if err != nil {
   544  			t.Errorf("Server reading request body: %v", err)
   545  		}
   546  		if string(slurp) != "foo" {
   547  			t.Errorf("Server read request body %q; want foo", slurp)
   548  		}
   549  		if r.Trailer == nil {
   550  			io.WriteString(w, "nil Trailer")
   551  		} else {
   552  			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
   553  				decl,
   554  				r.Trailer.Get("Client-Trailer-A"),
   555  				r.Trailer.Get("Client-Trailer-B"))
   556  		}
   557  	}))
   558  	defer cst.close()
   559  
   560  	var req *Request
   561  	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
   562  		eofReaderFunc(func() {
   563  			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
   564  		}),
   565  		strings.NewReader("foo"),
   566  		eofReaderFunc(func() {
   567  			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
   568  		}),
   569  	))
   570  	req.Trailer = Header{
   571  		"Client-Trailer-A": nil, //  to be set later
   572  		"Client-Trailer-B": nil, //  to be set later
   573  	}
   574  	req.ContentLength = -1
   575  	res, err := cst.c.Do(req)
   576  	if err != nil {
   577  		t.Fatal(err)
   578  	}
   579  	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
   580  		t.Error(err)
   581  	}
   582  }
   583  
// Tests that servers send trailers to a client and that the client can read them.
// The _Flush variants make the handler flush after writing the body.
func TestTrailersServerToClient_h1(t *testing.T)       { testTrailersServerToClient(t, h1Mode, false) }
func TestTrailersServerToClient_h2(t *testing.T)       { testTrailersServerToClient(t, h2Mode, false) }
func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) }
func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) }

// testTrailersServerToClient has a handler declare trailers A, B and C,
// write a body (flushing it if flush is set), and then set values for A
// and C plus an undeclared header. The client must see:
//   - the expected headers and ContentLength;
//   - the declared trailer keys with nil values before reading the body;
//   - A and C populated after reading the body, B still nil, and the
//     undeclared header absent.
func testTrailersServerToClient(t *testing.T, h2, flush bool) {
	defer afterTest(t)
	const body = "Some body"
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")

		io.WriteString(w, body)
		if flush {
			w.(Flusher).Flush()
		}

		// How handlers set Trailers: declare it ahead of time
		// with the Trailer header, and then mutate the
		// Header() of those values later, after the response
		// has been written (we wrote to w above).
		// NOTE(review): w.Header() is deliberately called again after the
		// write rather than cached; keep it that way when editing.
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
	}))
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	wantHeader := Header{
		"Content-Type": {"text/plain; charset=utf-8"},
	}
	wantLen := -1
	if h2 && !flush {
		// In HTTP/1.1, any use of trailers forces HTTP/1.1
		// chunking and a flush at the first write. That's
		// unnecessary with HTTP/2's framing, so the server
		// is able to calculate the length while still sending
		// trailers afterwards.
		wantLen = len(body)
		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
	}
	if res.ContentLength != int64(wantLen) {
		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
	}

	delete(res.Header, "Date") // irrelevant for test
	if !reflect.DeepEqual(res.Header, wantHeader) {
		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
	}

	// Before the body is read, only the declared trailer keys are known.
	if got, want := res.Trailer, (Header{
		"Server-Trailer-A": nil,
		"Server-Trailer-B": nil,
		"Server-Trailer-C": nil,
	}); !reflect.DeepEqual(got, want) {
		t.Errorf("Trailer before body read = %v; want %v", got, want)
	}

	// wantBody is defined elsewhere; presumably it reads the body and
	// compares it with body.
	if err := wantBody(res, nil, body); err != nil {
		t.Fatal(err)
	}

	// After the body is read, the values the handler set are present. B
	// stays nil and the undeclared header is not a trailer.
	if got, want := res.Trailer, (Header{
		"Server-Trailer-A": {"valuea"},
		"Server-Trailer-B": nil,
		"Server-Trailer-C": {"valuec"},
	}); !reflect.DeepEqual(got, want) {
		t.Errorf("Trailer after body read = %v; want %v", got, want)
	}
}
   659  
   660  // Don't allow a Body.Read after Body.Close. Issue 13648.
   661  func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) }
   662  func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) }
   663  
   664  func testResponseBodyReadAfterClose(t *testing.T, h2 bool) {
   665  	defer afterTest(t)
   666  	const body = "Some body"
   667  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   668  		io.WriteString(w, body)
   669  	}))
   670  	defer cst.close()
   671  	res, err := cst.c.Get(cst.ts.URL)
   672  	if err != nil {
   673  		t.Fatal(err)
   674  	}
   675  	res.Body.Close()
   676  	data, err := ioutil.ReadAll(res.Body)
   677  	if len(data) != 0 || err == nil {
   678  		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
   679  	}
   680  }
   681  
   682  func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) }
   683  func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) }
   684  func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) {
   685  	defer afterTest(t)
   686  	const reqBody = "some request body"
   687  	const resBody = "some response body"
   688  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   689  		var wg sync.WaitGroup
   690  		wg.Add(2)
   691  		didRead := make(chan bool, 1)
   692  		// Read in one goroutine.
   693  		go func() {
   694  			defer wg.Done()
   695  			data, err := ioutil.ReadAll(r.Body)
   696  			if string(data) != reqBody {
   697  				t.Errorf("Handler read %q; want %q", data, reqBody)
   698  			}
   699  			if err != nil {
   700  				t.Errorf("Handler Read: %v", err)
   701  			}
   702  			didRead <- true
   703  		}()
   704  		// Write in another goroutine.
   705  		go func() {
   706  			defer wg.Done()
   707  			if !h2 {
   708  				// our HTTP/1 implementation intentionally
   709  				// doesn't permit writes during read (mostly
   710  				// due to it being undefined); if that is ever
   711  				// relaxed, change this.
   712  				<-didRead
   713  			}
   714  			io.WriteString(w, resBody)
   715  		}()
   716  		wg.Wait()
   717  	}))
   718  	defer cst.close()
   719  	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
   720  	req.Header.Add("Expect", "100-continue") // just to complicate things
   721  	res, err := cst.c.Do(req)
   722  	if err != nil {
   723  		t.Fatal(err)
   724  	}
   725  	data, err := ioutil.ReadAll(res.Body)
   726  	defer res.Body.Close()
   727  	if err != nil {
   728  		t.Fatal(err)
   729  	}
   730  	if string(data) != resBody {
   731  		t.Errorf("read %q; want %q", data, resBody)
   732  	}
   733  }
   734  
   735  func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) }
   736  func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) }
   737  func testConnectRequest(t *testing.T, h2 bool) {
   738  	defer afterTest(t)
   739  	gotc := make(chan *Request, 1)
   740  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   741  		gotc <- r
   742  	}))
   743  	defer cst.close()
   744  
   745  	u, err := url.Parse(cst.ts.URL)
   746  	if err != nil {
   747  		t.Fatal(err)
   748  	}
   749  
   750  	tests := []struct {
   751  		req  *Request
   752  		want string
   753  	}{
   754  		{
   755  			req: &Request{
   756  				Method: "CONNECT",
   757  				Header: Header{},
   758  				URL:    u,
   759  			},
   760  			want: u.Host,
   761  		},
   762  		{
   763  			req: &Request{
   764  				Method: "CONNECT",
   765  				Header: Header{},
   766  				URL:    u,
   767  				Host:   "example.com:123",
   768  			},
   769  			want: "example.com:123",
   770  		},
   771  	}
   772  
   773  	for i, tt := range tests {
   774  		res, err := cst.c.Do(tt.req)
   775  		if err != nil {
   776  			t.Errorf("%d. RoundTrip = %v", i, err)
   777  			continue
   778  		}
   779  		res.Body.Close()
   780  		req := <-gotc
   781  		if req.Method != "CONNECT" {
   782  			t.Errorf("method = %q; want CONNECT", req.Method)
   783  		}
   784  		if req.Host != tt.want {
   785  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
   786  		}
   787  		if req.URL.Host != tt.want {
   788  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
   789  		}
   790  	}
   791  }
   792  
   793  func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) }
   794  func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) }
   795  func testTransportUserAgent(t *testing.T, h2 bool) {
   796  	defer afterTest(t)
   797  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   798  		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
   799  	}))
   800  	defer cst.close()
   801  
   802  	either := func(a, b string) string {
   803  		if h2 {
   804  			return b
   805  		}
   806  		return a
   807  	}
   808  
   809  	tests := []struct {
   810  		setup func(*Request)
   811  		want  string
   812  	}{
   813  		{
   814  			func(r *Request) {},
   815  			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
   816  		},
   817  		{
   818  			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
   819  			`["foo/1.2.3"]`,
   820  		},
   821  		{
   822  			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
   823  			`["single"]`,
   824  		},
   825  		{
   826  			func(r *Request) { r.Header.Set("User-Agent", "") },
   827  			`[]`,
   828  		},
   829  		{
   830  			func(r *Request) { r.Header["User-Agent"] = nil },
   831  			`[]`,
   832  		},
   833  	}
   834  	for i, tt := range tests {
   835  		req, _ := NewRequest("GET", cst.ts.URL, nil)
   836  		tt.setup(req)
   837  		res, err := cst.c.Do(req)
   838  		if err != nil {
   839  			t.Errorf("%d. RoundTrip = %v", i, err)
   840  			continue
   841  		}
   842  		slurp, err := ioutil.ReadAll(res.Body)
   843  		res.Body.Close()
   844  		if err != nil {
   845  			t.Errorf("%d. read body = %v", i, err)
   846  			continue
   847  		}
   848  		if string(slurp) != tt.want {
   849  			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
   850  		}
   851  	}
   852  }
   853  
   854  func TestStarRequestFoo_h1(t *testing.T)     { testStarRequest(t, "FOO", h1Mode) }
   855  func TestStarRequestFoo_h2(t *testing.T)     { testStarRequest(t, "FOO", h2Mode) }
   856  func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) }
   857  func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) }
   858  func testStarRequest(t *testing.T, method string, h2 bool) {
   859  	defer afterTest(t)
   860  	gotc := make(chan *Request, 1)
   861  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   862  		w.Header().Set("foo", "bar")
   863  		gotc <- r
   864  		w.(Flusher).Flush()
   865  	}))
   866  	defer cst.close()
   867  
   868  	u, err := url.Parse(cst.ts.URL)
   869  	if err != nil {
   870  		t.Fatal(err)
   871  	}
   872  	u.Path = "*"
   873  
   874  	req := &Request{
   875  		Method: method,
   876  		Header: Header{},
   877  		URL:    u,
   878  	}
   879  
   880  	res, err := cst.c.Do(req)
   881  	if err != nil {
   882  		t.Fatalf("RoundTrip = %v", err)
   883  	}
   884  	res.Body.Close()
   885  
   886  	wantFoo := "bar"
   887  	wantLen := int64(-1)
   888  	if method == "OPTIONS" {
   889  		wantFoo = ""
   890  		wantLen = 0
   891  	}
   892  	if res.StatusCode != 200 {
   893  		t.Errorf("status code = %v; want %d", res.Status, 200)
   894  	}
   895  	if res.ContentLength != wantLen {
   896  		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
   897  	}
   898  	if got := res.Header.Get("foo"); got != wantFoo {
   899  		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
   900  	}
   901  	select {
   902  	case req = <-gotc:
   903  	default:
   904  		req = nil
   905  	}
   906  	if req == nil {
   907  		if method != "OPTIONS" {
   908  			t.Fatalf("handler never got request")
   909  		}
   910  		return
   911  	}
   912  	if req.Method != method {
   913  		t.Errorf("method = %q; want %q", req.Method, method)
   914  	}
   915  	if req.URL.Path != "*" {
   916  		t.Errorf("URL.Path = %q; want *", req.URL.Path)
   917  	}
   918  	if req.RequestURI != "*" {
   919  		t.Errorf("RequestURI = %q; want *", req.RequestURI)
   920  	}
   921  }
   922  
   923  // Issue 13957
   924  func TestTransportDiscardsUnneededConns(t *testing.T) {
   925  	defer afterTest(t)
   926  	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   927  		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
   928  	}))
   929  	defer cst.close()
   930  
   931  	var numOpen, numClose int32 // atomic
   932  
   933  	tlsConfig := &tls.Config{InsecureSkipVerify: true}
   934  	tr := &Transport{
   935  		TLSClientConfig: tlsConfig,
   936  		DialTLS: func(_, addr string) (net.Conn, error) {
   937  			time.Sleep(10 * time.Millisecond)
   938  			rc, err := net.Dial("tcp", addr)
   939  			if err != nil {
   940  				return nil, err
   941  			}
   942  			atomic.AddInt32(&numOpen, 1)
   943  			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
   944  			return tls.Client(c, tlsConfig), nil
   945  		},
   946  	}
   947  	if err := ExportHttp2ConfigureTransport(tr); err != nil {
   948  		t.Fatal(err)
   949  	}
   950  	defer tr.CloseIdleConnections()
   951  
   952  	c := &Client{Transport: tr}
   953  
   954  	const N = 10
   955  	gotBody := make(chan string, N)
   956  	var wg sync.WaitGroup
   957  	for i := 0; i < N; i++ {
   958  		wg.Add(1)
   959  		go func() {
   960  			defer wg.Done()
   961  			resp, err := c.Get(cst.ts.URL)
   962  			if err != nil {
   963  				t.Errorf("Get: %v", err)
   964  				return
   965  			}
   966  			defer resp.Body.Close()
   967  			slurp, err := ioutil.ReadAll(resp.Body)
   968  			if err != nil {
   969  				t.Error(err)
   970  			}
   971  			gotBody <- string(slurp)
   972  		}()
   973  	}
   974  	wg.Wait()
   975  	close(gotBody)
   976  
   977  	var last string
   978  	for got := range gotBody {
   979  		if last == "" {
   980  			last = got
   981  			continue
   982  		}
   983  		if got != last {
   984  			t.Errorf("Response body changed: %q -> %q", last, got)
   985  		}
   986  	}
   987  
   988  	var open, close int32
   989  	for i := 0; i < 150; i++ {
   990  		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
   991  		if open < 1 {
   992  			t.Fatalf("open = %d; want at least", open)
   993  		}
   994  		if close == open-1 {
   995  			// Success
   996  			return
   997  		}
   998  		time.Sleep(10 * time.Millisecond)
   999  	}
  1000  	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
  1001  }
  1002  
  1003  // tests that Transport doesn't retain a pointer to the provided request.
  1004  func TestTransportGCRequest_Body_h1(t *testing.T)   { testTransportGCRequest(t, h1Mode, true) }
  1005  func TestTransportGCRequest_Body_h2(t *testing.T)   { testTransportGCRequest(t, h2Mode, true) }
  1006  func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) }
  1007  func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) }
  1008  func testTransportGCRequest(t *testing.T, h2, body bool) {
  1009  	defer afterTest(t)
  1010  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1011  		ioutil.ReadAll(r.Body)
  1012  		if body {
  1013  			io.WriteString(w, "Hello.")
  1014  		}
  1015  	}))
  1016  	defer cst.close()
  1017  
  1018  	didGC := make(chan struct{})
  1019  	(func() {
  1020  		body := strings.NewReader("some body")
  1021  		req, _ := NewRequest("POST", cst.ts.URL, body)
  1022  		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
  1023  		res, err := cst.c.Do(req)
  1024  		if err != nil {
  1025  			t.Fatal(err)
  1026  		}
  1027  		if _, err := ioutil.ReadAll(res.Body); err != nil {
  1028  			t.Fatal(err)
  1029  		}
  1030  		if err := res.Body.Close(); err != nil {
  1031  			t.Fatal(err)
  1032  		}
  1033  	})()
  1034  	timeout := time.NewTimer(5 * time.Second)
  1035  	defer timeout.Stop()
  1036  	for {
  1037  		select {
  1038  		case <-didGC:
  1039  			return
  1040  		case <-time.After(100 * time.Millisecond):
  1041  			runtime.GC()
  1042  		case <-timeout.C:
  1043  			t.Fatal("never saw GC of request")
  1044  		}
  1045  	}
  1046  }
  1047  
  1048  func TestTransportRejectsInvalidHeaders_h1(t *testing.T) {
  1049  	testTransportRejectsInvalidHeaders(t, h1Mode)
  1050  }
  1051  func TestTransportRejectsInvalidHeaders_h2(t *testing.T) {
  1052  	testTransportRejectsInvalidHeaders(t, h2Mode)
  1053  }
  1054  func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) {
  1055  	defer afterTest(t)
  1056  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1057  		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
  1058  	}))
  1059  	defer cst.close()
  1060  	cst.tr.DisableKeepAlives = true
  1061  
  1062  	tests := []struct {
  1063  		key, val string
  1064  		ok       bool
  1065  	}{
  1066  		{"Foo", "capital-key", true}, // verify h2 allows capital keys
  1067  		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
  1068  		{"Foo", "two\nlines", false}, // \n byte in value not allowed
  1069  		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
  1070  		{"A space", "v", false},      // spaces in keys not allowed
  1071  		{"имя", "v", false},          // key must be ascii
  1072  		{"name", "валю", true},       // value may be non-ascii
  1073  		{"", "v", false},             // key must be non-empty
  1074  		{"k", "", true},              // value may be empty
  1075  	}
  1076  	for _, tt := range tests {
  1077  		dialedc := make(chan bool, 1)
  1078  		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  1079  			dialedc <- true
  1080  			return net.Dial(netw, addr)
  1081  		}
  1082  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1083  		req.Header[tt.key] = []string{tt.val}
  1084  		res, err := cst.c.Do(req)
  1085  		var body []byte
  1086  		if err == nil {
  1087  			body, _ = ioutil.ReadAll(res.Body)
  1088  			res.Body.Close()
  1089  		}
  1090  		var dialed bool
  1091  		select {
  1092  		case <-dialedc:
  1093  			dialed = true
  1094  		default:
  1095  		}
  1096  
  1097  		if !tt.ok && dialed {
  1098  			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
  1099  		} else if (err == nil) != tt.ok {
  1100  			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
  1101  		}
  1102  	}
  1103  }
  1104  
  1105  // Tests that we support bogus under-100 HTTP statuses, because we historically
  1106  // have. This might change at some point, but not yet in Go 1.6.
  1107  func TestBogusStatusWorks_h1(t *testing.T) { testBogusStatusWorks(t, h1Mode) }
  1108  func TestBogusStatusWorks_h2(t *testing.T) { testBogusStatusWorks(t, h2Mode) }
  1109  func testBogusStatusWorks(t *testing.T, h2 bool) {
  1110  	defer afterTest(t)
  1111  	const code = 7
  1112  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1113  		w.WriteHeader(code)
  1114  	}))
  1115  	defer cst.close()
  1116  
  1117  	res, err := cst.c.Get(cst.ts.URL)
  1118  	if err != nil {
  1119  		t.Fatal(err)
  1120  	}
  1121  	if res.StatusCode != code {
  1122  		t.Errorf("StatusCode = %d; want %d", res.StatusCode, code)
  1123  	}
  1124  }
  1125  
  1126  type noteCloseConn struct {
  1127  	net.Conn
  1128  	closeFunc func()
  1129  }
  1130  
  1131  func (x noteCloseConn) Close() error {
  1132  	x.closeFunc()
  1133  	return x.Conn.Close()
  1134  }