golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/http2/server_push_test.go (about) 1 // Copyright 2016 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package http2 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net/http" 13 "reflect" 14 "runtime" 15 "strconv" 16 "sync" 17 "testing" 18 "time" 19 ) 20 21 func TestServer_Push_Success(t *testing.T) { 22 const ( 23 mainBody = "<html>index page</html>" 24 pushedBody = "<html>pushed page</html>" 25 userAgent = "testagent" 26 cookie = "testcookie" 27 ) 28 29 var stURL string 30 checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error { 31 if got, want := r.Method, wantMethod; got != want { 32 return fmt.Errorf("promised Req.Method=%q, want %q", got, want) 33 } 34 if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) { 35 return fmt.Errorf("promised Req.Header=%q, want %q", got, want) 36 } 37 if got, want := "https://"+r.Host, stURL; got != want { 38 return fmt.Errorf("promised Req.Host=%q, want %q", got, want) 39 } 40 if r.Body == nil { 41 return fmt.Errorf("nil Body") 42 } 43 if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 { 44 return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err) 45 } 46 return nil 47 } 48 49 errc := make(chan error, 3) 50 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 51 switch r.URL.RequestURI() { 52 case "/": 53 // Push "/pushed?get" as a GET request, using an absolute URL. 54 opt := &http.PushOptions{ 55 Header: http.Header{ 56 "User-Agent": {userAgent}, 57 }, 58 } 59 if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil { 60 errc <- fmt.Errorf("error pushing /pushed?get: %v", err) 61 return 62 } 63 // Push "/pushed?head" as a HEAD request, using a path. 64 opt = &http.PushOptions{ 65 Method: "HEAD", 66 Header: http.Header{ 67 "User-Agent": {userAgent}, 68 "Cookie": {cookie}, 69 }, 70 } 71 if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil { 72 errc <- fmt.Errorf("error pushing /pushed?head: %v", err) 73 return 74 } 75 w.Header().Set("Content-Type", "text/html") 76 w.Header().Set("Content-Length", strconv.Itoa(len(mainBody))) 77 w.WriteHeader(200) 78 io.WriteString(w, mainBody) 79 errc <- nil 80 81 case "/pushed?get": 82 wantH := http.Header{} 83 wantH.Set("User-Agent", userAgent) 84 if err := checkPromisedReq(r, "GET", wantH); err != nil { 85 errc <- fmt.Errorf("/pushed?get: %v", err) 86 return 87 } 88 w.Header().Set("Content-Type", "text/html") 89 w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody))) 90 w.WriteHeader(200) 91 io.WriteString(w, pushedBody) 92 errc <- nil 93 94 case "/pushed?head": 95 wantH := http.Header{} 96 wantH.Set("User-Agent", userAgent) 97 wantH.Set("Cookie", cookie) 98 if err := checkPromisedReq(r, "HEAD", wantH); err != nil { 99 errc <- fmt.Errorf("/pushed?head: %v", err) 100 return 101 } 102 w.WriteHeader(204) 103 errc <- nil 104 105 default: 106 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) 107 } 108 }) 109 stURL = st.ts.URL 110 111 // Send one request, which should push two responses. 112 st.greet() 113 getSlash(st) 114 for k := 0; k < 3; k++ { 115 select { 116 case <-time.After(2 * time.Second): 117 t.Errorf("timeout waiting for handler %d to finish", k) 118 case err := <-errc: 119 if err != nil { 120 t.Fatal(err) 121 } 122 } 123 } 124 125 checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error { 126 pp, ok := f.(*PushPromiseFrame) 127 if !ok { 128 return fmt.Errorf("got a %T; want *PushPromiseFrame", f) 129 } 130 if !pp.HeadersEnded() { 131 return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame") 132 } 133 if got, want := pp.PromiseID, promiseID; got != want { 134 return fmt.Errorf("got PromiseID %v; want %v", got, want) 135 } 136 gotH := st.decodeHeader(pp.HeaderBlockFragment()) 137 if !reflect.DeepEqual(gotH, wantH) { 138 return fmt.Errorf("got promised headers %v; want %v", gotH, wantH) 139 } 140 return nil 141 } 142 checkHeaders := func(f Frame, wantH [][2]string) error { 143 hf, ok := f.(*HeadersFrame) 144 if !ok { 145 return fmt.Errorf("got a %T; want *HeadersFrame", f) 146 } 147 gotH := st.decodeHeader(hf.HeaderBlockFragment()) 148 if !reflect.DeepEqual(gotH, wantH) { 149 return fmt.Errorf("got response headers %v; want %v", gotH, wantH) 150 } 151 return nil 152 } 153 checkData := func(f Frame, wantData string) error { 154 df, ok := f.(*DataFrame) 155 if !ok { 156 return fmt.Errorf("got a %T; want *DataFrame", f) 157 } 158 if gotData := string(df.Data()); gotData != wantData { 159 return fmt.Errorf("got response data %q; want %q", gotData, wantData) 160 } 161 return nil 162 } 163 164 // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA 165 // Stream 2 has HEADERS + DATA 166 // Stream 4 has HEADERS 167 expected := map[uint32][]func(Frame) error{ 168 1: { 169 func(f Frame) error { 170 return checkPushPromise(f, 2, [][2]string{ 171 {":method", "GET"}, 172 {":scheme", "https"}, 173 {":authority", st.ts.Listener.Addr().String()}, 174 {":path", "/pushed?get"}, 175 {"user-agent", userAgent}, 176 }) 177 }, 178 func(f Frame) error { 179 return checkPushPromise(f, 4, [][2]string{ 180 {":method", "HEAD"}, 181 {":scheme", "https"}, 182 {":authority", st.ts.Listener.Addr().String()}, 183 {":path", "/pushed?head"}, 184 {"cookie", cookie}, 185 {"user-agent", userAgent}, 186 }) 187 }, 188 func(f Frame) error { 189 return checkHeaders(f, [][2]string{ 190 {":status", "200"}, 191 {"content-type", "text/html"}, 192 {"content-length", strconv.Itoa(len(mainBody))}, 193 }) 194 }, 195 func(f Frame) error { 196 return checkData(f, mainBody) 197 }, 198 }, 199 2: { 200 func(f Frame) error { 201 return checkHeaders(f, [][2]string{ 202 {":status", "200"}, 203 {"content-type", "text/html"}, 204 {"content-length", strconv.Itoa(len(pushedBody))}, 205 }) 206 }, 207 func(f Frame) error { 208 return checkData(f, pushedBody) 209 }, 210 }, 211 4: { 212 func(f Frame) error { 213 return checkHeaders(f, [][2]string{ 214 {":status", "204"}, 215 }) 216 }, 217 }, 218 } 219 220 consumed := map[uint32]int{} 221 for k := 0; len(expected) > 0; k++ { 222 f, err := st.readFrame() 223 if err != nil { 224 for id, left := range expected { 225 t.Errorf("stream %d: missing %d frames", id, len(left)) 226 } 227 t.Fatalf("readFrame %d: %v", k, err) 228 } 229 id := f.Header().StreamID 230 label := fmt.Sprintf("stream %d, frame %d", id, consumed[id]) 231 if len(expected[id]) == 0 { 232 t.Fatalf("%s: unexpected frame %#+v", label, f) 233 } 234 check := expected[id][0] 235 expected[id] = expected[id][1:] 236 if len(expected[id]) == 0 { 237 delete(expected, id) 238 } 239 if err := check(f); err != nil { 240 t.Fatalf("%s: %v", label, err) 241 } 242 consumed[id]++ 243 } 244 } 245 246 func TestServer_Push_SuccessNoRace(t *testing.T) { 247 // Regression test for issue #18326. Ensure the request handler can mutate 248 // pushed request headers without racing with the PUSH_PROMISE write. 249 errc := make(chan error, 2) 250 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 251 switch r.URL.RequestURI() { 252 case "/": 253 opt := &http.PushOptions{ 254 Header: http.Header{"User-Agent": {"testagent"}}, 255 } 256 if err := w.(http.Pusher).Push("/pushed", opt); err != nil { 257 errc <- fmt.Errorf("error pushing: %v", err) 258 return 259 } 260 w.WriteHeader(200) 261 errc <- nil 262 263 case "/pushed": 264 // Update request header, ensure there is no race. 265 r.Header.Set("User-Agent", "newagent") 266 r.Header.Set("Cookie", "cookie") 267 w.WriteHeader(200) 268 errc <- nil 269 270 default: 271 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) 272 } 273 }) 274 275 // Send one request, which should push one response. 276 st.greet() 277 getSlash(st) 278 for k := 0; k < 2; k++ { 279 select { 280 case <-time.After(2 * time.Second): 281 t.Errorf("timeout waiting for handler %d to finish", k) 282 case err := <-errc: 283 if err != nil { 284 t.Fatal(err) 285 } 286 } 287 } 288 } 289 290 func TestServer_Push_RejectRecursivePush(t *testing.T) { 291 // Expect two requests, but might get three if there's a bug and the second push succeeds. 292 errc := make(chan error, 3) 293 handler := func(w http.ResponseWriter, r *http.Request) error { 294 baseURL := "https://" + r.Host 295 switch r.URL.Path { 296 case "/": 297 if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil { 298 return fmt.Errorf("first Push()=%v, want nil", err) 299 } 300 return nil 301 302 case "/push1": 303 if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want { 304 return fmt.Errorf("Push()=%v, want %v", got, want) 305 } 306 return nil 307 308 default: 309 return fmt.Errorf("unexpected path: %q", r.URL.Path) 310 } 311 } 312 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 313 errc <- handler(w, r) 314 }) 315 defer st.Close() 316 st.greet() 317 getSlash(st) 318 if err := <-errc; err != nil { 319 t.Errorf("First request failed: %v", err) 320 } 321 if err := <-errc; err != nil { 322 t.Errorf("Second request failed: %v", err) 323 } 324 } 325 326 func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { 327 // Expect one request, but might get two if there's a bug and the push succeeds. 328 errc := make(chan error, 2) 329 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 330 errc <- doPush(w.(http.Pusher), r) 331 }) 332 defer st.Close() 333 st.greet() 334 if err := st.fr.WriteSettings(settings...); err != nil { 335 st.t.Fatalf("WriteSettings: %v", err) 336 } 337 st.wantSettingsAck() 338 getSlash(st) 339 if err := <-errc; err != nil { 340 t.Error(err) 341 } 342 // Should not get a PUSH_PROMISE frame. 343 hf := st.wantHeaders() 344 if !hf.StreamEnded() { 345 t.Error("stream should end after headers") 346 } 347 } 348 349 func TestServer_Push_RejectIfDisabled(t *testing.T) { 350 testServer_Push_RejectSingleRequest(t, 351 func(p http.Pusher, r *http.Request) error { 352 if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { 353 return fmt.Errorf("Push()=%v, want %v", got, want) 354 } 355 return nil 356 }, 357 Setting{SettingEnablePush, 0}) 358 } 359 360 func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) { 361 testServer_Push_RejectSingleRequest(t, 362 func(p http.Pusher, r *http.Request) error { 363 if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want { 364 return fmt.Errorf("Push()=%v, want %v", got, want) 365 } 366 return nil 367 }, 368 Setting{SettingMaxConcurrentStreams, 0}) 369 } 370 371 func TestServer_Push_RejectWrongScheme(t *testing.T) { 372 testServer_Push_RejectSingleRequest(t, 373 func(p http.Pusher, r *http.Request) error { 374 if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil { 375 return errors.New("Push() should have failed (push target URL is http)") 376 } 377 return nil 378 }) 379 } 380 381 func TestServer_Push_RejectMissingHost(t *testing.T) { 382 testServer_Push_RejectSingleRequest(t, 383 func(p http.Pusher, r *http.Request) error { 384 if err := p.Push("https:pushed", nil); err == nil { 385 return errors.New("Push() should have failed (push target URL missing host)") 386 } 387 return nil 388 }) 389 } 390 391 func TestServer_Push_RejectRelativePath(t *testing.T) { 392 testServer_Push_RejectSingleRequest(t, 393 func(p http.Pusher, r *http.Request) error { 394 if err := p.Push("../test", nil); err == nil { 395 return errors.New("Push() should have failed (push target is a relative path)") 396 } 397 return nil 398 }) 399 } 400 401 func TestServer_Push_RejectForbiddenMethod(t *testing.T) { 402 testServer_Push_RejectSingleRequest(t, 403 func(p http.Pusher, r *http.Request) error { 404 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil { 405 return errors.New("Push() should have failed (cannot promise a POST)") 406 } 407 return nil 408 }) 409 } 410 411 func TestServer_Push_RejectForbiddenHeader(t *testing.T) { 412 testServer_Push_RejectSingleRequest(t, 413 func(p http.Pusher, r *http.Request) error { 414 header := http.Header{ 415 "Content-Length": {"10"}, 416 "Content-Encoding": {"gzip"}, 417 "Trailer": {"Foo"}, 418 "Te": {"trailers"}, 419 "Host": {"test.com"}, 420 ":authority": {"test.com"}, 421 } 422 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil { 423 return errors.New("Push() should have failed (forbidden headers)") 424 } 425 return nil 426 }) 427 } 428 429 func TestServer_Push_StateTransitions(t *testing.T) { 430 const body = "foo" 431 432 gotPromise := make(chan bool) 433 finishedPush := make(chan bool) 434 435 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 436 switch r.URL.RequestURI() { 437 case "/": 438 if err := w.(http.Pusher).Push("/pushed", nil); err != nil { 439 t.Errorf("Push error: %v", err) 440 } 441 // Don't finish this request until the push finishes so we don't 442 // nondeterministically interleave output frames with the push. 443 <-finishedPush 444 case "/pushed": 445 <-gotPromise 446 } 447 w.Header().Set("Content-Type", "text/html") 448 w.Header().Set("Content-Length", strconv.Itoa(len(body))) 449 w.WriteHeader(200) 450 io.WriteString(w, body) 451 }) 452 defer st.Close() 453 454 st.greet() 455 if st.stream(2) != nil { 456 t.Fatal("stream 2 should be empty") 457 } 458 if got, want := st.streamState(2), stateIdle; got != want { 459 t.Fatalf("streamState(2)=%v, want %v", got, want) 460 } 461 getSlash(st) 462 // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote. 463 st.wantPushPromise() 464 if got, want := st.streamState(2), stateHalfClosedRemote; got != want { 465 t.Fatalf("streamState(2)=%v, want %v", got, want) 466 } 467 // We stall the HTTP handler for "/pushed" until the above check. If we don't 468 // stall the handler, then the handler might write HEADERS and DATA and finish 469 // the stream before we check st.streamState(2) -- should that happen, we'll 470 // see stateClosed and fail the above check. 471 close(gotPromise) 472 st.wantHeaders() 473 if df := st.wantData(); !df.StreamEnded() { 474 t.Fatal("expected END_STREAM flag on DATA") 475 } 476 if got, want := st.streamState(2), stateClosed; got != want { 477 t.Fatalf("streamState(2)=%v, want %v", got, want) 478 } 479 close(finishedPush) 480 } 481 482 func TestServer_Push_RejectAfterGoAway(t *testing.T) { 483 var readyOnce sync.Once 484 ready := make(chan struct{}) 485 errc := make(chan error, 2) 486 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 487 <-ready 488 if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { 489 errc <- fmt.Errorf("Push()=%v, want %v", got, want) 490 } 491 errc <- nil 492 }) 493 defer st.Close() 494 st.greet() 495 getSlash(st) 496 497 // Send GOAWAY and wait for it to be processed. 498 st.fr.WriteGoAway(1, ErrCodeNo, nil) 499 go func() { 500 for { 501 select { 502 case <-ready: 503 return 504 default: 505 if runtime.GOARCH == "wasm" { 506 // Work around https://go.dev/issue/65178 to avoid goroutine starvation. 507 runtime.Gosched() 508 } 509 } 510 st.sc.serveMsgCh <- func(loopNum int) { 511 if !st.sc.pushEnabled { 512 readyOnce.Do(func() { close(ready) }) 513 } 514 } 515 } 516 }() 517 if err := <-errc; err != nil { 518 t.Error(err) 519 } 520 } 521 522 func TestServer_Push_Underflow(t *testing.T) { 523 // Test for #63511: Send several requests which generate PUSH_PROMISE responses, 524 // verify they all complete successfully. 525 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 526 switch r.URL.RequestURI() { 527 case "/": 528 opt := &http.PushOptions{ 529 Header: http.Header{"User-Agent": {"testagent"}}, 530 } 531 if err := w.(http.Pusher).Push("/pushed", opt); err != nil { 532 t.Errorf("error pushing: %v", err) 533 } 534 w.WriteHeader(200) 535 case "/pushed": 536 r.Header.Set("User-Agent", "newagent") 537 r.Header.Set("Cookie", "cookie") 538 w.WriteHeader(200) 539 default: 540 t.Errorf("unknown RequestURL %q", r.URL.RequestURI()) 541 } 542 }) 543 // Send several requests. 544 st.greet() 545 const numRequests = 4 546 for i := 0; i < numRequests; i++ { 547 st.writeHeaders(HeadersFrameParam{ 548 StreamID: uint32(1 + i*2), // clients send odd numbers 549 BlockFragment: st.encodeHeader(), 550 EndStream: true, 551 EndHeaders: true, 552 }) 553 } 554 // Each request should result in one PUSH_PROMISE and two responses. 555 numPushPromises := 0 556 numHeaders := 0 557 for numHeaders < numRequests*2 || numPushPromises < numRequests { 558 f, err := st.readFrame() 559 if err != nil { 560 st.t.Fatal(err) 561 } 562 switch f := f.(type) { 563 case *HeadersFrame: 564 if !f.Flags.Has(FlagHeadersEndStream) { 565 t.Fatalf("got HEADERS frame with no END_STREAM, expected END_STREAM: %v", f) 566 } 567 numHeaders++ 568 case *PushPromiseFrame: 569 numPushPromises++ 570 } 571 } 572 }