github.com/useflyent/fhttp@v0.0.0-20211004035111-333f430cfbbf/http2/transport_test.go (about)

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"encoding/pem"
    14  	"errors"
    15  	"flag"
    16  	"fmt"
    17  	"io"
    18  	"io/ioutil"
    19  	"log"
    20  	"math/rand"
    21  	"net"
    22  	"net/textproto"
    23  	"net/url"
    24  	"os"
    25  	"reflect"
    26  	"runtime"
    27  	"sort"
    28  	"strconv"
    29  	"strings"
    30  	"sync"
    31  	"sync/atomic"
    32  	"testing"
    33  	"time"
    34  
    35  	http "github.com/useflyent/fhttp"
    36  	"github.com/useflyent/fhttp/http2/hpack"
    37  	"github.com/useflyent/fhttp/httptest"
    38  	"github.com/useflyent/fhttp/httptrace"
    39  )
    40  
    41  var (
    42  	extNet        = flag.Bool("extnet", false, "do external network tests")
    43  	transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
    44  	insecure      = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove?
    45  )
    46  
    47  var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
    48  
    49  var canceledCtx context.Context
    50  
    51  func init() {
    52  	ctx, cancel := context.WithCancel(context.Background())
    53  	cancel()
    54  	canceledCtx = ctx
    55  }
    56  
    57  func TestTransportExternal(t *testing.T) {
    58  	if !*extNet {
    59  		t.Skip("skipping external network test")
    60  	}
    61  	req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
    62  	rt := &Transport{TLSClientConfig: tlsConfigInsecure}
    63  	res, err := rt.RoundTrip(req)
    64  	if err != nil {
    65  		t.Fatalf("%v", err)
    66  	}
    67  	res.Write(os.Stdout)
    68  }
    69  
    70  type fakeTLSConn struct {
    71  	net.Conn
    72  }
    73  
    74  func (c *fakeTLSConn) ConnectionState() tls.ConnectionState {
    75  	return tls.ConnectionState{
    76  		Version:     tls.VersionTLS12,
    77  		CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
    78  	}
    79  }
    80  
    81  func startH2cServer(t *testing.T) net.Listener {
    82  	h2Server := &Server{}
    83  	l := newLocalListener(t)
    84  	go func() {
    85  		conn, err := l.Accept()
    86  		if err != nil {
    87  			t.Error(err)
    88  			return
    89  		}
    90  		h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    91  			fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
    92  		})})
    93  	}()
    94  	return l
    95  }
    96  
    97  func TestTransportH2c(t *testing.T) {
    98  	l := startH2cServer(t)
    99  	defer l.Close()
   100  	req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil)
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	var gotConnCnt int32
   105  	trace := &httptrace.ClientTrace{
   106  		GotConn: func(connInfo httptrace.GotConnInfo) {
   107  			if !connInfo.Reused {
   108  				atomic.AddInt32(&gotConnCnt, 1)
   109  			}
   110  		},
   111  	}
   112  	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   113  	tr := &Transport{
   114  		AllowHTTP: true,
   115  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   116  			return net.Dial(network, addr)
   117  		},
   118  	}
   119  	res, err := tr.RoundTrip(req)
   120  	if err != nil {
   121  		t.Fatal(err)
   122  	}
   123  	if res.ProtoMajor != 2 {
   124  		t.Fatal("proto not h2c")
   125  	}
   126  	body, err := ioutil.ReadAll(res.Body)
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	if got, want := string(body), "Hello, /foobar, http: true"; got != want {
   131  		t.Fatalf("response got %v, want %v", got, want)
   132  	}
   133  	if got, want := gotConnCnt, int32(1); got != want {
   134  		t.Errorf("Too many got connections: %d", gotConnCnt)
   135  	}
   136  }
   137  
   138  func TestTransport(t *testing.T) {
   139  	const body = "sup"
   140  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   141  		io.WriteString(w, body)
   142  	}, optOnlyServer)
   143  	defer st.Close()
   144  
   145  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   146  	defer tr.CloseIdleConnections()
   147  
   148  	u, err := url.Parse(st.ts.URL)
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  	for i, m := range []string{"GET", ""} {
   153  		req := &http.Request{
   154  			Method: m,
   155  			URL:    u,
   156  		}
   157  		res, err := tr.RoundTrip(req)
   158  		if err != nil {
   159  			t.Fatalf("%d: %s", i, err)
   160  		}
   161  
   162  		t.Logf("%d: Got res: %+v", i, res)
   163  		if g, w := res.StatusCode, 200; g != w {
   164  			t.Errorf("%d: StatusCode = %v; want %v", i, g, w)
   165  		}
   166  		if g, w := res.Status, "200 OK"; g != w {
   167  			t.Errorf("%d: Status = %q; want %q", i, g, w)
   168  		}
   169  		wantHeader := http.Header{
   170  			"Content-Length": []string{"3"},
   171  			"Content-Type":   []string{"text/plain; charset=utf-8"},
   172  			"Date":           []string{"XXX"}, // see cleanDate
   173  		}
   174  		cleanDate(res)
   175  		if !reflect.DeepEqual(res.Header, wantHeader) {
   176  			t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader)
   177  		}
   178  		if res.Request != req {
   179  			t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req)
   180  		}
   181  		if res.TLS == nil {
   182  			t.Errorf("%d: Response.TLS = nil; want non-nil", i)
   183  		}
   184  		slurp, err := ioutil.ReadAll(res.Body)
   185  		if err != nil {
   186  			t.Errorf("%d: Body read: %v", i, err)
   187  		} else if string(slurp) != body {
   188  			t.Errorf("%d: Body = %q; want %q", i, slurp, body)
   189  		}
   190  		res.Body.Close()
   191  	}
   192  }
   193  
   194  func onSameConn(t *testing.T, modReq func(*http.Request)) bool {
   195  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   196  		io.WriteString(w, r.RemoteAddr)
   197  	}, optOnlyServer, func(c net.Conn, st http.ConnState) {
   198  		t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
   199  	})
   200  	defer st.Close()
   201  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   202  	defer tr.CloseIdleConnections()
   203  	get := func() string {
   204  		req, err := http.NewRequest("GET", st.ts.URL, nil)
   205  		if err != nil {
   206  			t.Fatal(err)
   207  		}
   208  		modReq(req)
   209  		res, err := tr.RoundTrip(req)
   210  		if err != nil {
   211  			t.Fatal(err)
   212  		}
   213  		defer res.Body.Close()
   214  		slurp, err := ioutil.ReadAll(res.Body)
   215  		if err != nil {
   216  			t.Fatalf("Body read: %v", err)
   217  		}
   218  		addr := strings.TrimSpace(string(slurp))
   219  		if addr == "" {
   220  			t.Fatalf("didn't get an addr in response")
   221  		}
   222  		return addr
   223  	}
   224  	first := get()
   225  	second := get()
   226  	return first == second
   227  }
   228  
   229  func TestTransportReusesConns(t *testing.T) {
   230  	if !onSameConn(t, func(*http.Request) {}) {
   231  		t.Errorf("first and second responses were on different connections")
   232  	}
   233  }
   234  
   235  func TestTransportReusesConn_RequestClose(t *testing.T) {
   236  	if onSameConn(t, func(r *http.Request) { r.Close = true }) {
   237  		t.Errorf("first and second responses were not on different connections")
   238  	}
   239  }
   240  
   241  func TestTransportReusesConn_ConnClose(t *testing.T) {
   242  	if onSameConn(t, func(r *http.Request) { r.Header.Set("Connection", "close") }) {
   243  		t.Errorf("first and second responses were not on different connections")
   244  	}
   245  }
   246  
   247  // Tests that the Transport only keeps one pending dial open per destination address.
   248  // https://golang.org/issue/13397
   249  func TestTransportGroupsPendingDials(t *testing.T) {
   250  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   251  		io.WriteString(w, r.RemoteAddr)
   252  	}, optOnlyServer)
   253  	defer st.Close()
   254  	tr := &Transport{
   255  		TLSClientConfig: tlsConfigInsecure,
   256  	}
   257  	defer tr.CloseIdleConnections()
   258  	var (
   259  		mu    sync.Mutex
   260  		dials = map[string]int{}
   261  	)
   262  	var gotConnCnt int32
   263  	trace := &httptrace.ClientTrace{
   264  		GotConn: func(connInfo httptrace.GotConnInfo) {
   265  			if !connInfo.Reused {
   266  				atomic.AddInt32(&gotConnCnt, 1)
   267  			}
   268  		},
   269  	}
   270  	var wg sync.WaitGroup
   271  	for i := 0; i < 10; i++ {
   272  		wg.Add(1)
   273  		go func() {
   274  			defer wg.Done()
   275  			req, err := http.NewRequest("GET", st.ts.URL, nil)
   276  			if err != nil {
   277  				t.Error(err)
   278  				return
   279  			}
   280  			req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   281  			res, err := tr.RoundTrip(req)
   282  			if err != nil {
   283  				t.Error(err)
   284  				return
   285  			}
   286  			defer res.Body.Close()
   287  			slurp, err := ioutil.ReadAll(res.Body)
   288  			if err != nil {
   289  				t.Errorf("Body read: %v", err)
   290  			}
   291  			addr := strings.TrimSpace(string(slurp))
   292  			if addr == "" {
   293  				t.Errorf("didn't get an addr in response")
   294  			}
   295  			mu.Lock()
   296  			dials[addr]++
   297  			mu.Unlock()
   298  		}()
   299  	}
   300  	wg.Wait()
   301  	if len(dials) != 1 {
   302  		t.Errorf("saw %d dials; want 1: %v", len(dials), dials)
   303  	}
   304  	tr.CloseIdleConnections()
   305  	if err := retry(50, 10*time.Millisecond, func() error {
   306  		cp, ok := tr.connPool().(*clientConnPool)
   307  		if !ok {
   308  			return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
   309  		}
   310  		cp.mu.Lock()
   311  		defer cp.mu.Unlock()
   312  		if len(cp.dialing) != 0 {
   313  			return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
   314  		}
   315  		if len(cp.conns) != 0 {
   316  			return fmt.Errorf("conns = %v; want empty", cp.conns)
   317  		}
   318  		if len(cp.keys) != 0 {
   319  			return fmt.Errorf("keys = %v; want empty", cp.keys)
   320  		}
   321  		return nil
   322  	}); err != nil {
   323  		t.Errorf("State of pool after CloseIdleConnections: %v", err)
   324  	}
   325  	if got, want := gotConnCnt, int32(1); got != want {
   326  		t.Errorf("Too many got connections: %d", gotConnCnt)
   327  	}
   328  }
   329  
   330  func retry(tries int, delay time.Duration, fn func() error) error {
   331  	var err error
   332  	for i := 0; i < tries; i++ {
   333  		err = fn()
   334  		if err == nil {
   335  			return nil
   336  		}
   337  		time.Sleep(delay)
   338  	}
   339  	return err
   340  }
   341  
   342  func TestTransportAbortClosesPipes(t *testing.T) {
   343  	shutdown := make(chan struct{})
   344  	st := newServerTester(t,
   345  		func(w http.ResponseWriter, r *http.Request) {
   346  			w.(http.Flusher).Flush()
   347  			<-shutdown
   348  		},
   349  		optOnlyServer,
   350  	)
   351  	defer st.Close()
   352  	defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
   353  
   354  	errCh := make(chan error)
   355  	go func() {
   356  		defer close(errCh)
   357  		tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   358  		req, err := http.NewRequest("GET", st.ts.URL, nil)
   359  		if err != nil {
   360  			errCh <- err
   361  			return
   362  		}
   363  		res, err := tr.RoundTrip(req)
   364  		if err != nil {
   365  			errCh <- err
   366  			return
   367  		}
   368  		defer res.Body.Close()
   369  		st.closeConn()
   370  		_, err = ioutil.ReadAll(res.Body)
   371  		if err == nil {
   372  			errCh <- errors.New("expected error from res.Body.Read")
   373  			return
   374  		}
   375  	}()
   376  
   377  	select {
   378  	case err := <-errCh:
   379  		if err != nil {
   380  			t.Fatal(err)
   381  		}
   382  	// deadlock? that's a bug.
   383  	case <-time.After(3 * time.Second):
   384  		t.Fatal("timeout")
   385  	}
   386  }
   387  
   388  // TODO: merge this with TestTransportBody to make TestTransportRequest? This
   389  // could be a table-driven test with extra goodies.
   390  func TestTransportPath(t *testing.T) {
   391  	gotc := make(chan *url.URL, 1)
   392  	st := newServerTester(t,
   393  		func(w http.ResponseWriter, r *http.Request) {
   394  			gotc <- r.URL
   395  		},
   396  		optOnlyServer,
   397  	)
   398  	defer st.Close()
   399  
   400  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   401  	defer tr.CloseIdleConnections()
   402  	const (
   403  		path  = "/testpath"
   404  		query = "q=1"
   405  	)
   406  	surl := st.ts.URL + path + "?" + query
   407  	req, err := http.NewRequest("POST", surl, nil)
   408  	if err != nil {
   409  		t.Fatal(err)
   410  	}
   411  	c := &http.Client{Transport: tr}
   412  	res, err := c.Do(req)
   413  	if err != nil {
   414  		t.Fatal(err)
   415  	}
   416  	defer res.Body.Close()
   417  	got := <-gotc
   418  	if got.Path != path {
   419  		t.Errorf("Read Path = %q; want %q", got.Path, path)
   420  	}
   421  	if got.RawQuery != query {
   422  		t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
   423  	}
   424  }
   425  
   426  func randString(n int) string {
   427  	rnd := rand.New(rand.NewSource(int64(n)))
   428  	b := make([]byte, n)
   429  	for i := range b {
   430  		b[i] = byte(rnd.Intn(256))
   431  	}
   432  	return string(b)
   433  }
   434  
   435  type panicReader struct{}
   436  
   437  func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") }
   438  func (panicReader) Close() error             { panic("unexpected Close") }
   439  
   440  func TestActualContentLength(t *testing.T) {
   441  	tests := []struct {
   442  		req  *http.Request
   443  		want int64
   444  	}{
   445  		// Verify we don't read from Body:
   446  		0: {
   447  			req:  &http.Request{Body: panicReader{}},
   448  			want: -1,
   449  		},
   450  		// nil Body means 0, regardless of ContentLength:
   451  		1: {
   452  			req:  &http.Request{Body: nil, ContentLength: 5},
   453  			want: 0,
   454  		},
   455  		// ContentLength is used if set.
   456  		2: {
   457  			req:  &http.Request{Body: panicReader{}, ContentLength: 5},
   458  			want: 5,
   459  		},
   460  		// http.NoBody means 0, not -1.
   461  		3: {
   462  			req:  &http.Request{Body: http.NoBody},
   463  			want: 0,
   464  		},
   465  	}
   466  	for i, tt := range tests {
   467  		got := actualContentLength(tt.req)
   468  		if got != tt.want {
   469  			t.Errorf("test[%d]: got %d; want %d", i, got, tt.want)
   470  		}
   471  	}
   472  }
   473  
   474  func TestTransportBody(t *testing.T) {
   475  	bodyTests := []struct {
   476  		body         string
   477  		noContentLen bool
   478  	}{
   479  		{body: "some message"},
   480  		{body: "some message", noContentLen: true},
   481  		{body: strings.Repeat("a", 1<<20), noContentLen: true},
   482  		{body: strings.Repeat("a", 1<<20)},
   483  		{body: randString(16<<10 - 1)},
   484  		{body: randString(16 << 10)},
   485  		{body: randString(16<<10 + 1)},
   486  		{body: randString(512<<10 - 1)},
   487  		{body: randString(512 << 10)},
   488  		{body: randString(512<<10 + 1)},
   489  		{body: randString(1<<20 - 1)},
   490  		{body: randString(1 << 20)},
   491  		{body: randString(1<<20 + 2)},
   492  	}
   493  
   494  	type reqInfo struct {
   495  		req   *http.Request
   496  		slurp []byte
   497  		err   error
   498  	}
   499  	gotc := make(chan reqInfo, 1)
   500  	st := newServerTester(t,
   501  		func(w http.ResponseWriter, r *http.Request) {
   502  			slurp, err := ioutil.ReadAll(r.Body)
   503  			if err != nil {
   504  				gotc <- reqInfo{err: err}
   505  			} else {
   506  				gotc <- reqInfo{req: r, slurp: slurp}
   507  			}
   508  		},
   509  		optOnlyServer,
   510  	)
   511  	defer st.Close()
   512  
   513  	for i, tt := range bodyTests {
   514  		tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   515  		defer tr.CloseIdleConnections()
   516  
   517  		var body io.Reader = strings.NewReader(tt.body)
   518  		if tt.noContentLen {
   519  			body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods
   520  		}
   521  		req, err := http.NewRequest("POST", st.ts.URL, body)
   522  		if err != nil {
   523  			t.Fatalf("#%d: %v", i, err)
   524  		}
   525  		c := &http.Client{Transport: tr}
   526  		res, err := c.Do(req)
   527  		if err != nil {
   528  			t.Fatalf("#%d: %v", i, err)
   529  		}
   530  		defer res.Body.Close()
   531  		ri := <-gotc
   532  		if ri.err != nil {
   533  			t.Errorf("#%d: read error: %v", i, ri.err)
   534  			continue
   535  		}
   536  		if got := string(ri.slurp); got != tt.body {
   537  			t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
   538  		}
   539  		wantLen := int64(len(tt.body))
   540  		if tt.noContentLen && tt.body != "" {
   541  			wantLen = -1
   542  		}
   543  		if ri.req.ContentLength != wantLen {
   544  			t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
   545  		}
   546  	}
   547  }
   548  
   549  func shortString(v string) string {
   550  	const maxLen = 100
   551  	if len(v) <= maxLen {
   552  		return v
   553  	}
   554  	return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:])
   555  }
   556  
   557  func TestTransportDialTLS(t *testing.T) {
   558  	var mu sync.Mutex // guards following
   559  	var gotReq, didDial bool
   560  
   561  	ts := newServerTester(t,
   562  		func(w http.ResponseWriter, r *http.Request) {
   563  			mu.Lock()
   564  			gotReq = true
   565  			mu.Unlock()
   566  		},
   567  		optOnlyServer,
   568  	)
   569  	defer ts.Close()
   570  	tr := &Transport{
   571  		DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
   572  			mu.Lock()
   573  			didDial = true
   574  			mu.Unlock()
   575  			cfg.InsecureSkipVerify = true
   576  			c, err := tls.Dial(netw, addr, cfg)
   577  			if err != nil {
   578  				return nil, err
   579  			}
   580  			return c, c.Handshake()
   581  		},
   582  	}
   583  	defer tr.CloseIdleConnections()
   584  	client := &http.Client{Transport: tr}
   585  	res, err := client.Get(ts.ts.URL)
   586  	if err != nil {
   587  		t.Fatal(err)
   588  	}
   589  	res.Body.Close()
   590  	mu.Lock()
   591  	if !gotReq {
   592  		t.Error("didn't get request")
   593  	}
   594  	if !didDial {
   595  		t.Error("didn't use dial hook")
   596  	}
   597  }
   598  
   599  func TestConfigureTransport(t *testing.T) {
   600  	t1 := &http.Transport{}
   601  	err := ConfigureTransport(t1)
   602  	if err != nil {
   603  		t.Fatal(err)
   604  	}
   605  	if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) {
   606  		// Laziness, to avoid buildtags.
   607  		t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
   608  	}
   609  	wantNextProtos := []string{"h2", "http/1.1"}
   610  	if t1.TLSClientConfig == nil {
   611  		t.Errorf("nil t1.TLSClientConfig")
   612  	} else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) {
   613  		t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos)
   614  	}
   615  	if err := ConfigureTransport(t1); err == nil {
   616  		t.Error("unexpected success on second call to ConfigureTransport")
   617  	}
   618  
   619  	// And does it work?
   620  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   621  		io.WriteString(w, r.Proto)
   622  	}, optOnlyServer)
   623  	defer st.Close()
   624  
   625  	t1.TLSClientConfig.InsecureSkipVerify = true
   626  	c := &http.Client{Transport: t1}
   627  	res, err := c.Get(st.ts.URL)
   628  	if err != nil {
   629  		t.Fatal(err)
   630  	}
   631  	slurp, err := ioutil.ReadAll(res.Body)
   632  	if err != nil {
   633  		t.Fatal(err)
   634  	}
   635  	if got, want := string(slurp), "HTTP/2.0"; got != want {
   636  		t.Errorf("body = %q; want %q", got, want)
   637  	}
   638  }
   639  
   640  type capitalizeReader struct {
   641  	r io.Reader
   642  }
   643  
   644  func (cr capitalizeReader) Read(p []byte) (n int, err error) {
   645  	n, err = cr.r.Read(p)
   646  	for i, b := range p[:n] {
   647  		if b >= 'a' && b <= 'z' {
   648  			p[i] = b - ('a' - 'A')
   649  		}
   650  	}
   651  	return
   652  }
   653  
   654  type flushWriter struct {
   655  	w io.Writer
   656  }
   657  
   658  func (fw flushWriter) Write(p []byte) (n int, err error) {
   659  	n, err = fw.w.Write(p)
   660  	if f, ok := fw.w.(http.Flusher); ok {
   661  		f.Flush()
   662  	}
   663  	return
   664  }
   665  
   666  type clientTester struct {
   667  	t      *testing.T
   668  	tr     *Transport
   669  	sc, cc net.Conn // server and client conn
   670  	fr     *Framer  // server's framer
   671  	client func() error
   672  	server func() error
   673  }
   674  
   675  func newClientTester(t *testing.T) *clientTester {
   676  	var dialOnce struct {
   677  		sync.Mutex
   678  		dialed bool
   679  	}
   680  	ct := &clientTester{
   681  		t: t,
   682  	}
   683  	ct.tr = &Transport{
   684  		TLSClientConfig: tlsConfigInsecure,
   685  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   686  			dialOnce.Lock()
   687  			defer dialOnce.Unlock()
   688  			if dialOnce.dialed {
   689  				return nil, errors.New("only one dial allowed in test mode")
   690  			}
   691  			dialOnce.dialed = true
   692  			return ct.cc, nil
   693  		},
   694  	}
   695  
   696  	ln := newLocalListener(t)
   697  	cc, err := net.Dial("tcp", ln.Addr().String())
   698  	if err != nil {
   699  		t.Fatal(err)
   700  
   701  	}
   702  	sc, err := ln.Accept()
   703  	if err != nil {
   704  		t.Fatal(err)
   705  	}
   706  	ln.Close()
   707  	ct.cc = cc
   708  	ct.sc = sc
   709  	ct.fr = NewFramer(sc, sc)
   710  	return ct
   711  }
   712  
   713  func newLocalListener(t *testing.T) net.Listener {
   714  	ln, err := net.Listen("tcp4", "127.0.0.1:0")
   715  	if err == nil {
   716  		return ln
   717  	}
   718  	ln, err = net.Listen("tcp6", "[::1]:0")
   719  	if err != nil {
   720  		t.Fatal(err)
   721  	}
   722  	return ln
   723  }
   724  
   725  func (ct *clientTester) greet(settings ...Setting) {
   726  	buf := make([]byte, len(ClientPreface))
   727  	_, err := io.ReadFull(ct.sc, buf)
   728  	if err != nil {
   729  		ct.t.Fatalf("reading client preface: %v", err)
   730  	}
   731  	f, err := ct.fr.ReadFrame()
   732  	if err != nil {
   733  		ct.t.Fatalf("Reading client settings frame: %v", err)
   734  	}
   735  	if sf, ok := f.(*SettingsFrame); !ok {
   736  		ct.t.Fatalf("Wanted client settings frame; got %v", f)
   737  		_ = sf // stash it away?
   738  	}
   739  	if err := ct.fr.WriteSettings(settings...); err != nil {
   740  		ct.t.Fatal(err)
   741  	}
   742  	if err := ct.fr.WriteSettingsAck(); err != nil {
   743  		ct.t.Fatal(err)
   744  	}
   745  }
   746  
   747  func (ct *clientTester) readNonSettingsFrame() (Frame, error) {
   748  	for {
   749  		f, err := ct.fr.ReadFrame()
   750  		if err != nil {
   751  			return nil, err
   752  		}
   753  		if _, ok := f.(*SettingsFrame); ok {
   754  			continue
   755  		}
   756  		return f, nil
   757  	}
   758  }
   759  
   760  func (ct *clientTester) cleanup() {
   761  	ct.tr.CloseIdleConnections()
   762  
   763  	// close both connections, ignore the error if its already closed
   764  	ct.sc.Close()
   765  	ct.cc.Close()
   766  }
   767  
   768  func (ct *clientTester) run() {
   769  	var errOnce sync.Once
   770  	var wg sync.WaitGroup
   771  
   772  	run := func(which string, fn func() error) {
   773  		defer wg.Done()
   774  		if err := fn(); err != nil {
   775  			errOnce.Do(func() {
   776  				ct.t.Errorf("%s: %v", which, err)
   777  				ct.cleanup()
   778  			})
   779  		}
   780  	}
   781  
   782  	wg.Add(2)
   783  	go run("client", ct.client)
   784  	go run("server", ct.server)
   785  	wg.Wait()
   786  
   787  	errOnce.Do(ct.cleanup) // clean up if no error
   788  }
   789  
   790  func (ct *clientTester) readFrame() (Frame, error) {
   791  	return readFrameTimeout(ct.fr, 2*time.Second)
   792  }
   793  
   794  func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
   795  	for {
   796  		f, err := ct.readFrame()
   797  		if err != nil {
   798  			return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
   799  		}
   800  		switch f.(type) {
   801  		case *WindowUpdateFrame, *SettingsFrame:
   802  			continue
   803  		}
   804  		hf, ok := f.(*HeadersFrame)
   805  		if !ok {
   806  			return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
   807  		}
   808  		return hf, nil
   809  	}
   810  }
   811  
   812  type countingReader struct {
   813  	n *int64
   814  }
   815  
   816  func (r countingReader) Read(p []byte) (n int, err error) {
   817  	for i := range p {
   818  		p[i] = byte(i)
   819  	}
   820  	atomic.AddInt64(r.n, int64(len(p)))
   821  	return len(p), err
   822  }
   823  
   824  func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
   825  func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
   826  
// testTransportReqBodyAfterResponse drives a scripted HTTP/2 server against
// the client Transport.
//
// The client PUTs a request body of bodySize (10 MiB) bytes produced by
// countingReader. When the server sees the first non-empty DATA frame, it
// replies with a HEADERS frame carrying the given :status (without
// END_STREAM). It returns connection- and stream-level flow-control credit
// for every non-empty DATA frame it receives.
//
//   - status == 200: the server ends the response stream only after the full
//     body has arrived. The client checks that every byte was written.
//   - any other status: the server ends the response right after the first
//     non-empty DATA frame. The client checks that the Transport wrote some,
//     but not all, of the body.
//
// In both cases the response must have an empty body.
func testTransportReqBodyAfterResponse(t *testing.T, status int) {
	const bodySize = 10 << 20
	// Closed when the client func returns, so the server can tell an
	// expected read error (client gone) from a real one.
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.client = func() error {
		// Half-close the client's side of the TCP conn when the client is
		// done. Presumably this lets the server's ReadFrame fail and the
		// server loop exit (see the clientDone check below).
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)

		var n int64 // atomic; total bytes handed out by countingReader
		req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
		if err != nil {
			return err
		}
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != status {
			return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
		}
		slurp, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("Slurp: %v", err)
		}
		if len(slurp) > 0 {
			return fmt.Errorf("unexpected body: %q", slurp)
		}
		if status == 200 {
			// A successful response must not cut the upload short.
			if got := atomic.LoadInt64(&n); got != bodySize {
				return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
			}
		} else {
			// A non-200 response ended early: expect a partial upload.
			if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
				return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
			}
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		var dataRecv int64 // total request-body bytes received so far
		var closed bool    // whether the response stream has been ended
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it
					// will have reported any
					// errors on its side.
					return nil
				default:
					return err
				}
			}
			//println(fmt.Sprintf("server got frame: %v", f))
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				// Ignored.
			case *HeadersFrame:
				// The request headers must be complete and must not end the
				// stream, since a body follows.
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				if f.StreamEnded() {
					return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
				}
			case *DataFrame:
				dataLen := len(f.Data())
				if dataLen > 0 {
					if dataRecv == 0 {
						// First body bytes: send the response headers now,
						// while the client is still uploading.
						// NOTE(review): errors from WriteField and
						// WriteHeaders are not checked.
						enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
						ct.fr.WriteHeaders(HeadersFrameParam{
							StreamID:      f.StreamID,
							EndHeaders:    true,
							EndStream:     false,
							BlockFragment: buf.Bytes(),
						})
					}
					// Return flow-control credit at connection (stream 0)
					// and stream level so the client can keep sending.
					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
						return err
					}
					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
						return err
					}
				}
				dataRecv += int64(dataLen)

				// End the response (empty DATA with END_STREAM) once, at the
				// point that depends on the status under test.
				if !closed && ((status != 200 && dataRecv > 0) ||
					(status == 200 && dataRecv == bodySize)) {
					closed = true
					if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
						return err
					}
				}
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
   934  
   935  // See golang.org/issue/13444
   936  func TestTransportFullDuplex(t *testing.T) {
   937  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   938  		w.WriteHeader(200) // redundant but for clarity
   939  		w.(http.Flusher).Flush()
   940  		io.Copy(flushWriter{w}, capitalizeReader{r.Body})
   941  		fmt.Fprintf(w, "bye.\n")
   942  	}, optOnlyServer)
   943  	defer st.Close()
   944  
   945  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   946  	defer tr.CloseIdleConnections()
   947  	c := &http.Client{Transport: tr}
   948  
   949  	pr, pw := io.Pipe()
   950  	req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
   951  	if err != nil {
   952  		t.Fatal(err)
   953  	}
   954  	req.ContentLength = -1
   955  	res, err := c.Do(req)
   956  	if err != nil {
   957  		t.Fatal(err)
   958  	}
   959  	defer res.Body.Close()
   960  	if res.StatusCode != 200 {
   961  		t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
   962  	}
   963  	bs := bufio.NewScanner(res.Body)
   964  	want := func(v string) {
   965  		if !bs.Scan() {
   966  			t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
   967  		}
   968  	}
   969  	write := func(v string) {
   970  		_, err := io.WriteString(pw, v)
   971  		if err != nil {
   972  			t.Fatalf("pipe write: %v", err)
   973  		}
   974  	}
   975  	write("foo\n")
   976  	want("FOO")
   977  	write("bar\n")
   978  	want("BAR")
   979  	pw.Close()
   980  	want("bye.")
   981  	if err := bs.Err(); err != nil {
   982  		t.Fatal(err)
   983  	}
   984  }
   985  
   986  func TestTransportConnectRequest(t *testing.T) {
   987  	gotc := make(chan *http.Request, 1)
   988  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   989  		gotc <- r
   990  	}, optOnlyServer)
   991  	defer st.Close()
   992  
   993  	u, err := url.Parse(st.ts.URL)
   994  	if err != nil {
   995  		t.Fatal(err)
   996  	}
   997  
   998  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
   999  	defer tr.CloseIdleConnections()
  1000  	c := &http.Client{Transport: tr}
  1001  
  1002  	tests := []struct {
  1003  		req  *http.Request
  1004  		want string
  1005  	}{
  1006  		{
  1007  			req: &http.Request{
  1008  				Method: "CONNECT",
  1009  				Header: http.Header{},
  1010  				URL:    u,
  1011  			},
  1012  			want: u.Host,
  1013  		},
  1014  		{
  1015  			req: &http.Request{
  1016  				Method: "CONNECT",
  1017  				Header: http.Header{},
  1018  				URL:    u,
  1019  				Host:   "example.com:123",
  1020  			},
  1021  			want: "example.com:123",
  1022  		},
  1023  	}
  1024  
  1025  	for i, tt := range tests {
  1026  		res, err := c.Do(tt.req)
  1027  		if err != nil {
  1028  			t.Errorf("%d. RoundTrip = %v", i, err)
  1029  			continue
  1030  		}
  1031  		res.Body.Close()
  1032  		req := <-gotc
  1033  		if req.Method != "CONNECT" {
  1034  			t.Errorf("method = %q; want CONNECT", req.Method)
  1035  		}
  1036  		if req.Host != tt.want {
  1037  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
  1038  		}
  1039  		if req.URL.Host != tt.want {
  1040  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
  1041  		}
  1042  	}
  1043  }
  1044  
  1045  type headerType int
  1046  
  1047  const (
  1048  	noHeader headerType = iota // omitted
  1049  	oneHeader
  1050  	splitHeader // broken into continuation on purpose
  1051  )
  1052  
  1053  const (
  1054  	f0 = noHeader
  1055  	f1 = oneHeader
  1056  	f2 = splitHeader
  1057  	d0 = false
  1058  	d1 = true
  1059  )
  1060  
// Test all 36 combinations of response frame orders:
//    (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers).
  1063  // Generated by http://play.golang.org/p/SScqYKJYXd
  1064  func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
  1065  func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
  1066  func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
  1067  func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
  1068  func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
  1069  func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
  1070  func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
  1071  func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
  1072  func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
  1073  func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
  1074  func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
  1075  func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
  1076  func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
  1077  func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
  1078  func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
  1079  func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
  1080  func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
  1081  func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
  1082  func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
  1083  func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
  1084  func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
  1085  func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
  1086  func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
  1087  func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
  1088  func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
  1089  func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
  1090  func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
  1091  func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
  1092  func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
  1093  func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
  1094  func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
  1095  func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
  1096  func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
  1097  func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
  1098  func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
  1099  func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
  1100  
  1101  func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
  1102  	const reqBody = "some request body"
  1103  	const resBody = "some response body"
  1104  
  1105  	if resHeader == noHeader {
  1106  		// TODO: test 100-continue followed by immediate
  1107  		// server stream reset, without headers in the middle?
  1108  		panic("invalid combination")
  1109  	}
  1110  
  1111  	ct := newClientTester(t)
  1112  	ct.client = func() error {
  1113  		req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
  1114  		if expect100Continue != noHeader {
  1115  			req.Header.Set("Expect", "100-continue")
  1116  		}
  1117  		res, err := ct.tr.RoundTrip(req)
  1118  		if err != nil {
  1119  			return fmt.Errorf("RoundTrip: %v", err)
  1120  		}
  1121  		defer res.Body.Close()
  1122  		if res.StatusCode != 200 {
  1123  			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  1124  		}
  1125  		slurp, err := ioutil.ReadAll(res.Body)
  1126  		if err != nil {
  1127  			return fmt.Errorf("Slurp: %v", err)
  1128  		}
  1129  		wantBody := resBody
  1130  		if !withData {
  1131  			wantBody = ""
  1132  		}
  1133  		if string(slurp) != wantBody {
  1134  			return fmt.Errorf("body = %q; want %q", slurp, wantBody)
  1135  		}
  1136  		if trailers == noHeader {
  1137  			if len(res.Trailer) > 0 {
  1138  				t.Errorf("Trailer = %v; want none", res.Trailer)
  1139  			}
  1140  		} else {
  1141  			want := http.Header{"Some-Trailer": {"some-value"}}
  1142  			if !reflect.DeepEqual(res.Trailer, want) {
  1143  				t.Errorf("Trailer = %v; want %v", res.Trailer, want)
  1144  			}
  1145  		}
  1146  		return nil
  1147  	}
  1148  	ct.server = func() error {
  1149  		ct.greet()
  1150  		var buf bytes.Buffer
  1151  		enc := hpack.NewEncoder(&buf)
  1152  
  1153  		for {
  1154  			f, err := ct.fr.ReadFrame()
  1155  			if err != nil {
  1156  				return err
  1157  			}
  1158  			endStream := false
  1159  			send := func(mode headerType) {
  1160  				hbf := buf.Bytes()
  1161  				switch mode {
  1162  				case oneHeader:
  1163  					ct.fr.WriteHeaders(HeadersFrameParam{
  1164  						StreamID:      f.Header().StreamID,
  1165  						EndHeaders:    true,
  1166  						EndStream:     endStream,
  1167  						BlockFragment: hbf,
  1168  					})
  1169  				case splitHeader:
  1170  					if len(hbf) < 2 {
  1171  						panic("too small")
  1172  					}
  1173  					ct.fr.WriteHeaders(HeadersFrameParam{
  1174  						StreamID:      f.Header().StreamID,
  1175  						EndHeaders:    false,
  1176  						EndStream:     endStream,
  1177  						BlockFragment: hbf[:1],
  1178  					})
  1179  					ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
  1180  				default:
  1181  					panic("bogus mode")
  1182  				}
  1183  			}
  1184  			switch f := f.(type) {
  1185  			case *WindowUpdateFrame, *SettingsFrame:
  1186  			case *DataFrame:
  1187  				if !f.StreamEnded() {
  1188  					// No need to send flow control tokens. The test request body is tiny.
  1189  					continue
  1190  				}
  1191  				// Response headers (1+ frames; 1 or 2 in this test, but never 0)
  1192  				{
  1193  					buf.Reset()
  1194  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1195  					enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
  1196  					enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
  1197  					if trailers != noHeader {
  1198  						enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
  1199  					}
  1200  					endStream = withData == false && trailers == noHeader
  1201  					send(resHeader)
  1202  				}
  1203  				if withData {
  1204  					endStream = trailers == noHeader
  1205  					ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
  1206  				}
  1207  				if trailers != noHeader {
  1208  					endStream = true
  1209  					buf.Reset()
  1210  					enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
  1211  					send(trailers)
  1212  				}
  1213  				if endStream {
  1214  					return nil
  1215  				}
  1216  			case *HeadersFrame:
  1217  				if expect100Continue != noHeader {
  1218  					buf.Reset()
  1219  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
  1220  					send(expect100Continue)
  1221  				}
  1222  			}
  1223  		}
  1224  	}
  1225  	ct.run()
  1226  }
  1227  
  1228  // Issue 26189, Issue 17739: ignore unknown 1xx responses
  1229  func TestTransportUnknown1xx(t *testing.T) {
  1230  	var buf bytes.Buffer
  1231  	defer func() { got1xxFuncForTests = nil }()
  1232  	got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error {
  1233  		fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
  1234  		return nil
  1235  	}
  1236  
  1237  	ct := newClientTester(t)
  1238  	ct.client = func() error {
  1239  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1240  		res, err := ct.tr.RoundTrip(req)
  1241  		if err != nil {
  1242  			return fmt.Errorf("RoundTrip: %v", err)
  1243  		}
  1244  		defer res.Body.Close()
  1245  		if res.StatusCode != 204 {
  1246  			return fmt.Errorf("status code = %v; want 204", res.StatusCode)
  1247  		}
  1248  		want := `code=110 header=map[Foo-Bar:[110]]
  1249  code=111 header=map[Foo-Bar:[111]]
  1250  code=112 header=map[Foo-Bar:[112]]
  1251  code=113 header=map[Foo-Bar:[113]]
  1252  code=114 header=map[Foo-Bar:[114]]
  1253  `
  1254  		if got := buf.String(); got != want {
  1255  			t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
  1256  		}
  1257  		return nil
  1258  	}
  1259  	ct.server = func() error {
  1260  		ct.greet()
  1261  		var buf bytes.Buffer
  1262  		enc := hpack.NewEncoder(&buf)
  1263  
  1264  		for {
  1265  			f, err := ct.fr.ReadFrame()
  1266  			if err != nil {
  1267  				return err
  1268  			}
  1269  			switch f := f.(type) {
  1270  			case *WindowUpdateFrame, *SettingsFrame:
  1271  			case *HeadersFrame:
  1272  				for i := 110; i <= 114; i++ {
  1273  					buf.Reset()
  1274  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)})
  1275  					enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)})
  1276  					ct.fr.WriteHeaders(HeadersFrameParam{
  1277  						StreamID:      f.StreamID,
  1278  						EndHeaders:    true,
  1279  						EndStream:     false,
  1280  						BlockFragment: buf.Bytes(),
  1281  					})
  1282  				}
  1283  				buf.Reset()
  1284  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
  1285  				ct.fr.WriteHeaders(HeadersFrameParam{
  1286  					StreamID:      f.StreamID,
  1287  					EndHeaders:    true,
  1288  					EndStream:     false,
  1289  					BlockFragment: buf.Bytes(),
  1290  				})
  1291  				return nil
  1292  			}
  1293  		}
  1294  	}
  1295  	ct.run()
  1296  
  1297  }
  1298  
  1299  func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
  1300  	ct := newClientTester(t)
  1301  	ct.client = func() error {
  1302  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1303  		res, err := ct.tr.RoundTrip(req)
  1304  		if err != nil {
  1305  			return fmt.Errorf("RoundTrip: %v", err)
  1306  		}
  1307  		defer res.Body.Close()
  1308  		if res.StatusCode != 200 {
  1309  			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  1310  		}
  1311  		slurp, err := ioutil.ReadAll(res.Body)
  1312  		if err != nil {
  1313  			return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
  1314  		}
  1315  		if len(slurp) > 0 {
  1316  			return fmt.Errorf("body = %q; want nothing", slurp)
  1317  		}
  1318  		if _, ok := res.Trailer["Some-Trailer"]; !ok {
  1319  			return fmt.Errorf("expected Some-Trailer")
  1320  		}
  1321  		return nil
  1322  	}
  1323  	ct.server = func() error {
  1324  		ct.greet()
  1325  
  1326  		var n int
  1327  		var hf *HeadersFrame
  1328  		for hf == nil && n < 10 {
  1329  			f, err := ct.fr.ReadFrame()
  1330  			if err != nil {
  1331  				return err
  1332  			}
  1333  			hf, _ = f.(*HeadersFrame)
  1334  			n++
  1335  		}
  1336  
  1337  		var buf bytes.Buffer
  1338  		enc := hpack.NewEncoder(&buf)
  1339  
  1340  		// send headers without Trailer header
  1341  		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1342  		ct.fr.WriteHeaders(HeadersFrameParam{
  1343  			StreamID:      hf.StreamID,
  1344  			EndHeaders:    true,
  1345  			EndStream:     false,
  1346  			BlockFragment: buf.Bytes(),
  1347  		})
  1348  
  1349  		// send trailers
  1350  		buf.Reset()
  1351  		enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
  1352  		ct.fr.WriteHeaders(HeadersFrameParam{
  1353  			StreamID:      hf.StreamID,
  1354  			EndHeaders:    true,
  1355  			EndStream:     true,
  1356  			BlockFragment: buf.Bytes(),
  1357  		})
  1358  		return nil
  1359  	}
  1360  	ct.run()
  1361  }
  1362  
  1363  func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
  1364  	testTransportInvalidTrailer_Pseudo(t, oneHeader)
  1365  }
  1366  func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
  1367  	testTransportInvalidTrailer_Pseudo(t, splitHeader)
  1368  }
  1369  func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
  1370  	testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
  1371  		enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
  1372  		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
  1373  	})
  1374  }
  1375  
  1376  func TestTransportInvalidTrailer_Capital1(t *testing.T) {
  1377  	testTransportInvalidTrailer_Capital(t, oneHeader)
  1378  }
  1379  func TestTransportInvalidTrailer_Capital2(t *testing.T) {
  1380  	testTransportInvalidTrailer_Capital(t, splitHeader)
  1381  }
  1382  func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
  1383  	testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
  1384  		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
  1385  		enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
  1386  	})
  1387  }
  1388  func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
  1389  	testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
  1390  		enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
  1391  	})
  1392  }
  1393  func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
  1394  	testInvalidTrailer(t, oneHeader, headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) {
  1395  		enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
  1396  	})
  1397  }
  1398  
  1399  func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
  1400  	ct := newClientTester(t)
  1401  	ct.client = func() error {
  1402  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1403  		res, err := ct.tr.RoundTrip(req)
  1404  		if err != nil {
  1405  			return fmt.Errorf("RoundTrip: %v", err)
  1406  		}
  1407  		defer res.Body.Close()
  1408  		if res.StatusCode != 200 {
  1409  			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  1410  		}
  1411  		slurp, err := ioutil.ReadAll(res.Body)
  1412  		se, ok := err.(StreamError)
  1413  		if !ok || se.Cause != wantErr {
  1414  			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
  1415  		}
  1416  		if len(slurp) > 0 {
  1417  			return fmt.Errorf("body = %q; want nothing", slurp)
  1418  		}
  1419  		return nil
  1420  	}
  1421  	ct.server = func() error {
  1422  		ct.greet()
  1423  		var buf bytes.Buffer
  1424  		enc := hpack.NewEncoder(&buf)
  1425  
  1426  		for {
  1427  			f, err := ct.fr.ReadFrame()
  1428  			if err != nil {
  1429  				return err
  1430  			}
  1431  			switch f := f.(type) {
  1432  			case *HeadersFrame:
  1433  				var endStream bool
  1434  				send := func(mode headerType) {
  1435  					hbf := buf.Bytes()
  1436  					switch mode {
  1437  					case oneHeader:
  1438  						ct.fr.WriteHeaders(HeadersFrameParam{
  1439  							StreamID:      f.StreamID,
  1440  							EndHeaders:    true,
  1441  							EndStream:     endStream,
  1442  							BlockFragment: hbf,
  1443  						})
  1444  					case splitHeader:
  1445  						if len(hbf) < 2 {
  1446  							panic("too small")
  1447  						}
  1448  						ct.fr.WriteHeaders(HeadersFrameParam{
  1449  							StreamID:      f.StreamID,
  1450  							EndHeaders:    false,
  1451  							EndStream:     endStream,
  1452  							BlockFragment: hbf[:1],
  1453  						})
  1454  						ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
  1455  					default:
  1456  						panic("bogus mode")
  1457  					}
  1458  				}
  1459  				// Response headers (1+ frames; 1 or 2 in this test, but never 0)
  1460  				{
  1461  					buf.Reset()
  1462  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1463  					enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
  1464  					endStream = false
  1465  					send(oneHeader)
  1466  				}
  1467  				// Trailers:
  1468  				{
  1469  					endStream = true
  1470  					buf.Reset()
  1471  					writeTrailer(enc)
  1472  					send(trailers)
  1473  				}
  1474  				return nil
  1475  			}
  1476  		}
  1477  	}
  1478  	ct.run()
  1479  }
  1480  
  1481  // headerListSize returns the HTTP2 header list size of h.
  1482  //   http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
  1483  //   http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock
  1484  func headerListSize(h http.Header) (size uint32) {
  1485  	for k, vv := range h {
  1486  		for _, v := range vv {
  1487  			hf := hpack.HeaderField{Name: k, Value: v}
  1488  			size += hf.Size()
  1489  		}
  1490  	}
  1491  	return size
  1492  }
  1493  
  1494  // padHeaders adds data to an http.Header until headerListSize(h) ==
  1495  // limit. Due to the way header list sizes are calculated, padHeaders
  1496  // cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will
  1497  // call t.Fatal if asked to do so. PadHeaders first reserves enough
  1498  // space for an empty "Pad-Headers" key, then adds as many copies of
  1499  // filler as possible. Any remaining bytes necessary to push the
  1500  // header list size up to limit are added to h["Pad-Headers"].
  1501  func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
  1502  	if limit > 0xffffffff {
  1503  		t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
  1504  	}
  1505  	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
  1506  	minPadding := uint64(hf.Size())
  1507  	size := uint64(headerListSize(h))
  1508  
  1509  	minlimit := size + minPadding
  1510  	if limit < minlimit {
  1511  		t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
  1512  	}
  1513  
  1514  	// Use a fixed-width format for name so that fieldSize
  1515  	// remains constant.
  1516  	nameFmt := "Pad-Headers-%06d"
  1517  	hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
  1518  	fieldSize := uint64(hf.Size())
  1519  
  1520  	// Add as many complete filler values as possible, leaving
  1521  	// room for at least one empty "Pad-Headers" key.
  1522  	limit = limit - minPadding
  1523  	for i := 0; size+fieldSize < limit; i++ {
  1524  		name := fmt.Sprintf(nameFmt, i)
  1525  		h.Add(name, filler)
  1526  		size += fieldSize
  1527  	}
  1528  
  1529  	// Add enough bytes to reach limit.
  1530  	remain := limit - size
  1531  	lastValue := strings.Repeat("*", int(remain))
  1532  	h.Add("Pad-Headers", lastValue)
  1533  }
  1534  
  1535  func TestPadHeaders(t *testing.T) {
  1536  	check := func(h http.Header, limit uint32, fillerLen int) {
  1537  		if h == nil {
  1538  			h = make(http.Header)
  1539  		}
  1540  		filler := strings.Repeat("f", fillerLen)
  1541  		padHeaders(t, h, uint64(limit), filler)
  1542  		gotSize := headerListSize(h)
  1543  		if gotSize != limit {
  1544  			t.Errorf("Got size = %v; want %v", gotSize, limit)
  1545  		}
  1546  	}
  1547  	// Try all possible combinations for small fillerLen and limit.
  1548  	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
  1549  	minLimit := hf.Size()
  1550  	for limit := minLimit; limit <= 128; limit++ {
  1551  		for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
  1552  			check(nil, limit, fillerLen)
  1553  		}
  1554  	}
  1555  
  1556  	// Try a few tests with larger limits, plus cumulative
  1557  	// tests. Since these tests are cumulative, tests[i+1].limit
  1558  	// must be >= tests[i].limit + minLimit. See the comment on
  1559  	// padHeaders for more info on why the limit arg has this
  1560  	// restriction.
  1561  	tests := []struct {
  1562  		fillerLen int
  1563  		limit     uint32
  1564  	}{
  1565  		{
  1566  			fillerLen: 64,
  1567  			limit:     1024,
  1568  		},
  1569  		{
  1570  			fillerLen: 1024,
  1571  			limit:     1286,
  1572  		},
  1573  		{
  1574  			fillerLen: 256,
  1575  			limit:     2048,
  1576  		},
  1577  		{
  1578  			fillerLen: 1024,
  1579  			limit:     10 * 1024,
  1580  		},
  1581  		{
  1582  			fillerLen: 1023,
  1583  			limit:     11 * 1024,
  1584  		},
  1585  	}
  1586  	h := make(http.Header)
  1587  	for _, tc := range tests {
  1588  		check(nil, tc.limit, tc.fillerLen)
  1589  		check(h, tc.limit, tc.fillerLen)
  1590  	}
  1591  }
  1592  
// TestTransportChecksRequestHeaderListSize verifies that the Transport learns
// the server's advertised max header list size (peerMaxHeaderListSize). Requests
// whose headers and trailers fit within it must succeed. Requests whose headers
// or trailers exceed it must fail RoundTrip with errRequestHeaderListSize.
func TestTransportChecksRequestHeaderListSize(t *testing.T) {
	st := newServerTester(t,
		func(w http.ResponseWriter, r *http.Request) {
			// Consume body & force client to send
			// trailers before writing response.
			// ioutil.ReadAll returns non-nil err for
			// requests that attempt to send greater than
			// maxHeaderListSize bytes of trailers, since
			// those requests generate a stream reset.
			ioutil.ReadAll(r.Body)
			r.Body.Close()
		},
		func(ts *httptest.Server) {
			// The server's max header list size is derived from this value
			// (see the +320 sanity check below).
			ts.Config.MaxHeaderBytes = 16 << 10
		},
		optOnlyServer,
		optQuiet,
	)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	// checkRoundTrip sends req and reports a test error (labelled with desc)
	// unless RoundTrip returns exactly wantErr. On success it also requires a
	// non-nil 200 response; on failure it requires a nil response.
	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
		res, err := tr.RoundTrip(req)
		if err != wantErr {
			if res != nil {
				res.Body.Close()
			}
			t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
			return
		}
		if err == nil {
			if res == nil {
				t.Errorf("%v: response nil; want non-nil.", desc)
				return
			}
			defer res.Body.Close()
			if res.StatusCode != http.StatusOK {
				t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
			}
			return
		}
		if res != nil {
			t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
		}
	}
	// headerListSizeForRequest returns the header list size of the block that
	// cc.encodeHeaders produces for req, including any headers it adds on its
	// own. It decodes that block and sums the field sizes.
	// The throwaway ClientConn gets the largest possible peerMaxHeaderListSize,
	// presumably so encodeHeaders' own limit check never rejects req here.
	headerListSizeForRequest := func(req *http.Request) (size uint64) {
		contentLen := actualContentLength(req)
		trailers, err := commaSeparatedTrailers(req)
		if err != nil {
			t.Fatalf("headerListSizeForRequest: %v", err)
		}
		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
		cc.henc = hpack.NewEncoder(&cc.hbuf)
		// NOTE(review): cc.mu is held around encodeHeaders, presumably
		// because encodeHeaders expects the lock; confirm in transport.go.
		cc.mu.Lock()
		hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
		cc.mu.Unlock()
		if err != nil {
			t.Fatalf("headerListSizeForRequest: %v", err)
		}
		hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
			size += uint64(hf.Size())
		})
		if len(hdrs) > 0 {
			if _, err := hpackDec.Write(hdrs); err != nil {
				t.Fatalf("headerListSizeForRequest: %v", err)
			}
		}
		return size
	}
	// Create a new Request for each test, rather than reusing the
	// same Request, to avoid a race when modifying req.Headers.
	// See https://github.com/golang/go/issues/21316
	newRequest := func() *http.Request {
		// Body must be non-nil to enable writing trailers.
		body := strings.NewReader("hello")
		req, err := http.NewRequest("POST", st.ts.URL, body)
		if err != nil {
			t.Fatalf("newRequest: NewRequest: %v", err)
		}
		return req
	}

	// Make an arbitrary request to ensure we get the server's
	// settings frame and initialize peerMaxHeaderListSize.
	req := newRequest()
	checkRoundTrip(req, nil, "Initial request")

	// Get the ClientConn associated with the request and validate
	// peerMaxHeaderListSize.
	addr := authorityAddr(req.URL.Scheme, req.URL.Host)
	cc, err := tr.connPool().GetClientConn(req, addr)
	if err != nil {
		t.Fatalf("GetClientConn: %v", err)
	}
	cc.mu.Lock()
	peerSize := cc.peerMaxHeaderListSize
	cc.mu.Unlock()
	// Compare against the value the server side computes for itself.
	st.scMu.Lock()
	wantSize := uint64(st.sc.maxHeaderListSize())
	st.scMu.Unlock()
	if peerSize != wantSize {
		t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
	}

	// Sanity check peerSize. (*serverConn) maxHeaderListSize adds
	// 320 bytes of padding.
	wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
	if peerSize != wantHeaderBytes {
		t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
	}

	// Pad headers & trailers, but stay under peerSize.
	req = newRequest()
	req.Header = make(http.Header)
	req.Trailer = make(http.Header)
	filler := strings.Repeat("*", 1024)
	padHeaders(t, req.Trailer, peerSize, filler)
	// cc.encodeHeaders adds some default headers to the request,
	// so we need to leave room for those.
	defaultBytes := headerListSizeForRequest(req)
	padHeaders(t, req.Header, peerSize-defaultBytes, filler)
	checkRoundTrip(req, nil, "Headers & Trailers under limit")

	// Add enough header bytes to push us over peerSize.
	req = newRequest()
	req.Header = make(http.Header)
	padHeaders(t, req.Header, peerSize, filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")

	// Push trailers over the limit.
	req = newRequest()
	req.Trailer = make(http.Header)
	padHeaders(t, req.Trailer, peerSize+1, filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")

	// Send headers with a single large value.
	// A value of peerSize bytes alone, plus its name, exceeds the limit.
	req = newRequest()
	filler = strings.Repeat("*", int(peerSize))
	req.Header = make(http.Header)
	req.Header.Set("Big", filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Single large header")

	// Send trailers with a single large value.
	req = newRequest()
	req.Trailer = make(http.Header)
	req.Trailer.Set("Big", filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
}
  1745  	ct := newClientTester(t)
  1746  	ct.client = func() error {
  1747  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1748  		res, err := ct.tr.RoundTrip(req)
  1749  		if err != errResponseHeaderListSize {
  1750  			if res != nil {
  1751  				res.Body.Close()
  1752  			}
  1753  			size := int64(0)
  1754  			for k, vv := range res.Header {
  1755  				for _, v := range vv {
  1756  					size += int64(len(k)) + int64(len(v)) + 32
  1757  				}
  1758  			}
  1759  			return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
  1760  		}
  1761  		return nil
  1762  	}
  1763  	ct.server = func() error {
  1764  		ct.greet()
  1765  		var buf bytes.Buffer
  1766  		enc := hpack.NewEncoder(&buf)
  1767  
  1768  		for {
  1769  			f, err := ct.fr.ReadFrame()
  1770  			if err != nil {
  1771  				return err
  1772  			}
  1773  			switch f := f.(type) {
  1774  			case *HeadersFrame:
  1775  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1776  				large := strings.Repeat("a", 1<<10)
  1777  				for i := 0; i < 5042; i++ {
  1778  					enc.WriteField(hpack.HeaderField{Name: large, Value: large})
  1779  				}
  1780  				if size, want := buf.Len(), 6329; size != want {
  1781  					// Note: this number might change if
  1782  					// our hpack implementation
  1783  					// changes. That's fine. This is
  1784  					// just a sanity check that our
  1785  					// response can fit in a single
  1786  					// header block fragment frame.
  1787  					return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
  1788  				}
  1789  				ct.fr.WriteHeaders(HeadersFrameParam{
  1790  					StreamID:      f.StreamID,
  1791  					EndHeaders:    true,
  1792  					EndStream:     true,
  1793  					BlockFragment: buf.Bytes(),
  1794  				})
  1795  				return nil
  1796  			}
  1797  		}
  1798  	}
  1799  	ct.run()
  1800  }
  1801  
  1802  func TestTransportCookieHeaderSplit(t *testing.T) {
  1803  	ct := newClientTester(t)
  1804  	ct.client = func() error {
  1805  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  1806  		req.Header.Add("Cookie", "a=b;c=d;  e=f;")
  1807  		req.Header.Add("Cookie", "e=f;g=h; ")
  1808  		req.Header.Add("Cookie", "i=j")
  1809  		_, err := ct.tr.RoundTrip(req)
  1810  		return err
  1811  	}
  1812  	ct.server = func() error {
  1813  		ct.greet()
  1814  		for {
  1815  			f, err := ct.fr.ReadFrame()
  1816  			if err != nil {
  1817  				return err
  1818  			}
  1819  			switch f := f.(type) {
  1820  			case *HeadersFrame:
  1821  				dec := hpack.NewDecoder(initialHeaderTableSize, nil)
  1822  				hfs, err := dec.DecodeFull(f.HeaderBlockFragment())
  1823  				if err != nil {
  1824  					return err
  1825  				}
  1826  				got := []string{}
  1827  				want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}
  1828  				for _, hf := range hfs {
  1829  					if hf.Name == "cookie" {
  1830  						got = append(got, hf.Value)
  1831  					}
  1832  				}
  1833  				if !reflect.DeepEqual(got, want) {
  1834  					t.Errorf("Cookies = %#v, want %#v", got, want)
  1835  				}
  1836  
  1837  				var buf bytes.Buffer
  1838  				enc := hpack.NewEncoder(&buf)
  1839  				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  1840  				ct.fr.WriteHeaders(HeadersFrameParam{
  1841  					StreamID:      f.StreamID,
  1842  					EndHeaders:    true,
  1843  					EndStream:     true,
  1844  					BlockFragment: buf.Bytes(),
  1845  				})
  1846  				return nil
  1847  			}
  1848  		}
  1849  	}
  1850  	ct.run()
  1851  }
  1852  
  1853  // Test that the Transport returns a typed error from Response.Body.Read calls
  1854  // when the server sends an error. (here we use a panic, since that should generate
  1855  // a stream error, but others like cancel should be similar)
  1856  func TestTransportBodyReadErrorType(t *testing.T) {
  1857  	doPanic := make(chan bool, 1)
  1858  	st := newServerTester(t,
  1859  		func(w http.ResponseWriter, r *http.Request) {
  1860  			w.(http.Flusher).Flush() // force headers out
  1861  			<-doPanic
  1862  			panic("boom")
  1863  		},
  1864  		optOnlyServer,
  1865  		optQuiet,
  1866  	)
  1867  	defer st.Close()
  1868  
  1869  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  1870  	defer tr.CloseIdleConnections()
  1871  	c := &http.Client{Transport: tr}
  1872  
  1873  	res, err := c.Get(st.ts.URL)
  1874  	if err != nil {
  1875  		t.Fatal(err)
  1876  	}
  1877  	defer res.Body.Close()
  1878  	doPanic <- true
  1879  	buf := make([]byte, 100)
  1880  	n, err := res.Body.Read(buf)
  1881  	want := StreamError{StreamID: 0x1, Code: 0x2}
  1882  	if !reflect.DeepEqual(want, err) {
  1883  		t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
  1884  	}
  1885  }
  1886  
  1887  // golang.org/issue/13924
  1888  // This used to fail after many iterations, especially with -race:
  1889  // go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race
  1890  func TestTransportDoubleCloseOnWriteError(t *testing.T) {
  1891  	var (
  1892  		mu   sync.Mutex
  1893  		conn net.Conn // to close if set
  1894  	)
  1895  
  1896  	st := newServerTester(t,
  1897  		func(w http.ResponseWriter, r *http.Request) {
  1898  			mu.Lock()
  1899  			defer mu.Unlock()
  1900  			if conn != nil {
  1901  				conn.Close()
  1902  			}
  1903  		},
  1904  		optOnlyServer,
  1905  	)
  1906  	defer st.Close()
  1907  
  1908  	tr := &Transport{
  1909  		TLSClientConfig: tlsConfigInsecure,
  1910  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  1911  			tc, err := tls.Dial(network, addr, cfg)
  1912  			if err != nil {
  1913  				return nil, err
  1914  			}
  1915  			mu.Lock()
  1916  			defer mu.Unlock()
  1917  			conn = tc
  1918  			return tc, nil
  1919  		},
  1920  	}
  1921  	defer tr.CloseIdleConnections()
  1922  	c := &http.Client{Transport: tr}
  1923  	c.Get(st.ts.URL)
  1924  }
  1925  
  1926  // Test that the http1 Transport.DisableKeepAlives option is respected
  1927  // and connections are closed as soon as idle.
  1928  // See golang.org/issue/14008
  1929  func TestTransportDisableKeepAlives(t *testing.T) {
  1930  	st := newServerTester(t,
  1931  		func(w http.ResponseWriter, r *http.Request) {
  1932  			io.WriteString(w, "hi")
  1933  		},
  1934  		optOnlyServer,
  1935  	)
  1936  	defer st.Close()
  1937  
  1938  	connClosed := make(chan struct{}) // closed on tls.Conn.Close
  1939  	tr := &Transport{
  1940  		t1: &http.Transport{
  1941  			DisableKeepAlives: true,
  1942  		},
  1943  		TLSClientConfig: tlsConfigInsecure,
  1944  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  1945  			tc, err := tls.Dial(network, addr, cfg)
  1946  			if err != nil {
  1947  				return nil, err
  1948  			}
  1949  			return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
  1950  		},
  1951  	}
  1952  	c := &http.Client{Transport: tr}
  1953  	res, err := c.Get(st.ts.URL)
  1954  	if err != nil {
  1955  		t.Fatal(err)
  1956  	}
  1957  	if _, err := ioutil.ReadAll(res.Body); err != nil {
  1958  		t.Fatal(err)
  1959  	}
  1960  	defer res.Body.Close()
  1961  
  1962  	select {
  1963  	case <-connClosed:
  1964  	case <-time.After(1 * time.Second):
  1965  		t.Errorf("timeout")
  1966  	}
  1967  
  1968  }
  1969  
  1970  // Test concurrent requests with Transport.DisableKeepAlives. We can share connections,
  1971  // but when things are totally idle, it still needs to close.
  1972  func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
  1973  	const D = 25 * time.Millisecond
  1974  	st := newServerTester(t,
  1975  		func(w http.ResponseWriter, r *http.Request) {
  1976  			time.Sleep(D)
  1977  			io.WriteString(w, "hi")
  1978  		},
  1979  		optOnlyServer,
  1980  	)
  1981  	defer st.Close()
  1982  
  1983  	var dials int32
  1984  	var conns sync.WaitGroup
  1985  	tr := &Transport{
  1986  		t1: &http.Transport{
  1987  			DisableKeepAlives: true,
  1988  		},
  1989  		TLSClientConfig: tlsConfigInsecure,
  1990  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  1991  			tc, err := tls.Dial(network, addr, cfg)
  1992  			if err != nil {
  1993  				return nil, err
  1994  			}
  1995  			atomic.AddInt32(&dials, 1)
  1996  			conns.Add(1)
  1997  			return &noteCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
  1998  		},
  1999  	}
  2000  	c := &http.Client{Transport: tr}
  2001  	var reqs sync.WaitGroup
  2002  	const N = 20
  2003  	for i := 0; i < N; i++ {
  2004  		reqs.Add(1)
  2005  		if i == N-1 {
  2006  			// For the final request, try to make all the
  2007  			// others close. This isn't verified in the
  2008  			// count, other than the Log statement, since
  2009  			// it's so timing dependent. This test is
  2010  			// really to make sure we don't interrupt a
  2011  			// valid request.
  2012  			time.Sleep(D * 2)
  2013  		}
  2014  		go func() {
  2015  			defer reqs.Done()
  2016  			res, err := c.Get(st.ts.URL)
  2017  			if err != nil {
  2018  				t.Error(err)
  2019  				return
  2020  			}
  2021  			if _, err := ioutil.ReadAll(res.Body); err != nil {
  2022  				t.Error(err)
  2023  				return
  2024  			}
  2025  			res.Body.Close()
  2026  		}()
  2027  	}
  2028  	reqs.Wait()
  2029  	conns.Wait()
  2030  	t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
  2031  }
  2032  
  2033  type noteCloseConn struct {
  2034  	net.Conn
  2035  	onceClose sync.Once
  2036  	closefn   func()
  2037  }
  2038  
  2039  func (c *noteCloseConn) Close() error {
  2040  	c.onceClose.Do(c.closefn)
  2041  	return c.Conn.Close()
  2042  }
  2043  
  2044  func isTimeout(err error) bool {
  2045  	switch err := err.(type) {
  2046  	case nil:
  2047  		return false
  2048  	case *url.Error:
  2049  		return isTimeout(err.Err)
  2050  	case net.Error:
  2051  		return err.Timeout()
  2052  	}
  2053  	return false
  2054  }
  2055  
  2056  // Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent.
  2057  func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
  2058  	testTransportResponseHeaderTimeout(t, false)
  2059  }
  2060  func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
  2061  	testTransportResponseHeaderTimeout(t, true)
  2062  }
  2063  
  2064  func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
  2065  	ct := newClientTester(t)
  2066  	ct.tr.t1 = &http.Transport{
  2067  		ResponseHeaderTimeout: 5 * time.Millisecond,
  2068  	}
  2069  	ct.client = func() error {
  2070  		c := &http.Client{Transport: ct.tr}
  2071  		var err error
  2072  		var n int64
  2073  		const bodySize = 4 << 20
  2074  		if body {
  2075  			_, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
  2076  		} else {
  2077  			_, err = c.Get("https://dummy.tld/")
  2078  		}
  2079  		if !isTimeout(err) {
  2080  			t.Errorf("client expected timeout error; got %#v", err)
  2081  		}
  2082  		if body && n != bodySize {
  2083  			t.Errorf("only read %d bytes of body; want %d", n, bodySize)
  2084  		}
  2085  		return nil
  2086  	}
  2087  	ct.server = func() error {
  2088  		ct.greet()
  2089  		for {
  2090  			f, err := ct.fr.ReadFrame()
  2091  			if err != nil {
  2092  				t.Logf("ReadFrame: %v", err)
  2093  				return nil
  2094  			}
  2095  			switch f := f.(type) {
  2096  			case *DataFrame:
  2097  				dataLen := len(f.Data())
  2098  				if dataLen > 0 {
  2099  					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
  2100  						return err
  2101  					}
  2102  					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
  2103  						return err
  2104  					}
  2105  				}
  2106  			case *RSTStreamFrame:
  2107  				if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
  2108  					return nil
  2109  				}
  2110  			}
  2111  		}
  2112  	}
  2113  	ct.run()
  2114  }
  2115  
  2116  func TestTransportDisableCompression(t *testing.T) {
  2117  	const body = "sup"
  2118  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2119  		want := http.Header{
  2120  			"User-Agent": []string{"Go-http-client/2.0"},
  2121  		}
  2122  		if !reflect.DeepEqual(r.Header, want) {
  2123  			t.Errorf("request headers = %v; want %v", r.Header, want)
  2124  		}
  2125  	}, optOnlyServer)
  2126  	defer st.Close()
  2127  
  2128  	tr := &Transport{
  2129  		TLSClientConfig: tlsConfigInsecure,
  2130  		t1: &http.Transport{
  2131  			DisableCompression: true,
  2132  		},
  2133  	}
  2134  	defer tr.CloseIdleConnections()
  2135  
  2136  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  2137  	if err != nil {
  2138  		t.Fatal(err)
  2139  	}
  2140  	res, err := tr.RoundTrip(req)
  2141  	if err != nil {
  2142  		t.Fatal(err)
  2143  	}
  2144  	defer res.Body.Close()
  2145  }
  2146  
  2147  // RFC 7540 section 8.1.2.2
  2148  func TestTransportRejectsConnHeaders(t *testing.T) {
  2149  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2150  		var got []string
  2151  		for k := range r.Header {
  2152  			got = append(got, k)
  2153  		}
  2154  		sort.Strings(got)
  2155  		w.Header().Set("Got-Header", strings.Join(got, ","))
  2156  	}, optOnlyServer)
  2157  	defer st.Close()
  2158  
  2159  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2160  	defer tr.CloseIdleConnections()
  2161  
  2162  	tests := []struct {
  2163  		key   string
  2164  		value []string
  2165  		want  string
  2166  	}{
  2167  		{
  2168  			key:   "Upgrade",
  2169  			value: []string{"anything"},
  2170  			want:  "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
  2171  		},
  2172  		{
  2173  			key:   "Connection",
  2174  			value: []string{"foo"},
  2175  			want:  "ERROR: http2: invalid Connection request header: [\"foo\"]",
  2176  		},
  2177  		{
  2178  			key:   "Connection",
  2179  			value: []string{"close"},
  2180  			want:  "Accept-Encoding,User-Agent",
  2181  		},
  2182  		{
  2183  			key:   "Connection",
  2184  			value: []string{"CLoSe"},
  2185  			want:  "Accept-Encoding,User-Agent",
  2186  		},
  2187  		{
  2188  			key:   "Connection",
  2189  			value: []string{"close", "something-else"},
  2190  			want:  "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
  2191  		},
  2192  		{
  2193  			key:   "Connection",
  2194  			value: []string{"keep-alive"},
  2195  			want:  "Accept-Encoding,User-Agent",
  2196  		},
  2197  		{
  2198  			key:   "Connection",
  2199  			value: []string{"Keep-ALIVE"},
  2200  			want:  "Accept-Encoding,User-Agent",
  2201  		},
  2202  		{
  2203  			key:   "Proxy-Connection", // just deleted and ignored
  2204  			value: []string{"keep-alive"},
  2205  			want:  "Accept-Encoding,User-Agent",
  2206  		},
  2207  		{
  2208  			key:   "Transfer-Encoding",
  2209  			value: []string{""},
  2210  			want:  "Accept-Encoding,User-Agent",
  2211  		},
  2212  		{
  2213  			key:   "Transfer-Encoding",
  2214  			value: []string{"foo"},
  2215  			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
  2216  		},
  2217  		{
  2218  			key:   "Transfer-Encoding",
  2219  			value: []string{"chunked"},
  2220  			want:  "Accept-Encoding,User-Agent",
  2221  		},
  2222  		{
  2223  			key:   "Transfer-Encoding",
  2224  			value: []string{"chunked", "other"},
  2225  			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
  2226  		},
  2227  		{
  2228  			key:   "Content-Length",
  2229  			value: []string{"123"},
  2230  			want:  "Accept-Encoding,User-Agent",
  2231  		},
  2232  		{
  2233  			key:   "Keep-Alive",
  2234  			value: []string{"doop"},
  2235  			want:  "Accept-Encoding,User-Agent",
  2236  		},
  2237  	}
  2238  
  2239  	for _, tt := range tests {
  2240  		req, _ := http.NewRequest("GET", st.ts.URL, nil)
  2241  		req.Header[tt.key] = tt.value
  2242  		res, err := tr.RoundTrip(req)
  2243  		var got string
  2244  		if err != nil {
  2245  			got = fmt.Sprintf("ERROR: %v", err)
  2246  		} else {
  2247  			got = res.Header.Get("Got-Header")
  2248  			res.Body.Close()
  2249  		}
  2250  		if got != tt.want {
  2251  			t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
  2252  		}
  2253  	}
  2254  }
  2255  
  2256  // Reject content-length headers containing a sign.
  2257  // See https://golang.org/issue/39017
  2258  func TestTransportRejectsContentLengthWithSign(t *testing.T) {
  2259  	tests := []struct {
  2260  		name   string
  2261  		cl     []string
  2262  		wantCL string
  2263  	}{
  2264  		{
  2265  			name:   "proper content-length",
  2266  			cl:     []string{"3"},
  2267  			wantCL: "3",
  2268  		},
  2269  		{
  2270  			name:   "ignore cl with plus sign",
  2271  			cl:     []string{"+3"},
  2272  			wantCL: "",
  2273  		},
  2274  		{
  2275  			name:   "ignore cl with minus sign",
  2276  			cl:     []string{"-3"},
  2277  			wantCL: "",
  2278  		},
  2279  		{
  2280  			name:   "max int64, for safe uint64->int64 conversion",
  2281  			cl:     []string{"9223372036854775807"},
  2282  			wantCL: "9223372036854775807",
  2283  		},
  2284  		{
  2285  			name:   "overflows int64, so ignored",
  2286  			cl:     []string{"9223372036854775808"},
  2287  			wantCL: "",
  2288  		},
  2289  	}
  2290  
  2291  	for _, tt := range tests {
  2292  		tt := tt
  2293  		t.Run(tt.name, func(t *testing.T) {
  2294  			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2295  				w.Header().Set("Content-Length", tt.cl[0])
  2296  			}, optOnlyServer)
  2297  			defer st.Close()
  2298  			tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2299  			defer tr.CloseIdleConnections()
  2300  
  2301  			req, _ := http.NewRequest("HEAD", st.ts.URL, nil)
  2302  			res, err := tr.RoundTrip(req)
  2303  
  2304  			var got string
  2305  			if err != nil {
  2306  				got = fmt.Sprintf("ERROR: %v", err)
  2307  			} else {
  2308  				got = res.Header.Get("Content-Length")
  2309  				res.Body.Close()
  2310  			}
  2311  
  2312  			if got != tt.wantCL {
  2313  				t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL)
  2314  			}
  2315  		})
  2316  	}
  2317  }
  2318  
  2319  // golang.org/issue/14048
  2320  func TestTransportFailsOnInvalidHeaders(t *testing.T) {
  2321  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2322  		var got []string
  2323  		for k := range r.Header {
  2324  			got = append(got, k)
  2325  		}
  2326  		sort.Strings(got)
  2327  		w.Header().Set("Got-Header", strings.Join(got, ","))
  2328  	}, optOnlyServer)
  2329  	defer st.Close()
  2330  
  2331  	tests := [...]struct {
  2332  		h       http.Header
  2333  		wantErr string
  2334  	}{
  2335  		0: {
  2336  			h:       http.Header{"with space": {"foo"}},
  2337  			wantErr: `invalid HTTP header name "with space"`,
  2338  		},
  2339  		1: {
  2340  			h:       http.Header{"name": {"Брэд"}},
  2341  			wantErr: "", // okay
  2342  		},
  2343  		2: {
  2344  			h:       http.Header{"имя": {"Brad"}},
  2345  			wantErr: `invalid HTTP header name "имя"`,
  2346  		},
  2347  		3: {
  2348  			h:       http.Header{"foo": {"foo\x01bar"}},
  2349  			wantErr: `invalid HTTP header value "foo\x01bar" for header "foo"`,
  2350  		},
  2351  	}
  2352  
  2353  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2354  	defer tr.CloseIdleConnections()
  2355  
  2356  	for i, tt := range tests {
  2357  		req, _ := http.NewRequest("GET", st.ts.URL, nil)
  2358  		req.Header = tt.h
  2359  		res, err := tr.RoundTrip(req)
  2360  		var bad bool
  2361  		if tt.wantErr == "" {
  2362  			if err != nil {
  2363  				bad = true
  2364  				t.Errorf("case %d: error = %v; want no error", i, err)
  2365  			}
  2366  		} else {
  2367  			if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
  2368  				bad = true
  2369  				t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
  2370  			}
  2371  		}
  2372  		if err == nil {
  2373  			if bad {
  2374  				t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
  2375  			}
  2376  			res.Body.Close()
  2377  		}
  2378  	}
  2379  }
  2380  
  2381  func TestTransportNewTLSConfig(t *testing.T) {
  2382  	tests := [...]struct {
  2383  		conf *tls.Config
  2384  		host string
  2385  		want *tls.Config
  2386  	}{
  2387  		// Normal case.
  2388  		0: {
  2389  			conf: nil,
  2390  			host: "foo.com",
  2391  			want: &tls.Config{
  2392  				ServerName: "foo.com",
  2393  				NextProtos: []string{NextProtoTLS},
  2394  			},
  2395  		},
  2396  
  2397  		// User-provided name (bar.com) takes precedence:
  2398  		1: {
  2399  			conf: &tls.Config{
  2400  				ServerName: "bar.com",
  2401  			},
  2402  			host: "foo.com",
  2403  			want: &tls.Config{
  2404  				ServerName: "bar.com",
  2405  				NextProtos: []string{NextProtoTLS},
  2406  			},
  2407  		},
  2408  
  2409  		// NextProto is prepended:
  2410  		2: {
  2411  			conf: &tls.Config{
  2412  				NextProtos: []string{"foo", "bar"},
  2413  			},
  2414  			host: "example.com",
  2415  			want: &tls.Config{
  2416  				ServerName: "example.com",
  2417  				NextProtos: []string{NextProtoTLS, "foo", "bar"},
  2418  			},
  2419  		},
  2420  
  2421  		// NextProto is not duplicated:
  2422  		3: {
  2423  			conf: &tls.Config{
  2424  				NextProtos: []string{"foo", "bar", NextProtoTLS},
  2425  			},
  2426  			host: "example.com",
  2427  			want: &tls.Config{
  2428  				ServerName: "example.com",
  2429  				NextProtos: []string{"foo", "bar", NextProtoTLS},
  2430  			},
  2431  		},
  2432  	}
  2433  	for i, tt := range tests {
  2434  		// Ignore the session ticket keys part, which ends up populating
  2435  		// unexported fields in the Config:
  2436  		if tt.conf != nil {
  2437  			tt.conf.SessionTicketsDisabled = true
  2438  		}
  2439  
  2440  		tr := &Transport{TLSClientConfig: tt.conf}
  2441  		got := tr.newTLSConfig(tt.host)
  2442  
  2443  		got.SessionTicketsDisabled = false
  2444  
  2445  		if !reflect.DeepEqual(got, tt.want) {
  2446  			t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
  2447  		}
  2448  	}
  2449  }
  2450  
  2451  // The Google GFE responds to HEAD requests with a HEADERS frame
  2452  // without END_STREAM, followed by a 0-length DATA frame with
  2453  // END_STREAM. Make sure we don't get confused by that. (We did.)
  2454  func TestTransportReadHeadResponse(t *testing.T) {
  2455  	ct := newClientTester(t)
  2456  	clientDone := make(chan struct{})
  2457  	ct.client = func() error {
  2458  		defer close(clientDone)
  2459  		req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
  2460  		res, err := ct.tr.RoundTrip(req)
  2461  		if err != nil {
  2462  			return err
  2463  		}
  2464  		if res.ContentLength != 123 {
  2465  			return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength)
  2466  		}
  2467  		slurp, err := ioutil.ReadAll(res.Body)
  2468  		if err != nil {
  2469  			return fmt.Errorf("ReadAll: %v", err)
  2470  		}
  2471  		if len(slurp) > 0 {
  2472  			return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
  2473  		}
  2474  		return nil
  2475  	}
  2476  	ct.server = func() error {
  2477  		ct.greet()
  2478  		for {
  2479  			f, err := ct.fr.ReadFrame()
  2480  			if err != nil {
  2481  				t.Logf("ReadFrame: %v", err)
  2482  				return nil
  2483  			}
  2484  			hf, ok := f.(*HeadersFrame)
  2485  			if !ok {
  2486  				continue
  2487  			}
  2488  			var buf bytes.Buffer
  2489  			enc := hpack.NewEncoder(&buf)
  2490  			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  2491  			enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
  2492  			ct.fr.WriteHeaders(HeadersFrameParam{
  2493  				StreamID:      hf.StreamID,
  2494  				EndHeaders:    true,
  2495  				EndStream:     false, // as the GFE does
  2496  				BlockFragment: buf.Bytes(),
  2497  			})
  2498  			ct.fr.WriteData(hf.StreamID, true, nil)
  2499  
  2500  			<-clientDone
  2501  			return nil
  2502  		}
  2503  	}
  2504  	ct.run()
  2505  }
  2506  
  2507  func TestTransportReadHeadResponseWithBody(t *testing.T) {
  2508  	// This test use not valid response format.
  2509  	// Discarding logger output to not spam tests output.
  2510  	log.SetOutput(ioutil.Discard)
  2511  	defer log.SetOutput(os.Stderr)
  2512  
  2513  	response := "redirecting to /elsewhere"
  2514  	ct := newClientTester(t)
  2515  	clientDone := make(chan struct{})
  2516  	ct.client = func() error {
  2517  		defer close(clientDone)
  2518  		req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
  2519  		res, err := ct.tr.RoundTrip(req)
  2520  		if err != nil {
  2521  			return err
  2522  		}
  2523  		if res.ContentLength != int64(len(response)) {
  2524  			return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response))
  2525  		}
  2526  		slurp, err := ioutil.ReadAll(res.Body)
  2527  		if err != nil {
  2528  			return fmt.Errorf("ReadAll: %v", err)
  2529  		}
  2530  		if len(slurp) > 0 {
  2531  			return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
  2532  		}
  2533  		return nil
  2534  	}
  2535  	ct.server = func() error {
  2536  		ct.greet()
  2537  		for {
  2538  			f, err := ct.fr.ReadFrame()
  2539  			if err != nil {
  2540  				t.Logf("ReadFrame: %v", err)
  2541  				return nil
  2542  			}
  2543  			hf, ok := f.(*HeadersFrame)
  2544  			if !ok {
  2545  				continue
  2546  			}
  2547  			var buf bytes.Buffer
  2548  			enc := hpack.NewEncoder(&buf)
  2549  			enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  2550  			enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))})
  2551  			ct.fr.WriteHeaders(HeadersFrameParam{
  2552  				StreamID:      hf.StreamID,
  2553  				EndHeaders:    true,
  2554  				EndStream:     false,
  2555  				BlockFragment: buf.Bytes(),
  2556  			})
  2557  			ct.fr.WriteData(hf.StreamID, true, []byte(response))
  2558  
  2559  			<-clientDone
  2560  			return nil
  2561  		}
  2562  	}
  2563  	ct.run()
  2564  }
  2565  
  2566  type neverEnding byte
  2567  
  2568  func (b neverEnding) Read(p []byte) (int, error) {
  2569  	for i := range p {
  2570  		p[i] = byte(b)
  2571  	}
  2572  	return len(p), nil
  2573  }
  2574  
  2575  // golang.org/issue/15425: test that a handler closing the request
  2576  // body doesn't terminate the stream to the peer. (It just stops
  2577  // readability from the handler's side, and eventually the client
  2578  // runs out of flow control tokens)
  2579  func TestTransportHandlerBodyClose(t *testing.T) {
  2580  	const bodySize = 10 << 20
  2581  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2582  		r.Body.Close()
  2583  		io.Copy(w, io.LimitReader(neverEnding('A'), bodySize))
  2584  	}, optOnlyServer)
  2585  	defer st.Close()
  2586  
  2587  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2588  	defer tr.CloseIdleConnections()
  2589  
  2590  	g0 := runtime.NumGoroutine()
  2591  
  2592  	const numReq = 10
  2593  	for i := 0; i < numReq; i++ {
  2594  		req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
  2595  		if err != nil {
  2596  			t.Fatal(err)
  2597  		}
  2598  		res, err := tr.RoundTrip(req)
  2599  		if err != nil {
  2600  			t.Fatal(err)
  2601  		}
  2602  		n, err := io.Copy(ioutil.Discard, res.Body)
  2603  		res.Body.Close()
  2604  		if n != bodySize || err != nil {
  2605  			t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize)
  2606  		}
  2607  	}
  2608  	tr.CloseIdleConnections()
  2609  
  2610  	if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool {
  2611  		gd := runtime.NumGoroutine() - g0
  2612  		return gd < numReq/2
  2613  	}) {
  2614  		t.Errorf("appeared to leak goroutines")
  2615  	}
  2616  }
  2617  
  2618  // https://golang.org/issue/15930
  2619  func TestTransportFlowControl(t *testing.T) {
  2620  	const bufLen = 64 << 10
  2621  	var total int64 = 100 << 20 // 100MB
  2622  	if testing.Short() {
  2623  		total = 10 << 20
  2624  	}
  2625  
  2626  	var wrote int64 // updated atomically
  2627  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  2628  		b := make([]byte, bufLen)
  2629  		for wrote < total {
  2630  			n, err := w.Write(b)
  2631  			atomic.AddInt64(&wrote, int64(n))
  2632  			if err != nil {
  2633  				t.Errorf("ResponseWriter.Write error: %v", err)
  2634  				break
  2635  			}
  2636  			w.(http.Flusher).Flush()
  2637  		}
  2638  	}, optOnlyServer)
  2639  
  2640  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  2641  	defer tr.CloseIdleConnections()
  2642  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  2643  	if err != nil {
  2644  		t.Fatal("NewRequest error:", err)
  2645  	}
  2646  	resp, err := tr.RoundTrip(req)
  2647  	if err != nil {
  2648  		t.Fatal("RoundTrip error:", err)
  2649  	}
  2650  	defer resp.Body.Close()
  2651  
  2652  	var read int64
  2653  	b := make([]byte, bufLen)
  2654  	for {
  2655  		n, err := resp.Body.Read(b)
  2656  		if err == io.EOF {
  2657  			break
  2658  		}
  2659  		if err != nil {
  2660  			t.Fatal("Read error:", err)
  2661  		}
  2662  		read += int64(n)
  2663  
  2664  		const max = transportDefaultStreamFlow
  2665  		if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
  2666  			t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
  2667  		}
  2668  
  2669  		// Let the server get ahead of the client.
  2670  		time.Sleep(1 * time.Millisecond)
  2671  	}
  2672  }
  2673  
  2674  // golang.org/issue/14627 -- if the server sends a GOAWAY frame, make
  2675  // the Transport remember it and return it back to users (via
  2676  // RoundTrip or request body reads) if needed (e.g. if the server
  2677  // proceeds to close the TCP connection before the client gets its
  2678  // response)
  2679  func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
  2680  	testTransportUsesGoAwayDebugError(t, false)
  2681  }
  2682  
  2683  func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
  2684  	testTransportUsesGoAwayDebugError(t, true)
  2685  }
  2686  
// testTransportUsesGoAwayDebugError drives one GET against a fake server.
// The server answers the request's HEADERS frame with two GOAWAY frames,
// both with LastStreamID 5:
//   - the first carries ErrCodeNo plus goAwayDebugData;
//   - the second carries goAwayErrCode and no debug data.
// The server then half-closes the TCP connection. The client must end up
// with a GoAwayError combining the debug data from the first frame and the
// error code from the second.
//
// If failMidBody is false, that error is expected from RoundTrip itself. If
// failMidBody is true, the server first sends response headers
// (content-length 123, no END_STREAM) so RoundTrip succeeds, and the error
// is expected from reading the response body instead.
func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
	ct := newClientTester(t)
	// Closed by the client when it finishes; the server blocks on it
	// before returning.
	clientDone := make(chan struct{})

	const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
	const goAwayDebugData = "some debug data"

	ct.client = func() error {
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if failMidBody {
			// Headers were sent before the GOAWAYs, so RoundTrip must
			// succeed; the GOAWAY error should come from the body read.
			if err != nil {
				return fmt.Errorf("unexpected client RoundTrip error: %v", err)
			}
			_, err = io.Copy(ioutil.Discard, res.Body)
			res.Body.Close()
		}
		// Either way, err must be the merged GoAwayError.
		want := GoAwayError{
			LastStreamID: 5,
			ErrCode:      goAwayErrCode,
			DebugData:    goAwayDebugData,
		}
		if !reflect.DeepEqual(err, want) {
			t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				// Read errors are logged, not treated as test failures.
				t.Logf("ReadFrame: %v", err)
				return nil
			}
			// Skip everything until the request's HEADERS frame.
			hf, ok := f.(*HeadersFrame)
			if !ok {
				continue
			}
			if failMidBody {
				// Start a response whose body never arrives.
				var buf bytes.Buffer
				enc := hpack.NewEncoder(&buf)
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      hf.StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: buf.Bytes(),
				})
			}
			// Write two GOAWAY frames, to test that the Transport takes
			// the interesting parts of both.
			ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
			ct.fr.WriteGoAway(5, goAwayErrCode, nil)
			// Half-close our side so the client reads EOF after the GOAWAYs.
			ct.sc.(*net.TCPConn).CloseWrite()
			if runtime.GOOS == "plan9" {
				// CloseWrite not supported on Plan 9; Issue 17906
				ct.sc.(*net.TCPConn).Close()
			}
			<-clientDone
			return nil
		}
	}
	ct.run()
}
  2754  
  2755  func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
  2756  	ct := newClientTester(t)
  2757  
  2758  	clientClosed := make(chan struct{})
  2759  	serverWroteFirstByte := make(chan struct{})
  2760  
  2761  	ct.client = func() error {
  2762  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  2763  		res, err := ct.tr.RoundTrip(req)
  2764  		if err != nil {
  2765  			return err
  2766  		}
  2767  		<-serverWroteFirstByte
  2768  
  2769  		if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
  2770  			return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
  2771  		}
  2772  		res.Body.Close() // leaving 4999 bytes unread
  2773  		close(clientClosed)
  2774  
  2775  		return nil
  2776  	}
  2777  	ct.server = func() error {
  2778  		ct.greet()
  2779  
  2780  		var hf *HeadersFrame
  2781  		for {
  2782  			f, err := ct.fr.ReadFrame()
  2783  			if err != nil {
  2784  				return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
  2785  			}
  2786  			switch f.(type) {
  2787  			case *WindowUpdateFrame, *SettingsFrame:
  2788  				continue
  2789  			}
  2790  			var ok bool
  2791  			hf, ok = f.(*HeadersFrame)
  2792  			if !ok {
  2793  				return fmt.Errorf("Got %T; want HeadersFrame", f)
  2794  			}
  2795  			break
  2796  		}
  2797  
  2798  		var buf bytes.Buffer
  2799  		enc := hpack.NewEncoder(&buf)
  2800  		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  2801  		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
  2802  		ct.fr.WriteHeaders(HeadersFrameParam{
  2803  			StreamID:      hf.StreamID,
  2804  			EndHeaders:    true,
  2805  			EndStream:     false,
  2806  			BlockFragment: buf.Bytes(),
  2807  		})
  2808  
  2809  		// Two cases:
  2810  		// - Send one DATA frame with 5000 bytes.
  2811  		// - Send two DATA frames with 1 and 4999 bytes each.
  2812  		//
  2813  		// In both cases, the client should consume one byte of data,
  2814  		// refund that byte, then refund the following 4999 bytes.
  2815  		//
  2816  		// In the second case, the server waits for the client connection to
  2817  		// close before seconding the second DATA frame. This tests the case
  2818  		// where the client receives a DATA frame after it has reset the stream.
  2819  		if oneDataFrame {
  2820  			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
  2821  			close(serverWroteFirstByte)
  2822  			<-clientClosed
  2823  		} else {
  2824  			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
  2825  			close(serverWroteFirstByte)
  2826  			<-clientClosed
  2827  			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
  2828  		}
  2829  
  2830  		waitingFor := "RSTStreamFrame"
  2831  		for {
  2832  			f, err := ct.fr.ReadFrame()
  2833  			if err != nil {
  2834  				return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err)
  2835  			}
  2836  			if _, ok := f.(*SettingsFrame); ok {
  2837  				continue
  2838  			}
  2839  			switch waitingFor {
  2840  			case "RSTStreamFrame":
  2841  				if rf, ok := f.(*RSTStreamFrame); !ok || rf.ErrCode != ErrCodeCancel {
  2842  					return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
  2843  				}
  2844  				waitingFor = "WindowUpdateFrame"
  2845  			case "WindowUpdateFrame":
  2846  				if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != 4999 {
  2847  					return fmt.Errorf("Expected WindowUpdateFrame for 4999 bytes; got %v", summarizeFrame(f))
  2848  				}
  2849  				return nil
  2850  			}
  2851  		}
  2852  	}
  2853  	ct.run()
  2854  }
  2855  
  2856  // See golang.org/issue/16481
  2857  func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
  2858  	testTransportReturnsUnusedFlowControl(t, true)
  2859  }
  2860  
  2861  // See golang.org/issue/20469
  2862  func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
  2863  	testTransportReturnsUnusedFlowControl(t, false)
  2864  }
  2865  
// Issue 16612: adjust flow control on open streams when transport
// receives SETTINGS with INITIAL_WINDOW_SIZE from server.
//
// The client POSTs a request body of bodySize (1<<20) bytes. The fake server
// counts received DATA bytes; once it has seen at least initialWindowSize/2,
// it raises SETTINGS_INITIAL_WINDOW_SIZE to bodySize and grants a
// WINDOW_UPDATE of bodySize on stream 0 (the connection), which should let the
// already-open stream finish uploading. When the DATA frame carrying
// END_STREAM arrives, the server answers with a bare 200 response.
func TestTransportAdjustsFlowControl(t *testing.T) {
	ct := newClientTester(t)
	clientDone := make(chan struct{})

	const bodySize = 1 << 20

	ct.client = func() error {
		// Half-close the client's side when done so the server's ReadFrame
		// loop below gets an error and, seeing clientDone closed, exits.
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)

		// The struct wrapper hides any extra methods of the LimitReader so
		// the body is treated as a plain io.Reader (presumably to avoid a
		// known-length fast path — TODO confirm).
		req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}
		res.Body.Close()
		return nil
	}
	ct.server = func() error {
		// Consume the client preface by hand rather than via ct.greet().
		_, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
		if err != nil {
			return fmt.Errorf("reading client preface: %v", err)
		}

		var gotBytes int64
		var sentSettings bool
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					return nil
				default:
					// NOTE(review): this loop reads request DATA frames, not
					// HEADERS; the "waiting for Headers" wording is misleading.
					return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
				}
			}
			switch f := f.(type) {
			case *DataFrame:
				gotBytes += int64(len(f.Data()))
				// After we've got half the client's
				// initial flow control window's worth
				// of request body data, give it just
				// enough flow control to finish.
				if gotBytes >= initialWindowSize/2 && !sentSettings {
					sentSettings = true

					ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
					ct.fr.WriteWindowUpdate(0, bodySize)
					ct.fr.WriteSettingsAck()
				}

				if f.StreamEnded() {
					// Whole request body received: reply 200 and end the stream.
					var buf bytes.Buffer
					enc := hpack.NewEncoder(&buf)
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				}
			}
		}
	}
	ct.run()
}
  2939  
// See golang.org/issue/16556
//
// TestTransportReturnsDataPaddingFlowControl checks that padding on a
// received DATA frame is refunded right away. The server sends 5000 data
// bytes plus 5 padding bytes while the client is not reading. It then
// expects two WINDOW_UPDATEs, each of len(pad)+1 bytes: first for the
// connection (stream 0), then for the stream.
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
	ct := newClientTester(t)

	// Keeps the client from returning (and closing the body) until the
	// server has seen both WINDOW_UPDATEs.
	unblockClient := make(chan bool, 1)

	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}
		defer res.Body.Close()
		<-unblockClient
		return nil
	}
	ct.server = func() error {
		ct.greet()

		// Skip WINDOW_UPDATE/SETTINGS until the request HEADERS arrive.
		var hf *HeadersFrame
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
			}
			switch f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				continue
			}
			var ok bool
			hf, ok = f.(*HeadersFrame)
			if !ok {
				return fmt.Errorf("Got %T; want HeadersFrame", f)
			}
			break
		}

		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
		ct.fr.WriteHeaders(HeadersFrameParam{
			StreamID:      hf.StreamID,
			EndHeaders:    true,
			EndStream:     false,
			BlockFragment: buf.Bytes(),
		})
		pad := make([]byte, 5)
		ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream

		// First refund: connection-level (StreamID 0).
		f, err := ct.readNonSettingsFrame()
		if err != nil {
			return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err)
		}
		wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding
		if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 {
			return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
		}

		// Second refund: stream-level (non-zero StreamID).
		f, err = ct.readNonSettingsFrame()
		if err != nil {
			return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err)
		}
		if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 {
			return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
		}
		unblockClient <- true
		return nil
	}
	ct.run()
}
  3011  
  3012  // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
  3013  // StreamError as a result of the response HEADERS
  3014  func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
  3015  	ct := newClientTester(t)
  3016  
  3017  	ct.client = func() error {
  3018  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  3019  		res, err := ct.tr.RoundTrip(req)
  3020  		if err == nil {
  3021  			res.Body.Close()
  3022  			return errors.New("unexpected successful GET")
  3023  		}
  3024  		want := StreamError{1, ErrCodeProtocol, headerFieldNameError("  content-type")}
  3025  		if !reflect.DeepEqual(want, err) {
  3026  			t.Errorf("RoundTrip error = %#v; want %#v", err, want)
  3027  		}
  3028  		return nil
  3029  	}
  3030  	ct.server = func() error {
  3031  		ct.greet()
  3032  
  3033  		hf, err := ct.firstHeaders()
  3034  		if err != nil {
  3035  			return err
  3036  		}
  3037  
  3038  		var buf bytes.Buffer
  3039  		enc := hpack.NewEncoder(&buf)
  3040  		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  3041  		enc.WriteField(hpack.HeaderField{Name: "  content-type", Value: "bogus"}) // bogus spaces
  3042  		ct.fr.WriteHeaders(HeadersFrameParam{
  3043  			StreamID:      hf.StreamID,
  3044  			EndHeaders:    true,
  3045  			EndStream:     false,
  3046  			BlockFragment: buf.Bytes(),
  3047  		})
  3048  
  3049  		for {
  3050  			fr, err := ct.readFrame()
  3051  			if err != nil {
  3052  				return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
  3053  			}
  3054  			if _, ok := fr.(*SettingsFrame); ok {
  3055  				continue
  3056  			}
  3057  			if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
  3058  				t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
  3059  			}
  3060  			break
  3061  		}
  3062  
  3063  		return nil
  3064  	}
  3065  	ct.run()
  3066  }
  3067  
  3068  // byteAndEOFReader returns is in an io.Reader which reads one byte
  3069  // (the underlying byte) and io.EOF at once in its Read call.
  3070  type byteAndEOFReader byte
  3071  
  3072  func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
  3073  	if len(p) == 0 {
  3074  		panic("unexpected useless call")
  3075  	}
  3076  	p[0] = byte(b)
  3077  	return 1, io.EOF
  3078  }
  3079  
  3080  // Issue 16788: the Transport had a regression where it started
  3081  // sending a spurious DATA frame with a duplicate END_STREAM bit after
  3082  // the request body writer goroutine had already read an EOF from the
  3083  // Request.Body and included the END_STREAM on a data-carrying DATA
  3084  // frame.
  3085  //
  3086  // Notably, to trigger this, the requests need to use a Request.Body
  3087  // which returns (non-0, io.EOF) and also needs to set the ContentLength
  3088  // explicitly.
  3089  func TestTransportBodyDoubleEndStream(t *testing.T) {
  3090  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  3091  		// Nothing.
  3092  	}, optOnlyServer)
  3093  	defer st.Close()
  3094  
  3095  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3096  	defer tr.CloseIdleConnections()
  3097  
  3098  	for i := 0; i < 2; i++ {
  3099  		req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a'))
  3100  		req.ContentLength = 1
  3101  		res, err := tr.RoundTrip(req)
  3102  		if err != nil {
  3103  			t.Fatalf("failure on req %d: %v", i+1, err)
  3104  		}
  3105  		defer res.Body.Close()
  3106  	}
  3107  }
  3108  
  3109  // golang.org/issue/16847, golang.org/issue/19103
  3110  func TestTransportRequestPathPseudo(t *testing.T) {
  3111  	type result struct {
  3112  		path string
  3113  		err  string
  3114  	}
  3115  	tests := []struct {
  3116  		req  *http.Request
  3117  		want result
  3118  	}{
  3119  		0: {
  3120  			req: &http.Request{
  3121  				Method: "GET",
  3122  				URL: &url.URL{
  3123  					Host: "foo.com",
  3124  					Path: "/foo",
  3125  				},
  3126  			},
  3127  			want: result{path: "/foo"},
  3128  		},
  3129  		// In Go 1.7, we accepted paths of "//foo".
  3130  		// In Go 1.8, we rejected it (issue 16847).
  3131  		// In Go 1.9, we accepted it again (issue 19103).
  3132  		1: {
  3133  			req: &http.Request{
  3134  				Method: "GET",
  3135  				URL: &url.URL{
  3136  					Host: "foo.com",
  3137  					Path: "//foo",
  3138  				},
  3139  			},
  3140  			want: result{path: "//foo"},
  3141  		},
  3142  
  3143  		// Opaque with //$Matching_Hostname/path
  3144  		2: {
  3145  			req: &http.Request{
  3146  				Method: "GET",
  3147  				URL: &url.URL{
  3148  					Scheme: "https",
  3149  					Opaque: "//foo.com/path",
  3150  					Host:   "foo.com",
  3151  					Path:   "/ignored",
  3152  				},
  3153  			},
  3154  			want: result{path: "/path"},
  3155  		},
  3156  
  3157  		// Opaque with some other Request.Host instead:
  3158  		3: {
  3159  			req: &http.Request{
  3160  				Method: "GET",
  3161  				Host:   "bar.com",
  3162  				URL: &url.URL{
  3163  					Scheme: "https",
  3164  					Opaque: "//bar.com/path",
  3165  					Host:   "foo.com",
  3166  					Path:   "/ignored",
  3167  				},
  3168  			},
  3169  			want: result{path: "/path"},
  3170  		},
  3171  
  3172  		// Opaque without the leading "//":
  3173  		4: {
  3174  			req: &http.Request{
  3175  				Method: "GET",
  3176  				URL: &url.URL{
  3177  					Opaque: "/path",
  3178  					Host:   "foo.com",
  3179  					Path:   "/ignored",
  3180  				},
  3181  			},
  3182  			want: result{path: "/path"},
  3183  		},
  3184  
  3185  		// Opaque we can't handle:
  3186  		5: {
  3187  			req: &http.Request{
  3188  				Method: "GET",
  3189  				URL: &url.URL{
  3190  					Scheme: "https",
  3191  					Opaque: "//unknown_host/path",
  3192  					Host:   "foo.com",
  3193  					Path:   "/ignored",
  3194  				},
  3195  			},
  3196  			want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
  3197  		},
  3198  
  3199  		// A CONNECT request:
  3200  		6: {
  3201  			req: &http.Request{
  3202  				Method: "CONNECT",
  3203  				URL: &url.URL{
  3204  					Host: "foo.com",
  3205  				},
  3206  			},
  3207  			want: result{},
  3208  		},
  3209  	}
  3210  	for i, tt := range tests {
  3211  		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
  3212  		cc.henc = hpack.NewEncoder(&cc.hbuf)
  3213  		cc.mu.Lock()
  3214  		hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
  3215  		cc.mu.Unlock()
  3216  		var got result
  3217  		hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
  3218  			if f.Name == ":path" {
  3219  				got.path = f.Value
  3220  			}
  3221  		})
  3222  		if err != nil {
  3223  			got.err = err.Error()
  3224  		} else if len(hdrs) > 0 {
  3225  			if _, err := hpackDec.Write(hdrs); err != nil {
  3226  				t.Errorf("%d. bogus hpack: %v", i, err)
  3227  				continue
  3228  			}
  3229  		}
  3230  		if got != tt.want {
  3231  			t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
  3232  		}
  3233  
  3234  	}
  3235  
  3236  }
  3237  
  3238  // golang.org/issue/17071 -- don't sniff the first byte of the request body
  3239  // before we've determined that the ClientConn is usable.
  3240  func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
  3241  	const body = "foo"
  3242  	req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body)))
  3243  	cc := &ClientConn{
  3244  		closed: true,
  3245  	}
  3246  	_, err := cc.RoundTrip(req)
  3247  	if err != errClientConnUnusable {
  3248  		t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err)
  3249  	}
  3250  	slurp, err := ioutil.ReadAll(req.Body)
  3251  	if err != nil {
  3252  		t.Errorf("ReadAll = %v", err)
  3253  	}
  3254  	if string(slurp) != body {
  3255  		t.Errorf("Body = %q; want %q", slurp, body)
  3256  	}
  3257  }
  3258  
  3259  func TestClientConnPing(t *testing.T) {
  3260  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
  3261  	defer st.Close()
  3262  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3263  	defer tr.CloseIdleConnections()
  3264  	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
  3265  	if err != nil {
  3266  		t.Fatal(err)
  3267  	}
  3268  	if err = cc.Ping(context.Background()); err != nil {
  3269  		t.Fatal(err)
  3270  	}
  3271  }
  3272  
  3273  // Issue 16974: if the server sent a DATA frame after the user
  3274  // canceled the Transport's Request, the Transport previously wrote to a
  3275  // closed pipe, got an error, and ended up closing the whole TCP
  3276  // connection.
  3277  func TestTransportCancelDataResponseRace(t *testing.T) {
  3278  	cancel := make(chan struct{})
  3279  	clientGotError := make(chan bool, 1)
  3280  
  3281  	const msg = "Hello."
  3282  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  3283  		if strings.Contains(r.URL.Path, "/hello") {
  3284  			time.Sleep(50 * time.Millisecond)
  3285  			io.WriteString(w, msg)
  3286  			return
  3287  		}
  3288  		for i := 0; i < 50; i++ {
  3289  			io.WriteString(w, "Some data.")
  3290  			w.(http.Flusher).Flush()
  3291  			if i == 2 {
  3292  				close(cancel)
  3293  				<-clientGotError
  3294  			}
  3295  			time.Sleep(10 * time.Millisecond)
  3296  		}
  3297  	}, optOnlyServer)
  3298  	defer st.Close()
  3299  
  3300  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3301  	defer tr.CloseIdleConnections()
  3302  
  3303  	c := &http.Client{Transport: tr}
  3304  	req, _ := http.NewRequest("GET", st.ts.URL, nil)
  3305  	req.Cancel = cancel
  3306  	res, err := c.Do(req)
  3307  	if err != nil {
  3308  		t.Fatal(err)
  3309  	}
  3310  	if _, err = io.Copy(ioutil.Discard, res.Body); err == nil {
  3311  		t.Fatal("unexpected success")
  3312  	}
  3313  	clientGotError <- true
  3314  
  3315  	res, err = c.Get(st.ts.URL + "/hello")
  3316  	if err != nil {
  3317  		t.Fatal(err)
  3318  	}
  3319  	slurp, err := ioutil.ReadAll(res.Body)
  3320  	if err != nil {
  3321  		t.Fatal(err)
  3322  	}
  3323  	if string(slurp) != msg {
  3324  		t.Errorf("Got = %q; want %q", slurp, msg)
  3325  	}
  3326  }
  3327  
  3328  // Issue 21316: It should be safe to reuse an http.Request after the
  3329  // request has completed.
  3330  func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
  3331  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  3332  		w.WriteHeader(200)
  3333  		io.WriteString(w, "body")
  3334  	}, optOnlyServer)
  3335  	defer st.Close()
  3336  
  3337  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  3338  	defer tr.CloseIdleConnections()
  3339  
  3340  	req, _ := http.NewRequest("GET", st.ts.URL, nil)
  3341  	resp, err := tr.RoundTrip(req)
  3342  	if err != nil {
  3343  		t.Fatal(err)
  3344  	}
  3345  	if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
  3346  		t.Fatalf("error reading response body: %v", err)
  3347  	}
  3348  	if err := resp.Body.Close(); err != nil {
  3349  		t.Fatalf("error closing response body: %v", err)
  3350  	}
  3351  
  3352  	// This access of req.Header should not race with code in the transport.
  3353  	req.Header = http.Header{}
  3354  }
  3355  
  3356  func TestTransportCloseAfterLostPing(t *testing.T) {
  3357  	clientDone := make(chan struct{})
  3358  	ct := newClientTester(t)
  3359  	ct.tr.PingTimeout = 1 * time.Second
  3360  	ct.tr.ReadIdleTimeout = 1 * time.Second
  3361  	ct.client = func() error {
  3362  		defer ct.cc.(*net.TCPConn).CloseWrite()
  3363  		defer close(clientDone)
  3364  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  3365  		_, err := ct.tr.RoundTrip(req)
  3366  		if err == nil || !strings.Contains(err.Error(), "client connection lost") {
  3367  			return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
  3368  		}
  3369  		return nil
  3370  	}
  3371  	ct.server = func() error {
  3372  		ct.greet()
  3373  		<-clientDone
  3374  		return nil
  3375  	}
  3376  	ct.run()
  3377  }
  3378  
  3379  func TestTransportPingWhenReading(t *testing.T) {
  3380  	testCases := []struct {
  3381  		name                   string
  3382  		readIdleTimeout        time.Duration
  3383  		serverResponseInterval time.Duration
  3384  		expectedPingCount      int
  3385  	}{
  3386  		{
  3387  			name:                   "two pings in each serverResponseInterval",
  3388  			readIdleTimeout:        400 * time.Millisecond,
  3389  			serverResponseInterval: 1000 * time.Millisecond,
  3390  			expectedPingCount:      4,
  3391  		},
  3392  		{
  3393  			name:                   "one ping in each serverResponseInterval",
  3394  			readIdleTimeout:        700 * time.Millisecond,
  3395  			serverResponseInterval: 1000 * time.Millisecond,
  3396  			expectedPingCount:      2,
  3397  		},
  3398  		{
  3399  			name:                   "zero ping in each serverResponseInterval",
  3400  			readIdleTimeout:        1000 * time.Millisecond,
  3401  			serverResponseInterval: 500 * time.Millisecond,
  3402  			expectedPingCount:      0,
  3403  		},
  3404  		{
  3405  			name:                   "0 readIdleTimeout means no ping",
  3406  			readIdleTimeout:        0 * time.Millisecond,
  3407  			serverResponseInterval: 500 * time.Millisecond,
  3408  			expectedPingCount:      0,
  3409  		},
  3410  	}
  3411  
  3412  	for _, tc := range testCases {
  3413  		tc := tc // capture range variable
  3414  		t.Run(tc.name, func(t *testing.T) {
  3415  			t.Parallel()
  3416  			testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
  3417  		})
  3418  	}
  3419  }
  3420  
// testTransportPingWhenReading sets the Transport's ReadIdleTimeout to
// readIdleTimeout and PingTimeout to 10ms. The server answers a GET with
// three DATA frames: two written serverResponseInterval apart, then a final
// one with END_STREAM. It counts the PING frames the client sends while the
// body is being read and fails if the count differs from expectedPingCount.
func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
	// Incremented only in the server loop and read after ct.run() returns
	// (presumably ct.run waits for both client and server — confirm).
	var pingCount int
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.tr.PingTimeout = 10 * time.Millisecond
	ct.tr.ReadIdleTimeout = readIdleTimeout
	// guards the ct.fr.Write: serializes frame writes between the server
	// loop (PING acks) and the DATA-writing goroutine below.
	var wmu sync.Mutex

	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != 200 {
			return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
		}
		_, err = ioutil.ReadAll(res.Body)
		return err
	}

	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		var wg sync.WaitGroup
		// Don't return until the DATA-writing goroutine has finished.
		defer wg.Wait()
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it
					// will have reported any
					// errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				// Ignored.
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				// Written before the goroutine below starts, so no lock
				// is needed yet.
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: buf.Bytes(),
				})

				// Stream the body slowly in the background so this loop can
				// keep answering PINGs in the meantime.
				wg.Add(1)
				go func() {
					defer wg.Done()
					for i := 0; i < 2; i++ {
						wmu.Lock()
						if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
							wmu.Unlock()
							t.Error(err)
							return
						}
						wmu.Unlock()
						time.Sleep(serverResponseInterval)
					}
					wmu.Lock()
					if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
						wmu.Unlock()
						t.Error(err)
						return
					}
					wmu.Unlock()
				}()
			case *PingFrame:
				// Count and acknowledge each client PING.
				pingCount++
				wmu.Lock()
				if err := ct.fr.WritePing(true, f.Data); err != nil {
					wmu.Unlock()
					return err
				}
				wmu.Unlock()
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
	if e, a := expectedPingCount, pingCount; e != a {
		t.Errorf("expected receiving %d pings, got %d pings", e, a)

	}
}
  3523  
// TestTransportRetryAfterGOAWAY checks that a request answered only by a
// GOAWAY with last-stream-ID 0 is retried on a new connection.
//
// A custom DialTLS connects to a local listener and hands the server side of
// the first and second dials to server goroutines via ct1 and ct2; a third
// dial is treated as an error. Server 1 reads the request HEADERS and sends
// GOAWAY(0); server 2 replies with a 200 carrying "foo: bar", which the
// client checks. The client and both servers each report one result on errs.
func TestTransportRetryAfterGOAWAY(t *testing.T) {
	// Counts DialTLS calls; the mutex also serializes the dials.
	var dialer struct {
		sync.Mutex
		count int
	}
	ct1 := make(chan *clientTester)
	ct2 := make(chan *clientTester)

	ln := newLocalListener(t)
	defer ln.Close()

	tr := &Transport{
		TLSClientConfig: tlsConfigInsecure,
	}
	tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
		dialer.Lock()
		defer dialer.Unlock()
		dialer.count++
		if dialer.count == 3 {
			return nil, errors.New("unexpected number of dials")
		}
		// Plain TCP (no TLS): dial the listener and accept the peer end,
		// which becomes the fake server's connection.
		cc, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			return nil, fmt.Errorf("dial error: %v", err)
		}
		sc, err := ln.Accept()
		if err != nil {
			return nil, fmt.Errorf("accept error: %v", err)
		}
		ct := &clientTester{
			t:  t,
			tr: tr,
			cc: cc,
			sc: sc,
			fr: NewFramer(sc, sc),
		}
		// Unbuffered sends: the dial blocks until the matching server
		// goroutine has taken ownership of this connection.
		switch dialer.count {
		case 1:
			ct1 <- ct
		case 2:
			ct2 <- ct
		}
		return cc, nil
	}

	// One result each from the client and the two servers.
	errs := make(chan error, 3)

	// Client.
	go func() {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := tr.RoundTrip(req)
		if res != nil {
			res.Body.Close()
			if got := res.Header.Get("Foo"); got != "bar" {
				err = fmt.Errorf("foo header = %q; want bar", got)
			}
		}
		if err != nil {
			err = fmt.Errorf("RoundTrip: %v", err)
		}
		errs <- err
	}()

	// Client-side conns, closed after all results are in.
	connToClose := make(chan io.Closer, 2)

	// Server for the first request.
	go func() {
		ct := <-ct1

		connToClose <- ct.cc
		ct.greet()
		hf, err := ct.firstHeaders()
		if err != nil {
			errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
			return
		}
		t.Logf("server1 got %v", hf)
		// Last-stream-ID 0: per HTTP/2 GOAWAY semantics stream 1 was not
		// processed, so the Transport should retry it elsewhere.
		if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
			errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
			return
		}
		errs <- nil
	}()

	// Server for the second request.
	go func() {
		ct := <-ct2

		connToClose <- ct.cc
		ct.greet()
		hf, err := ct.firstHeaders()
		if err != nil {
			errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
			return
		}
		t.Logf("server2 got %v", hf)

		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
		err = ct.fr.WriteHeaders(HeadersFrameParam{
			StreamID:      hf.StreamID,
			EndHeaders:    true,
			EndStream:     false,
			BlockFragment: buf.Bytes(),
		})
		if err != nil {
			errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
		} else {
			errs <- nil
		}
	}()

	for k := 0; k < 3; k++ {
		err := <-errs
		if err != nil {
			t.Error(err)
		}
	}

	close(connToClose)
	for c := range connToClose {
		c.Close()
	}
}
  3650  
// TestTransportRetryAfterRefusedStream checks that a request whose first
// stream is reset with ErrCodeRefusedStream is retried transparently. The
// server refuses the first HEADERS it sees and answers the next one with a
// 204 (END_STREAM), which the client must receive as its final result.
func TestTransportRetryAfterRefusedStream(t *testing.T) {
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		resp, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		resp.Body.Close()
		if resp.StatusCode != 204 {
			return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		// Number of request HEADERS frames seen so far.
		nreq := 0

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it
					// will have reported any
					// errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				// Ignored.
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				nreq++
				if nreq == 1 {
					// Refuse the first attempt so the client must retry.
					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
				} else {
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				}
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
  3716  
  3717  func TestTransportRetryHasLimit(t *testing.T) {
  3718  	// Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
  3719  	if testing.Short() {
  3720  		t.Skip("skipping long test in short mode")
  3721  	}
  3722  	clientDone := make(chan struct{})
  3723  	ct := newClientTester(t)
  3724  	ct.client = func() error {
  3725  		defer ct.cc.(*net.TCPConn).CloseWrite()
  3726  		if runtime.GOOS == "plan9" {
  3727  			// CloseWrite not supported on Plan 9; Issue 17906
  3728  			defer ct.cc.(*net.TCPConn).Close()
  3729  		}
  3730  		defer close(clientDone)
  3731  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  3732  		resp, err := ct.tr.RoundTrip(req)
  3733  		if err == nil {
  3734  			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
  3735  		}
  3736  		t.Logf("expected error, got: %v", err)
  3737  		return nil
  3738  	}
  3739  	ct.server = func() error {
  3740  		ct.greet()
  3741  		for {
  3742  			f, err := ct.fr.ReadFrame()
  3743  			if err != nil {
  3744  				select {
  3745  				case <-clientDone:
  3746  					// If the client's done, it
  3747  					// will have reported any
  3748  					// errors on its side.
  3749  					return nil
  3750  				default:
  3751  					return err
  3752  				}
  3753  			}
  3754  			switch f := f.(type) {
  3755  			case *WindowUpdateFrame, *SettingsFrame:
  3756  			case *HeadersFrame:
  3757  				if !f.HeadersEnded() {
  3758  					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
  3759  				}
  3760  				ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
  3761  			default:
  3762  				return fmt.Errorf("Unexpected client frame %v", f)
  3763  			}
  3764  		}
  3765  	}
  3766  	ct.run()
  3767  }
  3768  
  3769  func TestTransportResponseDataBeforeHeaders(t *testing.T) {
  3770  	// This test use not valid response format.
  3771  	// Discarding logger output to not spam tests output.
  3772  	log.SetOutput(ioutil.Discard)
  3773  	defer log.SetOutput(os.Stderr)
  3774  
  3775  	ct := newClientTester(t)
  3776  	ct.client = func() error {
  3777  		defer ct.cc.(*net.TCPConn).CloseWrite()
  3778  		if runtime.GOOS == "plan9" {
  3779  			// CloseWrite not supported on Plan 9; Issue 17906
  3780  			defer ct.cc.(*net.TCPConn).Close()
  3781  		}
  3782  		req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
  3783  		// First request is normal to ensure the check is per stream and not per connection.
  3784  		_, err := ct.tr.RoundTrip(req)
  3785  		if err != nil {
  3786  			return fmt.Errorf("RoundTrip expected no error, got: %v", err)
  3787  		}
  3788  		// Second request returns a DATA frame with no HEADERS.
  3789  		resp, err := ct.tr.RoundTrip(req)
  3790  		if err == nil {
  3791  			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
  3792  		}
  3793  		if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
  3794  			return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
  3795  		}
  3796  		return nil
  3797  	}
  3798  	ct.server = func() error {
  3799  		ct.greet()
  3800  		for {
  3801  			f, err := ct.fr.ReadFrame()
  3802  			if err == io.EOF {
  3803  				return nil
  3804  			} else if err != nil {
  3805  				return err
  3806  			}
  3807  			switch f := f.(type) {
  3808  			case *WindowUpdateFrame, *SettingsFrame:
  3809  			case *HeadersFrame:
  3810  				switch f.StreamID {
  3811  				case 1:
  3812  					// Send a valid response to first request.
  3813  					var buf bytes.Buffer
  3814  					enc := hpack.NewEncoder(&buf)
  3815  					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  3816  					ct.fr.WriteHeaders(HeadersFrameParam{
  3817  						StreamID:      f.StreamID,
  3818  						EndHeaders:    true,
  3819  						EndStream:     true,
  3820  						BlockFragment: buf.Bytes(),
  3821  					})
  3822  				case 3:
  3823  					ct.fr.WriteData(f.StreamID, true, []byte("payload"))
  3824  				}
  3825  			default:
  3826  				return fmt.Errorf("Unexpected client frame %v", f)
  3827  			}
  3828  		}
  3829  	}
  3830  	ct.run()
  3831  }
  3832  
// TestTransportRequestsStallAtServerLimit tests Transport.StrictMaxConcurrentStreams.
// The server advertises SETTINGS_MAX_CONCURRENT_STREAMS = maxConcurrent and
// withholds its responses. The test checks two things:
//   - once maxConcurrent requests are in flight, a further request is not
//     delivered to the server until the earlier ones are answered;
//   - a request whose Cancel channel is already closed fails with an error.
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
	const maxConcurrent = 2

	greet := make(chan struct{})      // closed when the server reads the client's SETTINGS ack
	gotRequest := make(chan struct{}) // server received a request
	clientDone := make(chan struct{}) // closed when the client side has finished

	// Collect errors from goroutines.
	var wg sync.WaitGroup
	errs := make(chan error, 100)
	defer func() {
		wg.Wait()
		close(errs)
		for err := range errs {
			t.Error(err)
		}
	}()

	// We will send maxConcurrent+2 requests. This checker goroutine waits for the
	// following stages:
	//   1. The first maxConcurrent requests are received by the server.
	//   2. The client will cancel the next request
	//   3. The server is unblocked so it can service the first maxConcurrent requests
	//   4. The client will send the final request
	wg.Add(1)
	unblockClient := make(chan struct{})
	clientRequestCancelled := make(chan struct{})
	unblockServer := make(chan struct{})
	go func() {
		defer wg.Done()
		// Stage 1.
		for k := 0; k < maxConcurrent; k++ {
			<-gotRequest
		}
		// Stage 2.
		close(unblockClient)
		<-clientRequestCancelled
		// Stage 3: give some time for the final RoundTrip call to be scheduled and
		// verify that the final request is not sent.
		time.Sleep(50 * time.Millisecond)
		select {
		case <-gotRequest:
			errs <- errors.New("last request did not stall")
			close(unblockServer)
			return
		default:
		}
		close(unblockServer)
		// Stage 4.
		<-gotRequest
	}()

	ct := newClientTester(t)
	ct.tr.StrictMaxConcurrentStreams = true
	ct.client = func() error {
		// This wg shadows the outer one: it tracks only the per-request
		// goroutines so the client can wait for them before signaling done.
		var wg sync.WaitGroup
		defer func() {
			wg.Wait()
			close(clientDone)
			ct.cc.(*net.TCPConn).CloseWrite()
			if runtime.GOOS == "plan9" {
				// CloseWrite not supported on Plan 9; Issue 17906
				ct.cc.(*net.TCPConn).Close()
			}
		}()
		for k := 0; k < maxConcurrent+2; k++ {
			wg.Add(1)
			go func(k int) {
				defer wg.Done()
				// Don't send the second request until after receiving SETTINGS from the server
				// to avoid a race where we use the default SettingMaxConcurrentStreams, which
				// is much larger than maxConcurrent. We have to send the first request before
				// waiting because the first request triggers the dial and greet.
				if k > 0 {
					<-greet
				}
				// Block until maxConcurrent requests are sent before sending any more.
				if k >= maxConcurrent {
					<-unblockClient
				}
				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
				if k == maxConcurrent {
					// This request will be canceled: its Cancel channel is closed
					// before RoundTrip is even called.
					cancel := make(chan struct{})
					req.Cancel = cancel
					close(cancel)
					_, err := ct.tr.RoundTrip(req)
					close(clientRequestCancelled)
					if err == nil {
						errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
						return
					}
				} else {
					resp, err := ct.tr.RoundTrip(req)
					if err != nil {
						errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
						return
					}
					ioutil.ReadAll(resp.Body)
					resp.Body.Close()
					if resp.StatusCode != 204 {
						errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
						return
					}
				}
			}(k)
		}
		return nil
	}

	ct.server = func() error {
		var wg sync.WaitGroup
		defer wg.Wait()

		ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

		// Server write loop.
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		// Buffered for every request expected to reach the server
		// (maxConcurrent+2 sent, minus the canceled one), so the read loop
		// can queue stream IDs while the writer is still held by unblockServer.
		writeResp := make(chan uint32, maxConcurrent+1)

		wg.Add(1)
		go func() {
			defer wg.Done()
			// Hold every response until the checker has verified the stall.
			<-unblockServer
			for id := range writeResp {
				buf.Reset()
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      id,
					EndHeaders:    true,
					EndStream:     true,
					BlockFragment: buf.Bytes(),
				})
			}
		}()

		// Server read loop.
		var nreq int
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it will have reported any errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame:
			case *SettingsFrame:
				// Wait for the client SETTINGS ack until ending the greet.
				close(greet)
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				// Unbuffered: blocks until the checker goroutine takes it.
				gotRequest <- struct{}{}
				nreq++
				writeResp <- f.StreamID
				if nreq == maxConcurrent+1 {
					// Last expected request; let the writer drain and exit.
					close(writeResp)
				}
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}

	ct.run()
}
  4007  
  4008  func TestAuthorityAddr(t *testing.T) {
  4009  	tests := []struct {
  4010  		scheme, authority string
  4011  		want              string
  4012  	}{
  4013  		{"http", "foo.com", "foo.com:80"},
  4014  		{"https", "foo.com", "foo.com:443"},
  4015  		{"https", "foo.com:1234", "foo.com:1234"},
  4016  		{"https", "1.2.3.4:1234", "1.2.3.4:1234"},
  4017  		{"https", "1.2.3.4", "1.2.3.4:443"},
  4018  		{"https", "[::1]:1234", "[::1]:1234"},
  4019  		{"https", "[::1]", "[::1]:443"},
  4020  	}
  4021  	for _, tt := range tests {
  4022  		got := authorityAddr(tt.scheme, tt.authority)
  4023  		if got != tt.want {
  4024  			t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
  4025  		}
  4026  	}
  4027  }
  4028  
  4029  // Issue 20448: stop allocating for DATA frames' payload after
  4030  // Response.Body.Close is called.
  4031  func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
  4032  	megabyteZero := make([]byte, 1<<20)
  4033  
  4034  	writeErr := make(chan error, 1)
  4035  
  4036  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  4037  		w.(http.Flusher).Flush()
  4038  		var sum int64
  4039  		for i := 0; i < 100; i++ {
  4040  			n, err := w.Write(megabyteZero)
  4041  			sum += int64(n)
  4042  			if err != nil {
  4043  				writeErr <- err
  4044  				return
  4045  			}
  4046  		}
  4047  		t.Logf("wrote all %d bytes", sum)
  4048  		writeErr <- nil
  4049  	}, optOnlyServer)
  4050  	defer st.Close()
  4051  
  4052  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4053  	defer tr.CloseIdleConnections()
  4054  	c := &http.Client{Transport: tr}
  4055  	res, err := c.Get(st.ts.URL)
  4056  	if err != nil {
  4057  		t.Fatal(err)
  4058  	}
  4059  	var buf [1]byte
  4060  	if _, err := res.Body.Read(buf[:]); err != nil {
  4061  		t.Error(err)
  4062  	}
  4063  	if err := res.Body.Close(); err != nil {
  4064  		t.Error(err)
  4065  	}
  4066  
  4067  	trb, ok := res.Body.(transportResponseBody)
  4068  	if !ok {
  4069  		t.Fatalf("res.Body = %T; want transportResponseBody", res.Body)
  4070  	}
  4071  	if trb.cs.bufPipe.b != nil {
  4072  		t.Errorf("response body pipe is still open")
  4073  	}
  4074  
  4075  	gotErr := <-writeErr
  4076  	if gotErr == nil {
  4077  		t.Errorf("Handler unexpectedly managed to write its entire response without getting an error")
  4078  	} else if gotErr != errStreamClosed {
  4079  		t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr)
  4080  	}
  4081  }
  4082  
  4083  // Issue 18891: make sure Request.Body == NoBody means no DATA frame
  4084  // is ever sent, even if empty.
  4085  func TestTransportNoBodyMeansNoDATA(t *testing.T) {
  4086  	ct := newClientTester(t)
  4087  
  4088  	unblockClient := make(chan bool)
  4089  
  4090  	ct.client = func() error {
  4091  		req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
  4092  		ct.tr.RoundTrip(req)
  4093  		<-unblockClient
  4094  		return nil
  4095  	}
  4096  	ct.server = func() error {
  4097  		defer close(unblockClient)
  4098  		defer ct.cc.(*net.TCPConn).Close()
  4099  		ct.greet()
  4100  
  4101  		for {
  4102  			f, err := ct.fr.ReadFrame()
  4103  			if err != nil {
  4104  				return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
  4105  			}
  4106  			switch f := f.(type) {
  4107  			default:
  4108  				return fmt.Errorf("Got %T; want HeadersFrame", f)
  4109  			case *WindowUpdateFrame, *SettingsFrame:
  4110  				continue
  4111  			case *HeadersFrame:
  4112  				if !f.StreamEnded() {
  4113  					return fmt.Errorf("got headers frame without END_STREAM")
  4114  				}
  4115  				return nil
  4116  			}
  4117  		}
  4118  	}
  4119  	ct.run()
  4120  }
  4121  
  4122  func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
  4123  	defer disableGoroutineTracking()()
  4124  	b.ReportAllocs()
  4125  	st := newServerTester(b,
  4126  		func(w http.ResponseWriter, r *http.Request) {
  4127  			for i := 0; i < nResHeader; i++ {
  4128  				name := fmt.Sprint("A-", i)
  4129  				w.Header().Set(name, "*")
  4130  			}
  4131  		},
  4132  		optOnlyServer,
  4133  		optQuiet,
  4134  	)
  4135  	defer st.Close()
  4136  
  4137  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4138  	defer tr.CloseIdleConnections()
  4139  
  4140  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  4141  	if err != nil {
  4142  		b.Fatal(err)
  4143  	}
  4144  
  4145  	for i := 0; i < nReqHeaders; i++ {
  4146  		name := fmt.Sprint("A-", i)
  4147  		req.Header.Set(name, "*")
  4148  	}
  4149  
  4150  	b.ResetTimer()
  4151  
  4152  	for i := 0; i < b.N; i++ {
  4153  		res, err := tr.RoundTrip(req)
  4154  		if err != nil {
  4155  			if res != nil {
  4156  				res.Body.Close()
  4157  			}
  4158  			b.Fatalf("RoundTrip err = %v; want nil", err)
  4159  		}
  4160  		res.Body.Close()
  4161  		if res.StatusCode != http.StatusOK {
  4162  			b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
  4163  		}
  4164  	}
  4165  }
  4166  
  4167  type infiniteReader struct{}
  4168  
  4169  func (r infiniteReader) Read(b []byte) (int, error) {
  4170  	return len(b), nil
  4171  }
  4172  
  4173  // Issue 20521: it is not an error to receive a response and end stream
  4174  // from the server without the body being consumed.
  4175  func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
  4176  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  4177  		w.WriteHeader(http.StatusOK)
  4178  	}, optOnlyServer)
  4179  	defer st.Close()
  4180  
  4181  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4182  	defer tr.CloseIdleConnections()
  4183  
  4184  	// The request body needs to be big enough to trigger flow control.
  4185  	req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{})
  4186  	res, err := tr.RoundTrip(req)
  4187  	if err != nil {
  4188  		t.Fatal(err)
  4189  	}
  4190  	if res.StatusCode != http.StatusOK {
  4191  		t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
  4192  	}
  4193  }
  4194  
  4195  // Verify transport doesn't crash when receiving bogus response lacking a :status header.
  4196  // Issue 22880.
  4197  func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
  4198  	ct := newClientTester(t)
  4199  	ct.client = func() error {
  4200  		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
  4201  		_, err := ct.tr.RoundTrip(req)
  4202  		const substr = "malformed response from server: missing status pseudo header"
  4203  		if !strings.Contains(fmt.Sprint(err), substr) {
  4204  			return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
  4205  		}
  4206  		return nil
  4207  	}
  4208  	ct.server = func() error {
  4209  		ct.greet()
  4210  		var buf bytes.Buffer
  4211  		enc := hpack.NewEncoder(&buf)
  4212  
  4213  		for {
  4214  			f, err := ct.fr.ReadFrame()
  4215  			if err != nil {
  4216  				return err
  4217  			}
  4218  			switch f := f.(type) {
  4219  			case *HeadersFrame:
  4220  				enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header
  4221  				ct.fr.WriteHeaders(HeadersFrameParam{
  4222  					StreamID:      f.StreamID,
  4223  					EndHeaders:    true,
  4224  					EndStream:     false, // we'll send some DATA to try to crash the transport
  4225  					BlockFragment: buf.Bytes(),
  4226  				})
  4227  				ct.fr.WriteData(f.StreamID, true, []byte("payload"))
  4228  				return nil
  4229  			}
  4230  		}
  4231  	}
  4232  	ct.run()
  4233  }
  4234  
  4235  func newTestPushHandlerReadResponse() *testPushHandlerReadResponse {
  4236  	return &testPushHandlerReadResponse{
  4237  		done: make(chan struct{}),
  4238  	}
  4239  }
  4240  
// testPushHandlerReadResponse is a test PushHandler that records what it was
// given for a pushed request, reads the pushed response, and then closes done.
// Tests wait on done before inspecting the other fields.
type testPushHandlerReadResponse struct {
	// promise is the PushedRequest's Promise request.
	promise *http.Request
	// origReqURL and origReqHeader describe the request that triggered the push.
	origReqURL    *url.URL
	origReqHeader http.Header
	// push and pushErr hold the result of PushedRequest.ReadResponse.
	push    *http.Response
	pushErr error
	// done is closed by HandlePush after all fields above are set.
	done chan struct{}
}
  4249  
  4250  func (ph *testPushHandlerReadResponse) HandlePush(r *PushedRequest) {
  4251  	ph.promise = r.Promise
  4252  	ph.origReqURL = r.OriginalRequestURL
  4253  	ph.origReqHeader = r.OriginalRequestHeader
  4254  	ph.push, ph.pushErr = r.ReadResponse(r.Promise.Context())
  4255  	close(ph.done)
  4256  }
  4257  
  4258  func testTransportHandlePushPromise(t *testing.T,
  4259  	configTransport func(t *testing.T, tr *Transport),
  4260  	useHTTP bool) {
  4261  	const (
  4262  		initiatingResponseText           = "response text"
  4263  		promisePath                      = "/getmestuff"
  4264  		promiseHeaderKey                 = "headkey"
  4265  		pushText                         = "push text"
  4266  		pushTrailerKey, pushTrailerValue = "trailkey", "end val"
  4267  	)
  4268  	scheme := "https"
  4269  	if useHTTP {
  4270  		scheme = "http"
  4271  	}
  4272  	promiseHeaderValue := strings.Repeat("a", 2*initialMaxFrameSize) // test PUSH_PROMISE+CONTINUATION
  4273  	checkResp := func(t *testing.T, res *http.Response, text string) error {
  4274  		defer res.Body.Close()
  4275  		if res.StatusCode != 200 {
  4276  			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
  4277  		}
  4278  		if slurp, err := ioutil.ReadAll(res.Body); string(slurp) != text || err != nil {
  4279  			return fmt.Errorf("res.Body ReadAll = %q, %v; want %q, %v", slurp, err, text, nil)
  4280  		}
  4281  		return nil
  4282  	}
  4283  	ct := newClientTester(t)
  4284  	if configTransport != nil {
  4285  		configTransport(t, ct.tr)
  4286  	}
  4287  	ct.client = func() error {
  4288  		pushHandler := newTestPushHandlerReadResponse()
  4289  		ct.tr.PushHandler = pushHandler
  4290  		req := httptest.NewRequest("GET", scheme+"://dummy.tld/", nil)
  4291  		req.Header.Set("foo", "bar")
  4292  		res, err := ct.tr.RoundTrip(req)
  4293  		if err != nil {
  4294  			return fmt.Errorf("RoundTrip: %v", err)
  4295  		}
  4296  		if err = checkResp(t, res, initiatingResponseText); err != nil {
  4297  			return err
  4298  		}
  4299  		select {
  4300  		case <-pushHandler.done:
  4301  		case <-time.After(5 * time.Second):
  4302  			return errors.New("timed out waiting for push to be handled")
  4303  		}
  4304  		if pushHandler.origReqURL != req.URL {
  4305  			return fmt.Errorf("expected original request %q, got %q",
  4306  				req.URL.String(), pushHandler.origReqURL.String())
  4307  		}
  4308  		if pushHandler.origReqHeader.Get("foo") != "bar" {
  4309  			return fmt.Errorf("expected original request header %q's value to be %q, got %q",
  4310  				"foo", "bar", pushHandler.origReqHeader.Get("foo"))
  4311  		}
  4312  		if pushHandler.promise == nil {
  4313  			return fmt.Errorf("promise not received")
  4314  		}
  4315  		if pushHandler.promise.URL.Path != promisePath {
  4316  			return fmt.Errorf("promise path = %q, want %q", pushHandler.promise.URL.Path, promisePath)
  4317  		}
  4318  		if pushHandler.promise.Header.Get(promiseHeaderKey) != promiseHeaderValue {
  4319  			return fmt.Errorf("promise value for key %q = %q, want %q", promiseHeaderKey,
  4320  				pushHandler.promise.Header.Get(promiseHeaderKey), promiseHeaderValue)
  4321  		}
  4322  		if pushHandler.pushErr != nil {
  4323  			return fmt.Errorf("push error = %v; want %v", pushHandler.pushErr, nil)
  4324  		}
  4325  		if pushHandler.push == nil {
  4326  			return fmt.Errorf("push not received")
  4327  		}
  4328  		if err = checkResp(t, pushHandler.push, pushText); err != nil {
  4329  			return err
  4330  		}
  4331  		if pushHandler.push.Trailer.Get(pushTrailerKey) != pushTrailerValue {
  4332  			return fmt.Errorf("promise value for key %q = %q, want %q", pushTrailerKey,
  4333  				pushHandler.push.Trailer.Get(pushTrailerKey), pushTrailerValue)
  4334  		}
  4335  		return nil
  4336  	}
  4337  	ct.server = func() error {
  4338  		ct.greet()
  4339  		hf, _ := ct.firstHeaders()
  4340  		var buf bytes.Buffer
  4341  		enc := hpack.NewEncoder(&buf)
  4342  		// Promise
  4343  		const promiseID = 2
  4344  		enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4345  		enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: scheme})
  4346  		enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4347  		enc.WriteField(hpack.HeaderField{Name: ":path", Value: promisePath})
  4348  		enc.WriteField(hpack.HeaderField{Name: promiseHeaderKey, Value: promiseHeaderValue})
  4349  		ct.fr.WritePushPromise(PushPromiseParam{
  4350  			StreamID:      hf.StreamID,
  4351  			PromiseID:     promiseID,
  4352  			BlockFragment: buf.Bytes(),
  4353  			EndHeaders:    true,
  4354  		})
  4355  		// Push
  4356  		buf.Reset()
  4357  		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  4358  		ct.fr.WriteHeaders(HeadersFrameParam{
  4359  			StreamID:      promiseID,
  4360  			EndHeaders:    true,
  4361  			EndStream:     false,
  4362  			BlockFragment: buf.Bytes(),
  4363  		})
  4364  		ct.fr.WriteData(promiseID, false, []byte(pushText))
  4365  		// add trailer
  4366  		buf.Reset()
  4367  		enc.WriteField(hpack.HeaderField{Name: pushTrailerKey, Value: pushTrailerValue})
  4368  		ct.fr.WriteHeaders(HeadersFrameParam{
  4369  			StreamID:      promiseID,
  4370  			EndHeaders:    true,
  4371  			EndStream:     true,
  4372  			BlockFragment: buf.Bytes(),
  4373  		})
  4374  		// Respond to initiating request
  4375  		buf.Reset()
  4376  		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
  4377  		ct.fr.WriteHeaders(HeadersFrameParam{
  4378  			StreamID:      hf.StreamID,
  4379  			EndHeaders:    true,
  4380  			EndStream:     false,
  4381  			BlockFragment: buf.Bytes(),
  4382  		})
  4383  		ct.fr.WriteData(hf.StreamID, true, []byte(initiatingResponseText))
  4384  		return nil
  4385  	}
  4386  	ct.run()
  4387  }
  4388  
  4389  func setMockCert(t *testing.T, tr *Transport) {
  4390  	// Self-signed certificate to `dummy.tld` using SHA256 and RSA,
  4391  	// valid from 9/22/2019 to 8/29/2119. Hopefully this piece of code
  4392  	// will retire before the mock certificate.
  4393  	certPem := `
  4394  -----BEGIN CERTIFICATE-----
  4395  MIIEADCCAuigAwIBAgIJALtnD0hvKA2fMA0GCSqGSIb3DQEBCwUAMIGTMQswCQYD
  4396  VQQGEwJDTjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzESMBAG
  4397  A1UECgwJVGlhbmppIFd1MRMwEQYDVQQLDApnb2xhbmctZGV2MRIwEAYDVQQDDAlk
  4398  dW1teS50bGQxIzAhBgkqhkiG9w0BCQEWFGdvbGFuZy1kZXZAd3V0ai5pbmZvMCAX
  4399  DTE5MDkyMjE0MjUwNVoYDzIxMTkwODI5MTQyNTA1WjCBkzELMAkGA1UEBhMCQ04x
  4400  EDAOBgNVBAgMB0JlaWppbmcxEDAOBgNVBAcMB0JlaWppbmcxEjAQBgNVBAoMCVRp
  4401  YW5qaSBXdTETMBEGA1UECwwKZ29sYW5nLWRldjESMBAGA1UEAwwJZHVtbXkudGxk
  4402  MSMwIQYJKoZIhvcNAQkBFhRnb2xhbmctZGV2QHd1dGouaW5mbzCCASIwDQYJKoZI
  4403  hvcNAQEBBQADggEPADCCAQoCggEBALqfRTzoEDHZN2a1uBmU2NFvxKGY0zAI07bB
  4404  +0kGOOuqlixj2+Dvd2/eJXoDh8GugRaihSmmvx+XiuoA7MVOhUbE/tkhPmJ7L/sv
  4405  JRY7YNNq7hTSj0DXoP8iteKF5uTyCuBB1zQUFYfPcs4Nl5hF5iuhPmEPG9vn9b8Z
  4406  XFcwITakUPeXGLkJb8D1vXmXFew3e1hyROZ+klbJ96yXnGXoYQ4WDrBsVA3rHBuW
  4407  ouHNT0qA3dGPqkniOIGBuMUNaeEGoPhi1o4B9vQBmEHwULKpcOJbr+sj5YopCs5p
  4408  9wQuFNI6VsbaqLyiQE+BSGtoCX3FQyl4lpoPj/9k5kmty5K0v4kCAwEAAaNTMFEw
  4409  HQYDVR0OBBYEFIl5xOTIsFwQ68QhyY7kPy9NNUx4MB8GA1UdIwQYMBaAFIl5xOTI
  4410  sFwQ68QhyY7kPy9NNUx4MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQAD
  4411  ggEBAC6+U++9uTaM9JI5BGLATP8TSHdgJdC4nJWGao4CBEOWN1RO8LhSxwNHz729
  4412  GtJWoah9giIx2mjmYJfJtzZH30rRSRt1MXAOJX1NSJ4iWDH/jT0tC7oPhfrvr7FM
  4413  ZVZplk7k59+bQlo6q0+u0ax9Hrarwgx6j+K9v+5/dhc+qGk3pdoFa/Sa3gzQ3Gqo
  4414  V08pmKjQSDuM4Cvgd7UXg9PhrYtlCQsWOnuhJcPl+gkUlTvRStlpQurAt8FfpieV
  4415  Mff+u1o/ompjfVA7Knr38ZdaoXkzoLoypAWX1veeTzhFCITSMnnHVq0/OhpmbtcF
  4416  sIzU/RQ4lLHuDXjzQVWxCwUp0oY=
  4417  -----END CERTIFICATE-----
  4418  `
  4419  	pemBlock, _ := pem.Decode([]byte(certPem))
  4420  	cert, err := x509.ParseCertificate(pemBlock.Bytes)
  4421  	if err != nil {
  4422  		t.Fatalf("failed to parse certificate: %s", err)
  4423  	}
  4424  	req := httptest.NewRequest("GET", "https://dummy.tld:443/", nil)
  4425  	cc, err := tr.connPool().GetClientConn(req, "dummy.tld:443")
  4426  	if err != nil {
  4427  		t.Fatal(err)
  4428  	}
  4429  	cc.tlsState = &tls.ConnectionState{}
  4430  	cert.DNSNames = []string{"dummy.tld"}
  4431  	cc.tlsState.VerifiedChains = [][]*x509.Certificate{{cert}}
  4432  	cc.tlsState.PeerCertificates = []*x509.Certificate{cert}
  4433  	tr.TLSClientConfig = &tls.Config{}
  4434  }
  4435  
  4436  func TestTransportHandlePushPromise_TLS_SkipVerify(t *testing.T) {
  4437  	testTransportHandlePushPromise(t, nil, false)
  4438  }
  4439  
  4440  func TestTransportHandlePushPromise_TLS(t *testing.T) {
  4441  	testTransportHandlePushPromise(t, setMockCert, false)
  4442  }
  4443  
  4444  func TestTransportHandlePushPromise_NonTLS(t *testing.T) {
  4445  	allowHTTP := func(t *testing.T, tr *Transport) {
  4446  		tr.AllowHTTP = true
  4447  		tr.TLSClientConfig = &tls.Config{}
  4448  	}
  4449  	testTransportHandlePushPromise(t, allowHTTP, true)
  4450  }
  4451  
  4452  func testTransport_Push_Reject(t *testing.T,
  4453  	h PushHandler,
  4454  	getPush func(streamID uint32) PushPromiseParam,
  4455  	getExpectedErr func(streamID uint32) error) {
  4456  	ct := newClientTester(t)
  4457  	ct.client = func() error {
  4458  		ct.tr.PushHandler = h
  4459  		req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
  4460  		_, gotErr := ct.tr.RoundTrip(req)
  4461  		var streamID uint32
  4462  		if se, ok := gotErr.(StreamError); ok {
  4463  			gotErr = streamError(se.StreamID, se.Code)
  4464  			streamID = se.StreamID
  4465  		}
  4466  		wantErr := getExpectedErr(streamID)
  4467  		if !reflect.DeepEqual(wantErr, gotErr) {
  4468  			return fmt.Errorf("expected %v, but got %v", wantErr, gotErr)
  4469  		}
  4470  		return nil
  4471  	}
  4472  	ct.server = func() error {
  4473  		ct.greet()
  4474  		hf, _ := ct.firstHeaders()
  4475  		ct.fr.WritePushPromise(getPush(hf.StreamID))
  4476  		return nil
  4477  	}
  4478  	ct.run()
  4479  }
  4480  
  4481  func TestTransport_Push_RejectIfDisabled(t *testing.T) {
  4482  	testTransport_Push_Reject(t,
  4483  		nil,
  4484  		func(streamID uint32) PushPromiseParam {
  4485  			var buf bytes.Buffer
  4486  			enc := hpack.NewEncoder(&buf)
  4487  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4488  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4489  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4490  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4491  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4492  		},
  4493  		func(uint32) error {
  4494  			return ConnectionError(ErrCodeProtocol)
  4495  		},
  4496  	)
  4497  }
  4498  
  4499  func TestTransport_Push_RejectRecursivePush(t *testing.T) {
  4500  	testTransport_Push_Reject(t,
  4501  		newTestPushHandlerReadResponse(),
  4502  		func(streamID uint32) PushPromiseParam {
  4503  			var buf bytes.Buffer
  4504  			enc := hpack.NewEncoder(&buf)
  4505  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4506  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4507  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4508  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4509  			return PushPromiseParam{2, 2, buf.Bytes(), true, 0}
  4510  		},
  4511  		func(uint32) error {
  4512  			return ConnectionError(ErrCodeProtocol)
  4513  		},
  4514  	)
  4515  }
  4516  
  4517  func TestTransport_Push_RejectInvalidPromiseId(t *testing.T) {
  4518  	testTransport_Push_Reject(t,
  4519  		newTestPushHandlerReadResponse(),
  4520  		func(streamID uint32) PushPromiseParam {
  4521  			var buf bytes.Buffer
  4522  			enc := hpack.NewEncoder(&buf)
  4523  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4524  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4525  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4526  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4527  			return PushPromiseParam{streamID, 3, buf.Bytes(), true, 0}
  4528  		},
  4529  		func(uint32) error {
  4530  			return ConnectionError(ErrCodeProtocol)
  4531  		},
  4532  	)
  4533  }
  4534  
  4535  func TestTransport_Push_RejectInitiatingStream_NonExistent(t *testing.T) {
  4536  	testTransport_Push_Reject(t,
  4537  		newTestPushHandlerReadResponse(),
  4538  		func(streamID uint32) PushPromiseParam {
  4539  			var buf bytes.Buffer
  4540  			enc := hpack.NewEncoder(&buf)
  4541  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4542  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4543  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4544  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4545  			return PushPromiseParam{7, 2, buf.Bytes(), true, 0}
  4546  		},
  4547  		func(uint32) error {
  4548  			return ConnectionError(ErrCodeProtocol)
  4549  		},
  4550  	)
  4551  }
  4552  
  4553  func TestTransport_Push_RejectMissingAuthority(t *testing.T) {
  4554  	testTransport_Push_Reject(t,
  4555  		newTestPushHandlerReadResponse(),
  4556  		func(streamID uint32) PushPromiseParam {
  4557  			var buf bytes.Buffer
  4558  			enc := hpack.NewEncoder(&buf)
  4559  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4560  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4561  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4562  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4563  		},
  4564  		func(streamID uint32) error {
  4565  			return streamError(streamID, ErrCodeProtocol)
  4566  		},
  4567  	)
  4568  }
  4569  
  4570  func TestTransport_Push_RejectHeader_BodyRelated(t *testing.T) {
  4571  	testTransport_Push_Reject(t, newTestPushHandlerReadResponse(),
  4572  		func(streamID uint32) PushPromiseParam {
  4573  			var buf bytes.Buffer
  4574  			enc := hpack.NewEncoder(&buf)
  4575  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4576  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4577  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4578  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4579  			enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "12"})
  4580  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4581  		},
  4582  		func(streamID uint32) error {
  4583  			return streamError(streamID, ErrCodeProtocol)
  4584  		},
  4585  	)
  4586  }
  4587  
  4588  func TestTransport_Push_RejectHeader_ConnRelated(t *testing.T) {
  4589  	testTransport_Push_Reject(t, newTestPushHandlerReadResponse(),
  4590  		func(streamID uint32) PushPromiseParam {
  4591  			var buf bytes.Buffer
  4592  			enc := hpack.NewEncoder(&buf)
  4593  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4594  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4595  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4596  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4597  			enc.WriteField(hpack.HeaderField{Name: "connection", Value: "close"})
  4598  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4599  		},
  4600  		func(streamID uint32) error {
  4601  			return streamError(streamID, ErrCodeProtocol)
  4602  		},
  4603  	)
  4604  }
  4605  
  4606  func testTransport_Push_RejectAuthError(t *testing.T,
  4607  	h PushHandler,
  4608  	getPush func(streamId uint32) PushPromiseParam,
  4609  	useHTTP bool) {
  4610  	ct := newClientTester(t)
  4611  	scheme := "https"
  4612  	if useHTTP {
  4613  		scheme = "http"
  4614  		ct.tr.AllowHTTP = true
  4615  		ct.tr.TLSClientConfig = &tls.Config{}
  4616  	} else {
  4617  		setMockCert(t, ct.tr)
  4618  	}
  4619  	req := httptest.NewRequest("GET", scheme+"://dummy.tld:443/", nil)
  4620  	ct.client = func() error {
  4621  		ct.tr.PushHandler = h
  4622  		_, err := ct.tr.RoundTrip(req)
  4623  		if err != nil {
  4624  			if _, ok := err.(StreamError); !ok {
  4625  				return fmt.Errorf("expected stream error, but got %q", err)
  4626  			}
  4627  		} else {
  4628  			return fmt.Errorf("expected stream error, but got no error")
  4629  		}
  4630  		return nil
  4631  	}
  4632  	ct.server = func() error {
  4633  		ct.greet()
  4634  		hf, _ := ct.firstHeaders()
  4635  		ct.fr.WritePushPromise(getPush(hf.StreamID))
  4636  		return nil
  4637  	}
  4638  	ct.run()
  4639  }
  4640  
  4641  func TestTransport_Push_RejectAuthError_NonAuthoritativeHostname_TLS(t *testing.T) {
  4642  	testTransport_Push_RejectAuthError(t,
  4643  		newTestPushHandlerReadResponse(),
  4644  		func(streamID uint32) PushPromiseParam {
  4645  			var buf bytes.Buffer
  4646  			enc := hpack.NewEncoder(&buf)
  4647  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4648  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4649  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "sub.foo.net"})
  4650  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4651  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4652  		},
  4653  		false,
  4654  	)
  4655  }
  4656  
  4657  func TestTransport_Push_RejectAuthError_NonAuthoritativeHostname_NonTLS(t *testing.T) {
  4658  	testTransport_Push_RejectAuthError(t,
  4659  		newTestPushHandlerReadResponse(),
  4660  		func(streamID uint32) PushPromiseParam {
  4661  			var buf bytes.Buffer
  4662  			enc := hpack.NewEncoder(&buf)
  4663  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4664  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "http"})
  4665  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "sub.foo.net"})
  4666  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4667  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4668  		},
  4669  		true,
  4670  	)
  4671  }
  4672  
  4673  func TestTransport_Push_RejectAuthError_DifferentScheme(t *testing.T) {
  4674  	testTransport_Push_RejectAuthError(t,
  4675  		newTestPushHandlerReadResponse(),
  4676  		func(streamID uint32) PushPromiseParam {
  4677  			var buf bytes.Buffer
  4678  			enc := hpack.NewEncoder(&buf)
  4679  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4680  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "http"})
  4681  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld"})
  4682  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4683  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4684  		},
  4685  		false,
  4686  	)
  4687  }
  4688  
  4689  func TestTransport_Push_RejectAuthError_DifferentPort(t *testing.T) {
  4690  	testTransport_Push_RejectAuthError(t,
  4691  		newTestPushHandlerReadResponse(),
  4692  		func(streamID uint32) PushPromiseParam {
  4693  			var buf bytes.Buffer
  4694  			enc := hpack.NewEncoder(&buf)
  4695  			enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
  4696  			enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
  4697  			enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "dummy.tld:1234"})
  4698  			enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/hello"})
  4699  			return PushPromiseParam{streamID, 2, buf.Bytes(), true, 0}
  4700  		},
  4701  		false,
  4702  	)
  4703  }
  4704  
  4705  func BenchmarkClientRequestHeaders(b *testing.B) {
  4706  	b.Run("   0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
  4707  	b.Run("  10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) })
  4708  	b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) })
  4709  	b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) })
  4710  }
  4711  
  4712  func BenchmarkClientResponseHeaders(b *testing.B) {
  4713  	b.Run("   0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
  4714  	b.Run("  10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) })
  4715  	b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) })
  4716  	b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
  4717  }
  4718  
  4719  func activeStreams(cc *ClientConn) int {
  4720  	cc.mu.Lock()
  4721  	defer cc.mu.Unlock()
  4722  	return len(cc.streams)
  4723  }
  4724  
  4725  type closeMode int
  4726  
  4727  const (
  4728  	closeAtHeaders closeMode = iota
  4729  	closeAtBody
  4730  	shutdown
  4731  	shutdownCancel
  4732  )
  4733  
  4734  // See golang.org/issue/17292
  4735  func testClientConnClose(t *testing.T, closeMode closeMode) {
  4736  	clientDone := make(chan struct{})
  4737  	defer close(clientDone)
  4738  	handlerDone := make(chan struct{})
  4739  	closeDone := make(chan struct{})
  4740  	beforeHeader := func() {}
  4741  	bodyWrite := func(w http.ResponseWriter) {}
  4742  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  4743  		defer close(handlerDone)
  4744  		beforeHeader()
  4745  		w.WriteHeader(http.StatusOK)
  4746  		w.(http.Flusher).Flush()
  4747  		bodyWrite(w)
  4748  		select {
  4749  		case <-w.(http.CloseNotifier).CloseNotify():
  4750  			// client closed connection before completion
  4751  			if closeMode == shutdown || closeMode == shutdownCancel {
  4752  				t.Error("expected request to complete")
  4753  			}
  4754  		case <-clientDone:
  4755  			if closeMode == closeAtHeaders || closeMode == closeAtBody {
  4756  				t.Error("expected connection closed by client")
  4757  			}
  4758  		}
  4759  	}, optOnlyServer)
  4760  	defer st.Close()
  4761  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  4762  	defer tr.CloseIdleConnections()
  4763  	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
  4764  	req, err := http.NewRequest("GET", st.ts.URL, nil)
  4765  	if err != nil {
  4766  		t.Fatal(err)
  4767  	}
  4768  	if closeMode == closeAtHeaders {
  4769  		beforeHeader = func() {
  4770  			if err := cc.Close(); err != nil {
  4771  				t.Error(err)
  4772  			}
  4773  			close(closeDone)
  4774  		}
  4775  	}
  4776  	var sendBody chan struct{}
  4777  	if closeMode == closeAtBody {
  4778  		sendBody = make(chan struct{})
  4779  		bodyWrite = func(w http.ResponseWriter) {
  4780  			<-sendBody
  4781  			b := make([]byte, 32)
  4782  			w.Write(b)
  4783  			w.(http.Flusher).Flush()
  4784  			if err := cc.Close(); err != nil {
  4785  				t.Errorf("unexpected ClientConn close error: %v", err)
  4786  			}
  4787  			close(closeDone)
  4788  			w.Write(b)
  4789  			w.(http.Flusher).Flush()
  4790  		}
  4791  	}
  4792  	res, err := cc.RoundTrip(req)
  4793  	if res != nil {
  4794  		defer res.Body.Close()
  4795  	}
  4796  	if closeMode == closeAtHeaders {
  4797  		got := fmt.Sprint(err)
  4798  		want := "http2: client connection force closed via ClientConn.Close"
  4799  		if got != want {
  4800  			t.Fatalf("RoundTrip error = %v, want %v", got, want)
  4801  		}
  4802  	} else {
  4803  		if err != nil {
  4804  			t.Fatalf("RoundTrip: %v", err)
  4805  		}
  4806  		if got, want := activeStreams(cc), 1; got != want {
  4807  			t.Errorf("got %d active streams, want %d", got, want)
  4808  		}
  4809  	}
  4810  	switch closeMode {
  4811  	case shutdownCancel:
  4812  		if err = cc.Shutdown(canceledCtx); err != context.Canceled {
  4813  			t.Errorf("got %v, want %v", err, context.Canceled)
  4814  		}
  4815  		if cc.closing == false {
  4816  			t.Error("expected closing to be true")
  4817  		}
  4818  		if cc.CanTakeNewRequest() == true {
  4819  			t.Error("CanTakeNewRequest to return false")
  4820  		}
  4821  		if v, want := len(cc.streams), 1; v != want {
  4822  			t.Errorf("expected %d active streams, got %d", want, v)
  4823  		}
  4824  		clientDone <- struct{}{}
  4825  		<-handlerDone
  4826  	case shutdown:
  4827  		wait := make(chan struct{})
  4828  		shutdownEnterWaitStateHook = func() {
  4829  			close(wait)
  4830  			shutdownEnterWaitStateHook = func() {}
  4831  		}
  4832  		defer func() { shutdownEnterWaitStateHook = func() {} }()
  4833  		shutdown := make(chan struct{}, 1)
  4834  		go func() {
  4835  			if err = cc.Shutdown(context.Background()); err != nil {
  4836  				t.Error(err)
  4837  			}
  4838  			close(shutdown)
  4839  		}()
  4840  		// Let the shutdown to enter wait state
  4841  		<-wait
  4842  		cc.mu.Lock()
  4843  		if cc.closing == false {
  4844  			t.Error("expected closing to be true")
  4845  		}
  4846  		cc.mu.Unlock()
  4847  		if cc.CanTakeNewRequest() == true {
  4848  			t.Error("CanTakeNewRequest to return false")
  4849  		}
  4850  		if got, want := activeStreams(cc), 1; got != want {
  4851  			t.Errorf("got %d active streams, want %d", got, want)
  4852  		}
  4853  		// Let the active request finish
  4854  		clientDone <- struct{}{}
  4855  		// Wait for the shutdown to end
  4856  		select {
  4857  		case <-shutdown:
  4858  		case <-time.After(2 * time.Second):
  4859  			t.Fatal("expected server connection to close")
  4860  		}
  4861  	case closeAtHeaders, closeAtBody:
  4862  		if closeMode == closeAtBody {
  4863  			go close(sendBody)
  4864  			if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
  4865  				t.Error("expected a Copy error, got nil")
  4866  			}
  4867  		}
  4868  		<-closeDone
  4869  		if got, want := activeStreams(cc), 0; got != want {
  4870  			t.Errorf("got %d active streams, want %d", got, want)
  4871  		}
  4872  		// wait for server to get the connection close notice
  4873  		select {
  4874  		case <-handlerDone:
  4875  		case <-time.After(2 * time.Second):
  4876  			t.Fatal("expected server connection to close")
  4877  		}
  4878  	}
  4879  }
  4880  
  4881  // The client closes the connection just after the server got the client's HEADERS
  4882  // frame, but before the server sends its HEADERS response back. The expected
  4883  // result is an error on RoundTrip explaining the client closed the connection.
  4884  func TestClientConnCloseAtHeaders(t *testing.T) {
  4885  	testClientConnClose(t, closeAtHeaders)
  4886  }
  4887  
  4888  // The client closes the connection between two server's response DATA frames.
  4889  // The expected behavior is a response body io read error on the client.
  4890  func TestClientConnCloseAtBody(t *testing.T) {
  4891  	testClientConnClose(t, closeAtBody)
  4892  }
  4893  
  4894  // The client sends a GOAWAY frame before the server finished processing a request.
  4895  // We expect the connection not to close until the request is completed.
  4896  func TestClientConnShutdown(t *testing.T) {
  4897  	testClientConnClose(t, shutdown)
  4898  }
  4899  
  4900  // The client sends a GOAWAY frame before the server finishes processing a request,
  4901  // but cancels the passed context before the request is completed. The expected
  4902  // behavior is the client closing the connection after the context is canceled.
  4903  func TestClientConnShutdownCancel(t *testing.T) {
  4904  	testClientConnClose(t, shutdownCancel)
  4905  }
  4906  
  4907  // Issue 25009: use Request.GetBody if present, even if it seems like
  4908  // we might not need it. Apparently something else can still read from
  4909  // the original request body. Data race? In any case, rewinding
  4910  // unconditionally on retry is a nicer model anyway and should
  4911  // simplify code in the future (after the Go 1.11 freeze)
  4912  func TestTransportUsesGetBodyWhenPresent(t *testing.T) {
  4913  	calls := 0
  4914  	someBody := func() io.ReadCloser {
  4915  		return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))}
  4916  	}
  4917  	req := &http.Request{
  4918  		Body: someBody(),
  4919  		GetBody: func() (io.ReadCloser, error) {
  4920  			calls++
  4921  			return someBody(), nil
  4922  		},
  4923  	}
  4924  
  4925  	afterBodyWrite := false // pretend we haven't read+written the body yet
  4926  	req2, err := shouldRetryRequest(req, errClientConnUnusable, afterBodyWrite)
  4927  	if err != nil {
  4928  		t.Fatal(err)
  4929  	}
  4930  	if calls != 1 {
  4931  		t.Errorf("Calls = %d; want 1", calls)
  4932  	}
  4933  	if req2 == req {
  4934  		t.Error("req2 changed")
  4935  	}
  4936  	if req2 == nil {
  4937  		t.Fatal("req2 is nil")
  4938  	}
  4939  	if req2.Body == nil {
  4940  		t.Fatal("req2.Body is nil")
  4941  	}
  4942  	if req2.GetBody == nil {
  4943  		t.Fatal("req2.GetBody is nil")
  4944  	}
  4945  	if req2.Body == req.Body {
  4946  		t.Error("req2.Body unchanged")
  4947  	}
  4948  }
  4949  
  4950  // Issue 22891: verify that the "https" altproto we register with net/http
  4951  // is a certain type: a struct with one field with our *http2.Transport in it.
  4952  func TestNoDialH2RoundTripperType(t *testing.T) {
  4953  	t1 := new(http.Transport)
  4954  	t2 := new(Transport)
  4955  	rt := noDialH2RoundTripper{t2}
  4956  	if err := registerHTTPSProtocol(t1, rt); err != nil {
  4957  		t.Fatal(err)
  4958  	}
  4959  	rv := reflect.ValueOf(rt)
  4960  	if rv.Type().Kind() != reflect.Struct {
  4961  		t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind())
  4962  	}
  4963  	if n := rv.Type().NumField(); n != 1 {
  4964  		t.Fatalf("fields = %d; net/http expects 1", n)
  4965  	}
  4966  	v := rv.Field(0)
  4967  	if _, ok := v.Interface().(*Transport); !ok {
  4968  		t.Fatalf("wrong kind %T; want *Transport", v.Interface())
  4969  	}
  4970  }
  4971  
  4972  type errReader struct {
  4973  	body []byte
  4974  	err  error
  4975  }
  4976  
  4977  func (r *errReader) Read(p []byte) (int, error) {
  4978  	if len(r.body) > 0 {
  4979  		n := copy(p, r.body)
  4980  		r.body = r.body[n:]
  4981  		return n, nil
  4982  	}
  4983  	return 0, r.err
  4984  }
  4985  
  4986  func testTransportBodyReadError(t *testing.T, body []byte) {
  4987  	if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
  4988  		// So far we've only seen this be flaky on Windows and Plan 9,
  4989  		// perhaps due to TCP behavior on shutdowns while
  4990  		// unread data is in flight. This test should be
  4991  		// fixed, but a skip is better than annoying people
  4992  		// for now.
  4993  		t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS)
  4994  	}
  4995  	clientDone := make(chan struct{})
  4996  	ct := newClientTester(t)
  4997  	ct.client = func() error {
  4998  		defer ct.cc.(*net.TCPConn).CloseWrite()
  4999  		if runtime.GOOS == "plan9" {
  5000  			// CloseWrite not supported on Plan 9; Issue 17906
  5001  			defer ct.cc.(*net.TCPConn).Close()
  5002  		}
  5003  		defer close(clientDone)
  5004  
  5005  		checkNoStreams := func() error {
  5006  			cp, ok := ct.tr.connPool().(*clientConnPool)
  5007  			if !ok {
  5008  				return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
  5009  			}
  5010  			cp.mu.Lock()
  5011  			defer cp.mu.Unlock()
  5012  			conns, ok := cp.conns["dummy.tld:443"]
  5013  			if !ok {
  5014  				return fmt.Errorf("missing connection")
  5015  			}
  5016  			if len(conns) != 1 {
  5017  				return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
  5018  			}
  5019  			if activeStreams(conns[0]) != 0 {
  5020  				return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
  5021  			}
  5022  			return nil
  5023  		}
  5024  		bodyReadError := errors.New("body read error")
  5025  		body := &errReader{body, bodyReadError}
  5026  		req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
  5027  		if err != nil {
  5028  			return err
  5029  		}
  5030  		_, err = ct.tr.RoundTrip(req)
  5031  		if err != bodyReadError {
  5032  			return fmt.Errorf("err = %v; want %v", err, bodyReadError)
  5033  		}
  5034  		if err = checkNoStreams(); err != nil {
  5035  			return err
  5036  		}
  5037  		return nil
  5038  	}
  5039  	ct.server = func() error {
  5040  		ct.greet()
  5041  		var receivedBody []byte
  5042  		var resetCount int
  5043  		for {
  5044  			f, err := ct.fr.ReadFrame()
  5045  			t.Logf("server: ReadFrame = %v, %v", f, err)
  5046  			if err != nil {
  5047  				select {
  5048  				case <-clientDone:
  5049  					// If the client's done, it
  5050  					// will have reported any
  5051  					// errors on its side.
  5052  					if bytes.Compare(receivedBody, body) != 0 {
  5053  						return fmt.Errorf("body: %q; expected %q", receivedBody, body)
  5054  					}
  5055  					if resetCount != 1 {
  5056  						return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
  5057  					}
  5058  					return nil
  5059  				default:
  5060  					return err
  5061  				}
  5062  			}
  5063  			switch f := f.(type) {
  5064  			case *WindowUpdateFrame, *SettingsFrame:
  5065  			case *HeadersFrame:
  5066  			case *DataFrame:
  5067  				receivedBody = append(receivedBody, f.Data()...)
  5068  			case *RSTStreamFrame:
  5069  				resetCount++
  5070  			default:
  5071  				return fmt.Errorf("Unexpected client frame %v", f)
  5072  			}
  5073  		}
  5074  	}
  5075  	ct.run()
  5076  }
  5077  
  5078  func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
  5079  func TestTransportBodyReadError_Some(t *testing.T)        { testTransportBodyReadError(t, []byte("123")) }
  5080  
// Issue 32254: verify that the client sends END_STREAM flag eagerly with the last
// (or in this test-case the only one) request body data frame, and does not send
// extra zero-len data frames.
//
// The fake server below fails the test (and resets the stream with
// REFUSED_STREAM) on the first DATA frame that lacks END_STREAM. Otherwise it
// replies with a 200 HEADERS frame followed by a single DATA frame carrying
// resBody with END_STREAM set.
func TestTransportBodyEagerEndStream(t *testing.T) {
	const reqBody = "some request body"
	const resBody = "some response body"

	ct := newClientTester(t)
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		if runtime.GOOS == "plan9" {
			// CloseWrite not supported on Plan 9; Issue 17906
			defer ct.cc.(*net.TCPConn).Close()
		}
		body := strings.NewReader(reqBody)
		req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
		if err != nil {
			return err
		}
		_, err = ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}
		return nil
	}
	ct.server = func() error {
		ct.greet()

		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return err
			}

			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
			case *DataFrame:
				// The first DATA frame must already carry END_STREAM;
				// otherwise the client split the body or would follow up
				// with an extra (e.g. zero-length) DATA frame.
				if !f.StreamEnded() {
					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
					return fmt.Errorf("data frame without END_STREAM %v", f)
				}
				// Respond: HEADERS (:status 200, stream left open), then
				// the response body with END_STREAM.
				var buf bytes.Buffer
				enc := hpack.NewEncoder(&buf)
				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
				ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.Header().StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: buf.Bytes(),
				})
				ct.fr.WriteData(f.StreamID, true, []byte(resBody))
				return nil
			case *RSTStreamFrame:
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
}
  5142  
  5143  type chunkReader struct {
  5144  	chunks [][]byte
  5145  }
  5146  
  5147  func (r *chunkReader) Read(p []byte) (int, error) {
  5148  	if len(r.chunks) > 0 {
  5149  		n := copy(p, r.chunks[0])
  5150  		r.chunks = r.chunks[1:]
  5151  		return n, nil
  5152  	}
  5153  	panic("shouldn't read this many times")
  5154  }
  5155  
  5156  // Issue 32254: if the request body is larger than the specified
  5157  // content length, the client should refuse to send the extra part
  5158  // and abort the stream.
  5159  //
  5160  // In _len3 case, the first Read() matches the expected content length
  5161  // but the second read returns more data.
  5162  //
  5163  // In _len2 case, the first Read() exceeds the expected content length.
  5164  func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) {
  5165  	body := &chunkReader{[][]byte{
  5166  		[]byte("123"),
  5167  		[]byte("456"),
  5168  	}}
  5169  	testTransportBodyLargerThanSpecifiedContentLength(t, body, 3)
  5170  }
  5171  
  5172  func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) {
  5173  	body := &chunkReader{[][]byte{
  5174  		[]byte("123"),
  5175  	}}
  5176  	testTransportBodyLargerThanSpecifiedContentLength(t, body, 2)
  5177  }
  5178  
  5179  func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) {
  5180  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5181  		r.Body.Read(make([]byte, 6))
  5182  	}, optOnlyServer)
  5183  	defer st.Close()
  5184  
  5185  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  5186  	defer tr.CloseIdleConnections()
  5187  
  5188  	req, _ := http.NewRequest("POST", st.ts.URL, body)
  5189  	req.ContentLength = contentLen
  5190  	_, err := tr.RoundTrip(req)
  5191  	if err != errReqBodyTooLong {
  5192  		t.Fatalf("expected %v, got %v", errReqBodyTooLong, err)
  5193  	}
  5194  }
  5195  
  5196  func TestClientConnTooIdle(t *testing.T) {
  5197  	tests := []struct {
  5198  		cc   func() *ClientConn
  5199  		want bool
  5200  	}{
  5201  		{
  5202  			func() *ClientConn {
  5203  				return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)}
  5204  			},
  5205  			true,
  5206  		},
  5207  		{
  5208  			func() *ClientConn {
  5209  				return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}}
  5210  			},
  5211  			false,
  5212  		},
  5213  		{
  5214  			func() *ClientConn {
  5215  				return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)}
  5216  			},
  5217  			false,
  5218  		},
  5219  		{
  5220  			func() *ClientConn {
  5221  				return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)}
  5222  			},
  5223  			false,
  5224  		},
  5225  	}
  5226  	for i, tt := range tests {
  5227  		got := tt.cc().tooIdleLocked()
  5228  		if got != tt.want {
  5229  			t.Errorf("%d. got %v; want %v", i, got, tt.want)
  5230  		}
  5231  	}
  5232  }
  5233  
  5234  type fakeConnErr struct {
  5235  	net.Conn
  5236  	writeErr error
  5237  	closed   bool
  5238  }
  5239  
  5240  func (fce *fakeConnErr) Write(b []byte) (n int, err error) {
  5241  	return 0, fce.writeErr
  5242  }
  5243  
  5244  func (fce *fakeConnErr) Close() error {
  5245  	fce.closed = true
  5246  	return nil
  5247  }
  5248  
  5249  // issue 39337: close the connection on a failed write
  5250  func TestTransportNewClientConnCloseOnWriteError(t *testing.T) {
  5251  	tr := &Transport{}
  5252  	writeErr := errors.New("write error")
  5253  	fakeConn := &fakeConnErr{writeErr: writeErr}
  5254  	_, err := tr.NewClientConn(fakeConn)
  5255  	if err != writeErr {
  5256  		t.Fatalf("expected %v, got %v", writeErr, err)
  5257  	}
  5258  	if !fakeConn.closed {
  5259  		t.Error("expected closed conn")
  5260  	}
  5261  }
  5262  
  5263  func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
  5264  	req, err := http.NewRequest("GET", "https://dummy.tld/", nil)
  5265  	if err != nil {
  5266  		t.Fatal(err)
  5267  	}
  5268  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
  5269  	defer st.Close()
  5270  
  5271  	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
  5272  	defer tr.CloseIdleConnections()
  5273  	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
  5274  	if err != nil {
  5275  		t.Fatal(err)
  5276  	}
  5277  
  5278  	writeErr := errors.New("write error")
  5279  	cc.wmu.Lock()
  5280  	cc.werr = writeErr
  5281  	cc.wmu.Unlock()
  5282  
  5283  	_, err = cc.RoundTrip(req)
  5284  	if err != writeErr {
  5285  		t.Fatalf("expected %v, got %v", writeErr, err)
  5286  	}
  5287  
  5288  	cc.mu.Lock()
  5289  	closed := cc.closed
  5290  	cc.mu.Unlock()
  5291  	if !closed {
  5292  		t.Fatal("expected closed")
  5293  	}
  5294  }
  5295  
  5296  // Issue 31192: A failed request may be retried if the body has not been read
  5297  // already. If the request body has started to be sent, one must wait until it
  5298  // is completed.
  5299  func TestTransportBodyRewindRace(t *testing.T) {
  5300  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5301  		w.Header().Set("Connection", "close")
  5302  		w.WriteHeader(http.StatusOK)
  5303  		return
  5304  	}, optOnlyServer)
  5305  	defer st.Close()
  5306  
  5307  	tr := &http.Transport{
  5308  		TLSClientConfig: tlsConfigInsecure,
  5309  		MaxConnsPerHost: 1,
  5310  	}
  5311  	err := ConfigureTransport(tr)
  5312  	if err != nil {
  5313  		t.Fatal(err)
  5314  	}
  5315  	client := &http.Client{
  5316  		Transport: tr,
  5317  	}
  5318  
  5319  	const clients = 50
  5320  
  5321  	var wg sync.WaitGroup
  5322  	wg.Add(clients)
  5323  	for i := 0; i < clients; i++ {
  5324  		req, err := http.NewRequest("POST", st.ts.URL, bytes.NewBufferString("abcdef"))
  5325  		if err != nil {
  5326  			t.Fatalf("unexpect new request error: %v", err)
  5327  		}
  5328  
  5329  		go func() {
  5330  			defer wg.Done()
  5331  			res, err := client.Do(req)
  5332  			if err == nil {
  5333  				res.Body.Close()
  5334  			}
  5335  		}()
  5336  	}
  5337  
  5338  	wg.Wait()
  5339  }
  5340  
  5341  // Issue 42498: A request with a body will never be sent if the stream is
  5342  // reset prior to sending any data.
  5343  func TestTransportServerResetStreamAtHeaders(t *testing.T) {
  5344  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
  5345  		w.WriteHeader(http.StatusUnauthorized)
  5346  		return
  5347  	}, optOnlyServer)
  5348  	defer st.Close()
  5349  
  5350  	tr := &http.Transport{
  5351  		TLSClientConfig:       tlsConfigInsecure,
  5352  		MaxConnsPerHost:       1,
  5353  		ExpectContinueTimeout: 10 * time.Second,
  5354  	}
  5355  
  5356  	err := ConfigureTransport(tr)
  5357  	if err != nil {
  5358  		t.Fatal(err)
  5359  	}
  5360  	client := &http.Client{
  5361  		Transport: tr,
  5362  	}
  5363  
  5364  	req, err := http.NewRequest("POST", st.ts.URL, errorReader{io.EOF})
  5365  	if err != nil {
  5366  		t.Fatalf("unexpect new request error: %v", err)
  5367  	}
  5368  	req.ContentLength = 0 // so transport is tempted to sniff it
  5369  	req.Header.Set("Expect", "100-continue")
  5370  	res, err := client.Do(req)
  5371  	if err != nil {
  5372  		t.Fatal(err)
  5373  	}
  5374  	res.Body.Close()
  5375  }