google.golang.org/grpc@v1.72.2/internal/transport/transport_test.go (about) 1 /* 2 * 3 * Copyright 2014 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package transport 20 21 import ( 22 "bytes" 23 "context" 24 "encoding/binary" 25 "errors" 26 "fmt" 27 "io" 28 "math" 29 "net" 30 "os" 31 "runtime" 32 "strconv" 33 "strings" 34 "sync" 35 "sync/atomic" 36 "testing" 37 "time" 38 39 "github.com/google/go-cmp/cmp" 40 "golang.org/x/net/http2" 41 "golang.org/x/net/http2/hpack" 42 "google.golang.org/grpc/attributes" 43 "google.golang.org/grpc/codes" 44 "google.golang.org/grpc/credentials" 45 "google.golang.org/grpc/internal/channelz" 46 "google.golang.org/grpc/internal/grpctest" 47 "google.golang.org/grpc/internal/leakcheck" 48 "google.golang.org/grpc/internal/testutils" 49 "google.golang.org/grpc/mem" 50 "google.golang.org/grpc/metadata" 51 "google.golang.org/grpc/resolver" 52 "google.golang.org/grpc/status" 53 ) 54 55 type s struct { 56 grpctest.Tester 57 } 58 59 func Test(t *testing.T) { 60 grpctest.RunSubTests(t, s{}) 61 } 62 63 var ( 64 expectedRequest = []byte("ping") 65 expectedResponse = []byte("pong") 66 expectedRequestLarge = make([]byte, initialWindowSize*2) 67 expectedResponseLarge = make([]byte, initialWindowSize*2) 68 expectedInvalidHeaderField = "invalid/content-type" 69 ) 70 71 func init() { 72 expectedRequestLarge[0] = 'g' 73 expectedRequestLarge[len(expectedRequestLarge)-1] = 'r' 74 expectedResponseLarge[0] = 'p' 75 expectedResponseLarge[len(expectedResponseLarge)-1] = 'c' 76 } 77 78 func newBufferSlice(b []byte) mem.BufferSlice { 79 return mem.BufferSlice{mem.SliceBuffer(b)} 80 } 81 82 func (s *Stream) readTo(p []byte) (int, error) { 83 data, err := s.read(len(p)) 84 defer data.Free() 85 86 if err != nil { 87 return 0, err 88 } 89 90 if data.Len() != len(p) { 91 if err == nil { 92 err = io.ErrUnexpectedEOF 93 } 94 return 0, err 95 } 96 97 data.CopyTo(p) 98 return len(p), nil 99 } 100 101 type testStreamHandler struct { 102 t *http2Server 103 notify chan struct{} 104 getNotified chan struct{} 105 } 106 107 type hType int 108 109 const ( 110 normal hType = iota 111 suspended 112 notifyCall 113 misbehaved 114 encodingRequiredStatus 115 invalidHeaderField 116 delayRead 117 pingpong 118 ) 119 120 func (h *testStreamHandler) handleStreamAndNotify(*ServerStream) { 121 if h.notify == nil { 122 return 123 } 124 go func() { 125 select { 126 case <-h.notify: 127 default: 128 close(h.notify) 129 } 130 }() 131 } 132 133 func (h *testStreamHandler) handleStream(t *testing.T, s *ServerStream) { 134 req := expectedRequest 135 resp := expectedResponse 136 if s.Method() == "foo.Large" { 137 req = expectedRequestLarge 138 resp = expectedResponseLarge 139 } 140 p := make([]byte, len(req)) 141 _, err := s.readTo(p) 142 if err != nil { 143 return 144 } 145 if !bytes.Equal(p, req) { 146 t.Errorf("handleStream got %v, want %v", p, req) 147 s.WriteStatus(status.New(codes.Internal, "panic")) 148 return 149 } 150 // send a response back to the client. 151 s.Write(nil, newBufferSlice(resp), &WriteOptions{}) 152 // send the trailer to end the stream. 153 s.WriteStatus(status.New(codes.OK, "")) 154 } 155 156 func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *ServerStream) { 157 header := make([]byte, 5) 158 for { 159 if _, err := s.readTo(header); err != nil { 160 if err == io.EOF { 161 s.WriteStatus(status.New(codes.OK, "")) 162 return 163 } 164 t.Errorf("Error on server while reading data header: %v", err) 165 s.WriteStatus(status.New(codes.Internal, "panic")) 166 return 167 } 168 sz := binary.BigEndian.Uint32(header[1:]) 169 msg := make([]byte, int(sz)) 170 if _, err := s.readTo(msg); err != nil { 171 t.Errorf("Error on server while reading message: %v", err) 172 s.WriteStatus(status.New(codes.Internal, "panic")) 173 return 174 } 175 buf := make([]byte, sz+5) 176 buf[0] = byte(0) 177 binary.BigEndian.PutUint32(buf[1:], uint32(sz)) 178 copy(buf[5:], msg) 179 s.Write(nil, newBufferSlice(buf), &WriteOptions{}) 180 } 181 } 182 183 func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *ServerStream) { 184 conn, ok := s.st.(*http2Server) 185 if !ok { 186 t.Errorf("Failed to convert %v to *http2Server", s.st) 187 s.WriteStatus(status.New(codes.Internal, "")) 188 return 189 } 190 var sent int 191 p := make([]byte, http2MaxFrameLen) 192 for sent < initialWindowSize { 193 n := initialWindowSize - sent 194 // The last message may be smaller than http2MaxFrameLen 195 if n <= http2MaxFrameLen { 196 if s.Method() == "foo.Connection" { 197 // Violate connection level flow control window of client but do not 198 // violate any stream level windows. 199 p = make([]byte, n) 200 } else { 201 // Violate stream level flow control window of client. 202 p = make([]byte, n+1) 203 } 204 } 205 data := newBufferSlice(p) 206 conn.controlBuf.put(&dataFrame{ 207 streamID: s.id, 208 h: nil, 209 reader: data.Reader(), 210 onEachWrite: func() {}, 211 }) 212 sent += len(p) 213 } 214 } 215 216 func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *ServerStream) { 217 // raw newline is not accepted by http2 framer so it must be encoded. 218 s.WriteStatus(encodingTestStatus) 219 // Drain any remaining buffers from the stream since it was closed early. 220 s.Read(math.MaxInt) 221 } 222 223 func (h *testStreamHandler) handleStreamInvalidHeaderField(s *ServerStream) { 224 headerFields := []hpack.HeaderField{} 225 headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) 226 h.t.controlBuf.put(&headerFrame{ 227 streamID: s.id, 228 hf: headerFields, 229 endStream: false, 230 }) 231 } 232 233 // handleStreamDelayRead delays reads so that the other side has to halt on 234 // stream-level flow control. 235 // This handler assumes dynamic flow control is turned off and assumes window 236 // sizes to be set to defaultWindowSize. 237 func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream) { 238 req := expectedRequest 239 resp := expectedResponse 240 if s.Method() == "foo.Large" { 241 req = expectedRequestLarge 242 resp = expectedResponseLarge 243 } 244 var ( 245 mu sync.Mutex 246 total int 247 ) 248 s.wq.replenish = func(n int) { 249 mu.Lock() 250 total += n 251 mu.Unlock() 252 s.wq.realReplenish(n) 253 } 254 getTotal := func() int { 255 mu.Lock() 256 defer mu.Unlock() 257 return total 258 } 259 done := make(chan struct{}) 260 defer close(done) 261 go func() { 262 for { 263 select { 264 // Prevent goroutine from leaking. 265 case <-done: 266 return 267 default: 268 } 269 if getTotal() == defaultWindowSize { 270 // Signal the client to start reading and 271 // thereby send window update. 272 close(h.notify) 273 return 274 } 275 runtime.Gosched() 276 } 277 }() 278 p := make([]byte, len(req)) 279 280 // Let the other side run out of stream-level window before 281 // starting to read and thereby sending a window update. 282 timer := time.NewTimer(time.Second * 10) 283 select { 284 case <-h.getNotified: 285 timer.Stop() 286 case <-timer.C: 287 t.Errorf("Server timed-out.") 288 return 289 } 290 _, err := s.readTo(p) 291 if err != nil { 292 t.Errorf("s.Read(_) = _, %v, want _, <nil>", err) 293 return 294 } 295 296 if !bytes.Equal(p, req) { 297 t.Errorf("handleStream got %v, want %v", p, req) 298 return 299 } 300 // This write will cause server to run out of stream level, 301 // flow control and the other side won't send a window update 302 // until that happens. 303 if err := s.Write(nil, newBufferSlice(resp), &WriteOptions{}); err != nil { 304 t.Errorf("server Write got %v, want <nil>", err) 305 return 306 } 307 // Read one more time to ensure that everything remains fine and 308 // that the goroutine, that we launched earlier to signal client 309 // to read, gets enough time to process. 310 _, err = s.readTo(p) 311 if err != nil { 312 t.Errorf("s.Read(_) = _, %v, want _, nil", err) 313 return 314 } 315 // send the trailer to end the stream. 316 if err := s.WriteStatus(status.New(codes.OK, "")); err != nil { 317 t.Errorf("server WriteStatus got %v, want <nil>", err) 318 return 319 } 320 } 321 322 type server struct { 323 lis net.Listener 324 port string 325 startedErr chan error // error (or nil) with server start value 326 mu sync.Mutex 327 conns map[ServerTransport]net.Conn 328 h *testStreamHandler 329 ready chan struct{} 330 channelz *channelz.Server 331 } 332 333 func newTestServer() *server { 334 return &server{ 335 startedErr: make(chan error, 1), 336 ready: make(chan struct{}), 337 channelz: channelz.RegisterServer("test server"), 338 } 339 } 340 341 // start starts server. Other goroutines should block on s.readyChan for further operations. 342 func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) { 343 var err error 344 if port == 0 { 345 s.lis, err = net.Listen("tcp", "localhost:0") 346 } else { 347 s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port)) 348 } 349 if err != nil { 350 s.startedErr <- fmt.Errorf("failed to listen: %v", err) 351 return 352 } 353 _, p, err := net.SplitHostPort(s.lis.Addr().String()) 354 if err != nil { 355 s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err) 356 return 357 } 358 s.port = p 359 s.conns = make(map[ServerTransport]net.Conn) 360 s.startedErr <- nil 361 for { 362 conn, err := s.lis.Accept() 363 if err != nil { 364 return 365 } 366 rawConn := conn 367 if serverConfig.MaxStreams == 0 { 368 serverConfig.MaxStreams = math.MaxUint32 369 } 370 transport, err := NewServerTransport(conn, serverConfig) 371 if err != nil { 372 return 373 } 374 s.mu.Lock() 375 if s.conns == nil { 376 s.mu.Unlock() 377 transport.Close(errors.New("s.conns is nil")) 378 return 379 } 380 s.conns[transport] = rawConn 381 h := &testStreamHandler{t: transport.(*http2Server)} 382 s.h = h 383 s.mu.Unlock() 384 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 385 defer cancel() 386 switch ht { 387 case notifyCall: 388 go transport.HandleStreams(ctx, h.handleStreamAndNotify) 389 case suspended: 390 go transport.HandleStreams(ctx, func(*ServerStream) {}) 391 case misbehaved: 392 go transport.HandleStreams(ctx, func(s *ServerStream) { 393 go h.handleStreamMisbehave(t, s) 394 }) 395 case encodingRequiredStatus: 396 go transport.HandleStreams(ctx, func(s *ServerStream) { 397 go h.handleStreamEncodingRequiredStatus(s) 398 }) 399 case invalidHeaderField: 400 go transport.HandleStreams(ctx, func(s *ServerStream) { 401 go h.handleStreamInvalidHeaderField(s) 402 }) 403 case delayRead: 404 h.notify = make(chan struct{}) 405 h.getNotified = make(chan struct{}) 406 s.mu.Lock() 407 close(s.ready) 408 s.mu.Unlock() 409 go transport.HandleStreams(ctx, func(s *ServerStream) { 410 go h.handleStreamDelayRead(t, s) 411 }) 412 case pingpong: 413 go transport.HandleStreams(ctx, func(s *ServerStream) { 414 go h.handleStreamPingPong(t, s) 415 }) 416 default: 417 go transport.HandleStreams(ctx, func(s *ServerStream) { 418 go h.handleStream(t, s) 419 }) 420 } 421 } 422 } 423 424 func (s *server) wait(t *testing.T, timeout time.Duration) { 425 select { 426 case err := <-s.startedErr: 427 if err != nil { 428 t.Fatal(err) 429 } 430 case <-time.After(timeout): 431 t.Fatalf("Timed out after %v waiting for server to be ready", timeout) 432 } 433 } 434 435 func (s *server) stop() { 436 s.lis.Close() 437 s.mu.Lock() 438 for c := range s.conns { 439 c.Close(errors.New("server Stop called")) 440 } 441 s.conns = nil 442 s.mu.Unlock() 443 } 444 445 func (s *server) addr() string { 446 if s.lis == nil { 447 return "" 448 } 449 return s.lis.Addr().String() 450 } 451 452 func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server { 453 server := newTestServer() 454 sc.ChannelzParent = server.channelz 455 go server.start(t, port, sc, ht) 456 server.wait(t, 2*time.Second) 457 return server 458 } 459 460 func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) { 461 return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{}) 462 } 463 464 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) { 465 server := setUpServerOnly(t, port, sc, ht) 466 addr := resolver.Address{Addr: "localhost:" + server.port} 467 copts.ChannelzParent = channelzSubChannel(t) 468 469 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 470 t.Cleanup(cancel) 471 connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second) 472 ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {}) 473 if connErr != nil { 474 cCancel() // Do not cancel in success path. 475 t.Fatalf("failed to create transport: %v", connErr) 476 } 477 return server, ct.(*http2Client), cCancel 478 } 479 480 func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) { 481 lis, err := net.Listen("tcp", "localhost:0") 482 if err != nil { 483 t.Fatalf("Failed to listen: %v", err) 484 } 485 // Launch a non responsive server. 486 go func() { 487 defer lis.Close() 488 conn, err := lis.Accept() 489 if err != nil { 490 t.Errorf("Error at server-side while accepting: %v", err) 491 close(connCh) 492 return 493 } 494 framer := http2.NewFramer(conn, conn) 495 if err := framer.WriteSettings(); err != nil { 496 t.Errorf("Error at server-side while writing settings: %v", err) 497 close(connCh) 498 return 499 } 500 connCh <- conn 501 }() 502 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 503 t.Cleanup(cancel) 504 connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second) 505 tr, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) 506 if err != nil { 507 cCancel() // Do not cancel in success path. 508 // Server clean-up. 509 lis.Close() 510 if conn, ok := <-connCh; ok { 511 conn.Close() 512 } 513 t.Fatalf("Failed to dial: %v", err) 514 } 515 return tr.(*http2Client), cCancel 516 } 517 518 // TestInflightStreamClosing ensures that closing in-flight stream 519 // sends status error to concurrent stream reader. 520 func (s) TestInflightStreamClosing(t *testing.T) { 521 serverConfig := &ServerConfig{} 522 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 523 defer cancel() 524 defer server.stop() 525 defer client.Close(fmt.Errorf("closed manually by test")) 526 527 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 528 defer cancel() 529 stream, err := client.NewStream(ctx, &CallHdr{}) 530 if err != nil { 531 t.Fatalf("Client failed to create RPC request: %v", err) 532 } 533 534 donec := make(chan struct{}) 535 serr := status.Error(codes.Internal, "client connection is closing") 536 go func() { 537 defer close(donec) 538 if _, err := stream.readTo(make([]byte, defaultWindowSize)); err != serr { 539 t.Errorf("unexpected Stream error %v, expected %v", err, serr) 540 } 541 }() 542 543 // should unblock concurrent stream.Read 544 stream.Close(serr) 545 546 // wait for stream.Read error 547 timeout := time.NewTimer(5 * time.Second) 548 select { 549 case <-donec: 550 if !timeout.Stop() { 551 <-timeout.C 552 } 553 case <-timeout.C: 554 t.Fatalf("Test timed out, expected a status error.") 555 } 556 } 557 558 // Tests that when streamID > MaxStreamId, the current client transport drains. 559 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { 560 server, ct, cancel := setUp(t, 0, normal) 561 defer cancel() 562 defer server.stop() 563 callHdr := &CallHdr{ 564 Host: "localhost", 565 Method: "foo.Small", 566 } 567 568 originalMaxStreamID := MaxStreamID 569 MaxStreamID = 3 570 defer func() { 571 MaxStreamID = originalMaxStreamID 572 }() 573 574 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 575 defer ctxCancel() 576 577 s, err := ct.NewStream(ctx, callHdr) 578 if err != nil { 579 t.Fatalf("ct.NewStream() = %v", err) 580 } 581 if s.id != 1 { 582 t.Fatalf("Stream id: %d, want: 1", s.id) 583 } 584 585 if got, want := ct.stateForTesting(), reachable; got != want { 586 t.Fatalf("Client transport state %v, want %v", got, want) 587 } 588 589 // The expected stream ID here is 3 since stream IDs are incremented by 2. 590 s, err = ct.NewStream(ctx, callHdr) 591 if err != nil { 592 t.Fatalf("ct.NewStream() = %v", err) 593 } 594 if s.id != 3 { 595 t.Fatalf("Stream id: %d, want: 3", s.id) 596 } 597 598 // Verifying that ct.state is draining when next stream ID > MaxStreamId. 599 if got, want := ct.stateForTesting(), draining; got != want { 600 t.Fatalf("Client transport state %v, want %v", got, want) 601 } 602 } 603 604 func (s) TestClientSendAndReceive(t *testing.T) { 605 server, ct, cancel := setUp(t, 0, normal) 606 defer cancel() 607 callHdr := &CallHdr{ 608 Host: "localhost", 609 Method: "foo.Small", 610 } 611 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 612 defer ctxCancel() 613 s1, err1 := ct.NewStream(ctx, callHdr) 614 if err1 != nil { 615 t.Fatalf("failed to open stream: %v", err1) 616 } 617 if s1.id != 1 { 618 t.Fatalf("wrong stream id: %d", s1.id) 619 } 620 s2, err2 := ct.NewStream(ctx, callHdr) 621 if err2 != nil { 622 t.Fatalf("failed to open stream: %v", err2) 623 } 624 if s2.id != 3 { 625 t.Fatalf("wrong stream id: %d", s2.id) 626 } 627 opts := WriteOptions{Last: true} 628 if err := s1.Write(nil, newBufferSlice(expectedRequest), &opts); err != nil && err != io.EOF { 629 t.Fatalf("failed to send data: %v", err) 630 } 631 p := make([]byte, len(expectedResponse)) 632 _, recvErr := s1.readTo(p) 633 if recvErr != nil || !bytes.Equal(p, expectedResponse) { 634 t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse) 635 } 636 _, recvErr = s1.readTo(p) 637 if recvErr != io.EOF { 638 t.Fatalf("Error: %v; want <EOF>", recvErr) 639 } 640 ct.Close(fmt.Errorf("closed manually by test")) 641 server.stop() 642 } 643 644 func (s) TestClientErrorNotify(t *testing.T) { 645 server, ct, cancel := setUp(t, 0, normal) 646 defer cancel() 647 go server.stop() 648 // ct.reader should detect the error and activate ct.Error(). 649 <-ct.Error() 650 ct.Close(fmt.Errorf("closed manually by test")) 651 } 652 653 func performOneRPC(ct ClientTransport) { 654 callHdr := &CallHdr{ 655 Host: "localhost", 656 Method: "foo.Small", 657 } 658 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 659 defer cancel() 660 s, err := ct.NewStream(ctx, callHdr) 661 if err != nil { 662 return 663 } 664 opts := WriteOptions{Last: true} 665 if err := s.Write([]byte{}, newBufferSlice(expectedRequest), &opts); err == nil || err == io.EOF { 666 time.Sleep(5 * time.Millisecond) 667 // The following s.Recv()'s could error out because the 668 // underlying transport is gone. 669 // 670 // Read response 671 p := make([]byte, len(expectedResponse)) 672 s.readTo(p) 673 // Read io.EOF 674 s.readTo(p) 675 } 676 } 677 678 func (s) TestClientMix(t *testing.T) { 679 s, ct, cancel := setUp(t, 0, normal) 680 defer cancel() 681 time.AfterFunc(time.Second, s.stop) 682 go func(ct ClientTransport) { 683 <-ct.Error() 684 ct.Close(fmt.Errorf("closed manually by test")) 685 }(ct) 686 for i := 0; i < 750; i++ { 687 time.Sleep(2 * time.Millisecond) 688 go performOneRPC(ct) 689 } 690 } 691 692 func (s) TestLargeMessage(t *testing.T) { 693 server, ct, cancel := setUp(t, 0, normal) 694 defer cancel() 695 callHdr := &CallHdr{ 696 Host: "localhost", 697 Method: "foo.Large", 698 } 699 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 700 defer ctxCancel() 701 var wg sync.WaitGroup 702 for i := 0; i < 2; i++ { 703 wg.Add(1) 704 go func() { 705 defer wg.Done() 706 s, err := ct.NewStream(ctx, callHdr) 707 if err != nil { 708 t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 709 } 710 if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{Last: true}); err != nil && err != io.EOF { 711 t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) 712 } 713 p := make([]byte, len(expectedResponseLarge)) 714 if _, err := s.readTo(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { 715 t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) 716 } 717 if _, err = s.readTo(p); err != io.EOF { 718 t.Errorf("Failed to complete the stream %v; want <EOF>", err) 719 } 720 }() 721 } 722 wg.Wait() 723 ct.Close(fmt.Errorf("closed manually by test")) 724 server.stop() 725 } 726 727 func (s) TestLargeMessageWithDelayRead(t *testing.T) { 728 // Disable dynamic flow control. 729 sc := &ServerConfig{ 730 InitialWindowSize: defaultWindowSize, 731 InitialConnWindowSize: defaultWindowSize, 732 } 733 co := ConnectOptions{ 734 InitialWindowSize: defaultWindowSize, 735 InitialConnWindowSize: defaultWindowSize, 736 } 737 server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co) 738 defer cancel() 739 defer server.stop() 740 defer ct.Close(fmt.Errorf("closed manually by test")) 741 server.mu.Lock() 742 ready := server.ready 743 server.mu.Unlock() 744 callHdr := &CallHdr{ 745 Host: "localhost", 746 Method: "foo.Large", 747 } 748 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 749 defer cancel() 750 s, err := ct.NewStream(ctx, callHdr) 751 if err != nil { 752 t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 753 return 754 } 755 // Wait for server's handler to be initialized 756 select { 757 case <-ready: 758 case <-ctx.Done(): 759 t.Fatalf("Client timed out waiting for server handler to be initialized.") 760 } 761 server.mu.Lock() 762 serviceHandler := server.h 763 server.mu.Unlock() 764 var ( 765 mu sync.Mutex 766 total int 767 ) 768 s.wq.replenish = func(n int) { 769 mu.Lock() 770 total += n 771 mu.Unlock() 772 s.wq.realReplenish(n) 773 } 774 getTotal := func() int { 775 mu.Lock() 776 defer mu.Unlock() 777 return total 778 } 779 done := make(chan struct{}) 780 defer close(done) 781 go func() { 782 for { 783 select { 784 // Prevent goroutine from leaking in case of error. 785 case <-done: 786 return 787 default: 788 } 789 if getTotal() == defaultWindowSize { 790 // unblock server to be able to read and 791 // thereby send stream level window update. 792 close(serviceHandler.getNotified) 793 return 794 } 795 runtime.Gosched() 796 } 797 }() 798 // This write will cause client to run out of stream level, 799 // flow control and the other side won't send a window update 800 // until that happens. 801 if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{}); err != nil { 802 t.Fatalf("write(_, _, _) = %v, want <nil>", err) 803 } 804 p := make([]byte, len(expectedResponseLarge)) 805 806 // Wait for the other side to run out of stream level flow control before 807 // reading and thereby sending a window update. 808 select { 809 case <-serviceHandler.notify: 810 case <-ctx.Done(): 811 t.Fatalf("Client timed out") 812 } 813 if _, err := s.readTo(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { 814 t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err) 815 } 816 if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{Last: true}); err != nil { 817 t.Fatalf("Write(_, _, _) = %v, want <nil>", err) 818 } 819 if _, err = s.readTo(p); err != io.EOF { 820 t.Fatalf("Failed to complete the stream %v; want <EOF>", err) 821 } 822 } 823 824 // TestGracefulClose ensures that GracefulClose allows in-flight streams to 825 // proceed until they complete naturally, while not allowing creation of new 826 // streams during this window. 827 func (s) TestGracefulClose(t *testing.T) { 828 leakcheck.SetTrackingBufferPool(t) 829 server, ct, cancel := setUp(t, 0, pingpong) 830 defer cancel() 831 defer func() { 832 // Stop the server's listener to make the server's goroutines terminate 833 // (after the last active stream is done). 834 server.lis.Close() 835 // Check for goroutine leaks (i.e. GracefulClose with an active stream 836 // doesn't eventually close the connection when that stream completes). 837 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 838 defer cancel() 839 leakcheck.CheckGoroutines(ctx, t) 840 leakcheck.CheckTrackingBufferPool() 841 // Correctly clean up the server 842 server.stop() 843 }() 844 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 845 defer cancel() 846 847 // Create a stream that will exist for this whole test and confirm basic 848 // functionality. 849 s, err := ct.NewStream(ctx, &CallHdr{}) 850 if err != nil { 851 t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err) 852 } 853 msg := make([]byte, 1024) 854 outgoingHeader := make([]byte, 5) 855 outgoingHeader[0] = byte(0) 856 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg))) 857 incomingHeader := make([]byte, 5) 858 if err := s.Write(outgoingHeader, newBufferSlice(msg), &WriteOptions{}); err != nil { 859 t.Fatalf("Error while writing: %v", err) 860 } 861 if _, err := s.readTo(incomingHeader); err != nil { 862 t.Fatalf("Error while reading: %v", err) 863 } 864 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 865 recvMsg := make([]byte, int(sz)) 866 if _, err := s.readTo(recvMsg); err != nil { 867 t.Fatalf("Error while reading: %v", err) 868 } 869 870 // Gracefully close the transport, which should not affect the existing 871 // stream. 872 ct.GracefulClose() 873 874 var wg sync.WaitGroup 875 // Expect errors creating new streams because the client transport has been 876 // gracefully closed. 877 for i := 0; i < 200; i++ { 878 wg.Add(1) 879 go func() { 880 defer wg.Done() 881 _, err := ct.NewStream(ctx, &CallHdr{}) 882 if err != nil && err.(*NewStreamError).Err == ErrConnClosing && err.(*NewStreamError).AllowTransparentRetry { 883 return 884 } 885 t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) 886 }() 887 } 888 889 // Confirm the existing stream still functions as expected. 890 s.Write(nil, nil, &WriteOptions{Last: true}) 891 if _, err := s.readTo(incomingHeader); err != io.EOF { 892 t.Fatalf("Client expected EOF from the server. Got: %v", err) 893 } 894 wg.Wait() 895 } 896 897 func (s) TestLargeMessageSuspension(t *testing.T) { 898 server, ct, cancel := setUp(t, 0, suspended) 899 defer cancel() 900 defer ct.Close(fmt.Errorf("closed manually by test")) 901 defer server.stop() 902 callHdr := &CallHdr{ 903 Host: "localhost", 904 Method: "foo.Large", 905 } 906 // Set a long enough timeout for writing a large message out. 907 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 908 defer cancel() 909 s, err := ct.NewStream(ctx, callHdr) 910 if err != nil { 911 t.Fatalf("failed to open stream: %v", err) 912 } 913 // Write should not be done successfully due to flow control. 914 msg := make([]byte, initialWindowSize*8) 915 s.Write(nil, newBufferSlice(msg), &WriteOptions{}) 916 err = s.Write(nil, newBufferSlice(msg), &WriteOptions{Last: true}) 917 if err != errStreamDone { 918 t.Fatalf("Write got %v, want io.EOF", err) 919 } 920 // The server will send an RST stream frame on observing the deadline 921 // expiration making the client stream fail with a DeadlineExceeded status. 922 _, err = s.readTo(make([]byte, 8)) 923 if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded { 924 t.Fatalf("Read got unexpected error: %v, want status with code %v", err, codes.DeadlineExceeded) 925 } 926 if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want { 927 t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want) 928 } 929 } 930 931 func (s) TestMaxStreams(t *testing.T) { 932 serverConfig := &ServerConfig{ 933 MaxStreams: 1, 934 } 935 server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 936 defer cancel() 937 defer ct.Close(fmt.Errorf("closed manually by test")) 938 defer server.stop() 939 callHdr := &CallHdr{ 940 Host: "localhost", 941 Method: "foo.Large", 942 } 943 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 944 defer cancel() 945 s, err := ct.NewStream(ctx, callHdr) 946 if err != nil { 947 t.Fatalf("Failed to open stream: %v", err) 948 } 949 // Keep creating streams until one fails with deadline exceeded, marking the application 950 // of server settings on client. 951 slist := []*ClientStream{} 952 pctx, cancel := context.WithCancel(context.Background()) 953 defer cancel() 954 timer := time.NewTimer(time.Second * 10) 955 expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) 956 for { 957 select { 958 case <-timer.C: 959 t.Fatalf("Test timeout: client didn't receive server settings.") 960 default: 961 } 962 ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second)) 963 // This is only to get rid of govet. All these context are based on a base 964 // context which is canceled at the end of the test. 965 defer cancel() 966 if str, err := ct.NewStream(ctx, callHdr); err == nil { 967 slist = append(slist, str) 968 continue 969 } else if err.Error() != expectedErr.Error() { 970 t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) 971 } 972 timer.Stop() 973 break 974 } 975 done := make(chan struct{}) 976 // Try and create a new stream. 977 go func() { 978 defer close(done) 979 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 980 defer cancel() 981 if _, err := ct.NewStream(ctx, callHdr); err != nil { 982 t.Errorf("Failed to open stream: %v", err) 983 } 984 }() 985 // Close all the extra streams created and make sure the new stream is not created. 986 for _, str := range slist { 987 str.Close(nil) 988 } 989 select { 990 case <-done: 991 t.Fatalf("Test failed: didn't expect new stream to be created just yet.") 992 default: 993 } 994 // Close the first stream created so that the new stream can finally be created. 995 s.Close(nil) 996 <-done 997 ct.Close(fmt.Errorf("closed manually by test")) 998 <-ct.writerDone 999 if ct.maxConcurrentStreams != 1 { 1000 t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) 1001 } 1002 } 1003 1004 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { 1005 server, ct, cancel := setUp(t, 0, suspended) 1006 defer cancel() 1007 callHdr := &CallHdr{ 1008 Host: "localhost", 1009 Method: "foo", 1010 } 1011 var sc *http2Server 1012 // Wait until the server transport is setup. 1013 for { 1014 server.mu.Lock() 1015 if len(server.conns) == 0 { 1016 server.mu.Unlock() 1017 time.Sleep(time.Millisecond) 1018 continue 1019 } 1020 for k := range server.conns { 1021 var ok bool 1022 sc, ok = k.(*http2Server) 1023 if !ok { 1024 t.Fatalf("Failed to convert %v to *http2Server", k) 1025 } 1026 } 1027 server.mu.Unlock() 1028 break 1029 } 1030 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1031 defer cancel() 1032 s, err := ct.NewStream(ctx, callHdr) 1033 if err != nil { 1034 t.Fatalf("Failed to open stream: %v", err) 1035 } 1036 d := newBufferSlice(make([]byte, http2MaxFrameLen)) 1037 ct.controlBuf.put(&dataFrame{ 1038 streamID: s.id, 1039 endStream: false, 1040 h: nil, 1041 reader: d.Reader(), 1042 onEachWrite: func() {}, 1043 }) 1044 // Loop until the server side stream is created. 1045 var ss *ServerStream 1046 for { 1047 time.Sleep(time.Second) 1048 sc.mu.Lock() 1049 if len(sc.activeStreams) == 0 { 1050 sc.mu.Unlock() 1051 continue 1052 } 1053 ss = sc.activeStreams[s.id] 1054 sc.mu.Unlock() 1055 break 1056 } 1057 ct.Close(fmt.Errorf("closed manually by test")) 1058 select { 1059 case <-ss.Context().Done(): 1060 if ss.Context().Err() != context.Canceled { 1061 t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled) 1062 } 1063 case <-time.After(5 * time.Second): 1064 t.Fatalf("Failed to cancel the context of the sever side stream.") 1065 } 1066 server.stop() 1067 } 1068 1069 func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { 1070 connectOptions := ConnectOptions{ 1071 InitialWindowSize: defaultWindowSize, 1072 InitialConnWindowSize: defaultWindowSize, 1073 } 1074 server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) 1075 defer cancel() 1076 defer server.stop() 1077 defer client.Close(fmt.Errorf("closed manually by test")) 1078 1079 waitWhileTrue(t, func() (bool, error) { 1080 server.mu.Lock() 1081 defer server.mu.Unlock() 1082 1083 if len(server.conns) == 0 { 1084 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 1085 } 1086 return false, nil 1087 }) 1088 1089 var st *http2Server 1090 server.mu.Lock() 1091 for k := range server.conns { 1092 st = k.(*http2Server) 1093 } 1094 notifyChan := make(chan struct{}) 1095 server.h.notify = notifyChan 1096 server.mu.Unlock() 1097 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1098 defer cancel() 1099 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1100 if err != nil { 1101 t.Fatalf("Client failed to create first stream. Err: %v", err) 1102 } 1103 1104 <-notifyChan 1105 var sstream1 *ServerStream 1106 // Access stream on the server. 1107 st.mu.Lock() 1108 for _, v := range st.activeStreams { 1109 if v.id == cstream1.id { 1110 sstream1 = v 1111 } 1112 } 1113 st.mu.Unlock() 1114 if sstream1 == nil { 1115 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) 1116 } 1117 // Exhaust client's connection window. 1118 if err := sstream1.Write([]byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil { 1119 t.Fatalf("Server failed to write data. Err: %v", err) 1120 } 1121 notifyChan = make(chan struct{}) 1122 server.mu.Lock() 1123 server.h.notify = notifyChan 1124 server.mu.Unlock() 1125 // Create another stream on client. 1126 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1127 if err != nil { 1128 t.Fatalf("Client failed to create second stream. Err: %v", err) 1129 } 1130 <-notifyChan 1131 var sstream2 *ServerStream 1132 st.mu.Lock() 1133 for _, v := range st.activeStreams { 1134 if v.id == cstream2.id { 1135 sstream2 = v 1136 } 1137 } 1138 st.mu.Unlock() 1139 if sstream2 == nil { 1140 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) 1141 } 1142 // Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream. 1143 if err := sstream2.Write([]byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil { 1144 t.Fatalf("Server failed to write data. Err: %v", err) 1145 } 1146 1147 // Client should be able to read data on second stream. 1148 if _, err := cstream2.readTo(make([]byte, defaultWindowSize)); err != nil { 1149 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1150 } 1151 1152 // Client should be able to read data on first stream. 1153 if _, err := cstream1.readTo(make([]byte, defaultWindowSize)); err != nil { 1154 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1155 } 1156 } 1157 1158 func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { 1159 serverConfig := &ServerConfig{ 1160 InitialWindowSize: defaultWindowSize, 1161 InitialConnWindowSize: defaultWindowSize, 1162 } 1163 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 1164 defer cancel() 1165 defer server.stop() 1166 defer client.Close(fmt.Errorf("closed manually by test")) 1167 waitWhileTrue(t, func() (bool, error) { 1168 server.mu.Lock() 1169 defer server.mu.Unlock() 1170 1171 if len(server.conns) == 0 { 1172 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 1173 } 1174 return false, nil 1175 }) 1176 var st *http2Server 1177 server.mu.Lock() 1178 for k := range server.conns { 1179 st = k.(*http2Server) 1180 } 1181 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1182 defer cancel() 1183 server.mu.Unlock() 1184 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1185 if err != nil { 1186 t.Fatalf("Failed to create 1st stream. Err: %v", err) 1187 } 1188 // Exhaust server's connection window. 1189 if err := cstream1.Write(nil, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{Last: true}); err != nil { 1190 t.Fatalf("Client failed to write data. Err: %v", err) 1191 } 1192 // Client should be able to create another stream and send data on it. 1193 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1194 if err != nil { 1195 t.Fatalf("Failed to create 2nd stream. Err: %v", err) 1196 } 1197 if err := cstream2.Write(nil, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil { 1198 t.Fatalf("Client failed to write data. Err: %v", err) 1199 } 1200 // Get the streams on server. 1201 waitWhileTrue(t, func() (bool, error) { 1202 st.mu.Lock() 1203 defer st.mu.Unlock() 1204 1205 if len(st.activeStreams) != 2 { 1206 return true, fmt.Errorf("timed-out while waiting for server to have created the streams") 1207 } 1208 return false, nil 1209 }) 1210 var sstream1 *ServerStream 1211 st.mu.Lock() 1212 for _, v := range st.activeStreams { 1213 if v.id == 1 { 1214 sstream1 = v 1215 } 1216 } 1217 st.mu.Unlock() 1218 // Reading from the stream on server should succeed. 1219 if _, err := sstream1.readTo(make([]byte, defaultWindowSize)); err != nil { 1220 t.Fatalf("_.Read(_) = %v, want <nil>", err) 1221 } 1222 1223 if _, err := sstream1.readTo(make([]byte, 1)); err != io.EOF { 1224 t.Fatalf("_.Read(_) = %v, want io.EOF", err) 1225 } 1226 1227 } 1228 1229 func (s) TestServerWithMisbehavedClient(t *testing.T) { 1230 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1231 defer server.stop() 1232 // Create a client that can override server stream quota. 1233 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1234 if err != nil { 1235 t.Fatalf("Clent failed to dial:%v", err) 1236 } 1237 defer mconn.Close() 1238 if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil { 1239 t.Fatalf("Failed to set write deadline: %v", err) 1240 } 1241 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1242 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1243 } 1244 // success chan indicates that reader received a RSTStream from server. 1245 success := make(chan struct{}) 1246 var mu sync.Mutex 1247 framer := http2.NewFramer(mconn, mconn) 1248 if err := framer.WriteSettings(); err != nil { 1249 t.Fatalf("Error while writing settings: %v", err) 1250 } 1251 go func() { // Launch a reader for this misbehaving client. 1252 for { 1253 frame, err := framer.ReadFrame() 1254 if err != nil { 1255 return 1256 } 1257 switch frame := frame.(type) { 1258 case *http2.PingFrame: 1259 // Write ping ack back so that server's BDP estimation works right. 1260 mu.Lock() 1261 framer.WritePing(true, frame.Data) 1262 mu.Unlock() 1263 case *http2.RSTStreamFrame: 1264 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { 1265 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 1266 } 1267 close(success) 1268 return 1269 default: 1270 // Do nothing. 1271 } 1272 1273 } 1274 }() 1275 // Create a stream. 1276 var buf bytes.Buffer 1277 henc := hpack.NewEncoder(&buf) 1278 // TODO(mmukhi): Remove unnecessary fields. 1279 if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil { 1280 t.Fatalf("Error while encoding header: %v", err) 1281 } 1282 if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { 1283 t.Fatalf("Error while encoding header: %v", err) 1284 } 1285 if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { 1286 t.Fatalf("Error while encoding header: %v", err) 1287 } 1288 if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { 1289 t.Fatalf("Error while encoding header: %v", err) 1290 } 1291 mu.Lock() 1292 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1293 mu.Unlock() 1294 t.Fatalf("Error while writing headers: %v", err) 1295 } 1296 mu.Unlock() 1297 1298 // Test server behavior for violation of stream flow control window size restriction. 1299 timer := time.NewTimer(time.Second * 5) 1300 dbuf := make([]byte, http2MaxFrameLen) 1301 for { 1302 select { 1303 case <-timer.C: 1304 t.Fatalf("Test timed out.") 1305 case <-success: 1306 return 1307 default: 1308 } 1309 mu.Lock() 1310 if err := framer.WriteData(1, false, dbuf); err != nil { 1311 mu.Unlock() 1312 // Error here means the server could have closed the connection due to flow control 1313 // violation. Make sure that is the case by waiting for success chan to be closed. 1314 select { 1315 case <-timer.C: 1316 t.Fatalf("Error while writing data: %v", err) 1317 case <-success: 1318 return 1319 } 1320 } 1321 mu.Unlock() 1322 // This for loop is capable of hogging the CPU and cause starvation 1323 // in Go versions prior to 1.9, 1324 // in single CPU environment. Explicitly relinquish processor. 1325 runtime.Gosched() 1326 } 1327 } 1328 1329 func (s) TestClientHonorsConnectContext(t *testing.T) { 1330 // Create a server that will not send a preface. 1331 lis, err := net.Listen("tcp", "localhost:0") 1332 if err != nil { 1333 t.Fatalf("Error while listening: %v", err) 1334 } 1335 defer lis.Close() 1336 go func() { // Launch the misbehaving server. 1337 sconn, err := lis.Accept() 1338 if err != nil { 1339 t.Errorf("Error while accepting: %v", err) 1340 return 1341 } 1342 defer sconn.Close() 1343 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 1344 t.Errorf("Error while reading client preface: %v", err) 1345 return 1346 } 1347 sfr := http2.NewFramer(sconn, sconn) 1348 // Do not write a settings frame, but read from the conn forever. 1349 for { 1350 if _, err := sfr.ReadFrame(); err != nil { 1351 return 1352 } 1353 } 1354 }() 1355 1356 // Test context cancellation. 1357 timeBefore := time.Now() 1358 connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1359 time.AfterFunc(100*time.Millisecond, cancel) 1360 1361 parent := channelzSubChannel(t) 1362 copts := ConnectOptions{ChannelzParent: parent} 1363 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1364 defer cancel() 1365 _, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) 1366 if err == nil { 1367 t.Fatalf("NewHTTP2Client() returned successfully; wanted error") 1368 } 1369 t.Logf("NewHTTP2Client() = _, %v", err) 1370 if time.Since(timeBefore) > 3*time.Second { 1371 t.Fatalf("NewHTTP2Client returned > 2.9s after context cancellation") 1372 } 1373 1374 // Test context deadline. 1375 connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) 1376 defer cancel() 1377 _, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) 1378 if err == nil { 1379 t.Fatalf("NewHTTP2Client() returned successfully; wanted error") 1380 } 1381 t.Logf("NewHTTP2Client() = _, %v", err) 1382 } 1383 1384 func (s) TestClientWithMisbehavedServer(t *testing.T) { 1385 // Create a misbehaving server. 1386 lis, err := net.Listen("tcp", "localhost:0") 1387 if err != nil { 1388 t.Fatalf("Error while listening: %v", err) 1389 } 1390 defer lis.Close() 1391 // success chan indicates that the server received 1392 // RSTStream from the client. 1393 success := make(chan struct{}) 1394 go func() { // Launch the misbehaving server. 1395 sconn, err := lis.Accept() 1396 if err != nil { 1397 t.Errorf("Error while accepting: %v", err) 1398 return 1399 } 1400 defer sconn.Close() 1401 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 1402 t.Errorf("Error while reading client preface: %v", err) 1403 return 1404 } 1405 sfr := http2.NewFramer(sconn, sconn) 1406 if err := sfr.WriteSettings(); err != nil { 1407 t.Errorf("Error while writing settings: %v", err) 1408 return 1409 } 1410 if err := sfr.WriteSettingsAck(); err != nil { 1411 t.Errorf("Error while writing settings: %v", err) 1412 return 1413 } 1414 var mu sync.Mutex 1415 for { 1416 frame, err := sfr.ReadFrame() 1417 if err != nil { 1418 return 1419 } 1420 switch frame := frame.(type) { 1421 case *http2.HeadersFrame: 1422 // When the client creates a stream, violate the stream flow control. 1423 go func() { 1424 buf := make([]byte, http2MaxFrameLen) 1425 for { 1426 mu.Lock() 1427 if err := sfr.WriteData(1, false, buf); err != nil { 1428 mu.Unlock() 1429 return 1430 } 1431 mu.Unlock() 1432 // This for loop is capable of hogging the CPU and cause starvation 1433 // in Go versions prior to 1.9, 1434 // in single CPU environment. Explicitly relinquish processor. 1435 runtime.Gosched() 1436 } 1437 }() 1438 case *http2.RSTStreamFrame: 1439 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { 1440 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 1441 } 1442 close(success) 1443 return 1444 case *http2.PingFrame: 1445 mu.Lock() 1446 sfr.WritePing(true, frame.Data) 1447 mu.Unlock() 1448 default: 1449 } 1450 } 1451 }() 1452 connectCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 1453 defer cancel() 1454 1455 parent := channelzSubChannel(t) 1456 copts := ConnectOptions{ChannelzParent: parent} 1457 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1458 defer cancel() 1459 ct, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) 1460 if err != nil { 1461 t.Fatalf("Error while creating client transport: %v", err) 1462 } 1463 defer ct.Close(fmt.Errorf("closed manually by test")) 1464 1465 str, err := ct.NewStream(connectCtx, &CallHdr{}) 1466 if err != nil { 1467 t.Fatalf("Error while creating stream: %v", err) 1468 } 1469 timer := time.NewTimer(time.Second * 5) 1470 go func() { // This go routine mimics the one in stream.go to call CloseStream. 1471 <-str.Done() 1472 str.Close(nil) 1473 }() 1474 select { 1475 case <-timer.C: 1476 t.Fatalf("Test timed-out.") 1477 case <-success: 1478 } 1479 // Drain the remaining buffers in the stream by reading until an error is 1480 // encountered. 1481 str.Read(math.MaxInt) 1482 } 1483 1484 var encodingTestStatus = status.New(codes.Internal, "\n") 1485 1486 func (s) TestEncodingRequiredStatus(t *testing.T) { 1487 server, ct, cancel := setUp(t, 0, encodingRequiredStatus) 1488 defer cancel() 1489 callHdr := &CallHdr{ 1490 Host: "localhost", 1491 Method: "foo", 1492 } 1493 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1494 defer cancel() 1495 s, err := ct.NewStream(ctx, callHdr) 1496 if err != nil { 1497 return 1498 } 1499 opts := WriteOptions{Last: true} 1500 if err := s.Write(nil, newBufferSlice(expectedRequest), &opts); err != nil && err != errStreamDone { 1501 t.Fatalf("Failed to write the request: %v", err) 1502 } 1503 p := make([]byte, http2MaxFrameLen) 1504 if _, err := s.readTo(p); err != io.EOF { 1505 t.Fatalf("Read got error %v, want %v", err, io.EOF) 1506 } 1507 if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { 1508 t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) 1509 } 1510 ct.Close(fmt.Errorf("closed manually by test")) 1511 server.stop() 1512 // Drain any remaining buffers from the stream since it was closed early. 1513 s.Read(math.MaxInt) 1514 } 1515 1516 func (s) TestInvalidHeaderField(t *testing.T) { 1517 server, ct, cancel := setUp(t, 0, invalidHeaderField) 1518 defer cancel() 1519 callHdr := &CallHdr{ 1520 Host: "localhost", 1521 Method: "foo", 1522 } 1523 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1524 defer cancel() 1525 s, err := ct.NewStream(ctx, callHdr) 1526 if err != nil { 1527 return 1528 } 1529 p := make([]byte, http2MaxFrameLen) 1530 _, err = s.readTo(p) 1531 if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { 1532 t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField) 1533 } 1534 ct.Close(fmt.Errorf("closed manually by test")) 1535 server.stop() 1536 } 1537 1538 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { 1539 server, ct, cancel := setUp(t, 0, invalidHeaderField) 1540 defer cancel() 1541 defer server.stop() 1542 defer ct.Close(fmt.Errorf("closed manually by test")) 1543 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1544 defer cancel() 1545 s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) 1546 if err != nil { 1547 t.Fatalf("failed to create the stream") 1548 } 1549 timer := time.NewTimer(time.Second) 1550 defer timer.Stop() 1551 select { 1552 case <-s.headerChan: 1553 case <-timer.C: 1554 t.Errorf("s.headerChan: got open, want closed") 1555 } 1556 } 1557 1558 func (s) TestIsReservedHeader(t *testing.T) { 1559 tests := []struct { 1560 h string 1561 want bool 1562 }{ 1563 {"", false}, // but should be rejected earlier 1564 {"foo", false}, 1565 {"content-type", true}, 1566 {"user-agent", true}, 1567 {":anything", true}, 1568 {"grpc-message-type", true}, 1569 {"grpc-encoding", true}, 1570 {"grpc-message", true}, 1571 {"grpc-status", true}, 1572 {"grpc-timeout", true}, 1573 {"te", true}, 1574 } 1575 for _, tt := range tests { 1576 got := isReservedHeader(tt.h) 1577 if got != tt.want { 1578 t.Errorf("isReservedHeader(%q) = %v; want %v", tt.h, got, tt.want) 1579 } 1580 } 1581 } 1582 1583 func (s) TestContextErr(t *testing.T) { 1584 for _, test := range []struct { 1585 // input 1586 errIn error 1587 // outputs 1588 errOut error 1589 }{ 1590 {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, 1591 {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())}, 1592 } { 1593 err := ContextErr(test.errIn) 1594 if err.Error() != test.errOut.Error() { 1595 t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) 1596 } 1597 } 1598 } 1599 1600 type windowSizeConfig struct { 1601 serverStream int32 1602 serverConn int32 1603 clientStream int32 1604 clientConn int32 1605 } 1606 1607 func (s) TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) { 1608 wc := windowSizeConfig{ 1609 serverStream: 10 * 1024 * 1024, 1610 serverConn: 12 * 1024 * 1024, 1611 clientStream: 6 * 1024 * 1024, 1612 clientConn: 8 * 1024 * 1024, 1613 } 1614 testFlowControlAccountCheck(t, 1024*1024, wc) 1615 } 1616 1617 func (s) TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) { 1618 // These settings disable dynamic window sizes based on BDP estimation; 1619 // must be at least defaultWindowSize or the setting is ignored. 1620 wc := windowSizeConfig{ 1621 serverStream: defaultWindowSize, 1622 serverConn: defaultWindowSize, 1623 clientStream: defaultWindowSize, 1624 clientConn: defaultWindowSize, 1625 } 1626 testFlowControlAccountCheck(t, 1024*1024, wc) 1627 } 1628 1629 func (s) TestAccountCheckDynamicWindowSmallMessage(t *testing.T) { 1630 testFlowControlAccountCheck(t, 1024, windowSizeConfig{}) 1631 } 1632 1633 func (s) TestAccountCheckDynamicWindowLargeMessage(t *testing.T) { 1634 testFlowControlAccountCheck(t, 1024*1024, windowSizeConfig{}) 1635 } 1636 1637 func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) { 1638 sc := &ServerConfig{ 1639 InitialWindowSize: wc.serverStream, 1640 InitialConnWindowSize: wc.serverConn, 1641 } 1642 co := ConnectOptions{ 1643 InitialWindowSize: wc.clientStream, 1644 InitialConnWindowSize: wc.clientConn, 1645 } 1646 server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co) 1647 defer cancel() 1648 defer server.stop() 1649 defer client.Close(fmt.Errorf("closed manually by test")) 1650 waitWhileTrue(t, func() (bool, error) { 1651 server.mu.Lock() 1652 defer server.mu.Unlock() 1653 if len(server.conns) == 0 { 1654 return true, fmt.Errorf("timed out while waiting for server transport to be created") 1655 } 1656 return false, nil 1657 }) 1658 var st *http2Server 1659 server.mu.Lock() 1660 for k := range server.conns { 1661 st = k.(*http2Server) 1662 } 1663 server.mu.Unlock() 1664 1665 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1666 defer cancel() 1667 const numStreams = 5 1668 clientStreams := make([]*ClientStream, numStreams) 1669 for i := 0; i < numStreams; i++ { 1670 var err error 1671 clientStreams[i], err = client.NewStream(ctx, &CallHdr{}) 1672 if err != nil { 1673 t.Fatalf("Failed to create stream. Err: %v", err) 1674 } 1675 } 1676 var wg sync.WaitGroup 1677 // For each stream send pingpong messages to the server. 1678 for _, stream := range clientStreams { 1679 wg.Add(1) 1680 go func(stream *ClientStream) { 1681 defer wg.Done() 1682 buf := make([]byte, msgSize+5) 1683 buf[0] = byte(0) 1684 binary.BigEndian.PutUint32(buf[1:], uint32(msgSize)) 1685 opts := WriteOptions{} 1686 header := make([]byte, 5) 1687 for i := 1; i <= 5; i++ { 1688 if err := stream.Write(nil, newBufferSlice(buf), &opts); err != nil { 1689 t.Errorf("Error on client while writing message %v on stream %v: %v", i, stream.id, err) 1690 return 1691 } 1692 if _, err := stream.readTo(header); err != nil { 1693 t.Errorf("Error on client while reading data frame header %v on stream %v: %v", i, stream.id, err) 1694 return 1695 } 1696 sz := binary.BigEndian.Uint32(header[1:]) 1697 recvMsg := make([]byte, int(sz)) 1698 if _, err := stream.readTo(recvMsg); err != nil { 1699 t.Errorf("Error on client while reading data %v on stream %v: %v", i, stream.id, err) 1700 return 1701 } 1702 if len(recvMsg) != msgSize { 1703 t.Errorf("Length of message %v received by client on stream %v: %v, want: %v", i, stream.id, len(recvMsg), msgSize) 1704 return 1705 } 1706 } 1707 t.Logf("stream %v done with pingpongs", stream.id) 1708 }(stream) 1709 } 1710 wg.Wait() 1711 serverStreams := map[uint32]*ServerStream{} 1712 loopyClientStreams := map[uint32]*outStream{} 1713 loopyServerStreams := map[uint32]*outStream{} 1714 // Get all the streams from server reader and writer and client writer. 1715 st.mu.Lock() 1716 client.mu.Lock() 1717 for _, stream := range clientStreams { 1718 id := stream.id 1719 serverStreams[id] = st.activeStreams[id] 1720 loopyServerStreams[id] = st.loopy.estdStreams[id] 1721 loopyClientStreams[id] = client.loopy.estdStreams[id] 1722 1723 } 1724 client.mu.Unlock() 1725 st.mu.Unlock() 1726 // Close all streams 1727 for _, stream := range clientStreams { 1728 stream.Write(nil, nil, &WriteOptions{Last: true}) 1729 if _, err := stream.readTo(make([]byte, 5)); err != io.EOF { 1730 t.Fatalf("Client expected an EOF from the server. Got: %v", err) 1731 } 1732 } 1733 // Close down both server and client so that their internals can be read without data 1734 // races. 1735 client.Close(errors.New("closed manually by test")) 1736 st.Close(errors.New("closed manually by test")) 1737 <-st.readerDone 1738 <-st.loopyWriterDone 1739 <-client.readerDone 1740 <-client.writerDone 1741 for _, cstream := range clientStreams { 1742 id := cstream.id 1743 sstream := serverStreams[id] 1744 loopyServerStream := loopyServerStreams[id] 1745 loopyClientStream := loopyClientStreams[id] 1746 if loopyServerStream == nil { 1747 t.Fatalf("Unexpected nil loopyServerStream") 1748 } 1749 // Check stream flow control. 1750 if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { 1751 t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) 1752 } 1753 if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(client.loopy.oiws)-loopyClientStream.bytesOutStanding { 1754 t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, client.loopy.oiws, loopyClientStream.bytesOutStanding) 1755 } 1756 } 1757 // Check transport flow control. 1758 if client.fc.limit != client.fc.unacked+st.loopy.sendQuota { 1759 t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", client.fc.limit, client.fc.unacked, st.loopy.sendQuota) 1760 } 1761 if st.fc.limit != st.fc.unacked+client.loopy.sendQuota { 1762 t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, client.loopy.sendQuota) 1763 } 1764 } 1765 1766 func waitWhileTrue(t *testing.T, condition func() (bool, error)) { 1767 var ( 1768 wait bool 1769 err error 1770 ) 1771 timer := time.NewTimer(time.Second * 5) 1772 for { 1773 wait, err = condition() 1774 if wait { 1775 select { 1776 case <-timer.C: 1777 t.Fatal(err) 1778 default: 1779 time.Sleep(50 * time.Millisecond) 1780 continue 1781 } 1782 } 1783 if !timer.Stop() { 1784 <-timer.C 1785 } 1786 break 1787 } 1788 } 1789 1790 // If any error occurs on a call to Stream.Read, future calls 1791 // should continue to return that same error. 1792 func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { 1793 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1794 defer cancel() 1795 testRecvBuffer := newRecvBuffer() 1796 s := &Stream{ 1797 ctx: ctx, 1798 buf: testRecvBuffer, 1799 requestRead: func(int) {}, 1800 } 1801 s.trReader = &transportReader{ 1802 reader: &recvBufferReader{ 1803 ctx: s.ctx, 1804 ctxDone: s.ctx.Done(), 1805 recv: s.buf, 1806 }, 1807 windowHandler: func(int) {}, 1808 } 1809 testData := make([]byte, 1) 1810 testData[0] = 5 1811 testErr := errors.New("test error") 1812 s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: testErr}) 1813 1814 inBuf := make([]byte, 1) 1815 actualCount, actualErr := s.readTo(inBuf) 1816 if actualCount != 0 { 1817 t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount) 1818 } 1819 if actualErr.Error() != testErr.Error() { 1820 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1821 } 1822 1823 s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: nil}) 1824 s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: errors.New("different error from first")}) 1825 1826 for i := 0; i < 2; i++ { 1827 inBuf := make([]byte, 1) 1828 actualCount, actualErr := s.readTo(inBuf) 1829 if actualCount != 0 { 1830 t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount) 1831 } 1832 if actualErr.Error() != testErr.Error() { 1833 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1834 } 1835 } 1836 } 1837 1838 // TestHeadersCausingStreamError tests headers that should cause a stream protocol 1839 // error, which would end up with a RST_STREAM being sent to the client and also 1840 // the server closing the stream. 1841 func (s) TestHeadersCausingStreamError(t *testing.T) { 1842 tests := []struct { 1843 name string 1844 headers []struct { 1845 name string 1846 values []string 1847 } 1848 }{ 1849 // "Transports must consider requests containing the Connection header 1850 // as malformed" - A41 Malformed requests map to a stream error of type 1851 // PROTOCOL_ERROR. 1852 { 1853 name: "Connection header present", 1854 headers: []struct { 1855 name string 1856 values []string 1857 }{ 1858 {name: ":method", values: []string{"POST"}}, 1859 {name: ":path", values: []string{"foo"}}, 1860 {name: ":authority", values: []string{"localhost"}}, 1861 {name: "content-type", values: []string{"application/grpc"}}, 1862 {name: "connection", values: []string{"not-supported"}}, 1863 }, 1864 }, 1865 // multiple :authority or multiple Host headers would make the eventual 1866 // :authority ambiguous as per A41. Since these headers won't have a 1867 // content-type that corresponds to a grpc-client, the server should 1868 // simply write a RST_STREAM to the wire. 1869 { 1870 // Note: multiple authority headers are handled by the framer 1871 // itself, which will cause a stream error. Thus, it will never get 1872 // to operateHeaders with the check in operateHeaders for stream 1873 // error, but the server transport will still send a stream error. 1874 name: "Multiple authority headers", 1875 headers: []struct { 1876 name string 1877 values []string 1878 }{ 1879 {name: ":method", values: []string{"POST"}}, 1880 {name: ":path", values: []string{"foo"}}, 1881 {name: ":authority", values: []string{"localhost", "localhost2"}}, 1882 {name: "host", values: []string{"localhost"}}, 1883 }, 1884 }, 1885 } 1886 for _, test := range tests { 1887 t.Run(test.name, func(t *testing.T) { 1888 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1889 defer server.stop() 1890 // Create a client directly to not tie what you can send to API of 1891 // http2_client.go (i.e. control headers being sent). 1892 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1893 if err != nil { 1894 t.Fatalf("Client failed to dial: %v", err) 1895 } 1896 defer mconn.Close() 1897 1898 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1899 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1900 } 1901 1902 framer := http2.NewFramer(mconn, mconn) 1903 if err := framer.WriteSettings(); err != nil { 1904 t.Fatalf("Error while writing settings: %v", err) 1905 } 1906 1907 // result chan indicates that reader received a RSTStream from server. 1908 // An error will be passed on it if any other frame is received. 1909 result := testutils.NewChannel() 1910 1911 // Launch a reader goroutine. 1912 go func() { 1913 for { 1914 frame, err := framer.ReadFrame() 1915 if err != nil { 1916 return 1917 } 1918 switch frame := frame.(type) { 1919 case *http2.SettingsFrame: 1920 // Do nothing. A settings frame is expected from server preface. 1921 case *http2.RSTStreamFrame: 1922 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol { 1923 // Client only created a single stream, so RST Stream should be for that single stream. 1924 result.Send(fmt.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))) 1925 } 1926 // Records that client successfully received RST Stream frame. 1927 result.Send(nil) 1928 return 1929 default: 1930 // The server should send nothing but a single RST Stream frame. 1931 result.Send(errors.New("the client received a frame other than RST Stream")) 1932 } 1933 } 1934 }() 1935 1936 var buf bytes.Buffer 1937 henc := hpack.NewEncoder(&buf) 1938 1939 // Needs to build headers deterministically to conform to gRPC over 1940 // HTTP/2 spec. 1941 for _, header := range test.headers { 1942 for _, value := range header.values { 1943 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 1944 t.Fatalf("Error while encoding header: %v", err) 1945 } 1946 } 1947 } 1948 1949 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1950 t.Fatalf("Error while writing headers: %v", err) 1951 } 1952 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1953 defer cancel() 1954 r, err := result.Receive(ctx) 1955 if err != nil { 1956 t.Fatalf("Error receiving from channel: %v", err) 1957 } 1958 if r != nil { 1959 t.Fatalf("want nil, got %v", r) 1960 } 1961 }) 1962 } 1963 } 1964 1965 // TestHeadersHTTPStatusGRPCStatus tests requests with certain headers get a 1966 // certain HTTP and gRPC status back. 1967 func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) { 1968 tests := []struct { 1969 name string 1970 headers []struct { 1971 name string 1972 values []string 1973 } 1974 httpStatusWant string 1975 grpcStatusWant string 1976 grpcMessageWant string 1977 }{ 1978 // Note: multiple authority headers are handled by the framer itself, 1979 // which will cause a stream error. Thus, it will never get to 1980 // operateHeaders with the check in operateHeaders for possible 1981 // grpc-status sent back. 1982 1983 // multiple :authority or multiple Host headers would make the eventual 1984 // :authority ambiguous as per A41. This takes precedence even over the 1985 // fact a request is non grpc. All of these requests should be rejected 1986 // with grpc-status Internal. Thus, requests with multiple hosts should 1987 // get rejected with HTTP Status 400 and gRPC status Internal, 1988 // regardless of whether the client is speaking gRPC or not. 1989 { 1990 name: "Multiple host headers non grpc", 1991 headers: []struct { 1992 name string 1993 values []string 1994 }{ 1995 {name: ":method", values: []string{"POST"}}, 1996 {name: ":path", values: []string{"foo"}}, 1997 {name: ":authority", values: []string{"localhost"}}, 1998 {name: "host", values: []string{"localhost", "localhost2"}}, 1999 }, 2000 httpStatusWant: "400", 2001 grpcStatusWant: "13", 2002 grpcMessageWant: "both must only have 1 value as per HTTP/2 spec", 2003 }, 2004 { 2005 name: "Multiple host headers grpc", 2006 headers: []struct { 2007 name string 2008 values []string 2009 }{ 2010 {name: ":method", values: []string{"POST"}}, 2011 {name: ":path", values: []string{"foo"}}, 2012 {name: ":authority", values: []string{"localhost"}}, 2013 {name: "content-type", values: []string{"application/grpc"}}, 2014 {name: "host", values: []string{"localhost", "localhost2"}}, 2015 }, 2016 httpStatusWant: "400", 2017 grpcStatusWant: "13", 2018 grpcMessageWant: "both must only have 1 value as per HTTP/2 spec", 2019 }, 2020 // If the client sends an HTTP/2 request with a :method header with a 2021 // value other than POST, as specified in the gRPC over HTTP/2 2022 // specification, the server should fail the RPC. 2023 { 2024 name: "Client Sending Wrong Method", 2025 headers: []struct { 2026 name string 2027 values []string 2028 }{ 2029 {name: ":method", values: []string{"PUT"}}, 2030 {name: ":path", values: []string{"foo"}}, 2031 {name: ":authority", values: []string{"localhost"}}, 2032 {name: "content-type", values: []string{"application/grpc"}}, 2033 }, 2034 httpStatusWant: "405", 2035 grpcStatusWant: "13", 2036 grpcMessageWant: "which should be POST", 2037 }, 2038 { 2039 name: "Client Sending Wrong Content-Type", 2040 headers: []struct { 2041 name string 2042 values []string 2043 }{ 2044 {name: ":method", values: []string{"POST"}}, 2045 {name: ":path", values: []string{"foo"}}, 2046 {name: ":authority", values: []string{"localhost"}}, 2047 {name: "content-type", values: []string{"application/json"}}, 2048 }, 2049 httpStatusWant: "415", 2050 grpcStatusWant: "3", 2051 grpcMessageWant: `invalid gRPC request content-type "application/json"`, 2052 }, 2053 { 2054 name: "Client Sending Bad Timeout", 2055 headers: []struct { 2056 name string 2057 values []string 2058 }{ 2059 {name: ":method", values: []string{"POST"}}, 2060 {name: ":path", values: []string{"foo"}}, 2061 {name: ":authority", values: []string{"localhost"}}, 2062 {name: "content-type", values: []string{"application/grpc"}}, 2063 {name: "grpc-timeout", values: []string{"18f6n"}}, 2064 }, 2065 httpStatusWant: "400", 2066 grpcStatusWant: "13", 2067 grpcMessageWant: "malformed grpc-timeout", 2068 }, 2069 { 2070 name: "Client Sending Bad Binary Header", 2071 headers: []struct { 2072 name string 2073 values []string 2074 }{ 2075 {name: ":method", values: []string{"POST"}}, 2076 {name: ":path", values: []string{"foo"}}, 2077 {name: ":authority", values: []string{"localhost"}}, 2078 {name: "content-type", values: []string{"application/grpc"}}, 2079 {name: "foobar-bin", values: []string{"X()3e@#$-"}}, 2080 }, 2081 httpStatusWant: "400", 2082 grpcStatusWant: "13", 2083 grpcMessageWant: `header "foobar-bin": illegal base64 data`, 2084 }, 2085 } 2086 for _, test := range tests { 2087 t.Run(test.name, func(t *testing.T) { 2088 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 2089 defer server.stop() 2090 // Create a client directly to not tie what you can send to API of 2091 // http2_client.go (i.e. control headers being sent). 2092 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 2093 if err != nil { 2094 t.Fatalf("Client failed to dial: %v", err) 2095 } 2096 defer mconn.Close() 2097 2098 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 2099 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 2100 } 2101 2102 framer := http2.NewFramer(mconn, mconn) 2103 framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) 2104 if err := framer.WriteSettings(); err != nil { 2105 t.Fatalf("Error while writing settings: %v", err) 2106 } 2107 2108 // result chan indicates that reader received a Headers Frame with 2109 // desired grpc status and message from server. An error will be passed 2110 // on it if any other frame is received. 2111 result := testutils.NewChannel() 2112 2113 // Launch a reader goroutine. 2114 go func() { 2115 for { 2116 frame, err := framer.ReadFrame() 2117 if err != nil { 2118 return 2119 } 2120 switch frame := frame.(type) { 2121 case *http2.SettingsFrame: 2122 // Do nothing. A settings frame is expected from server preface. 2123 case *http2.MetaHeadersFrame: 2124 var httpStatus, grpcStatus, grpcMessage string 2125 for _, header := range frame.Fields { 2126 if header.Name == ":status" { 2127 httpStatus = header.Value 2128 } 2129 if header.Name == "grpc-status" { 2130 grpcStatus = header.Value 2131 } 2132 if header.Name == "grpc-message" { 2133 grpcMessage = header.Value 2134 } 2135 } 2136 if httpStatus != test.httpStatusWant { 2137 result.Send(fmt.Errorf("incorrect HTTP Status got %v, want %v", httpStatus, test.httpStatusWant)) 2138 return 2139 } 2140 if grpcStatus != test.grpcStatusWant { // grpc status code internal 2141 result.Send(fmt.Errorf("incorrect gRPC Status got %v, want %v", grpcStatus, test.grpcStatusWant)) 2142 return 2143 } 2144 if !strings.Contains(grpcMessage, test.grpcMessageWant) { 2145 result.Send(fmt.Errorf("incorrect gRPC message, want %q got %q", test.grpcMessageWant, grpcMessage)) 2146 return 2147 } 2148 2149 // Records that client successfully received a HeadersFrame 2150 // with expected Trailers-Only response. 2151 result.Send(nil) 2152 return 2153 default: 2154 // The server should send nothing but a single Settings and Headers frame. 2155 result.Send(errors.New("the client received a frame other than Settings or Headers")) 2156 } 2157 } 2158 }() 2159 2160 var buf bytes.Buffer 2161 henc := hpack.NewEncoder(&buf) 2162 2163 // Needs to build headers deterministically to conform to gRPC over 2164 // HTTP/2 spec. 2165 for _, header := range test.headers { 2166 for _, value := range header.values { 2167 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 2168 t.Fatalf("Error while encoding header: %v", err) 2169 } 2170 } 2171 } 2172 2173 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 2174 t.Fatalf("Error while writing headers: %v", err) 2175 } 2176 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2177 defer cancel() 2178 r, err := result.Receive(ctx) 2179 if err != nil { 2180 t.Fatalf("Error receiving from channel: %v", err) 2181 } 2182 if r != nil { 2183 t.Fatalf("want nil, got %v", r) 2184 } 2185 }) 2186 } 2187 } 2188 2189 func (s) TestWriteHeaderConnectionError(t *testing.T) { 2190 server, client, cancel := setUp(t, 0, notifyCall) 2191 defer cancel() 2192 defer server.stop() 2193 2194 waitWhileTrue(t, func() (bool, error) { 2195 server.mu.Lock() 2196 defer server.mu.Unlock() 2197 2198 if len(server.conns) == 0 { 2199 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 2200 } 2201 return false, nil 2202 }) 2203 2204 server.mu.Lock() 2205 2206 if len(server.conns) != 1 { 2207 t.Fatalf("Server has %d connections from the client, want 1", len(server.conns)) 2208 } 2209 2210 // Get the server transport for the connection to the client. 2211 var serverTransport *http2Server 2212 for k := range server.conns { 2213 serverTransport = k.(*http2Server) 2214 } 2215 notifyChan := make(chan struct{}) 2216 server.h.notify = notifyChan 2217 server.mu.Unlock() 2218 2219 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2220 defer cancel() 2221 cstream, err := client.NewStream(ctx, &CallHdr{}) 2222 if err != nil { 2223 t.Fatalf("Client failed to create first stream. Err: %v", err) 2224 } 2225 2226 <-notifyChan // Wait for server stream to be established. 2227 var sstream *ServerStream 2228 // Access stream on the server. 2229 serverTransport.mu.Lock() 2230 for _, v := range serverTransport.activeStreams { 2231 if v.id == cstream.id { 2232 sstream = v 2233 } 2234 } 2235 serverTransport.mu.Unlock() 2236 if sstream == nil { 2237 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id) 2238 } 2239 2240 client.Close(fmt.Errorf("closed manually by test")) 2241 2242 // Wait for server transport to be closed. 2243 <-serverTransport.done 2244 2245 // Write header on a closed server transport. 2246 err = sstream.SendHeader(metadata.MD{}) 2247 st := status.Convert(err) 2248 if st.Code() != codes.Unavailable { 2249 t.Fatalf("WriteHeader() failed with status code %s, want %s", st.Code(), codes.Unavailable) 2250 } 2251 } 2252 2253 func (s) TestPingPong1B(t *testing.T) { 2254 runPingPongTest(t, 1) 2255 } 2256 2257 func TestPingPong1KB(t *testing.T) { 2258 runPingPongTest(t, 1024) 2259 } 2260 2261 func TestPingPong64KB(t *testing.T) { 2262 runPingPongTest(t, 65536) 2263 } 2264 2265 func (s) TestPingPong1MB(t *testing.T) { 2266 runPingPongTest(t, 1048576) 2267 } 2268 2269 // This is a stress-test of flow control logic. 2270 func runPingPongTest(t *testing.T, msgSize int) { 2271 server, client, cancel := setUp(t, 0, pingpong) 2272 defer cancel() 2273 defer server.stop() 2274 defer client.Close(fmt.Errorf("closed manually by test")) 2275 waitWhileTrue(t, func() (bool, error) { 2276 server.mu.Lock() 2277 defer server.mu.Unlock() 2278 if len(server.conns) == 0 { 2279 return true, fmt.Errorf("timed out while waiting for server transport to be created") 2280 } 2281 return false, nil 2282 }) 2283 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2284 defer cancel() 2285 stream, err := client.NewStream(ctx, &CallHdr{}) 2286 if err != nil { 2287 t.Fatalf("Failed to create stream. Err: %v", err) 2288 } 2289 msg := make([]byte, msgSize) 2290 outgoingHeader := make([]byte, 5) 2291 outgoingHeader[0] = byte(0) 2292 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize)) 2293 opts := &WriteOptions{} 2294 incomingHeader := make([]byte, 5) 2295 2296 ctx, cancel = context.WithTimeout(ctx, 10*time.Millisecond) 2297 defer cancel() 2298 for ctx.Err() == nil { 2299 if err := stream.Write(outgoingHeader, newBufferSlice(msg), opts); err != nil { 2300 t.Fatalf("Error on client while writing message. Err: %v", err) 2301 } 2302 if _, err := stream.readTo(incomingHeader); err != nil { 2303 t.Fatalf("Error on client while reading data header. Err: %v", err) 2304 } 2305 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 2306 recvMsg := make([]byte, int(sz)) 2307 if _, err := stream.readTo(recvMsg); err != nil { 2308 t.Fatalf("Error on client while reading data. Err: %v", err) 2309 } 2310 } 2311 2312 stream.Write(nil, nil, &WriteOptions{Last: true}) 2313 if _, err := stream.readTo(incomingHeader); err != io.EOF { 2314 t.Fatalf("Client expected EOF from the server. Got: %v", err) 2315 } 2316 } 2317 2318 type tableSizeLimit struct { 2319 mu sync.Mutex 2320 limits []uint32 2321 } 2322 2323 func (t *tableSizeLimit) add(limit uint32) { 2324 t.mu.Lock() 2325 t.limits = append(t.limits, limit) 2326 t.mu.Unlock() 2327 } 2328 2329 func (t *tableSizeLimit) getLen() int { 2330 t.mu.Lock() 2331 defer t.mu.Unlock() 2332 return len(t.limits) 2333 } 2334 2335 func (t *tableSizeLimit) getIndex(i int) uint32 { 2336 t.mu.Lock() 2337 defer t.mu.Unlock() 2338 return t.limits[i] 2339 } 2340 2341 func (s) TestHeaderTblSize(t *testing.T) { 2342 limits := &tableSizeLimit{} 2343 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2344 e.SetMaxDynamicTableSizeLimit(v) 2345 limits.add(v) 2346 } 2347 defer func() { 2348 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2349 e.SetMaxDynamicTableSizeLimit(v) 2350 } 2351 }() 2352 2353 server, ct, cancel := setUp(t, 0, normal) 2354 defer cancel() 2355 defer ct.Close(fmt.Errorf("closed manually by test")) 2356 defer server.stop() 2357 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2358 defer ctxCancel() 2359 _, err := ct.NewStream(ctx, &CallHdr{}) 2360 if err != nil { 2361 t.Fatalf("failed to open stream: %v", err) 2362 } 2363 2364 var svrTransport ServerTransport 2365 var i int 2366 for i = 0; i < 1000; i++ { 2367 server.mu.Lock() 2368 if len(server.conns) != 0 { 2369 server.mu.Unlock() 2370 break 2371 } 2372 server.mu.Unlock() 2373 time.Sleep(10 * time.Millisecond) 2374 continue 2375 } 2376 if i == 1000 { 2377 t.Fatalf("unable to create any server transport after 10s") 2378 } 2379 2380 for st := range server.conns { 2381 svrTransport = st 2382 break 2383 } 2384 svrTransport.(*http2Server).controlBuf.put(&outgoingSettings{ 2385 ss: []http2.Setting{ 2386 { 2387 ID: http2.SettingHeaderTableSize, 2388 Val: uint32(100), 2389 }, 2390 }, 2391 }) 2392 2393 for i = 0; i < 1000; i++ { 2394 if limits.getLen() != 1 { 2395 time.Sleep(10 * time.Millisecond) 2396 continue 2397 } 2398 if val := limits.getIndex(0); val != uint32(100) { 2399 t.Fatalf("expected limits[0] = 100, got %d", val) 2400 } 2401 break 2402 } 2403 if i == 1000 { 2404 t.Fatalf("expected len(limits) = 1 within 10s, got != 1") 2405 } 2406 2407 ct.controlBuf.put(&outgoingSettings{ 2408 ss: []http2.Setting{ 2409 { 2410 ID: http2.SettingHeaderTableSize, 2411 Val: uint32(200), 2412 }, 2413 }, 2414 }) 2415 2416 for i := 0; i < 1000; i++ { 2417 if limits.getLen() != 2 { 2418 time.Sleep(10 * time.Millisecond) 2419 continue 2420 } 2421 if val := limits.getIndex(1); val != uint32(200) { 2422 t.Fatalf("expected limits[1] = 200, got %d", val) 2423 } 2424 break 2425 } 2426 if i == 1000 { 2427 t.Fatalf("expected len(limits) = 2 within 10s, got != 2") 2428 } 2429 } 2430 2431 // attrTransportCreds is a transport credential implementation which stores 2432 // Attributes from the ClientHandshakeInfo struct passed in the context locally 2433 // for the test to inspect. 2434 type attrTransportCreds struct { 2435 credentials.TransportCredentials 2436 attr *attributes.Attributes 2437 } 2438 2439 func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 2440 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2441 ac.attr = ai.Attributes 2442 return rawConn, nil, nil 2443 } 2444 func (ac *attrTransportCreds) Info() credentials.ProtocolInfo { 2445 return credentials.ProtocolInfo{} 2446 } 2447 func (ac *attrTransportCreds) Clone() credentials.TransportCredentials { 2448 return nil 2449 } 2450 2451 // TestClientHandshakeInfo adds attributes to the resolver.Address passes to 2452 // NewHTTP2Client and verifies that these attributes are received by the 2453 // transport credential handshaker. 2454 func (s) TestClientHandshakeInfo(t *testing.T) { 2455 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2456 defer server.stop() 2457 2458 const ( 2459 testAttrKey = "foo" 2460 testAttrVal = "bar" 2461 ) 2462 addr := resolver.Address{ 2463 Addr: "localhost:" + server.port, 2464 Attributes: attributes.New(testAttrKey, testAttrVal), 2465 } 2466 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 2467 defer cancel() 2468 creds := &attrTransportCreds{} 2469 2470 copts := ConnectOptions{ 2471 TransportCredentials: creds, 2472 ChannelzParent: channelzSubChannel(t), 2473 } 2474 tr, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {}) 2475 if err != nil { 2476 t.Fatalf("NewHTTP2Client(): %v", err) 2477 } 2478 defer tr.Close(fmt.Errorf("closed manually by test")) 2479 2480 wantAttr := attributes.New(testAttrKey, testAttrVal) 2481 if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2482 t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) 2483 } 2484 } 2485 2486 // TestClientHandshakeInfoDialer adds attributes to the resolver.Address passes to 2487 // NewHTTP2Client and verifies that these attributes are received by a custom 2488 // dialer. 2489 func (s) TestClientHandshakeInfoDialer(t *testing.T) { 2490 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2491 defer server.stop() 2492 2493 const ( 2494 testAttrKey = "foo" 2495 testAttrVal = "bar" 2496 ) 2497 addr := resolver.Address{ 2498 Addr: "localhost:" + server.port, 2499 Attributes: attributes.New(testAttrKey, testAttrVal), 2500 } 2501 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 2502 defer cancel() 2503 2504 var attr *attributes.Attributes 2505 dialer := func(ctx context.Context, addr string) (net.Conn, error) { 2506 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2507 attr = ai.Attributes 2508 return (&net.Dialer{}).DialContext(ctx, "tcp", addr) 2509 } 2510 2511 copts := ConnectOptions{ 2512 Dialer: dialer, 2513 ChannelzParent: channelzSubChannel(t), 2514 } 2515 tr, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {}) 2516 if err != nil { 2517 t.Fatalf("NewHTTP2Client(): %v", err) 2518 } 2519 defer tr.Close(fmt.Errorf("closed manually by test")) 2520 2521 wantAttr := attributes.New(testAttrKey, testAttrVal) 2522 if gotAttr := attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2523 t.Errorf("Received attributes %v in custom dialer, want %v", gotAttr, wantAttr) 2524 } 2525 } 2526 2527 func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { 2528 testStream := func() *ClientStream { 2529 return &ClientStream{ 2530 Stream: &Stream{ 2531 buf: &recvBuffer{ 2532 c: make(chan recvMsg), 2533 mu: sync.Mutex{}, 2534 }, 2535 }, 2536 done: make(chan struct{}), 2537 headerChan: make(chan struct{}), 2538 } 2539 } 2540 2541 testClient := func(ts *ClientStream) *http2Client { 2542 return &http2Client{ 2543 mu: sync.Mutex{}, 2544 activeStreams: map[uint32]*ClientStream{ 2545 0: ts, 2546 }, 2547 controlBuf: newControlBuffer(make(<-chan struct{})), 2548 } 2549 } 2550 2551 for _, test := range []struct { 2552 name string 2553 // input 2554 metaHeaderFrame *http2.MetaHeadersFrame 2555 // output 2556 wantStatus *status.Status 2557 }{ 2558 { 2559 name: "valid header", 2560 metaHeaderFrame: &http2.MetaHeadersFrame{ 2561 Fields: []hpack.HeaderField{ 2562 {Name: "content-type", Value: "application/grpc"}, 2563 {Name: "grpc-status", Value: "0"}, 2564 {Name: ":status", Value: "200"}, 2565 }, 2566 }, 2567 // no error 2568 wantStatus: status.New(codes.OK, ""), 2569 }, 2570 { 2571 name: "missing content-type header", 2572 metaHeaderFrame: &http2.MetaHeadersFrame{ 2573 Fields: []hpack.HeaderField{ 2574 {Name: "grpc-status", Value: "0"}, 2575 {Name: ":status", Value: "200"}, 2576 }, 2577 }, 2578 wantStatus: status.New( 2579 codes.Unknown, 2580 "malformed header: missing HTTP content-type", 2581 ), 2582 }, 2583 { 2584 name: "invalid grpc status header field", 2585 metaHeaderFrame: &http2.MetaHeadersFrame{ 2586 Fields: []hpack.HeaderField{ 2587 {Name: "content-type", Value: "application/grpc"}, 2588 {Name: "grpc-status", Value: "xxxx"}, 2589 {Name: ":status", Value: "200"}, 2590 }, 2591 }, 2592 wantStatus: status.New( 2593 codes.Internal, 2594 "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2595 ), 2596 }, 2597 { 2598 name: "invalid http content type", 2599 metaHeaderFrame: &http2.MetaHeadersFrame{ 2600 Fields: []hpack.HeaderField{ 2601 {Name: "content-type", Value: "application/json"}, 2602 }, 2603 }, 2604 wantStatus: status.New( 2605 codes.Internal, 2606 "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"", 2607 ), 2608 }, 2609 { 2610 name: "http fallback and invalid http status", 2611 metaHeaderFrame: &http2.MetaHeadersFrame{ 2612 Fields: []hpack.HeaderField{ 2613 // No content type provided then fallback into handling http error. 2614 {Name: ":status", Value: "xxxx"}, 2615 }, 2616 }, 2617 wantStatus: status.New( 2618 codes.Internal, 2619 "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2620 ), 2621 }, 2622 { 2623 name: "http2 frame size exceeds", 2624 metaHeaderFrame: &http2.MetaHeadersFrame{ 2625 Fields: nil, 2626 Truncated: true, 2627 }, 2628 wantStatus: status.New( 2629 codes.Internal, 2630 "peer header list size exceeded limit", 2631 ), 2632 }, 2633 { 2634 name: "bad status in grpc mode", 2635 metaHeaderFrame: &http2.MetaHeadersFrame{ 2636 Fields: []hpack.HeaderField{ 2637 {Name: "content-type", Value: "application/grpc"}, 2638 {Name: "grpc-status", Value: "0"}, 2639 {Name: ":status", Value: "504"}, 2640 }, 2641 }, 2642 wantStatus: status.New( 2643 codes.Unavailable, 2644 "unexpected HTTP status code received from server: 504 (Gateway Timeout)", 2645 ), 2646 }, 2647 { 2648 name: "missing http status", 2649 metaHeaderFrame: &http2.MetaHeadersFrame{ 2650 Fields: []hpack.HeaderField{ 2651 {Name: "content-type", Value: "application/grpc"}, 2652 }, 2653 }, 2654 wantStatus: status.New( 2655 codes.Internal, 2656 "malformed header: missing HTTP status", 2657 ), 2658 }, 2659 } { 2660 2661 t.Run(test.name, func(t *testing.T) { 2662 ts := testStream() 2663 s := testClient(ts) 2664 2665 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2666 FrameHeader: http2.FrameHeader{ 2667 StreamID: 0, 2668 }, 2669 } 2670 2671 s.operateHeaders(test.metaHeaderFrame) 2672 2673 got := ts.status 2674 want := test.wantStatus 2675 if got.Code() != want.Code() || got.Message() != want.Message() { 2676 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2677 } 2678 }) 2679 t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) { 2680 ts := testStream() 2681 s := testClient(ts) 2682 2683 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2684 FrameHeader: http2.FrameHeader{ 2685 StreamID: 0, 2686 Flags: http2.FlagHeadersEndStream, 2687 }, 2688 } 2689 2690 s.operateHeaders(test.metaHeaderFrame) 2691 2692 got := ts.status 2693 want := test.wantStatus 2694 if got.Code() != want.Code() || got.Message() != want.Message() { 2695 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2696 } 2697 }) 2698 } 2699 } 2700 2701 func TestConnectionError_Unwrap(t *testing.T) { 2702 err := connectionErrorf(false, os.ErrNotExist, "unwrap me") 2703 if !errors.Is(err, os.ErrNotExist) { 2704 t.Error("ConnectionError does not unwrap") 2705 } 2706 } 2707 2708 // Test that in the event of a graceful client transport shutdown, i.e., 2709 // clientTransport.Close(), client sends a goaway to the server with the correct 2710 // error code and debug data. 2711 func (s) TestClientSendsAGoAwayFrame(t *testing.T) { 2712 // Create a server. 2713 lis, err := net.Listen("tcp", "localhost:0") 2714 if err != nil { 2715 t.Fatalf("Error while listening: %v", err) 2716 } 2717 defer lis.Close() 2718 // greetDone is used to notify when server is done greeting the client. 2719 greetDone := make(chan struct{}) 2720 // errorCh verifies that desired GOAWAY not received by server 2721 errorCh := make(chan error) 2722 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2723 defer cancel() 2724 // Launch the server. 2725 go func() { 2726 sconn, err := lis.Accept() 2727 if err != nil { 2728 t.Errorf("Error while accepting: %v", err) 2729 } 2730 defer sconn.Close() 2731 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 2732 t.Errorf("Error while writing settings ack: %v", err) 2733 return 2734 } 2735 sfr := http2.NewFramer(sconn, sconn) 2736 if err := sfr.WriteSettings(); err != nil { 2737 t.Errorf("Error while writing settings %v", err) 2738 return 2739 } 2740 fr, _ := sfr.ReadFrame() 2741 if _, ok := fr.(*http2.SettingsFrame); !ok { 2742 t.Errorf("Expected settings frame, got %v", fr) 2743 } 2744 fr, _ = sfr.ReadFrame() 2745 if fr, ok := fr.(*http2.SettingsFrame); !ok || !fr.IsAck() { 2746 t.Errorf("Expected settings ACK frame, got %v", fr) 2747 } 2748 fr, _ = sfr.ReadFrame() 2749 if fr, ok := fr.(*http2.HeadersFrame); !ok || !fr.Flags.Has(http2.FlagHeadersEndHeaders) { 2750 t.Errorf("Expected Headers frame with END_HEADERS frame, got %v", fr) 2751 } 2752 close(greetDone) 2753 2754 frame, err := sfr.ReadFrame() 2755 if err != nil { 2756 return 2757 } 2758 switch fr := frame.(type) { 2759 case *http2.GoAwayFrame: 2760 // Records that the server successfully received a GOAWAY frame. 2761 goAwayFrame := fr 2762 if goAwayFrame.ErrCode == http2.ErrCodeNo { 2763 t.Logf("Received goAway frame from client") 2764 close(errorCh) 2765 } else { 2766 errorCh <- fmt.Errorf("received unexpected goAway frame: %v", err) 2767 close(errorCh) 2768 } 2769 return 2770 default: 2771 errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", err) 2772 close(errorCh) 2773 return 2774 } 2775 }() 2776 2777 ct, err := NewHTTP2Client(ctx, ctx, resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {}) 2778 if err != nil { 2779 t.Fatalf("Error while creating client transport: %v", err) 2780 } 2781 _, err = ct.NewStream(ctx, &CallHdr{}) 2782 if err != nil { 2783 t.Fatalf("failed to open stream: %v", err) 2784 } 2785 // Wait until server receives the headers and settings frame as part of greet. 2786 <-greetDone 2787 ct.Close(errors.New("manually closed by client")) 2788 t.Logf("Closed the client connection") 2789 select { 2790 case err := <-errorCh: 2791 if err != nil { 2792 t.Errorf("Error receiving the GOAWAY frame: %v", err) 2793 } 2794 case <-ctx.Done(): 2795 t.Errorf("Context timed out") 2796 } 2797 } 2798 2799 // readHangingConn is a wrapper around net.Conn that makes the Read() hang when 2800 // Close() is called. 2801 type readHangingConn struct { 2802 net.Conn 2803 readHangConn chan struct{} // Read() hangs until this channel is closed by Close(). 2804 closed *atomic.Bool // Set to true when Close() is called. 2805 } 2806 2807 func (hc *readHangingConn) Read(b []byte) (n int, err error) { 2808 n, err = hc.Conn.Read(b) 2809 if hc.closed.Load() { 2810 <-hc.readHangConn // hang the read till we want 2811 } 2812 return n, err 2813 } 2814 2815 func (hc *readHangingConn) Close() error { 2816 hc.closed.Store(true) 2817 return hc.Conn.Close() 2818 } 2819 2820 // Tests that closing a client transport does not return until the reader 2821 // goroutine exits. 2822 func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) { 2823 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2824 defer cancel() 2825 2826 server := setUpServerOnly(t, 0, &ServerConfig{}, normal) 2827 defer server.stop() 2828 addr := resolver.Address{Addr: "localhost:" + server.port} 2829 2830 isReaderHanging := &atomic.Bool{} 2831 readHangConn := make(chan struct{}) 2832 copts := ConnectOptions{ 2833 Dialer: func(_ context.Context, addr string) (net.Conn, error) { 2834 conn, err := net.Dial("tcp", addr) 2835 if err != nil { 2836 return nil, err 2837 } 2838 return &readHangingConn{Conn: conn, readHangConn: readHangConn, closed: isReaderHanging}, nil 2839 }, 2840 ChannelzParent: channelzSubChannel(t), 2841 } 2842 2843 // Create a client transport with a custom dialer that hangs the Read() 2844 // after Close(). 2845 ct, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {}) 2846 if err != nil { 2847 t.Fatalf("Failed to create transport: %v", err) 2848 } 2849 2850 if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { 2851 t.Fatalf("Failed to open stream: %v", err) 2852 } 2853 2854 // Closing the client transport will result in the underlying net.Conn being 2855 // closed, which will result in readHangingConn.Read() to hang. This will 2856 // stall the exit of the reader goroutine, and will stall client 2857 // transport's Close from returning. 2858 transportClosed := make(chan struct{}) 2859 go func() { 2860 ct.Close(errors.New("manually closed by client")) 2861 close(transportClosed) 2862 }() 2863 2864 // Wait for a short duration and ensure that the client transport's Close() 2865 // does not return. 2866 select { 2867 case <-transportClosed: 2868 t.Fatal("Transport closed before reader completed") 2869 case <-time.After(defaultTestShortTimeout): 2870 } 2871 2872 // Closing the channel will unblock the reader goroutine and will ensure 2873 // that the client transport's Close() returns. 2874 close(readHangConn) 2875 select { 2876 case <-transportClosed: 2877 case <-time.After(defaultTestTimeout): 2878 t.Fatal("Timeout when waiting for transport to close") 2879 } 2880 } 2881 2882 // hangingConn is a net.Conn wrapper for testing, simulating hanging connections 2883 // after a GOAWAY frame is sent, of which Write operations pause until explicitly 2884 // signaled or a timeout occurs. 2885 type hangingConn struct { 2886 net.Conn 2887 hangConn chan struct{} 2888 startHanging *atomic.Bool 2889 } 2890 2891 func (hc *hangingConn) Write(b []byte) (n int, err error) { 2892 n, err = hc.Conn.Write(b) 2893 if hc.startHanging.Load() { 2894 <-hc.hangConn 2895 } 2896 return n, err 2897 } 2898 2899 // Tests the scenario where a client transport is closed and writing of the 2900 // GOAWAY frame as part of the close does not complete because of a network 2901 // hang. The test verifies that the client transport is closed without waiting 2902 // for too long. 2903 func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) { 2904 // Override timer for writing GOAWAY to 0 so that the connection write 2905 // always times out. It is equivalent of real network hang when conn 2906 // write for goaway doesn't finish in specified deadline 2907 origGoAwayLoopyTimeout := goAwayLoopyWriterTimeout 2908 goAwayLoopyWriterTimeout = time.Millisecond 2909 defer func() { 2910 goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout 2911 }() 2912 2913 // Create the server set up. 2914 connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2915 defer cancel() 2916 server := setUpServerOnly(t, 0, &ServerConfig{}, normal) 2917 defer server.stop() 2918 addr := resolver.Address{Addr: "localhost:" + server.port} 2919 isGreetingDone := &atomic.Bool{} 2920 hangConn := make(chan struct{}) 2921 defer close(hangConn) 2922 dialer := func(_ context.Context, addr string) (net.Conn, error) { 2923 conn, err := net.Dial("tcp", addr) 2924 if err != nil { 2925 return nil, err 2926 } 2927 return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil 2928 } 2929 copts := ConnectOptions{Dialer: dialer} 2930 copts.ChannelzParent = channelzSubChannel(t) 2931 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2932 defer cancel() 2933 // Create client transport with custom dialer 2934 ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {}) 2935 if connErr != nil { 2936 t.Fatalf("failed to create transport: %v", connErr) 2937 } 2938 2939 if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { 2940 t.Fatalf("Failed to open stream: %v", err) 2941 } 2942 2943 isGreetingDone.Store(true) 2944 ct.Close(errors.New("manually closed by client")) 2945 } 2946 2947 // TestReadHeaderMultipleBuffers tests the stream when the gRPC headers are 2948 // split across multiple buffers. It verifies that the reporting of the 2949 // number of bytes read for flow control is correct. 2950 func (s) TestReadMessageHeaderMultipleBuffers(t *testing.T) { 2951 headerLen := 5 2952 recvBuffer := newRecvBuffer() 2953 recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, 3)}) 2954 recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, headerLen-3)}) 2955 bytesRead := 0 2956 s := Stream{ 2957 requestRead: func(int) {}, 2958 trReader: &transportReader{ 2959 reader: &recvBufferReader{ 2960 recv: recvBuffer, 2961 }, 2962 windowHandler: func(i int) { 2963 bytesRead += i 2964 }, 2965 }, 2966 } 2967 2968 header := make([]byte, headerLen) 2969 err := s.ReadMessageHeader(header) 2970 if err != nil { 2971 t.Fatalf("ReadHeader(%v) = %v", header, err) 2972 } 2973 if bytesRead != headerLen { 2974 t.Errorf("bytesRead = %d, want = %d", bytesRead, headerLen) 2975 } 2976 } 2977 2978 // Tests a scenario when the client doesn't send an RST frame when the 2979 // configured deadline is reached. The test verifies that the server sends an 2980 // RST stream only after the deadline is reached. 2981 func (s) TestServerSendsRSTAfterDeadlineToMisbehavedClient(t *testing.T) { 2982 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 2983 defer server.stop() 2984 // Create a client that can override server stream quota. 2985 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 2986 if err != nil { 2987 t.Fatalf("Clent failed to dial:%v", err) 2988 } 2989 defer mconn.Close() 2990 if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil { 2991 t.Fatalf("Failed to set write deadline: %v", err) 2992 } 2993 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 2994 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 2995 } 2996 // rstTimeChan chan indicates that reader received a RSTStream from server. 2997 rstTimeChan := make(chan time.Time, 1) 2998 var mu sync.Mutex 2999 framer := http2.NewFramer(mconn, mconn) 3000 if err := framer.WriteSettings(); err != nil { 3001 t.Fatalf("Error while writing settings: %v", err) 3002 } 3003 go func() { // Launch a reader for this misbehaving client. 3004 for { 3005 frame, err := framer.ReadFrame() 3006 if err != nil { 3007 return 3008 } 3009 switch frame := frame.(type) { 3010 case *http2.PingFrame: 3011 // Write ping ack back so that server's BDP estimation works right. 3012 mu.Lock() 3013 framer.WritePing(true, frame.Data) 3014 mu.Unlock() 3015 case *http2.RSTStreamFrame: 3016 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeCancel { 3017 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeCancel", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 3018 } 3019 rstTimeChan <- time.Now() 3020 return 3021 default: 3022 // Do nothing. 3023 } 3024 } 3025 }() 3026 // Create a stream. 3027 var buf bytes.Buffer 3028 henc := hpack.NewEncoder(&buf) 3029 if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil { 3030 t.Fatalf("Error while encoding header: %v", err) 3031 } 3032 if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { 3033 t.Fatalf("Error while encoding header: %v", err) 3034 } 3035 if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { 3036 t.Fatalf("Error while encoding header: %v", err) 3037 } 3038 if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { 3039 t.Fatalf("Error while encoding header: %v", err) 3040 } 3041 if err := henc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: "10m"}); err != nil { 3042 t.Fatalf("Error while encoding header: %v", err) 3043 } 3044 mu.Lock() 3045 startTime := time.Now() 3046 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 3047 mu.Unlock() 3048 t.Fatalf("Error while writing headers: %v", err) 3049 } 3050 mu.Unlock() 3051 3052 // Test server behavior for deadline expiration. 3053 var rstTime time.Time 3054 select { 3055 case <-time.After(5 * time.Second): 3056 t.Fatalf("Test timed out.") 3057 case rstTime = <-rstTimeChan: 3058 } 3059 3060 if got, want := rstTime.Sub(startTime), 10*time.Millisecond; got < want { 3061 t.Fatalf("RST frame received earlier than expected by duration: %v", want-got) 3062 } 3063 }