golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/http2/server_push_test.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net/http"
    13  	"reflect"
    14  	"runtime"
    15  	"strconv"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  )
    20  
    21  func TestServer_Push_Success(t *testing.T) {
    22  	const (
    23  		mainBody   = "<html>index page</html>"
    24  		pushedBody = "<html>pushed page</html>"
    25  		userAgent  = "testagent"
    26  		cookie     = "testcookie"
    27  	)
    28  
    29  	var stURL string
    30  	checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
    31  		if got, want := r.Method, wantMethod; got != want {
    32  			return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
    33  		}
    34  		if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
    35  			return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
    36  		}
    37  		if got, want := "https://"+r.Host, stURL; got != want {
    38  			return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
    39  		}
    40  		if r.Body == nil {
    41  			return fmt.Errorf("nil Body")
    42  		}
    43  		if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
    44  			return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
    45  		}
    46  		return nil
    47  	}
    48  
    49  	errc := make(chan error, 3)
    50  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
    51  		switch r.URL.RequestURI() {
    52  		case "/":
    53  			// Push "/pushed?get" as a GET request, using an absolute URL.
    54  			opt := &http.PushOptions{
    55  				Header: http.Header{
    56  					"User-Agent": {userAgent},
    57  				},
    58  			}
    59  			if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
    60  				errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
    61  				return
    62  			}
    63  			// Push "/pushed?head" as a HEAD request, using a path.
    64  			opt = &http.PushOptions{
    65  				Method: "HEAD",
    66  				Header: http.Header{
    67  					"User-Agent": {userAgent},
    68  					"Cookie":     {cookie},
    69  				},
    70  			}
    71  			if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
    72  				errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
    73  				return
    74  			}
    75  			w.Header().Set("Content-Type", "text/html")
    76  			w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
    77  			w.WriteHeader(200)
    78  			io.WriteString(w, mainBody)
    79  			errc <- nil
    80  
    81  		case "/pushed?get":
    82  			wantH := http.Header{}
    83  			wantH.Set("User-Agent", userAgent)
    84  			if err := checkPromisedReq(r, "GET", wantH); err != nil {
    85  				errc <- fmt.Errorf("/pushed?get: %v", err)
    86  				return
    87  			}
    88  			w.Header().Set("Content-Type", "text/html")
    89  			w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
    90  			w.WriteHeader(200)
    91  			io.WriteString(w, pushedBody)
    92  			errc <- nil
    93  
    94  		case "/pushed?head":
    95  			wantH := http.Header{}
    96  			wantH.Set("User-Agent", userAgent)
    97  			wantH.Set("Cookie", cookie)
    98  			if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
    99  				errc <- fmt.Errorf("/pushed?head: %v", err)
   100  				return
   101  			}
   102  			w.WriteHeader(204)
   103  			errc <- nil
   104  
   105  		default:
   106  			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   107  		}
   108  	})
   109  	stURL = st.ts.URL
   110  
   111  	// Send one request, which should push two responses.
   112  	st.greet()
   113  	getSlash(st)
   114  	for k := 0; k < 3; k++ {
   115  		select {
   116  		case <-time.After(2 * time.Second):
   117  			t.Errorf("timeout waiting for handler %d to finish", k)
   118  		case err := <-errc:
   119  			if err != nil {
   120  				t.Fatal(err)
   121  			}
   122  		}
   123  	}
   124  
   125  	checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
   126  		pp, ok := f.(*PushPromiseFrame)
   127  		if !ok {
   128  			return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
   129  		}
   130  		if !pp.HeadersEnded() {
   131  			return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
   132  		}
   133  		if got, want := pp.PromiseID, promiseID; got != want {
   134  			return fmt.Errorf("got PromiseID %v; want %v", got, want)
   135  		}
   136  		gotH := st.decodeHeader(pp.HeaderBlockFragment())
   137  		if !reflect.DeepEqual(gotH, wantH) {
   138  			return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
   139  		}
   140  		return nil
   141  	}
   142  	checkHeaders := func(f Frame, wantH [][2]string) error {
   143  		hf, ok := f.(*HeadersFrame)
   144  		if !ok {
   145  			return fmt.Errorf("got a %T; want *HeadersFrame", f)
   146  		}
   147  		gotH := st.decodeHeader(hf.HeaderBlockFragment())
   148  		if !reflect.DeepEqual(gotH, wantH) {
   149  			return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
   150  		}
   151  		return nil
   152  	}
   153  	checkData := func(f Frame, wantData string) error {
   154  		df, ok := f.(*DataFrame)
   155  		if !ok {
   156  			return fmt.Errorf("got a %T; want *DataFrame", f)
   157  		}
   158  		if gotData := string(df.Data()); gotData != wantData {
   159  			return fmt.Errorf("got response data %q; want %q", gotData, wantData)
   160  		}
   161  		return nil
   162  	}
   163  
   164  	// Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
   165  	// Stream 2 has HEADERS + DATA
   166  	// Stream 4 has HEADERS
   167  	expected := map[uint32][]func(Frame) error{
   168  		1: {
   169  			func(f Frame) error {
   170  				return checkPushPromise(f, 2, [][2]string{
   171  					{":method", "GET"},
   172  					{":scheme", "https"},
   173  					{":authority", st.ts.Listener.Addr().String()},
   174  					{":path", "/pushed?get"},
   175  					{"user-agent", userAgent},
   176  				})
   177  			},
   178  			func(f Frame) error {
   179  				return checkPushPromise(f, 4, [][2]string{
   180  					{":method", "HEAD"},
   181  					{":scheme", "https"},
   182  					{":authority", st.ts.Listener.Addr().String()},
   183  					{":path", "/pushed?head"},
   184  					{"cookie", cookie},
   185  					{"user-agent", userAgent},
   186  				})
   187  			},
   188  			func(f Frame) error {
   189  				return checkHeaders(f, [][2]string{
   190  					{":status", "200"},
   191  					{"content-type", "text/html"},
   192  					{"content-length", strconv.Itoa(len(mainBody))},
   193  				})
   194  			},
   195  			func(f Frame) error {
   196  				return checkData(f, mainBody)
   197  			},
   198  		},
   199  		2: {
   200  			func(f Frame) error {
   201  				return checkHeaders(f, [][2]string{
   202  					{":status", "200"},
   203  					{"content-type", "text/html"},
   204  					{"content-length", strconv.Itoa(len(pushedBody))},
   205  				})
   206  			},
   207  			func(f Frame) error {
   208  				return checkData(f, pushedBody)
   209  			},
   210  		},
   211  		4: {
   212  			func(f Frame) error {
   213  				return checkHeaders(f, [][2]string{
   214  					{":status", "204"},
   215  				})
   216  			},
   217  		},
   218  	}
   219  
   220  	consumed := map[uint32]int{}
   221  	for k := 0; len(expected) > 0; k++ {
   222  		f, err := st.readFrame()
   223  		if err != nil {
   224  			for id, left := range expected {
   225  				t.Errorf("stream %d: missing %d frames", id, len(left))
   226  			}
   227  			t.Fatalf("readFrame %d: %v", k, err)
   228  		}
   229  		id := f.Header().StreamID
   230  		label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
   231  		if len(expected[id]) == 0 {
   232  			t.Fatalf("%s: unexpected frame %#+v", label, f)
   233  		}
   234  		check := expected[id][0]
   235  		expected[id] = expected[id][1:]
   236  		if len(expected[id]) == 0 {
   237  			delete(expected, id)
   238  		}
   239  		if err := check(f); err != nil {
   240  			t.Fatalf("%s: %v", label, err)
   241  		}
   242  		consumed[id]++
   243  	}
   244  }
   245  
   246  func TestServer_Push_SuccessNoRace(t *testing.T) {
   247  	// Regression test for issue #18326. Ensure the request handler can mutate
   248  	// pushed request headers without racing with the PUSH_PROMISE write.
   249  	errc := make(chan error, 2)
   250  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   251  		switch r.URL.RequestURI() {
   252  		case "/":
   253  			opt := &http.PushOptions{
   254  				Header: http.Header{"User-Agent": {"testagent"}},
   255  			}
   256  			if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
   257  				errc <- fmt.Errorf("error pushing: %v", err)
   258  				return
   259  			}
   260  			w.WriteHeader(200)
   261  			errc <- nil
   262  
   263  		case "/pushed":
   264  			// Update request header, ensure there is no race.
   265  			r.Header.Set("User-Agent", "newagent")
   266  			r.Header.Set("Cookie", "cookie")
   267  			w.WriteHeader(200)
   268  			errc <- nil
   269  
   270  		default:
   271  			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   272  		}
   273  	})
   274  
   275  	// Send one request, which should push one response.
   276  	st.greet()
   277  	getSlash(st)
   278  	for k := 0; k < 2; k++ {
   279  		select {
   280  		case <-time.After(2 * time.Second):
   281  			t.Errorf("timeout waiting for handler %d to finish", k)
   282  		case err := <-errc:
   283  			if err != nil {
   284  				t.Fatal(err)
   285  			}
   286  		}
   287  	}
   288  }
   289  
   290  func TestServer_Push_RejectRecursivePush(t *testing.T) {
   291  	// Expect two requests, but might get three if there's a bug and the second push succeeds.
   292  	errc := make(chan error, 3)
   293  	handler := func(w http.ResponseWriter, r *http.Request) error {
   294  		baseURL := "https://" + r.Host
   295  		switch r.URL.Path {
   296  		case "/":
   297  			if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
   298  				return fmt.Errorf("first Push()=%v, want nil", err)
   299  			}
   300  			return nil
   301  
   302  		case "/push1":
   303  			if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
   304  				return fmt.Errorf("Push()=%v, want %v", got, want)
   305  			}
   306  			return nil
   307  
   308  		default:
   309  			return fmt.Errorf("unexpected path: %q", r.URL.Path)
   310  		}
   311  	}
   312  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   313  		errc <- handler(w, r)
   314  	})
   315  	defer st.Close()
   316  	st.greet()
   317  	getSlash(st)
   318  	if err := <-errc; err != nil {
   319  		t.Errorf("First request failed: %v", err)
   320  	}
   321  	if err := <-errc; err != nil {
   322  		t.Errorf("Second request failed: %v", err)
   323  	}
   324  }
   325  
   326  func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
   327  	// Expect one request, but might get two if there's a bug and the push succeeds.
   328  	errc := make(chan error, 2)
   329  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   330  		errc <- doPush(w.(http.Pusher), r)
   331  	})
   332  	defer st.Close()
   333  	st.greet()
   334  	if err := st.fr.WriteSettings(settings...); err != nil {
   335  		st.t.Fatalf("WriteSettings: %v", err)
   336  	}
   337  	st.wantSettingsAck()
   338  	getSlash(st)
   339  	if err := <-errc; err != nil {
   340  		t.Error(err)
   341  	}
   342  	// Should not get a PUSH_PROMISE frame.
   343  	hf := st.wantHeaders()
   344  	if !hf.StreamEnded() {
   345  		t.Error("stream should end after headers")
   346  	}
   347  }
   348  
   349  func TestServer_Push_RejectIfDisabled(t *testing.T) {
   350  	testServer_Push_RejectSingleRequest(t,
   351  		func(p http.Pusher, r *http.Request) error {
   352  			if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
   353  				return fmt.Errorf("Push()=%v, want %v", got, want)
   354  			}
   355  			return nil
   356  		},
   357  		Setting{SettingEnablePush, 0})
   358  }
   359  
   360  func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
   361  	testServer_Push_RejectSingleRequest(t,
   362  		func(p http.Pusher, r *http.Request) error {
   363  			if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
   364  				return fmt.Errorf("Push()=%v, want %v", got, want)
   365  			}
   366  			return nil
   367  		},
   368  		Setting{SettingMaxConcurrentStreams, 0})
   369  }
   370  
   371  func TestServer_Push_RejectWrongScheme(t *testing.T) {
   372  	testServer_Push_RejectSingleRequest(t,
   373  		func(p http.Pusher, r *http.Request) error {
   374  			if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
   375  				return errors.New("Push() should have failed (push target URL is http)")
   376  			}
   377  			return nil
   378  		})
   379  }
   380  
   381  func TestServer_Push_RejectMissingHost(t *testing.T) {
   382  	testServer_Push_RejectSingleRequest(t,
   383  		func(p http.Pusher, r *http.Request) error {
   384  			if err := p.Push("https:pushed", nil); err == nil {
   385  				return errors.New("Push() should have failed (push target URL missing host)")
   386  			}
   387  			return nil
   388  		})
   389  }
   390  
   391  func TestServer_Push_RejectRelativePath(t *testing.T) {
   392  	testServer_Push_RejectSingleRequest(t,
   393  		func(p http.Pusher, r *http.Request) error {
   394  			if err := p.Push("../test", nil); err == nil {
   395  				return errors.New("Push() should have failed (push target is a relative path)")
   396  			}
   397  			return nil
   398  		})
   399  }
   400  
   401  func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
   402  	testServer_Push_RejectSingleRequest(t,
   403  		func(p http.Pusher, r *http.Request) error {
   404  			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
   405  				return errors.New("Push() should have failed (cannot promise a POST)")
   406  			}
   407  			return nil
   408  		})
   409  }
   410  
   411  func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
   412  	testServer_Push_RejectSingleRequest(t,
   413  		func(p http.Pusher, r *http.Request) error {
   414  			header := http.Header{
   415  				"Content-Length":   {"10"},
   416  				"Content-Encoding": {"gzip"},
   417  				"Trailer":          {"Foo"},
   418  				"Te":               {"trailers"},
   419  				"Host":             {"test.com"},
   420  				":authority":       {"test.com"},
   421  			}
   422  			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
   423  				return errors.New("Push() should have failed (forbidden headers)")
   424  			}
   425  			return nil
   426  		})
   427  }
   428  
   429  func TestServer_Push_StateTransitions(t *testing.T) {
   430  	const body = "foo"
   431  
   432  	gotPromise := make(chan bool)
   433  	finishedPush := make(chan bool)
   434  
   435  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   436  		switch r.URL.RequestURI() {
   437  		case "/":
   438  			if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
   439  				t.Errorf("Push error: %v", err)
   440  			}
   441  			// Don't finish this request until the push finishes so we don't
   442  			// nondeterministically interleave output frames with the push.
   443  			<-finishedPush
   444  		case "/pushed":
   445  			<-gotPromise
   446  		}
   447  		w.Header().Set("Content-Type", "text/html")
   448  		w.Header().Set("Content-Length", strconv.Itoa(len(body)))
   449  		w.WriteHeader(200)
   450  		io.WriteString(w, body)
   451  	})
   452  	defer st.Close()
   453  
   454  	st.greet()
   455  	if st.stream(2) != nil {
   456  		t.Fatal("stream 2 should be empty")
   457  	}
   458  	if got, want := st.streamState(2), stateIdle; got != want {
   459  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   460  	}
   461  	getSlash(st)
   462  	// After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
   463  	st.wantPushPromise()
   464  	if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
   465  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   466  	}
   467  	// We stall the HTTP handler for "/pushed" until the above check. If we don't
   468  	// stall the handler, then the handler might write HEADERS and DATA and finish
   469  	// the stream before we check st.streamState(2) -- should that happen, we'll
   470  	// see stateClosed and fail the above check.
   471  	close(gotPromise)
   472  	st.wantHeaders()
   473  	if df := st.wantData(); !df.StreamEnded() {
   474  		t.Fatal("expected END_STREAM flag on DATA")
   475  	}
   476  	if got, want := st.streamState(2), stateClosed; got != want {
   477  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   478  	}
   479  	close(finishedPush)
   480  }
   481  
   482  func TestServer_Push_RejectAfterGoAway(t *testing.T) {
   483  	var readyOnce sync.Once
   484  	ready := make(chan struct{})
   485  	errc := make(chan error, 2)
   486  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   487  		<-ready
   488  		if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
   489  			errc <- fmt.Errorf("Push()=%v, want %v", got, want)
   490  		}
   491  		errc <- nil
   492  	})
   493  	defer st.Close()
   494  	st.greet()
   495  	getSlash(st)
   496  
   497  	// Send GOAWAY and wait for it to be processed.
   498  	st.fr.WriteGoAway(1, ErrCodeNo, nil)
   499  	go func() {
   500  		for {
   501  			select {
   502  			case <-ready:
   503  				return
   504  			default:
   505  				if runtime.GOARCH == "wasm" {
   506  					// Work around https://go.dev/issue/65178 to avoid goroutine starvation.
   507  					runtime.Gosched()
   508  				}
   509  			}
   510  			st.sc.serveMsgCh <- func(loopNum int) {
   511  				if !st.sc.pushEnabled {
   512  					readyOnce.Do(func() { close(ready) })
   513  				}
   514  			}
   515  		}
   516  	}()
   517  	if err := <-errc; err != nil {
   518  		t.Error(err)
   519  	}
   520  }
   521  
   522  func TestServer_Push_Underflow(t *testing.T) {
   523  	// Test for #63511: Send several requests which generate PUSH_PROMISE responses,
   524  	// verify they all complete successfully.
   525  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   526  		switch r.URL.RequestURI() {
   527  		case "/":
   528  			opt := &http.PushOptions{
   529  				Header: http.Header{"User-Agent": {"testagent"}},
   530  			}
   531  			if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
   532  				t.Errorf("error pushing: %v", err)
   533  			}
   534  			w.WriteHeader(200)
   535  		case "/pushed":
   536  			r.Header.Set("User-Agent", "newagent")
   537  			r.Header.Set("Cookie", "cookie")
   538  			w.WriteHeader(200)
   539  		default:
   540  			t.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   541  		}
   542  	})
   543  	// Send several requests.
   544  	st.greet()
   545  	const numRequests = 4
   546  	for i := 0; i < numRequests; i++ {
   547  		st.writeHeaders(HeadersFrameParam{
   548  			StreamID:      uint32(1 + i*2), // clients send odd numbers
   549  			BlockFragment: st.encodeHeader(),
   550  			EndStream:     true,
   551  			EndHeaders:    true,
   552  		})
   553  	}
   554  	// Each request should result in one PUSH_PROMISE and two responses.
   555  	numPushPromises := 0
   556  	numHeaders := 0
   557  	for numHeaders < numRequests*2 || numPushPromises < numRequests {
   558  		f, err := st.readFrame()
   559  		if err != nil {
   560  			st.t.Fatal(err)
   561  		}
   562  		switch f := f.(type) {
   563  		case *HeadersFrame:
   564  			if !f.Flags.Has(FlagHeadersEndStream) {
   565  				t.Fatalf("got HEADERS frame with no END_STREAM, expected END_STREAM: %v", f)
   566  			}
   567  			numHeaders++
   568  		case *PushPromiseFrame:
   569  			numPushPromises++
   570  		}
   571  	}
   572  }