github.com/Andyfoo/golang/x/net@v0.0.0-20190901054642-57c1bf301704/http2/server_push_test.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net/http"
    13  	"reflect"
    14  	"strconv"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestServer_Push_Success(t *testing.T) {
    21  	const (
    22  		mainBody   = "<html>index page</html>"
    23  		pushedBody = "<html>pushed page</html>"
    24  		userAgent  = "testagent"
    25  		cookie     = "testcookie"
    26  	)
    27  
    28  	var stURL string
    29  	checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
    30  		if got, want := r.Method, wantMethod; got != want {
    31  			return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
    32  		}
    33  		if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
    34  			return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
    35  		}
    36  		if got, want := "https://"+r.Host, stURL; got != want {
    37  			return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
    38  		}
    39  		if r.Body == nil {
    40  			return fmt.Errorf("nil Body")
    41  		}
    42  		if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
    43  			return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
    44  		}
    45  		return nil
    46  	}
    47  
    48  	errc := make(chan error, 3)
    49  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
    50  		switch r.URL.RequestURI() {
    51  		case "/":
    52  			// Push "/pushed?get" as a GET request, using an absolute URL.
    53  			opt := &http.PushOptions{
    54  				Header: http.Header{
    55  					"User-Agent": {userAgent},
    56  				},
    57  			}
    58  			if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
    59  				errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
    60  				return
    61  			}
    62  			// Push "/pushed?head" as a HEAD request, using a path.
    63  			opt = &http.PushOptions{
    64  				Method: "HEAD",
    65  				Header: http.Header{
    66  					"User-Agent": {userAgent},
    67  					"Cookie":     {cookie},
    68  				},
    69  			}
    70  			if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
    71  				errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
    72  				return
    73  			}
    74  			w.Header().Set("Content-Type", "text/html")
    75  			w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
    76  			w.WriteHeader(200)
    77  			io.WriteString(w, mainBody)
    78  			errc <- nil
    79  
    80  		case "/pushed?get":
    81  			wantH := http.Header{}
    82  			wantH.Set("User-Agent", userAgent)
    83  			if err := checkPromisedReq(r, "GET", wantH); err != nil {
    84  				errc <- fmt.Errorf("/pushed?get: %v", err)
    85  				return
    86  			}
    87  			w.Header().Set("Content-Type", "text/html")
    88  			w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
    89  			w.WriteHeader(200)
    90  			io.WriteString(w, pushedBody)
    91  			errc <- nil
    92  
    93  		case "/pushed?head":
    94  			wantH := http.Header{}
    95  			wantH.Set("User-Agent", userAgent)
    96  			wantH.Set("Cookie", cookie)
    97  			if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
    98  				errc <- fmt.Errorf("/pushed?head: %v", err)
    99  				return
   100  			}
   101  			w.WriteHeader(204)
   102  			errc <- nil
   103  
   104  		default:
   105  			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   106  		}
   107  	})
   108  	stURL = st.ts.URL
   109  
   110  	// Send one request, which should push two responses.
   111  	st.greet()
   112  	getSlash(st)
   113  	for k := 0; k < 3; k++ {
   114  		select {
   115  		case <-time.After(2 * time.Second):
   116  			t.Errorf("timeout waiting for handler %d to finish", k)
   117  		case err := <-errc:
   118  			if err != nil {
   119  				t.Fatal(err)
   120  			}
   121  		}
   122  	}
   123  
   124  	checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
   125  		pp, ok := f.(*PushPromiseFrame)
   126  		if !ok {
   127  			return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
   128  		}
   129  		if !pp.HeadersEnded() {
   130  			return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
   131  		}
   132  		if got, want := pp.PromiseID, promiseID; got != want {
   133  			return fmt.Errorf("got PromiseID %v; want %v", got, want)
   134  		}
   135  		gotH := st.decodeHeader(pp.HeaderBlockFragment())
   136  		if !reflect.DeepEqual(gotH, wantH) {
   137  			return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
   138  		}
   139  		return nil
   140  	}
   141  	checkHeaders := func(f Frame, wantH [][2]string) error {
   142  		hf, ok := f.(*HeadersFrame)
   143  		if !ok {
   144  			return fmt.Errorf("got a %T; want *HeadersFrame", f)
   145  		}
   146  		gotH := st.decodeHeader(hf.HeaderBlockFragment())
   147  		if !reflect.DeepEqual(gotH, wantH) {
   148  			return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
   149  		}
   150  		return nil
   151  	}
   152  	checkData := func(f Frame, wantData string) error {
   153  		df, ok := f.(*DataFrame)
   154  		if !ok {
   155  			return fmt.Errorf("got a %T; want *DataFrame", f)
   156  		}
   157  		if gotData := string(df.Data()); gotData != wantData {
   158  			return fmt.Errorf("got response data %q; want %q", gotData, wantData)
   159  		}
   160  		return nil
   161  	}
   162  
   163  	// Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
   164  	// Stream 2 has HEADERS + DATA
   165  	// Stream 4 has HEADERS
   166  	expected := map[uint32][]func(Frame) error{
   167  		1: {
   168  			func(f Frame) error {
   169  				return checkPushPromise(f, 2, [][2]string{
   170  					{":method", "GET"},
   171  					{":scheme", "https"},
   172  					{":authority", st.ts.Listener.Addr().String()},
   173  					{":path", "/pushed?get"},
   174  					{"user-agent", userAgent},
   175  				})
   176  			},
   177  			func(f Frame) error {
   178  				return checkPushPromise(f, 4, [][2]string{
   179  					{":method", "HEAD"},
   180  					{":scheme", "https"},
   181  					{":authority", st.ts.Listener.Addr().String()},
   182  					{":path", "/pushed?head"},
   183  					{"cookie", cookie},
   184  					{"user-agent", userAgent},
   185  				})
   186  			},
   187  			func(f Frame) error {
   188  				return checkHeaders(f, [][2]string{
   189  					{":status", "200"},
   190  					{"content-type", "text/html"},
   191  					{"content-length", strconv.Itoa(len(mainBody))},
   192  				})
   193  			},
   194  			func(f Frame) error {
   195  				return checkData(f, mainBody)
   196  			},
   197  		},
   198  		2: {
   199  			func(f Frame) error {
   200  				return checkHeaders(f, [][2]string{
   201  					{":status", "200"},
   202  					{"content-type", "text/html"},
   203  					{"content-length", strconv.Itoa(len(pushedBody))},
   204  				})
   205  			},
   206  			func(f Frame) error {
   207  				return checkData(f, pushedBody)
   208  			},
   209  		},
   210  		4: {
   211  			func(f Frame) error {
   212  				return checkHeaders(f, [][2]string{
   213  					{":status", "204"},
   214  				})
   215  			},
   216  		},
   217  	}
   218  
   219  	consumed := map[uint32]int{}
   220  	for k := 0; len(expected) > 0; k++ {
   221  		f, err := st.readFrame()
   222  		if err != nil {
   223  			for id, left := range expected {
   224  				t.Errorf("stream %d: missing %d frames", id, len(left))
   225  			}
   226  			t.Fatalf("readFrame %d: %v", k, err)
   227  		}
   228  		id := f.Header().StreamID
   229  		label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
   230  		if len(expected[id]) == 0 {
   231  			t.Fatalf("%s: unexpected frame %#+v", label, f)
   232  		}
   233  		check := expected[id][0]
   234  		expected[id] = expected[id][1:]
   235  		if len(expected[id]) == 0 {
   236  			delete(expected, id)
   237  		}
   238  		if err := check(f); err != nil {
   239  			t.Fatalf("%s: %v", label, err)
   240  		}
   241  		consumed[id]++
   242  	}
   243  }
   244  
   245  func TestServer_Push_SuccessNoRace(t *testing.T) {
   246  	// Regression test for issue #18326. Ensure the request handler can mutate
   247  	// pushed request headers without racing with the PUSH_PROMISE write.
   248  	errc := make(chan error, 2)
   249  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   250  		switch r.URL.RequestURI() {
   251  		case "/":
   252  			opt := &http.PushOptions{
   253  				Header: http.Header{"User-Agent": {"testagent"}},
   254  			}
   255  			if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
   256  				errc <- fmt.Errorf("error pushing: %v", err)
   257  				return
   258  			}
   259  			w.WriteHeader(200)
   260  			errc <- nil
   261  
   262  		case "/pushed":
   263  			// Update request header, ensure there is no race.
   264  			r.Header.Set("User-Agent", "newagent")
   265  			r.Header.Set("Cookie", "cookie")
   266  			w.WriteHeader(200)
   267  			errc <- nil
   268  
   269  		default:
   270  			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   271  		}
   272  	})
   273  
   274  	// Send one request, which should push one response.
   275  	st.greet()
   276  	getSlash(st)
   277  	for k := 0; k < 2; k++ {
   278  		select {
   279  		case <-time.After(2 * time.Second):
   280  			t.Errorf("timeout waiting for handler %d to finish", k)
   281  		case err := <-errc:
   282  			if err != nil {
   283  				t.Fatal(err)
   284  			}
   285  		}
   286  	}
   287  }
   288  
   289  func TestServer_Push_RejectRecursivePush(t *testing.T) {
   290  	// Expect two requests, but might get three if there's a bug and the second push succeeds.
   291  	errc := make(chan error, 3)
   292  	handler := func(w http.ResponseWriter, r *http.Request) error {
   293  		baseURL := "https://" + r.Host
   294  		switch r.URL.Path {
   295  		case "/":
   296  			if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
   297  				return fmt.Errorf("first Push()=%v, want nil", err)
   298  			}
   299  			return nil
   300  
   301  		case "/push1":
   302  			if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
   303  				return fmt.Errorf("Push()=%v, want %v", got, want)
   304  			}
   305  			return nil
   306  
   307  		default:
   308  			return fmt.Errorf("unexpected path: %q", r.URL.Path)
   309  		}
   310  	}
   311  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   312  		errc <- handler(w, r)
   313  	})
   314  	defer st.Close()
   315  	st.greet()
   316  	getSlash(st)
   317  	if err := <-errc; err != nil {
   318  		t.Errorf("First request failed: %v", err)
   319  	}
   320  	if err := <-errc; err != nil {
   321  		t.Errorf("Second request failed: %v", err)
   322  	}
   323  }
   324  
   325  func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
   326  	// Expect one request, but might get two if there's a bug and the push succeeds.
   327  	errc := make(chan error, 2)
   328  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   329  		errc <- doPush(w.(http.Pusher), r)
   330  	})
   331  	defer st.Close()
   332  	st.greet()
   333  	if err := st.fr.WriteSettings(settings...); err != nil {
   334  		st.t.Fatalf("WriteSettings: %v", err)
   335  	}
   336  	st.wantSettingsAck()
   337  	getSlash(st)
   338  	if err := <-errc; err != nil {
   339  		t.Error(err)
   340  	}
   341  	// Should not get a PUSH_PROMISE frame.
   342  	hf := st.wantHeaders()
   343  	if !hf.StreamEnded() {
   344  		t.Error("stream should end after headers")
   345  	}
   346  }
   347  
   348  func TestServer_Push_RejectIfDisabled(t *testing.T) {
   349  	testServer_Push_RejectSingleRequest(t,
   350  		func(p http.Pusher, r *http.Request) error {
   351  			if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
   352  				return fmt.Errorf("Push()=%v, want %v", got, want)
   353  			}
   354  			return nil
   355  		},
   356  		Setting{SettingEnablePush, 0})
   357  }
   358  
   359  func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
   360  	testServer_Push_RejectSingleRequest(t,
   361  		func(p http.Pusher, r *http.Request) error {
   362  			if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
   363  				return fmt.Errorf("Push()=%v, want %v", got, want)
   364  			}
   365  			return nil
   366  		},
   367  		Setting{SettingMaxConcurrentStreams, 0})
   368  }
   369  
   370  func TestServer_Push_RejectWrongScheme(t *testing.T) {
   371  	testServer_Push_RejectSingleRequest(t,
   372  		func(p http.Pusher, r *http.Request) error {
   373  			if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
   374  				return errors.New("Push() should have failed (push target URL is http)")
   375  			}
   376  			return nil
   377  		})
   378  }
   379  
   380  func TestServer_Push_RejectMissingHost(t *testing.T) {
   381  	testServer_Push_RejectSingleRequest(t,
   382  		func(p http.Pusher, r *http.Request) error {
   383  			if err := p.Push("https:pushed", nil); err == nil {
   384  				return errors.New("Push() should have failed (push target URL missing host)")
   385  			}
   386  			return nil
   387  		})
   388  }
   389  
   390  func TestServer_Push_RejectRelativePath(t *testing.T) {
   391  	testServer_Push_RejectSingleRequest(t,
   392  		func(p http.Pusher, r *http.Request) error {
   393  			if err := p.Push("../test", nil); err == nil {
   394  				return errors.New("Push() should have failed (push target is a relative path)")
   395  			}
   396  			return nil
   397  		})
   398  }
   399  
   400  func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
   401  	testServer_Push_RejectSingleRequest(t,
   402  		func(p http.Pusher, r *http.Request) error {
   403  			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
   404  				return errors.New("Push() should have failed (cannot promise a POST)")
   405  			}
   406  			return nil
   407  		})
   408  }
   409  
   410  func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
   411  	testServer_Push_RejectSingleRequest(t,
   412  		func(p http.Pusher, r *http.Request) error {
   413  			header := http.Header{
   414  				"Content-Length":   {"10"},
   415  				"Content-Encoding": {"gzip"},
   416  				"Trailer":          {"Foo"},
   417  				"Te":               {"trailers"},
   418  				"Host":             {"test.com"},
   419  				":authority":       {"test.com"},
   420  			}
   421  			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
   422  				return errors.New("Push() should have failed (forbidden headers)")
   423  			}
   424  			return nil
   425  		})
   426  }
   427  
   428  func TestServer_Push_StateTransitions(t *testing.T) {
   429  	const body = "foo"
   430  
   431  	gotPromise := make(chan bool)
   432  	finishedPush := make(chan bool)
   433  
   434  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   435  		switch r.URL.RequestURI() {
   436  		case "/":
   437  			if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
   438  				t.Errorf("Push error: %v", err)
   439  			}
   440  			// Don't finish this request until the push finishes so we don't
   441  			// nondeterministically interleave output frames with the push.
   442  			<-finishedPush
   443  		case "/pushed":
   444  			<-gotPromise
   445  		}
   446  		w.Header().Set("Content-Type", "text/html")
   447  		w.Header().Set("Content-Length", strconv.Itoa(len(body)))
   448  		w.WriteHeader(200)
   449  		io.WriteString(w, body)
   450  	})
   451  	defer st.Close()
   452  
   453  	st.greet()
   454  	if st.stream(2) != nil {
   455  		t.Fatal("stream 2 should be empty")
   456  	}
   457  	if got, want := st.streamState(2), stateIdle; got != want {
   458  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   459  	}
   460  	getSlash(st)
   461  	// After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
   462  	st.wantPushPromise()
   463  	if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
   464  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   465  	}
   466  	// We stall the HTTP handler for "/pushed" until the above check. If we don't
   467  	// stall the handler, then the handler might write HEADERS and DATA and finish
   468  	// the stream before we check st.streamState(2) -- should that happen, we'll
   469  	// see stateClosed and fail the above check.
   470  	close(gotPromise)
   471  	st.wantHeaders()
   472  	if df := st.wantData(); !df.StreamEnded() {
   473  		t.Fatal("expected END_STREAM flag on DATA")
   474  	}
   475  	if got, want := st.streamState(2), stateClosed; got != want {
   476  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   477  	}
   478  	close(finishedPush)
   479  }
   480  
   481  func TestServer_Push_RejectAfterGoAway(t *testing.T) {
   482  	var readyOnce sync.Once
   483  	ready := make(chan struct{})
   484  	errc := make(chan error, 2)
   485  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   486  		select {
   487  		case <-ready:
   488  		case <-time.After(5 * time.Second):
   489  			errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
   490  		}
   491  		if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
   492  			errc <- fmt.Errorf("Push()=%v, want %v", got, want)
   493  		}
   494  		errc <- nil
   495  	})
   496  	defer st.Close()
   497  	st.greet()
   498  	getSlash(st)
   499  
   500  	// Send GOAWAY and wait for it to be processed.
   501  	st.fr.WriteGoAway(1, ErrCodeNo, nil)
   502  	go func() {
   503  		for {
   504  			select {
   505  			case <-ready:
   506  				return
   507  			default:
   508  			}
   509  			st.sc.serveMsgCh <- func(loopNum int) {
   510  				if !st.sc.pushEnabled {
   511  					readyOnce.Do(func() { close(ready) })
   512  				}
   513  			}
   514  		}
   515  	}()
   516  	if err := <-errc; err != nil {
   517  		t.Error(err)
   518  	}
   519  }