github.com/Andyfoo/golang/x/net@v0.0.0-20190901054642-57c1bf301704/http2/server_push_test.go (about) 1 // Copyright 2016 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package http2 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net/http" 13 "reflect" 14 "strconv" 15 "sync" 16 "testing" 17 "time" 18 ) 19 20 func TestServer_Push_Success(t *testing.T) { 21 const ( 22 mainBody = "<html>index page</html>" 23 pushedBody = "<html>pushed page</html>" 24 userAgent = "testagent" 25 cookie = "testcookie" 26 ) 27 28 var stURL string 29 checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error { 30 if got, want := r.Method, wantMethod; got != want { 31 return fmt.Errorf("promised Req.Method=%q, want %q", got, want) 32 } 33 if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) { 34 return fmt.Errorf("promised Req.Header=%q, want %q", got, want) 35 } 36 if got, want := "https://"+r.Host, stURL; got != want { 37 return fmt.Errorf("promised Req.Host=%q, want %q", got, want) 38 } 39 if r.Body == nil { 40 return fmt.Errorf("nil Body") 41 } 42 if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 { 43 return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err) 44 } 45 return nil 46 } 47 48 errc := make(chan error, 3) 49 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 50 switch r.URL.RequestURI() { 51 case "/": 52 // Push "/pushed?get" as a GET request, using an absolute URL. 53 opt := &http.PushOptions{ 54 Header: http.Header{ 55 "User-Agent": {userAgent}, 56 }, 57 } 58 if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil { 59 errc <- fmt.Errorf("error pushing /pushed?get: %v", err) 60 return 61 } 62 // Push "/pushed?head" as a HEAD request, using a path. 63 opt = &http.PushOptions{ 64 Method: "HEAD", 65 Header: http.Header{ 66 "User-Agent": {userAgent}, 67 "Cookie": {cookie}, 68 }, 69 } 70 if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil { 71 errc <- fmt.Errorf("error pushing /pushed?head: %v", err) 72 return 73 } 74 w.Header().Set("Content-Type", "text/html") 75 w.Header().Set("Content-Length", strconv.Itoa(len(mainBody))) 76 w.WriteHeader(200) 77 io.WriteString(w, mainBody) 78 errc <- nil 79 80 case "/pushed?get": 81 wantH := http.Header{} 82 wantH.Set("User-Agent", userAgent) 83 if err := checkPromisedReq(r, "GET", wantH); err != nil { 84 errc <- fmt.Errorf("/pushed?get: %v", err) 85 return 86 } 87 w.Header().Set("Content-Type", "text/html") 88 w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody))) 89 w.WriteHeader(200) 90 io.WriteString(w, pushedBody) 91 errc <- nil 92 93 case "/pushed?head": 94 wantH := http.Header{} 95 wantH.Set("User-Agent", userAgent) 96 wantH.Set("Cookie", cookie) 97 if err := checkPromisedReq(r, "HEAD", wantH); err != nil { 98 errc <- fmt.Errorf("/pushed?head: %v", err) 99 return 100 } 101 w.WriteHeader(204) 102 errc <- nil 103 104 default: 105 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) 106 } 107 }) 108 stURL = st.ts.URL 109 110 // Send one request, which should push two responses. 111 st.greet() 112 getSlash(st) 113 for k := 0; k < 3; k++ { 114 select { 115 case <-time.After(2 * time.Second): 116 t.Errorf("timeout waiting for handler %d to finish", k) 117 case err := <-errc: 118 if err != nil { 119 t.Fatal(err) 120 } 121 } 122 } 123 124 checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error { 125 pp, ok := f.(*PushPromiseFrame) 126 if !ok { 127 return fmt.Errorf("got a %T; want *PushPromiseFrame", f) 128 } 129 if !pp.HeadersEnded() { 130 return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame") 131 } 132 if got, want := pp.PromiseID, promiseID; got != want { 133 return fmt.Errorf("got PromiseID %v; want %v", got, want) 134 } 135 gotH := st.decodeHeader(pp.HeaderBlockFragment()) 136 if !reflect.DeepEqual(gotH, wantH) { 137 return fmt.Errorf("got promised headers %v; want %v", gotH, wantH) 138 } 139 return nil 140 } 141 checkHeaders := func(f Frame, wantH [][2]string) error { 142 hf, ok := f.(*HeadersFrame) 143 if !ok { 144 return fmt.Errorf("got a %T; want *HeadersFrame", f) 145 } 146 gotH := st.decodeHeader(hf.HeaderBlockFragment()) 147 if !reflect.DeepEqual(gotH, wantH) { 148 return fmt.Errorf("got response headers %v; want %v", gotH, wantH) 149 } 150 return nil 151 } 152 checkData := func(f Frame, wantData string) error { 153 df, ok := f.(*DataFrame) 154 if !ok { 155 return fmt.Errorf("got a %T; want *DataFrame", f) 156 } 157 if gotData := string(df.Data()); gotData != wantData { 158 return fmt.Errorf("got response data %q; want %q", gotData, wantData) 159 } 160 return nil 161 } 162 163 // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA 164 // Stream 2 has HEADERS + DATA 165 // Stream 4 has HEADERS 166 expected := map[uint32][]func(Frame) error{ 167 1: { 168 func(f Frame) error { 169 return checkPushPromise(f, 2, [][2]string{ 170 {":method", "GET"}, 171 {":scheme", "https"}, 172 {":authority", st.ts.Listener.Addr().String()}, 173 {":path", "/pushed?get"}, 174 {"user-agent", userAgent}, 175 }) 176 }, 177 func(f Frame) error { 178 return checkPushPromise(f, 4, [][2]string{ 179 {":method", "HEAD"}, 180 {":scheme", "https"}, 181 {":authority", st.ts.Listener.Addr().String()}, 182 {":path", "/pushed?head"}, 183 {"cookie", cookie}, 184 {"user-agent", userAgent}, 185 }) 186 }, 187 func(f Frame) error { 188 return checkHeaders(f, [][2]string{ 189 {":status", "200"}, 190 {"content-type", "text/html"}, 191 {"content-length", strconv.Itoa(len(mainBody))}, 192 }) 193 }, 194 func(f Frame) error { 195 return checkData(f, mainBody) 196 }, 197 }, 198 2: { 199 func(f Frame) error { 200 return checkHeaders(f, [][2]string{ 201 {":status", "200"}, 202 {"content-type", "text/html"}, 203 {"content-length", strconv.Itoa(len(pushedBody))}, 204 }) 205 }, 206 func(f Frame) error { 207 return checkData(f, pushedBody) 208 }, 209 }, 210 4: { 211 func(f Frame) error { 212 return checkHeaders(f, [][2]string{ 213 {":status", "204"}, 214 }) 215 }, 216 }, 217 } 218 219 consumed := map[uint32]int{} 220 for k := 0; len(expected) > 0; k++ { 221 f, err := st.readFrame() 222 if err != nil { 223 for id, left := range expected { 224 t.Errorf("stream %d: missing %d frames", id, len(left)) 225 } 226 t.Fatalf("readFrame %d: %v", k, err) 227 } 228 id := f.Header().StreamID 229 label := fmt.Sprintf("stream %d, frame %d", id, consumed[id]) 230 if len(expected[id]) == 0 { 231 t.Fatalf("%s: unexpected frame %#+v", label, f) 232 } 233 check := expected[id][0] 234 expected[id] = expected[id][1:] 235 if len(expected[id]) == 0 { 236 delete(expected, id) 237 } 238 if err := check(f); err != nil { 239 t.Fatalf("%s: %v", label, err) 240 } 241 consumed[id]++ 242 } 243 } 244 245 func TestServer_Push_SuccessNoRace(t *testing.T) { 246 // Regression test for issue #18326. Ensure the request handler can mutate 247 // pushed request headers without racing with the PUSH_PROMISE write. 248 errc := make(chan error, 2) 249 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 250 switch r.URL.RequestURI() { 251 case "/": 252 opt := &http.PushOptions{ 253 Header: http.Header{"User-Agent": {"testagent"}}, 254 } 255 if err := w.(http.Pusher).Push("/pushed", opt); err != nil { 256 errc <- fmt.Errorf("error pushing: %v", err) 257 return 258 } 259 w.WriteHeader(200) 260 errc <- nil 261 262 case "/pushed": 263 // Update request header, ensure there is no race. 264 r.Header.Set("User-Agent", "newagent") 265 r.Header.Set("Cookie", "cookie") 266 w.WriteHeader(200) 267 errc <- nil 268 269 default: 270 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) 271 } 272 }) 273 274 // Send one request, which should push one response. 275 st.greet() 276 getSlash(st) 277 for k := 0; k < 2; k++ { 278 select { 279 case <-time.After(2 * time.Second): 280 t.Errorf("timeout waiting for handler %d to finish", k) 281 case err := <-errc: 282 if err != nil { 283 t.Fatal(err) 284 } 285 } 286 } 287 } 288 289 func TestServer_Push_RejectRecursivePush(t *testing.T) { 290 // Expect two requests, but might get three if there's a bug and the second push succeeds. 291 errc := make(chan error, 3) 292 handler := func(w http.ResponseWriter, r *http.Request) error { 293 baseURL := "https://" + r.Host 294 switch r.URL.Path { 295 case "/": 296 if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil { 297 return fmt.Errorf("first Push()=%v, want nil", err) 298 } 299 return nil 300 301 case "/push1": 302 if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want { 303 return fmt.Errorf("Push()=%v, want %v", got, want) 304 } 305 return nil 306 307 default: 308 return fmt.Errorf("unexpected path: %q", r.URL.Path) 309 } 310 } 311 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 312 errc <- handler(w, r) 313 }) 314 defer st.Close() 315 st.greet() 316 getSlash(st) 317 if err := <-errc; err != nil { 318 t.Errorf("First request failed: %v", err) 319 } 320 if err := <-errc; err != nil { 321 t.Errorf("Second request failed: %v", err) 322 } 323 } 324 325 func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { 326 // Expect one request, but might get two if there's a bug and the push succeeds. 327 errc := make(chan error, 2) 328 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 329 errc <- doPush(w.(http.Pusher), r) 330 }) 331 defer st.Close() 332 st.greet() 333 if err := st.fr.WriteSettings(settings...); err != nil { 334 st.t.Fatalf("WriteSettings: %v", err) 335 } 336 st.wantSettingsAck() 337 getSlash(st) 338 if err := <-errc; err != nil { 339 t.Error(err) 340 } 341 // Should not get a PUSH_PROMISE frame. 342 hf := st.wantHeaders() 343 if !hf.StreamEnded() { 344 t.Error("stream should end after headers") 345 } 346 } 347 348 func TestServer_Push_RejectIfDisabled(t *testing.T) { 349 testServer_Push_RejectSingleRequest(t, 350 func(p http.Pusher, r *http.Request) error { 351 if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { 352 return fmt.Errorf("Push()=%v, want %v", got, want) 353 } 354 return nil 355 }, 356 Setting{SettingEnablePush, 0}) 357 } 358 359 func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) { 360 testServer_Push_RejectSingleRequest(t, 361 func(p http.Pusher, r *http.Request) error { 362 if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want { 363 return fmt.Errorf("Push()=%v, want %v", got, want) 364 } 365 return nil 366 }, 367 Setting{SettingMaxConcurrentStreams, 0}) 368 } 369 370 func TestServer_Push_RejectWrongScheme(t *testing.T) { 371 testServer_Push_RejectSingleRequest(t, 372 func(p http.Pusher, r *http.Request) error { 373 if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil { 374 return errors.New("Push() should have failed (push target URL is http)") 375 } 376 return nil 377 }) 378 } 379 380 func TestServer_Push_RejectMissingHost(t *testing.T) { 381 testServer_Push_RejectSingleRequest(t, 382 func(p http.Pusher, r *http.Request) error { 383 if err := p.Push("https:pushed", nil); err == nil { 384 return errors.New("Push() should have failed (push target URL missing host)") 385 } 386 return nil 387 }) 388 } 389 390 func TestServer_Push_RejectRelativePath(t *testing.T) { 391 testServer_Push_RejectSingleRequest(t, 392 func(p http.Pusher, r *http.Request) error { 393 if err := p.Push("../test", nil); err == nil { 394 return errors.New("Push() should have failed (push target is a relative path)") 395 } 396 return nil 397 }) 398 } 399 400 func TestServer_Push_RejectForbiddenMethod(t *testing.T) { 401 testServer_Push_RejectSingleRequest(t, 402 func(p http.Pusher, r *http.Request) error { 403 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil { 404 return errors.New("Push() should have failed (cannot promise a POST)") 405 } 406 return nil 407 }) 408 } 409 410 func TestServer_Push_RejectForbiddenHeader(t *testing.T) { 411 testServer_Push_RejectSingleRequest(t, 412 func(p http.Pusher, r *http.Request) error { 413 header := http.Header{ 414 "Content-Length": {"10"}, 415 "Content-Encoding": {"gzip"}, 416 "Trailer": {"Foo"}, 417 "Te": {"trailers"}, 418 "Host": {"test.com"}, 419 ":authority": {"test.com"}, 420 } 421 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil { 422 return errors.New("Push() should have failed (forbidden headers)") 423 } 424 return nil 425 }) 426 } 427 428 func TestServer_Push_StateTransitions(t *testing.T) { 429 const body = "foo" 430 431 gotPromise := make(chan bool) 432 finishedPush := make(chan bool) 433 434 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 435 switch r.URL.RequestURI() { 436 case "/": 437 if err := w.(http.Pusher).Push("/pushed", nil); err != nil { 438 t.Errorf("Push error: %v", err) 439 } 440 // Don't finish this request until the push finishes so we don't 441 // nondeterministically interleave output frames with the push. 442 <-finishedPush 443 case "/pushed": 444 <-gotPromise 445 } 446 w.Header().Set("Content-Type", "text/html") 447 w.Header().Set("Content-Length", strconv.Itoa(len(body))) 448 w.WriteHeader(200) 449 io.WriteString(w, body) 450 }) 451 defer st.Close() 452 453 st.greet() 454 if st.stream(2) != nil { 455 t.Fatal("stream 2 should be empty") 456 } 457 if got, want := st.streamState(2), stateIdle; got != want { 458 t.Fatalf("streamState(2)=%v, want %v", got, want) 459 } 460 getSlash(st) 461 // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote. 462 st.wantPushPromise() 463 if got, want := st.streamState(2), stateHalfClosedRemote; got != want { 464 t.Fatalf("streamState(2)=%v, want %v", got, want) 465 } 466 // We stall the HTTP handler for "/pushed" until the above check. If we don't 467 // stall the handler, then the handler might write HEADERS and DATA and finish 468 // the stream before we check st.streamState(2) -- should that happen, we'll 469 // see stateClosed and fail the above check. 470 close(gotPromise) 471 st.wantHeaders() 472 if df := st.wantData(); !df.StreamEnded() { 473 t.Fatal("expected END_STREAM flag on DATA") 474 } 475 if got, want := st.streamState(2), stateClosed; got != want { 476 t.Fatalf("streamState(2)=%v, want %v", got, want) 477 } 478 close(finishedPush) 479 } 480 481 func TestServer_Push_RejectAfterGoAway(t *testing.T) { 482 var readyOnce sync.Once 483 ready := make(chan struct{}) 484 errc := make(chan error, 2) 485 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { 486 select { 487 case <-ready: 488 case <-time.After(5 * time.Second): 489 errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed") 490 } 491 if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { 492 errc <- fmt.Errorf("Push()=%v, want %v", got, want) 493 } 494 errc <- nil 495 }) 496 defer st.Close() 497 st.greet() 498 getSlash(st) 499 500 // Send GOAWAY and wait for it to be processed. 501 st.fr.WriteGoAway(1, ErrCodeNo, nil) 502 go func() { 503 for { 504 select { 505 case <-ready: 506 return 507 default: 508 } 509 st.sc.serveMsgCh <- func(loopNum int) { 510 if !st.sc.pushEnabled { 511 readyOnce.Do(func() { close(ready) }) 512 } 513 } 514 } 515 }() 516 if err := <-errc; err != nil { 517 t.Error(err) 518 } 519 }