// Source: google.golang.org/grpc@v1.74.2/internal/transport/transport_test.go
/*
 *
 * Copyright 2014 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package transport

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"os"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
	"google.golang.org/grpc/attributes"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/internal/channelz"
	"google.golang.org/grpc/internal/grpctest"
	"google.golang.org/grpc/internal/leakcheck"
	"google.golang.org/grpc/internal/testutils"
	"google.golang.org/grpc/mem"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/status"
)

// s is the receiver type for the (s) TestXxx methods in this file. It embeds
// grpctest.Tester.
type s struct {
	grpctest.Tester
}

// Test is the single top-level Go test entry point. It hands an s{} to
// grpctest.RunSubTests, which presumably discovers and runs the (s) TestXxx
// methods as subtests.
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}

// Canned payloads shared by the server-side handlers and the client tests.
//
// expectedRequest/expectedResponse are the default pair used by handleStream.
// The *Large variants are used when the method is "foo.Large". They are
// initialWindowSize*2 bytes long, so a single message is larger than
// initialWindowSize. expectedInvalidHeaderField is the content-type value
// sent by handleStreamInvalidHeaderField.
var (
	expectedRequest            = []byte("ping")
	expectedResponse           = []byte("pong")
	expectedRequestLarge       = make([]byte, initialWindowSize*2)
	expectedResponseLarge      = make([]byte, initialWindowSize*2)
	expectedInvalidHeaderField = "invalid/content-type"
)

// init sets the first and last bytes of the large payloads to non-zero
// markers, so the payloads are not all zeros. Presumably this lets equality
// checks notice data that was truncated or shifted at either end.
func init() {
	expectedRequestLarge[0] = 'g'
	expectedRequestLarge[len(expectedRequestLarge)-1] = 'r'
	expectedResponseLarge[0] = 'p'
	expectedResponseLarge[len(expectedResponseLarge)-1] = 'c'
}

// newBufferSlice wraps b as the only buffer of a mem.BufferSlice.
func newBufferSlice(b []byte) mem.BufferSlice {
	return mem.BufferSlice{mem.SliceBuffer(b)}
}

// readTo asks s.read for exactly len(p) bytes and copies them into p.
//
// It returns len(p) on success. It returns the error from s.read if there is
// one. If the number of bytes returned differs from len(p), it returns
// io.ErrUnexpectedEOF. The buffers obtained from s.read are always freed.
func (s *Stream) readTo(p []byte) (int, error) {
	data, err := s.read(len(p))
	defer data.Free()

	if err != nil {
		return 0, err
	}

	if data.Len() != len(p) {
		// NOTE(review): err is always nil here (it was checked above), so
		// this branch always reports io.ErrUnexpectedEOF.
		if err == nil {
			err = io.ErrUnexpectedEOF
		}
		return 0, err
	}

	data.CopyTo(p)
	return len(p), nil
}

// testStreamHandler holds the per-connection state used by the server-side
// stream handlers below. server.start creates one for each accepted transport.
type testStreamHandler struct {
	// t is the server transport this handler was created for.
	t *http2Server
	// notify is closed by a handler to signal the test. It is used by
	// handleStreamAndNotify and handleStreamDelayRead.
	notify chan struct{}
	// getNotified is closed by the test to let handleStreamDelayRead
	// proceed with its first read.
	getNotified chan struct{}
}

// hType selects which stream handler server.start installs on each accepted
// transport:
//   - normal: handleStream (read request, write response, OK status).
//   - suspended: streams are accepted but never serviced.
//   - notifyCall: handleStreamAndNotify.
//   - misbehaved: handleStreamMisbehave.
//   - encodingRequiredStatus: handleStreamEncodingRequiredStatus.
//   - invalidHeaderField: handleStreamInvalidHeaderField.
//   - delayRead: handleStreamDelayRead.
//   - pingpong: handleStreamPingPong.
type hType int

const (
	normal hType = iota
	suspended
	notifyCall
	misbehaved
	encodingRequiredStatus
	invalidHeaderField
	delayRead
	pingpong
)

// handleStreamAndNotify closes h.notify from a new goroutine when a stream
// arrives, unless h.notify is nil or already closed. The stream itself is
// never read from or written to.
//
// NOTE(review): the "already closed" check and the close are not atomic. Two
// streams arriving concurrently could both reach close(h.notify) and panic.
// h.notify is also read here without server.mu, while tests reassign it
// under server.mu.
func (h *testStreamHandler) handleStreamAndNotify(*ServerStream) {
	if h.notify == nil {
		return
	}
	go func() {
		select {
		case <-h.notify:
		default:
			close(h.notify)
		}
	}()
}

// handleStream reads one request of the expected size and checks it against
// expectedRequest (expectedRequestLarge for method "foo.Large"). It then
// writes the matching response and finishes the stream with an OK status.
//
// A mismatched request ends the stream with codes.Internal. A read error
// makes it return silently. The error from Write is ignored.
func (h *testStreamHandler) handleStream(t *testing.T, s *ServerStream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	p := make([]byte, len(req))
	_, err := s.readTo(p)
	if err != nil {
		return
	}
	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		s.WriteStatus(status.New(codes.Internal, "panic"))
		return
	}
	// send a response back to the client.
	s.Write(nil, newBufferSlice(resp), &WriteOptions{})
	// send the trailer to end the stream.
	s.WriteStatus(status.New(codes.OK, ""))
}

// handleStreamPingPong echoes length-prefixed messages until a header read
// returns io.EOF, at which point it ends the stream with an OK status.
//
// Each message is read as a 5-byte header plus the payload. The header is
// one byte followed by a big-endian uint32 payload length. Each message is
// echoed back with a zero first byte and the same length. Any other read
// error ends the stream with codes.Internal.
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *ServerStream) {
	header := make([]byte, 5)
	for {
		if _, err := s.readTo(header); err != nil {
			if err == io.EOF {
				s.WriteStatus(status.New(codes.OK, ""))
				return
			}
			t.Errorf("Error on server while reading data header: %v", err)
			s.WriteStatus(status.New(codes.Internal, "panic"))
			return
		}
		sz := binary.BigEndian.Uint32(header[1:])
		msg := make([]byte, int(sz))
		if _, err := s.readTo(msg); err != nil {
			t.Errorf("Error on server while reading message: %v", err)
			s.WriteStatus(status.New(codes.Internal, "panic"))
			return
		}
		buf := make([]byte, sz+5)
		buf[0] = byte(0)
		binary.BigEndian.PutUint32(buf[1:], uint32(sz))
		copy(buf[5:], msg)
		s.Write(nil, newBufferSlice(buf), &WriteOptions{})
	}
}

// handleStreamMisbehave enqueues data frames for s directly on the server
// transport's controlBuf, presumably bypassing the regular Write path. It
// sends frames of up to http2MaxFrameLen bytes until at least
// initialWindowSize bytes have been queued.
//
// For method "foo.Connection" the total is exactly initialWindowSize. For
// any other method the last frame carries one extra byte. The inline
// comments describe which client window each case is meant to violate.
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *ServerStream) {
	conn, ok := s.st.(*http2Server)
	if !ok {
		t.Errorf("Failed to convert %v to *http2Server", s.st)
		s.WriteStatus(status.New(codes.Internal, ""))
		return
	}
	var sent int
	p := make([]byte, http2MaxFrameLen)
	for sent < initialWindowSize {
		n := initialWindowSize - sent
		// The last message may be smaller than http2MaxFrameLen
		if n <= http2MaxFrameLen {
			if s.Method() == "foo.Connection" {
				// Violate connection level flow control window of client but do not
				// violate any stream level windows.
				p = make([]byte, n)
			} else {
				// Violate stream level flow control window of client.
				p = make([]byte, n+1)
			}
		}
		data := newBufferSlice(p)
		data.Ref()
		conn.controlBuf.put(&dataFrame{
			streamID:    s.id,
			h:           nil,
			data:        data,
			onEachWrite: func() {},
		})
		sent += len(p)
	}
}

// handleStreamEncodingRequiredStatus ends the stream with encodingTestStatus,
// which is declared elsewhere in this package, and then drains any data the
// client still sends.
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *ServerStream) {
	// raw newline is not accepted by http2 framer so it must be encoded.
	s.WriteStatus(encodingTestStatus)
	// Drain any remaining buffers from the stream since it was closed early.
	s.Read(math.MaxInt)
}

// handleStreamInvalidHeaderField enqueues a header frame for s directly on
// the transport's controlBuf. The frame does not end the stream and has
// content-type set to expectedInvalidHeaderField.
func (h *testStreamHandler) handleStreamInvalidHeaderField(s *ServerStream) {
	headerFields := []hpack.HeaderField{}
	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
	h.t.controlBuf.put(&headerFrame{
		streamID:  s.id,
		hf:        headerFields,
		endStream: false,
	})
}

// handleStreamDelayRead delays reads so that the other side has to halt on
// stream-level flow control.
// This handler assumes dynamic flow control is turned off and assumes window
// sizes to be set to defaultWindowSize.
//
// The handler runs this sequence:
//  1. It hooks s.wq.replenish to count replenished bytes.
//  2. It closes h.notify once that count reaches defaultWindowSize.
//  3. It waits up to 10s for the test to close h.getNotified before its
//     first read.
//  4. It writes the response and reads once more.
//  5. It finishes with an OK status.
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	var (
		mu    sync.Mutex
		total int
	)
	// Count every replenish so the watcher goroutine below can tell when a
	// full window's worth has been accounted for.
	s.wq.replenish = func(n int) {
		mu.Lock()
		total += n
		mu.Unlock()
		s.wq.realReplenish(n)
	}
	getTotal := func() int {
		mu.Lock()
		defer mu.Unlock()
		return total
	}
	done := make(chan struct{})
	defer close(done)
	go func() {
		for {
			select {
			// Prevent goroutine from leaking.
			case <-done:
				return
			default:
			}
			if getTotal() == defaultWindowSize {
				// Signal the client to start reading and
				// thereby send window update.
				close(h.notify)
				return
			}
			runtime.Gosched()
		}
	}()
	p := make([]byte, len(req))

	// Let the other side run out of stream-level window before
	// starting to read and thereby sending a window update.
	timer := time.NewTimer(time.Second * 10)
	select {
	case <-h.getNotified:
		timer.Stop()
	case <-timer.C:
		t.Errorf("Server timed-out.")
		return
	}
	_, err := s.readTo(p)
	if err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
		return
	}

	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		return
	}
	// This write will cause server to run out of stream level,
	// flow control and the other side won't send a window update
	// until that happens.
	if err := s.Write(nil, newBufferSlice(resp), &WriteOptions{}); err != nil {
		t.Errorf("server Write got %v, want <nil>", err)
		return
	}
	// Read one more time to ensure that everything remains fine and
	// that the goroutine, that we launched earlier to signal client
	// to read, gets enough time to process.
	_, err = s.readTo(p)
	if err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, nil", err)
		return
	}
	// send the trailer to end the stream.
	if err := s.WriteStatus(status.New(codes.OK, "")); err != nil {
		t.Errorf("server WriteStatus got %v, want <nil>", err)
		return
	}
}

// server is a test gRPC transport server that accepts connections on lis
// and installs one of the testStreamHandler handlers on each.
type server struct {
	lis  net.Listener
	port string
	// startedErr receives the listener setup result (nil on success) once
	// from start.
	startedErr chan error
	mu         sync.Mutex
	// conns maps each live server transport to its raw connection. It is
	// set to nil by stop, after which start refuses new connections.
	conns map[ServerTransport]net.Conn
	// h is the handler state of the most recently accepted connection.
	h *testStreamHandler
	// ready is closed by start (delayRead mode only) once h is installed.
	ready    chan struct{}
	channelz *channelz.Server
	// servingTasksDone is closed by start after its accept loop exits and
	// all handler goroutines have finished.
	servingTasksDone chan struct{}
}

// newTestServer returns a server with its signaling channels allocated and
// registered with channelz as "test server". Call start to begin serving.
func newTestServer() *server {
	return &server{
		startedErr:       make(chan error, 1),
		ready:            make(chan struct{}),
		servingTasksDone: make(chan struct{}),
		channelz:         channelz.RegisterServer("test server"),
	}
}

// start listens on localhost:port (an ephemeral port when port is 0) and
// serves connections until the listener is closed. It reports the listen
// result on s.startedErr, which callers consume through wait. In delayRead
// mode it also closes s.ready once the handler for a connection is
// installed.
//
// Each accepted connection gets a testStreamHandler, and its streams are
// dispatched according to ht. If serverConfig.MaxStreams is 0, it is set to
// math.MaxUint32 in the caller's config.
//
// NOTE(review): when Listen or SplitHostPort fails, start returns before the
// deferred close of s.servingTasksDone is registered, so a later stop would
// block. setUpServerOnly fails the test before stop can run in that case.
func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
	var err error
	if port == 0 {
		s.lis, err = net.Listen("tcp", "localhost:0")
	} else {
		s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
	}
	if err != nil {
		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
		return
	}
	_, p, err := net.SplitHostPort(s.lis.Addr().String())
	if err != nil {
		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
		return
	}
	s.port = p
	s.conns = make(map[ServerTransport]net.Conn)
	s.startedErr <- nil
	wg := sync.WaitGroup{}
	defer func() {
		wg.Wait()
		close(s.servingTasksDone)
	}()

	for {
		conn, err := s.lis.Accept()
		if err != nil {
			return
		}
		rawConn := conn
		if serverConfig.MaxStreams == 0 {
			serverConfig.MaxStreams = math.MaxUint32
		}
		transport, err := NewServerTransport(conn, serverConfig)
		if err != nil {
			return
		}
		s.mu.Lock()
		// A nil map means stop has already run; reject the late connection.
		if s.conns == nil {
			s.mu.Unlock()
			transport.Close(errors.New("s.conns is nil"))
			return
		}
		s.conns[transport] = rawConn
		h := &testStreamHandler{t: transport.(*http2Server)}
		s.h = h
		s.mu.Unlock()
		// NOTE(review): this defer runs only when start returns, so every
		// connection's context stays alive until the accept loop exits.
		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
		defer cancel()
		wg.Add(1)
		switch ht {
		case notifyCall:
			go func() {
				transport.HandleStreams(ctx, h.handleStreamAndNotify)
				wg.Done()
			}()
		case suspended:
			go func() {
				transport.HandleStreams(ctx, func(*ServerStream) {})
				wg.Done()
			}()
		case misbehaved:
			go func() {
				transport.HandleStreams(ctx, func(s *ServerStream) {
					wg.Add(1)
					go func() {
						h.handleStreamMisbehave(t, s)
						wg.Done()
					}()
				})
				wg.Done()
			}()
		case encodingRequiredStatus:
			go func() {
				transport.HandleStreams(ctx, func(s *ServerStream) {
					wg.Add(1)
					go func() {
						h.handleStreamEncodingRequiredStatus(s)
						wg.Done()
					}()
				})
				wg.Done()
			}()
		case invalidHeaderField:
			go func() {
				transport.HandleStreams(ctx, func(s *ServerStream) {
					wg.Add(1)
					go func() {
						h.handleStreamInvalidHeaderField(s)
						wg.Done()
					}()
				})
				wg.Done()
			}()
		case delayRead:
			h.notify = make(chan struct{})
			h.getNotified = make(chan struct{})
			s.mu.Lock()
			close(s.ready)
			s.mu.Unlock()
			go func() {
				transport.HandleStreams(ctx, func(s *ServerStream) {
					wg.Add(1)
					go func() {
						h.handleStreamDelayRead(t, s)
						wg.Done()
					}()
				})
				wg.Done()
			}()
		case pingpong:
			go func() {
				transport.HandleStreams(ctx, func(s *ServerStream) {
					wg.Add(1)
					go func() {
						h.handleStreamPingPong(t, s)
						wg.Done()
					}()
				})
				wg.Done()
			}()
		default:
			go func() {
				transport.HandleStreams(ctx, func(s *ServerStream) {
					wg.Add(1)
					go func() {
						h.handleStream(t, s)
						wg.Done()
					}()
				})
				wg.Done()
			}()
		}
	}
}

// wait blocks until start reports its listener setup result. It fails the
// test if that result is an error or if nothing arrives within timeout.
func (s *server) wait(t *testing.T, timeout time.Duration) {
	select {
	case err := <-s.startedErr:
		if err != nil {
			t.Fatal(err)
		}
	case <-time.After(timeout):
		t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
	}
}

// stop closes the listener and every tracked server transport. It sets
// s.conns to nil so start rejects late connections, then blocks until start
// closes s.servingTasksDone.
func (s *server) stop() {
	s.lis.Close()
	s.mu.Lock()
	for c := range s.conns {
		c.Close(errors.New("server Stop called"))
	}
	s.conns = nil
	s.mu.Unlock()
	<-s.servingTasksDone
}

// addr returns the listener's address, or "" if the server never started
// listening.
func (s *server) addr() string {
	if s.lis == nil {
		return ""
	}
	return s.lis.Addr().String()
}

// setUpServerOnly starts a test server (channelz parent set to the server's
// channelz entry) in the background. It waits up to 2s for the listener to
// come up and fails the test otherwise.
func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server {
	server := newTestServer()
	sc.ChannelzParent = server.channelz
	go server.start(t, port, sc, ht)
	server.wait(t, 2*time.Second)
	return server
}

// setUp is setUpWithOptions with a default ServerConfig and ConnectOptions.
func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
	return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
}

// setUpWithOptions starts a test server and dials it with an http2Client.
//
// The client transport's context is cancelled through t.Cleanup. The
// returned func cancels the 2s connect context and should be deferred by
// the caller; on connect failure it is cancelled here and the test fails.
func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
	server := setUpServerOnly(t, port, sc, ht)
	addr := resolver.Address{Addr: "localhost:" + server.port}
	copts.ChannelzParent = channelzSubChannel(t)

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	t.Cleanup(cancel)
	connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
	ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {})
	if connErr != nil {
		cCancel() // Do not cancel in success path.
		t.Fatalf("failed to create transport: %v", connErr)
	}
	return server, ct.(*http2Client), cCancel
}

// setUpWithNoPingServer dials a minimal raw HTTP/2 peer with copts.
//
// The peer accepts a single connection and writes one SETTINGS frame. After
// that it does nothing with the connection, so it never reads and
// presumably never answers pings. The accepted server-side conn is sent on
// connCh, which blocks until the caller receives it. On a server-side error
// connCh is closed instead. The listener is closed once the accept goroutine
// returns.
//
// The returned func cancels the 2s connect context. On dial failure the
// listener and any accepted conn are closed and the test fails.
func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	// Launch a non responsive server.
	go func() {
		defer lis.Close()
		conn, err := lis.Accept()
		if err != nil {
			t.Errorf("Error at server-side while accepting: %v", err)
			close(connCh)
			return
		}
		framer := http2.NewFramer(conn, conn)
		if err := framer.WriteSettings(); err != nil {
			t.Errorf("Error at server-side while writing settings: %v", err)
			close(connCh)
			return
		}
		connCh <- conn
	}()
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	t.Cleanup(cancel)
	connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
	tr, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
	if err != nil {
		cCancel() // Do not cancel in success path.
		// Server clean-up.
		lis.Close()
		if conn, ok := <-connCh; ok {
			conn.Close()
		}
		t.Fatalf("Failed to dial: %v", err)
	}
	return tr.(*http2Client), cCancel
}

// TestInflightStreamClosing ensures that closing an in-flight stream
// sends a status error to a concurrent stream reader.
579 func (s) TestInflightStreamClosing(t *testing.T) { 580 serverConfig := &ServerConfig{} 581 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 582 defer cancel() 583 defer server.stop() 584 defer client.Close(fmt.Errorf("closed manually by test")) 585 586 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 587 defer cancel() 588 stream, err := client.NewStream(ctx, &CallHdr{}) 589 if err != nil { 590 t.Fatalf("Client failed to create RPC request: %v", err) 591 } 592 593 donec := make(chan struct{}) 594 serr := status.Error(codes.Internal, "client connection is closing") 595 go func() { 596 defer close(donec) 597 if _, err := stream.readTo(make([]byte, defaultWindowSize)); err != serr { 598 t.Errorf("unexpected Stream error %v, expected %v", err, serr) 599 } 600 }() 601 602 // should unblock concurrent stream.Read 603 stream.Close(serr) 604 605 // wait for stream.Read error 606 timeout := time.NewTimer(5 * time.Second) 607 select { 608 case <-donec: 609 if !timeout.Stop() { 610 <-timeout.C 611 } 612 case <-timeout.C: 613 t.Fatalf("Test timed out, expected a status error.") 614 } 615 } 616 617 // Tests that when streamID > MaxStreamId, the current client transport drains. 
618 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { 619 server, ct, cancel := setUp(t, 0, normal) 620 defer cancel() 621 defer server.stop() 622 callHdr := &CallHdr{ 623 Host: "localhost", 624 Method: "foo.Small", 625 } 626 627 originalMaxStreamID := MaxStreamID 628 MaxStreamID = 3 629 defer func() { 630 MaxStreamID = originalMaxStreamID 631 }() 632 633 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 634 defer ctxCancel() 635 636 s, err := ct.NewStream(ctx, callHdr) 637 if err != nil { 638 t.Fatalf("ct.NewStream() = %v", err) 639 } 640 if s.id != 1 { 641 t.Fatalf("Stream id: %d, want: 1", s.id) 642 } 643 644 if got, want := ct.stateForTesting(), reachable; got != want { 645 t.Fatalf("Client transport state %v, want %v", got, want) 646 } 647 648 // The expected stream ID here is 3 since stream IDs are incremented by 2. 649 s, err = ct.NewStream(ctx, callHdr) 650 if err != nil { 651 t.Fatalf("ct.NewStream() = %v", err) 652 } 653 if s.id != 3 { 654 t.Fatalf("Stream id: %d, want: 3", s.id) 655 } 656 657 // Verifying that ct.state is draining when next stream ID > MaxStreamId. 
658 if got, want := ct.stateForTesting(), draining; got != want { 659 t.Fatalf("Client transport state %v, want %v", got, want) 660 } 661 } 662 663 func (s) TestClientSendAndReceive(t *testing.T) { 664 server, ct, cancel := setUp(t, 0, normal) 665 defer cancel() 666 callHdr := &CallHdr{ 667 Host: "localhost", 668 Method: "foo.Small", 669 } 670 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 671 defer ctxCancel() 672 s1, err1 := ct.NewStream(ctx, callHdr) 673 if err1 != nil { 674 t.Fatalf("failed to open stream: %v", err1) 675 } 676 if s1.id != 1 { 677 t.Fatalf("wrong stream id: %d", s1.id) 678 } 679 s2, err2 := ct.NewStream(ctx, callHdr) 680 if err2 != nil { 681 t.Fatalf("failed to open stream: %v", err2) 682 } 683 if s2.id != 3 { 684 t.Fatalf("wrong stream id: %d", s2.id) 685 } 686 opts := WriteOptions{Last: true} 687 if err := s1.Write(nil, newBufferSlice(expectedRequest), &opts); err != nil && err != io.EOF { 688 t.Fatalf("failed to send data: %v", err) 689 } 690 p := make([]byte, len(expectedResponse)) 691 _, recvErr := s1.readTo(p) 692 if recvErr != nil || !bytes.Equal(p, expectedResponse) { 693 t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse) 694 } 695 _, recvErr = s1.readTo(p) 696 if recvErr != io.EOF { 697 t.Fatalf("Error: %v; want <EOF>", recvErr) 698 } 699 ct.Close(fmt.Errorf("closed manually by test")) 700 server.stop() 701 } 702 703 func (s) TestClientErrorNotify(t *testing.T) { 704 server, ct, cancel := setUp(t, 0, normal) 705 defer cancel() 706 go server.stop() 707 // ct.reader should detect the error and activate ct.Error(). 
708 <-ct.Error() 709 ct.Close(fmt.Errorf("closed manually by test")) 710 } 711 712 func performOneRPC(ct ClientTransport) { 713 callHdr := &CallHdr{ 714 Host: "localhost", 715 Method: "foo.Small", 716 } 717 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 718 defer cancel() 719 s, err := ct.NewStream(ctx, callHdr) 720 if err != nil { 721 return 722 } 723 opts := WriteOptions{Last: true} 724 if err := s.Write([]byte{}, newBufferSlice(expectedRequest), &opts); err == nil || err == io.EOF { 725 time.Sleep(5 * time.Millisecond) 726 // The following s.Recv()'s could error out because the 727 // underlying transport is gone. 728 // 729 // Read response 730 p := make([]byte, len(expectedResponse)) 731 s.readTo(p) 732 // Read io.EOF 733 s.readTo(p) 734 } 735 } 736 737 func (s) TestClientMix(t *testing.T) { 738 s, ct, cancel := setUp(t, 0, normal) 739 defer cancel() 740 time.AfterFunc(time.Second, s.stop) 741 go func(ct ClientTransport) { 742 <-ct.Error() 743 ct.Close(fmt.Errorf("closed manually by test")) 744 }(ct) 745 for i := 0; i < 750; i++ { 746 time.Sleep(2 * time.Millisecond) 747 go performOneRPC(ct) 748 } 749 } 750 751 func (s) TestLargeMessage(t *testing.T) { 752 server, ct, cancel := setUp(t, 0, normal) 753 defer cancel() 754 callHdr := &CallHdr{ 755 Host: "localhost", 756 Method: "foo.Large", 757 } 758 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 759 defer ctxCancel() 760 var wg sync.WaitGroup 761 for i := 0; i < 2; i++ { 762 wg.Add(1) 763 go func() { 764 defer wg.Done() 765 s, err := ct.NewStream(ctx, callHdr) 766 if err != nil { 767 t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 768 } 769 if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{Last: true}); err != nil && err != io.EOF { 770 t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) 771 } 772 p := make([]byte, len(expectedResponseLarge)) 773 if _, err := s.readTo(p); err != nil || !bytes.Equal(p, 
expectedResponseLarge) { 774 t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) 775 } 776 if _, err = s.readTo(p); err != io.EOF { 777 t.Errorf("Failed to complete the stream %v; want <EOF>", err) 778 } 779 }() 780 } 781 wg.Wait() 782 ct.Close(fmt.Errorf("closed manually by test")) 783 server.stop() 784 } 785 786 func (s) TestLargeMessageWithDelayRead(t *testing.T) { 787 // Disable dynamic flow control. 788 sc := &ServerConfig{ 789 InitialWindowSize: defaultWindowSize, 790 InitialConnWindowSize: defaultWindowSize, 791 StaticWindowSize: true, 792 } 793 co := ConnectOptions{ 794 InitialWindowSize: defaultWindowSize, 795 InitialConnWindowSize: defaultWindowSize, 796 StaticWindowSize: true, 797 } 798 server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co) 799 defer cancel() 800 defer server.stop() 801 defer ct.Close(fmt.Errorf("closed manually by test")) 802 server.mu.Lock() 803 ready := server.ready 804 server.mu.Unlock() 805 callHdr := &CallHdr{ 806 Host: "localhost", 807 Method: "foo.Large", 808 } 809 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 810 defer cancel() 811 s, err := ct.NewStream(ctx, callHdr) 812 if err != nil { 813 t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 814 return 815 } 816 // Wait for server's handler to be initialized 817 select { 818 case <-ready: 819 case <-ctx.Done(): 820 t.Fatalf("Client timed out waiting for server handler to be initialized.") 821 } 822 server.mu.Lock() 823 serviceHandler := server.h 824 server.mu.Unlock() 825 var ( 826 mu sync.Mutex 827 total int 828 ) 829 s.wq.replenish = func(n int) { 830 mu.Lock() 831 total += n 832 mu.Unlock() 833 s.wq.realReplenish(n) 834 } 835 getTotal := func() int { 836 mu.Lock() 837 defer mu.Unlock() 838 return total 839 } 840 done := make(chan struct{}) 841 defer close(done) 842 go func() { 843 for { 844 select { 845 // Prevent goroutine from leaking in case of error. 
846 case <-done: 847 return 848 default: 849 } 850 if getTotal() == defaultWindowSize { 851 // unblock server to be able to read and 852 // thereby send stream level window update. 853 close(serviceHandler.getNotified) 854 return 855 } 856 runtime.Gosched() 857 } 858 }() 859 // This write will cause client to run out of stream level, 860 // flow control and the other side won't send a window update 861 // until that happens. 862 if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{}); err != nil { 863 t.Fatalf("write(_, _, _) = %v, want <nil>", err) 864 } 865 p := make([]byte, len(expectedResponseLarge)) 866 867 // Wait for the other side to run out of stream level flow control before 868 // reading and thereby sending a window update. 869 select { 870 case <-serviceHandler.notify: 871 case <-ctx.Done(): 872 t.Fatalf("Client timed out") 873 } 874 if _, err := s.readTo(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { 875 t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err) 876 } 877 if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{Last: true}); err != nil { 878 t.Fatalf("Write(_, _, _) = %v, want <nil>", err) 879 } 880 if _, err = s.readTo(p); err != io.EOF { 881 t.Fatalf("Failed to complete the stream %v; want <EOF>", err) 882 } 883 } 884 885 // TestGracefulClose ensures that GracefulClose allows in-flight streams to 886 // proceed until they complete naturally, while not allowing creation of new 887 // streams during this window. 888 func (s) TestGracefulClose(t *testing.T) { 889 leakcheck.SetTrackingBufferPool(t) 890 server, ct, cancel := setUp(t, 0, pingpong) 891 defer cancel() 892 defer func() { 893 // Stop the server's listener to make the server's goroutines terminate 894 // (after the last active stream is done). 895 server.lis.Close() 896 // Check for goroutine leaks (i.e. GracefulClose with an active stream 897 // doesn't eventually close the connection when that stream completes). 
898 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 899 defer cancel() 900 leakcheck.CheckGoroutines(ctx, t) 901 leakcheck.CheckTrackingBufferPool() 902 // Correctly clean up the server 903 server.stop() 904 }() 905 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 906 defer cancel() 907 908 // Create a stream that will exist for this whole test and confirm basic 909 // functionality. 910 s, err := ct.NewStream(ctx, &CallHdr{}) 911 if err != nil { 912 t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err) 913 } 914 msg := make([]byte, 1024) 915 outgoingHeader := make([]byte, 5) 916 outgoingHeader[0] = byte(0) 917 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg))) 918 incomingHeader := make([]byte, 5) 919 if err := s.Write(outgoingHeader, newBufferSlice(msg), &WriteOptions{}); err != nil { 920 t.Fatalf("Error while writing: %v", err) 921 } 922 if _, err := s.readTo(incomingHeader); err != nil { 923 t.Fatalf("Error while reading: %v", err) 924 } 925 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 926 recvMsg := make([]byte, int(sz)) 927 if _, err := s.readTo(recvMsg); err != nil { 928 t.Fatalf("Error while reading: %v", err) 929 } 930 931 // Gracefully close the transport, which should not affect the existing 932 // stream. 933 ct.GracefulClose() 934 935 var wg sync.WaitGroup 936 // Expect errors creating new streams because the client transport has been 937 // gracefully closed. 938 for i := 0; i < 200; i++ { 939 wg.Add(1) 940 go func() { 941 defer wg.Done() 942 _, err := ct.NewStream(ctx, &CallHdr{}) 943 if err != nil && err.(*NewStreamError).Err == ErrConnClosing && err.(*NewStreamError).AllowTransparentRetry { 944 return 945 } 946 t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) 947 }() 948 } 949 950 // Confirm the existing stream still functions as expected. 
951 s.Write(nil, nil, &WriteOptions{Last: true}) 952 if _, err := s.readTo(incomingHeader); err != io.EOF { 953 t.Fatalf("Client expected EOF from the server. Got: %v", err) 954 } 955 wg.Wait() 956 } 957 958 func (s) TestLargeMessageSuspension(t *testing.T) { 959 server, ct, cancel := setUp(t, 0, suspended) 960 defer cancel() 961 defer ct.Close(fmt.Errorf("closed manually by test")) 962 defer server.stop() 963 callHdr := &CallHdr{ 964 Host: "localhost", 965 Method: "foo.Large", 966 } 967 // Set a long enough timeout for writing a large message out. 968 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 969 defer cancel() 970 s, err := ct.NewStream(ctx, callHdr) 971 if err != nil { 972 t.Fatalf("failed to open stream: %v", err) 973 } 974 // Write should not be done successfully due to flow control. 975 msg := make([]byte, initialWindowSize*8) 976 s.Write(nil, newBufferSlice(msg), &WriteOptions{}) 977 err = s.Write(nil, newBufferSlice(msg), &WriteOptions{Last: true}) 978 if err != errStreamDone { 979 t.Fatalf("Write got %v, want io.EOF", err) 980 } 981 // The server will send an RST stream frame on observing the deadline 982 // expiration making the client stream fail with a DeadlineExceeded status. 
983 _, err = s.readTo(make([]byte, 8)) 984 if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded { 985 t.Fatalf("Read got unexpected error: %v, want status with code %v", err, codes.DeadlineExceeded) 986 } 987 if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want { 988 t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want) 989 } 990 } 991 992 func (s) TestMaxStreams(t *testing.T) { 993 serverConfig := &ServerConfig{ 994 MaxStreams: 1, 995 } 996 server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 997 defer cancel() 998 defer ct.Close(fmt.Errorf("closed manually by test")) 999 defer server.stop() 1000 callHdr := &CallHdr{ 1001 Host: "localhost", 1002 Method: "foo.Large", 1003 } 1004 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1005 defer cancel() 1006 s, err := ct.NewStream(ctx, callHdr) 1007 if err != nil { 1008 t.Fatalf("Failed to open stream: %v", err) 1009 } 1010 // Keep creating streams until one fails with deadline exceeded, marking the application 1011 // of server settings on client. 1012 slist := []*ClientStream{} 1013 pctx, cancel := context.WithCancel(context.Background()) 1014 defer cancel() 1015 timer := time.NewTimer(time.Second * 10) 1016 expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) 1017 for { 1018 select { 1019 case <-timer.C: 1020 t.Fatalf("Test timeout: client didn't receive server settings.") 1021 default: 1022 } 1023 ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second)) 1024 // This is only to get rid of govet. All these context are based on a base 1025 // context which is canceled at the end of the test. 
1026 defer cancel() 1027 if str, err := ct.NewStream(ctx, callHdr); err == nil { 1028 slist = append(slist, str) 1029 continue 1030 } else if err.Error() != expectedErr.Error() { 1031 t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) 1032 } 1033 timer.Stop() 1034 break 1035 } 1036 done := make(chan struct{}) 1037 // Try and create a new stream. 1038 go func() { 1039 defer close(done) 1040 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 1041 defer cancel() 1042 if _, err := ct.NewStream(ctx, callHdr); err != nil { 1043 t.Errorf("Failed to open stream: %v", err) 1044 } 1045 }() 1046 // Close all the extra streams created and make sure the new stream is not created. 1047 for _, str := range slist { 1048 str.Close(nil) 1049 } 1050 select { 1051 case <-done: 1052 t.Fatalf("Test failed: didn't expect new stream to be created just yet.") 1053 default: 1054 } 1055 // Close the first stream created so that the new stream can finally be created. 1056 s.Close(nil) 1057 <-done 1058 ct.Close(fmt.Errorf("closed manually by test")) 1059 <-ct.writerDone 1060 if ct.maxConcurrentStreams != 1 { 1061 t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) 1062 } 1063 } 1064 1065 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { 1066 server, ct, cancel := setUp(t, 0, suspended) 1067 defer cancel() 1068 callHdr := &CallHdr{ 1069 Host: "localhost", 1070 Method: "foo", 1071 } 1072 var sc *http2Server 1073 // Wait until the server transport is setup. 
1074 for { 1075 server.mu.Lock() 1076 if len(server.conns) == 0 { 1077 server.mu.Unlock() 1078 time.Sleep(time.Millisecond) 1079 continue 1080 } 1081 for k := range server.conns { 1082 var ok bool 1083 sc, ok = k.(*http2Server) 1084 if !ok { 1085 t.Fatalf("Failed to convert %v to *http2Server", k) 1086 } 1087 } 1088 server.mu.Unlock() 1089 break 1090 } 1091 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1092 defer cancel() 1093 s, err := ct.NewStream(ctx, callHdr) 1094 if err != nil { 1095 t.Fatalf("Failed to open stream: %v", err) 1096 } 1097 d := newBufferSlice(make([]byte, http2MaxFrameLen)) 1098 d.Ref() 1099 ct.controlBuf.put(&dataFrame{ 1100 streamID: s.id, 1101 endStream: false, 1102 h: nil, 1103 data: d, 1104 onEachWrite: func() {}, 1105 }) 1106 // Loop until the server side stream is created. 1107 var ss *ServerStream 1108 for { 1109 time.Sleep(time.Second) 1110 sc.mu.Lock() 1111 if len(sc.activeStreams) == 0 { 1112 sc.mu.Unlock() 1113 continue 1114 } 1115 ss = sc.activeStreams[s.id] 1116 sc.mu.Unlock() 1117 break 1118 } 1119 ct.Close(fmt.Errorf("closed manually by test")) 1120 select { 1121 case <-ss.Context().Done(): 1122 if ss.Context().Err() != context.Canceled { 1123 t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled) 1124 } 1125 case <-time.After(5 * time.Second): 1126 t.Fatalf("Failed to cancel the context of the sever side stream.") 1127 } 1128 server.stop() 1129 } 1130 1131 func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { 1132 connectOptions := ConnectOptions{ 1133 InitialWindowSize: defaultWindowSize, 1134 InitialConnWindowSize: defaultWindowSize, 1135 } 1136 server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) 1137 defer cancel() 1138 defer server.stop() 1139 defer client.Close(fmt.Errorf("closed manually by test")) 1140 1141 waitWhileTrue(t, func() (bool, error) { 1142 server.mu.Lock() 1143 defer server.mu.Unlock() 1144 1145 
if len(server.conns) == 0 { 1146 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 1147 } 1148 return false, nil 1149 }) 1150 1151 var st *http2Server 1152 server.mu.Lock() 1153 for k := range server.conns { 1154 st = k.(*http2Server) 1155 } 1156 notifyChan := make(chan struct{}) 1157 server.h.notify = notifyChan 1158 server.mu.Unlock() 1159 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1160 defer cancel() 1161 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1162 if err != nil { 1163 t.Fatalf("Client failed to create first stream. Err: %v", err) 1164 } 1165 1166 <-notifyChan 1167 var sstream1 *ServerStream 1168 // Access stream on the server. 1169 st.mu.Lock() 1170 for _, v := range st.activeStreams { 1171 if v.id == cstream1.id { 1172 sstream1 = v 1173 } 1174 } 1175 st.mu.Unlock() 1176 if sstream1 == nil { 1177 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) 1178 } 1179 // Exhaust client's connection window. 1180 if err := sstream1.Write([]byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil { 1181 t.Fatalf("Server failed to write data. Err: %v", err) 1182 } 1183 notifyChan = make(chan struct{}) 1184 server.mu.Lock() 1185 server.h.notify = notifyChan 1186 server.mu.Unlock() 1187 // Create another stream on client. 1188 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1189 if err != nil { 1190 t.Fatalf("Client failed to create second stream. 
Err: %v", err) 1191 } 1192 <-notifyChan 1193 var sstream2 *ServerStream 1194 st.mu.Lock() 1195 for _, v := range st.activeStreams { 1196 if v.id == cstream2.id { 1197 sstream2 = v 1198 } 1199 } 1200 st.mu.Unlock() 1201 if sstream2 == nil { 1202 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) 1203 } 1204 // Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream. 1205 if err := sstream2.Write([]byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil { 1206 t.Fatalf("Server failed to write data. Err: %v", err) 1207 } 1208 1209 // Client should be able to read data on second stream. 1210 if _, err := cstream2.readTo(make([]byte, defaultWindowSize)); err != nil { 1211 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1212 } 1213 1214 // Client should be able to read data on first stream. 1215 if _, err := cstream1.readTo(make([]byte, defaultWindowSize)); err != nil { 1216 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1217 } 1218 } 1219 1220 func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { 1221 serverConfig := &ServerConfig{ 1222 InitialWindowSize: defaultWindowSize, 1223 InitialConnWindowSize: defaultWindowSize, 1224 } 1225 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 1226 defer cancel() 1227 defer server.stop() 1228 defer client.Close(fmt.Errorf("closed manually by test")) 1229 waitWhileTrue(t, func() (bool, error) { 1230 server.mu.Lock() 1231 defer server.mu.Unlock() 1232 1233 if len(server.conns) == 0 { 1234 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 1235 } 1236 return false, nil 1237 }) 1238 var st *http2Server 1239 server.mu.Lock() 1240 for k := range server.conns { 1241 st = k.(*http2Server) 1242 } 1243 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1244 defer cancel() 1245 
server.mu.Unlock() 1246 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1247 if err != nil { 1248 t.Fatalf("Failed to create 1st stream. Err: %v", err) 1249 } 1250 // Exhaust server's connection window. 1251 if err := cstream1.Write(nil, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{Last: true}); err != nil { 1252 t.Fatalf("Client failed to write data. Err: %v", err) 1253 } 1254 // Client should be able to create another stream and send data on it. 1255 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1256 if err != nil { 1257 t.Fatalf("Failed to create 2nd stream. Err: %v", err) 1258 } 1259 if err := cstream2.Write(nil, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil { 1260 t.Fatalf("Client failed to write data. Err: %v", err) 1261 } 1262 // Get the streams on server. 1263 waitWhileTrue(t, func() (bool, error) { 1264 st.mu.Lock() 1265 defer st.mu.Unlock() 1266 1267 if len(st.activeStreams) != 2 { 1268 return true, fmt.Errorf("timed-out while waiting for server to have created the streams") 1269 } 1270 return false, nil 1271 }) 1272 var sstream1 *ServerStream 1273 st.mu.Lock() 1274 for _, v := range st.activeStreams { 1275 if v.id == 1 { 1276 sstream1 = v 1277 } 1278 } 1279 st.mu.Unlock() 1280 // Reading from the stream on server should succeed. 1281 if _, err := sstream1.readTo(make([]byte, defaultWindowSize)); err != nil { 1282 t.Fatalf("_.Read(_) = %v, want <nil>", err) 1283 } 1284 1285 if _, err := sstream1.readTo(make([]byte, 1)); err != io.EOF { 1286 t.Fatalf("_.Read(_) = %v, want io.EOF", err) 1287 } 1288 1289 } 1290 1291 func (s) TestServerWithMisbehavedClient(t *testing.T) { 1292 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1293 defer server.stop() 1294 // Create a client that can override server stream quota. 
1295 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1296 if err != nil { 1297 t.Fatalf("Clent failed to dial:%v", err) 1298 } 1299 defer mconn.Close() 1300 if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil { 1301 t.Fatalf("Failed to set write deadline: %v", err) 1302 } 1303 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1304 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1305 } 1306 // success chan indicates that reader received a RSTStream from server. 1307 success := make(chan struct{}) 1308 var mu sync.Mutex 1309 framer := http2.NewFramer(mconn, mconn) 1310 if err := framer.WriteSettings(); err != nil { 1311 t.Fatalf("Error while writing settings: %v", err) 1312 } 1313 go func() { // Launch a reader for this misbehaving client. 1314 for { 1315 frame, err := framer.ReadFrame() 1316 if err != nil { 1317 return 1318 } 1319 switch frame := frame.(type) { 1320 case *http2.PingFrame: 1321 // Write ping ack back so that server's BDP estimation works right. 1322 mu.Lock() 1323 framer.WritePing(true, frame.Data) 1324 mu.Unlock() 1325 case *http2.RSTStreamFrame: 1326 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { 1327 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 1328 } 1329 close(success) 1330 return 1331 default: 1332 // Do nothing. 1333 } 1334 1335 } 1336 }() 1337 // Create a stream. 1338 var buf bytes.Buffer 1339 henc := hpack.NewEncoder(&buf) 1340 // TODO(mmukhi): Remove unnecessary fields. 
1341 if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil { 1342 t.Fatalf("Error while encoding header: %v", err) 1343 } 1344 if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { 1345 t.Fatalf("Error while encoding header: %v", err) 1346 } 1347 if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { 1348 t.Fatalf("Error while encoding header: %v", err) 1349 } 1350 if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { 1351 t.Fatalf("Error while encoding header: %v", err) 1352 } 1353 mu.Lock() 1354 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1355 mu.Unlock() 1356 t.Fatalf("Error while writing headers: %v", err) 1357 } 1358 mu.Unlock() 1359 1360 // Test server behavior for violation of stream flow control window size restriction. 1361 timer := time.NewTimer(time.Second * 5) 1362 dbuf := make([]byte, http2MaxFrameLen) 1363 for { 1364 select { 1365 case <-timer.C: 1366 t.Fatalf("Test timed out.") 1367 case <-success: 1368 return 1369 default: 1370 } 1371 mu.Lock() 1372 if err := framer.WriteData(1, false, dbuf); err != nil { 1373 mu.Unlock() 1374 // Error here means the server could have closed the connection due to flow control 1375 // violation. Make sure that is the case by waiting for success chan to be closed. 1376 select { 1377 case <-timer.C: 1378 t.Fatalf("Error while writing data: %v", err) 1379 case <-success: 1380 return 1381 } 1382 } 1383 mu.Unlock() 1384 // This for loop is capable of hogging the CPU and cause starvation 1385 // in Go versions prior to 1.9, 1386 // in single CPU environment. Explicitly relinquish processor. 1387 runtime.Gosched() 1388 } 1389 } 1390 1391 func (s) TestClientHonorsConnectContext(t *testing.T) { 1392 // Create a server that will not send a preface. 
1393 lis, err := net.Listen("tcp", "localhost:0") 1394 if err != nil { 1395 t.Fatalf("Error while listening: %v", err) 1396 } 1397 defer lis.Close() 1398 go func() { // Launch the misbehaving server. 1399 sconn, err := lis.Accept() 1400 if err != nil { 1401 t.Errorf("Error while accepting: %v", err) 1402 return 1403 } 1404 defer sconn.Close() 1405 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 1406 t.Errorf("Error while reading client preface: %v", err) 1407 return 1408 } 1409 sfr := http2.NewFramer(sconn, sconn) 1410 // Do not write a settings frame, but read from the conn forever. 1411 for { 1412 if _, err := sfr.ReadFrame(); err != nil { 1413 return 1414 } 1415 } 1416 }() 1417 1418 // Test context cancellation. 1419 timeBefore := time.Now() 1420 connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1421 time.AfterFunc(100*time.Millisecond, cancel) 1422 1423 parent := channelzSubChannel(t) 1424 copts := ConnectOptions{ChannelzParent: parent} 1425 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1426 defer cancel() 1427 _, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) 1428 if err == nil { 1429 t.Fatalf("NewHTTP2Client() returned successfully; wanted error") 1430 } 1431 t.Logf("NewHTTP2Client() = _, %v", err) 1432 if time.Since(timeBefore) > 3*time.Second { 1433 t.Fatalf("NewHTTP2Client returned > 2.9s after context cancellation") 1434 } 1435 1436 // Test context deadline. 
1437 connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) 1438 defer cancel() 1439 _, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) 1440 if err == nil { 1441 t.Fatalf("NewHTTP2Client() returned successfully; wanted error") 1442 } 1443 t.Logf("NewHTTP2Client() = _, %v", err) 1444 } 1445 1446 func (s) TestClientWithMisbehavedServer(t *testing.T) { 1447 // Create a misbehaving server. 1448 lis, err := net.Listen("tcp", "localhost:0") 1449 if err != nil { 1450 t.Fatalf("Error while listening: %v", err) 1451 } 1452 defer lis.Close() 1453 // success chan indicates that the server received 1454 // RSTStream from the client. 1455 success := make(chan struct{}) 1456 go func() { // Launch the misbehaving server. 1457 sconn, err := lis.Accept() 1458 if err != nil { 1459 t.Errorf("Error while accepting: %v", err) 1460 return 1461 } 1462 defer sconn.Close() 1463 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 1464 t.Errorf("Error while reading client preface: %v", err) 1465 return 1466 } 1467 sfr := http2.NewFramer(sconn, sconn) 1468 if err := sfr.WriteSettings(); err != nil { 1469 t.Errorf("Error while writing settings: %v", err) 1470 return 1471 } 1472 if err := sfr.WriteSettingsAck(); err != nil { 1473 t.Errorf("Error while writing settings: %v", err) 1474 return 1475 } 1476 var mu sync.Mutex 1477 for { 1478 frame, err := sfr.ReadFrame() 1479 if err != nil { 1480 return 1481 } 1482 switch frame := frame.(type) { 1483 case *http2.HeadersFrame: 1484 // When the client creates a stream, violate the stream flow control. 
1485 go func() { 1486 buf := make([]byte, http2MaxFrameLen) 1487 for { 1488 mu.Lock() 1489 if err := sfr.WriteData(1, false, buf); err != nil { 1490 mu.Unlock() 1491 return 1492 } 1493 mu.Unlock() 1494 // This for loop is capable of hogging the CPU and cause starvation 1495 // in Go versions prior to 1.9, 1496 // in single CPU environment. Explicitly relinquish processor. 1497 runtime.Gosched() 1498 } 1499 }() 1500 case *http2.RSTStreamFrame: 1501 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { 1502 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 1503 } 1504 close(success) 1505 return 1506 case *http2.PingFrame: 1507 mu.Lock() 1508 sfr.WritePing(true, frame.Data) 1509 mu.Unlock() 1510 default: 1511 } 1512 } 1513 }() 1514 connectCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 1515 defer cancel() 1516 1517 parent := channelzSubChannel(t) 1518 copts := ConnectOptions{ChannelzParent: parent} 1519 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1520 defer cancel() 1521 ct, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) 1522 if err != nil { 1523 t.Fatalf("Error while creating client transport: %v", err) 1524 } 1525 defer ct.Close(fmt.Errorf("closed manually by test")) 1526 1527 str, err := ct.NewStream(connectCtx, &CallHdr{}) 1528 if err != nil { 1529 t.Fatalf("Error while creating stream: %v", err) 1530 } 1531 timer := time.NewTimer(time.Second * 5) 1532 go func() { // This go routine mimics the one in stream.go to call CloseStream. 1533 <-str.Done() 1534 str.Close(nil) 1535 }() 1536 select { 1537 case <-timer.C: 1538 t.Fatalf("Test timed-out.") 1539 case <-success: 1540 } 1541 // Drain the remaining buffers in the stream by reading until an error is 1542 // encountered. 
1543 str.Read(math.MaxInt) 1544 } 1545 1546 var encodingTestStatus = status.New(codes.Internal, "\n") 1547 1548 func (s) TestEncodingRequiredStatus(t *testing.T) { 1549 server, ct, cancel := setUp(t, 0, encodingRequiredStatus) 1550 defer cancel() 1551 callHdr := &CallHdr{ 1552 Host: "localhost", 1553 Method: "foo", 1554 } 1555 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1556 defer cancel() 1557 s, err := ct.NewStream(ctx, callHdr) 1558 if err != nil { 1559 return 1560 } 1561 opts := WriteOptions{Last: true} 1562 if err := s.Write(nil, newBufferSlice(expectedRequest), &opts); err != nil && err != errStreamDone { 1563 t.Fatalf("Failed to write the request: %v", err) 1564 } 1565 p := make([]byte, http2MaxFrameLen) 1566 if _, err := s.readTo(p); err != io.EOF { 1567 t.Fatalf("Read got error %v, want %v", err, io.EOF) 1568 } 1569 if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { 1570 t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) 1571 } 1572 ct.Close(fmt.Errorf("closed manually by test")) 1573 server.stop() 1574 // Drain any remaining buffers from the stream since it was closed early. 
1575 s.Read(math.MaxInt) 1576 } 1577 1578 func (s) TestInvalidHeaderField(t *testing.T) { 1579 server, ct, cancel := setUp(t, 0, invalidHeaderField) 1580 defer cancel() 1581 callHdr := &CallHdr{ 1582 Host: "localhost", 1583 Method: "foo", 1584 } 1585 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1586 defer cancel() 1587 s, err := ct.NewStream(ctx, callHdr) 1588 if err != nil { 1589 return 1590 } 1591 p := make([]byte, http2MaxFrameLen) 1592 _, err = s.readTo(p) 1593 if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { 1594 t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField) 1595 } 1596 ct.Close(fmt.Errorf("closed manually by test")) 1597 server.stop() 1598 } 1599 1600 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { 1601 server, ct, cancel := setUp(t, 0, invalidHeaderField) 1602 defer cancel() 1603 defer server.stop() 1604 defer ct.Close(fmt.Errorf("closed manually by test")) 1605 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1606 defer cancel() 1607 s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) 1608 if err != nil { 1609 t.Fatalf("failed to create the stream") 1610 } 1611 timer := time.NewTimer(time.Second) 1612 defer timer.Stop() 1613 select { 1614 case <-s.headerChan: 1615 case <-timer.C: 1616 t.Errorf("s.headerChan: got open, want closed") 1617 } 1618 } 1619 1620 func (s) TestIsReservedHeader(t *testing.T) { 1621 tests := []struct { 1622 h string 1623 want bool 1624 }{ 1625 {"", false}, // but should be rejected earlier 1626 {"foo", false}, 1627 {"content-type", true}, 1628 {"user-agent", true}, 1629 {":anything", true}, 1630 {"grpc-message-type", true}, 1631 {"grpc-encoding", true}, 1632 {"grpc-message", true}, 1633 {"grpc-status", true}, 1634 {"grpc-timeout", true}, 1635 {"te", true}, 1636 } 1637 for 
_, tt := range tests { 1638 got := isReservedHeader(tt.h) 1639 if got != tt.want { 1640 t.Errorf("isReservedHeader(%q) = %v; want %v", tt.h, got, tt.want) 1641 } 1642 } 1643 } 1644 1645 func (s) TestContextErr(t *testing.T) { 1646 for _, test := range []struct { 1647 // input 1648 errIn error 1649 // outputs 1650 errOut error 1651 }{ 1652 {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, 1653 {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())}, 1654 } { 1655 err := ContextErr(test.errIn) 1656 if err.Error() != test.errOut.Error() { 1657 t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) 1658 } 1659 } 1660 } 1661 1662 type windowSizeConfig struct { 1663 serverStream int32 1664 serverConn int32 1665 clientStream int32 1666 clientConn int32 1667 } 1668 1669 func (s) TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) { 1670 wc := windowSizeConfig{ 1671 serverStream: 10 * 1024 * 1024, 1672 serverConn: 12 * 1024 * 1024, 1673 clientStream: 6 * 1024 * 1024, 1674 clientConn: 8 * 1024 * 1024, 1675 } 1676 testFlowControlAccountCheck(t, 1024*1024, wc) 1677 } 1678 1679 func (s) TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) { 1680 // These settings disable dynamic window sizes based on BDP estimation; 1681 // must be at least defaultWindowSize or the setting is ignored. 
1682 wc := windowSizeConfig{ 1683 serverStream: defaultWindowSize, 1684 serverConn: defaultWindowSize, 1685 clientStream: defaultWindowSize, 1686 clientConn: defaultWindowSize, 1687 } 1688 testFlowControlAccountCheck(t, 1024*1024, wc) 1689 } 1690 1691 func (s) TestAccountCheckDynamicWindowSmallMessage(t *testing.T) { 1692 testFlowControlAccountCheck(t, 1024, windowSizeConfig{}) 1693 } 1694 1695 func (s) TestAccountCheckDynamicWindowLargeMessage(t *testing.T) { 1696 testFlowControlAccountCheck(t, 1024*1024, windowSizeConfig{}) 1697 } 1698 1699 func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) { 1700 sc := &ServerConfig{ 1701 InitialWindowSize: wc.serverStream, 1702 InitialConnWindowSize: wc.serverConn, 1703 StaticWindowSize: true, 1704 } 1705 co := ConnectOptions{ 1706 InitialWindowSize: wc.clientStream, 1707 InitialConnWindowSize: wc.clientConn, 1708 StaticWindowSize: true, 1709 } 1710 server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co) 1711 defer cancel() 1712 defer server.stop() 1713 defer client.Close(fmt.Errorf("closed manually by test")) 1714 waitWhileTrue(t, func() (bool, error) { 1715 server.mu.Lock() 1716 defer server.mu.Unlock() 1717 if len(server.conns) == 0 { 1718 return true, fmt.Errorf("timed out while waiting for server transport to be created") 1719 } 1720 return false, nil 1721 }) 1722 var st *http2Server 1723 server.mu.Lock() 1724 for k := range server.conns { 1725 st = k.(*http2Server) 1726 } 1727 server.mu.Unlock() 1728 1729 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1730 defer cancel() 1731 const numStreams = 5 1732 clientStreams := make([]*ClientStream, numStreams) 1733 for i := 0; i < numStreams; i++ { 1734 var err error 1735 clientStreams[i], err = client.NewStream(ctx, &CallHdr{}) 1736 if err != nil { 1737 t.Fatalf("Failed to create stream. Err: %v", err) 1738 } 1739 } 1740 var wg sync.WaitGroup 1741 // For each stream send pingpong messages to the server. 
1742 for _, stream := range clientStreams { 1743 wg.Add(1) 1744 go func(stream *ClientStream) { 1745 defer wg.Done() 1746 buf := make([]byte, msgSize+5) 1747 buf[0] = byte(0) 1748 binary.BigEndian.PutUint32(buf[1:], uint32(msgSize)) 1749 opts := WriteOptions{} 1750 header := make([]byte, 5) 1751 for i := 1; i <= 5; i++ { 1752 if err := stream.Write(nil, newBufferSlice(buf), &opts); err != nil { 1753 t.Errorf("Error on client while writing message %v on stream %v: %v", i, stream.id, err) 1754 return 1755 } 1756 if _, err := stream.readTo(header); err != nil { 1757 t.Errorf("Error on client while reading data frame header %v on stream %v: %v", i, stream.id, err) 1758 return 1759 } 1760 sz := binary.BigEndian.Uint32(header[1:]) 1761 recvMsg := make([]byte, int(sz)) 1762 if _, err := stream.readTo(recvMsg); err != nil { 1763 t.Errorf("Error on client while reading data %v on stream %v: %v", i, stream.id, err) 1764 return 1765 } 1766 if len(recvMsg) != msgSize { 1767 t.Errorf("Length of message %v received by client on stream %v: %v, want: %v", i, stream.id, len(recvMsg), msgSize) 1768 return 1769 } 1770 } 1771 t.Logf("stream %v done with pingpongs", stream.id) 1772 }(stream) 1773 } 1774 wg.Wait() 1775 serverStreams := map[uint32]*ServerStream{} 1776 loopyClientStreams := map[uint32]*outStream{} 1777 loopyServerStreams := map[uint32]*outStream{} 1778 // Get all the streams from server reader and writer and client writer. 
1779 st.mu.Lock() 1780 client.mu.Lock() 1781 for _, stream := range clientStreams { 1782 id := stream.id 1783 serverStreams[id] = st.activeStreams[id] 1784 loopyServerStreams[id] = st.loopy.estdStreams[id] 1785 loopyClientStreams[id] = client.loopy.estdStreams[id] 1786 1787 } 1788 client.mu.Unlock() 1789 st.mu.Unlock() 1790 // Close all streams 1791 for _, stream := range clientStreams { 1792 stream.Write(nil, nil, &WriteOptions{Last: true}) 1793 if _, err := stream.readTo(make([]byte, 5)); err != io.EOF { 1794 t.Fatalf("Client expected an EOF from the server. Got: %v", err) 1795 } 1796 } 1797 // Close down both server and client so that their internals can be read without data 1798 // races. 1799 client.Close(errors.New("closed manually by test")) 1800 st.Close(errors.New("closed manually by test")) 1801 <-st.readerDone 1802 <-st.loopyWriterDone 1803 <-client.readerDone 1804 <-client.writerDone 1805 for _, cstream := range clientStreams { 1806 id := cstream.id 1807 sstream := serverStreams[id] 1808 loopyServerStream := loopyServerStreams[id] 1809 loopyClientStream := loopyClientStreams[id] 1810 if loopyServerStream == nil { 1811 t.Fatalf("Unexpected nil loopyServerStream") 1812 } 1813 // Check stream flow control. 
1814 if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { 1815 t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) 1816 } 1817 if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(client.loopy.oiws)-loopyClientStream.bytesOutStanding { 1818 t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, client.loopy.oiws, loopyClientStream.bytesOutStanding) 1819 } 1820 } 1821 // Check transport flow control. 
1822 if client.fc.limit != client.fc.unacked+st.loopy.sendQuota { 1823 t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", client.fc.limit, client.fc.unacked, st.loopy.sendQuota) 1824 } 1825 if st.fc.limit != st.fc.unacked+client.loopy.sendQuota { 1826 t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, client.loopy.sendQuota) 1827 } 1828 } 1829 1830 func waitWhileTrue(t *testing.T, condition func() (bool, error)) { 1831 var ( 1832 wait bool 1833 err error 1834 ) 1835 timer := time.NewTimer(time.Second * 5) 1836 for { 1837 wait, err = condition() 1838 if wait { 1839 select { 1840 case <-timer.C: 1841 t.Fatal(err) 1842 default: 1843 time.Sleep(50 * time.Millisecond) 1844 continue 1845 } 1846 } 1847 if !timer.Stop() { 1848 <-timer.C 1849 } 1850 break 1851 } 1852 } 1853 1854 // If any error occurs on a call to Stream.Read, future calls 1855 // should continue to return that same error. 
1856 func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { 1857 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1858 defer cancel() 1859 testRecvBuffer := newRecvBuffer() 1860 s := &Stream{ 1861 ctx: ctx, 1862 buf: testRecvBuffer, 1863 requestRead: func(int) {}, 1864 } 1865 s.trReader = &transportReader{ 1866 reader: &recvBufferReader{ 1867 ctx: s.ctx, 1868 ctxDone: s.ctx.Done(), 1869 recv: s.buf, 1870 }, 1871 windowHandler: func(int) {}, 1872 } 1873 testData := make([]byte, 1) 1874 testData[0] = 5 1875 testErr := errors.New("test error") 1876 s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: testErr}) 1877 1878 inBuf := make([]byte, 1) 1879 actualCount, actualErr := s.readTo(inBuf) 1880 if actualCount != 0 { 1881 t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount) 1882 } 1883 if actualErr.Error() != testErr.Error() { 1884 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1885 } 1886 1887 s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: nil}) 1888 s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: errors.New("different error from first")}) 1889 1890 for i := 0; i < 2; i++ { 1891 inBuf := make([]byte, 1) 1892 actualCount, actualErr := s.readTo(inBuf) 1893 if actualCount != 0 { 1894 t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount) 1895 } 1896 if actualErr.Error() != testErr.Error() { 1897 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1898 } 1899 } 1900 } 1901 1902 // TestHeadersCausingStreamError tests headers that should cause a stream protocol 1903 // error, which would end up with a RST_STREAM being sent to the client and also 1904 // the server closing the stream. 
1905 func (s) TestHeadersCausingStreamError(t *testing.T) { 1906 tests := []struct { 1907 name string 1908 headers []struct { 1909 name string 1910 values []string 1911 } 1912 }{ 1913 // "Transports must consider requests containing the Connection header 1914 // as malformed" - A41 Malformed requests map to a stream error of type 1915 // PROTOCOL_ERROR. 1916 { 1917 name: "Connection header present", 1918 headers: []struct { 1919 name string 1920 values []string 1921 }{ 1922 {name: ":method", values: []string{"POST"}}, 1923 {name: ":path", values: []string{"foo"}}, 1924 {name: ":authority", values: []string{"localhost"}}, 1925 {name: "content-type", values: []string{"application/grpc"}}, 1926 {name: "connection", values: []string{"not-supported"}}, 1927 }, 1928 }, 1929 // multiple :authority or multiple Host headers would make the eventual 1930 // :authority ambiguous as per A41. Since these headers won't have a 1931 // content-type that corresponds to a grpc-client, the server should 1932 // simply write a RST_STREAM to the wire. 1933 { 1934 // Note: multiple authority headers are handled by the framer 1935 // itself, which will cause a stream error. Thus, it will never get 1936 // to operateHeaders with the check in operateHeaders for stream 1937 // error, but the server transport will still send a stream error. 1938 name: "Multiple authority headers", 1939 headers: []struct { 1940 name string 1941 values []string 1942 }{ 1943 {name: ":method", values: []string{"POST"}}, 1944 {name: ":path", values: []string{"foo"}}, 1945 {name: ":authority", values: []string{"localhost", "localhost2"}}, 1946 {name: "host", values: []string{"localhost"}}, 1947 }, 1948 }, 1949 } 1950 for _, test := range tests { 1951 t.Run(test.name, func(t *testing.T) { 1952 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1953 defer server.stop() 1954 // Create a client directly to not tie what you can send to API of 1955 // http2_client.go (i.e. control headers being sent). 
1956 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1957 if err != nil { 1958 t.Fatalf("Client failed to dial: %v", err) 1959 } 1960 defer mconn.Close() 1961 1962 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1963 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1964 } 1965 1966 framer := http2.NewFramer(mconn, mconn) 1967 if err := framer.WriteSettings(); err != nil { 1968 t.Fatalf("Error while writing settings: %v", err) 1969 } 1970 1971 // result chan indicates that reader received a RSTStream from server. 1972 // An error will be passed on it if any other frame is received. 1973 result := testutils.NewChannel() 1974 1975 // Launch a reader goroutine. 1976 go func() { 1977 for { 1978 frame, err := framer.ReadFrame() 1979 if err != nil { 1980 return 1981 } 1982 switch frame := frame.(type) { 1983 case *http2.SettingsFrame: 1984 // Do nothing. A settings frame is expected from server preface. 1985 case *http2.RSTStreamFrame: 1986 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol { 1987 // Client only created a single stream, so RST Stream should be for that single stream. 1988 result.Send(fmt.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))) 1989 } 1990 // Records that client successfully received RST Stream frame. 1991 result.Send(nil) 1992 return 1993 default: 1994 // The server should send nothing but a single RST Stream frame. 1995 result.Send(errors.New("the client received a frame other than RST Stream")) 1996 } 1997 } 1998 }() 1999 2000 var buf bytes.Buffer 2001 henc := hpack.NewEncoder(&buf) 2002 2003 // Needs to build headers deterministically to conform to gRPC over 2004 // HTTP/2 spec. 
2005 for _, header := range test.headers { 2006 for _, value := range header.values { 2007 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 2008 t.Fatalf("Error while encoding header: %v", err) 2009 } 2010 } 2011 } 2012 2013 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 2014 t.Fatalf("Error while writing headers: %v", err) 2015 } 2016 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2017 defer cancel() 2018 r, err := result.Receive(ctx) 2019 if err != nil { 2020 t.Fatalf("Error receiving from channel: %v", err) 2021 } 2022 if r != nil { 2023 t.Fatalf("want nil, got %v", r) 2024 } 2025 }) 2026 } 2027 } 2028 2029 // TestHeadersHTTPStatusGRPCStatus tests requests with certain headers get a 2030 // certain HTTP and gRPC status back. 2031 func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) { 2032 tests := []struct { 2033 name string 2034 headers []struct { 2035 name string 2036 values []string 2037 } 2038 httpStatusWant string 2039 grpcStatusWant string 2040 grpcMessageWant string 2041 }{ 2042 // Note: multiple authority headers are handled by the framer itself, 2043 // which will cause a stream error. Thus, it will never get to 2044 // operateHeaders with the check in operateHeaders for possible 2045 // grpc-status sent back. 2046 2047 // multiple :authority or multiple Host headers would make the eventual 2048 // :authority ambiguous as per A41. This takes precedence even over the 2049 // fact a request is non grpc. All of these requests should be rejected 2050 // with grpc-status Internal. Thus, requests with multiple hosts should 2051 // get rejected with HTTP Status 400 and gRPC status Internal, 2052 // regardless of whether the client is speaking gRPC or not. 
2053 { 2054 name: "Multiple host headers non grpc", 2055 headers: []struct { 2056 name string 2057 values []string 2058 }{ 2059 {name: ":method", values: []string{"POST"}}, 2060 {name: ":path", values: []string{"foo"}}, 2061 {name: ":authority", values: []string{"localhost"}}, 2062 {name: "host", values: []string{"localhost", "localhost2"}}, 2063 }, 2064 httpStatusWant: "400", 2065 grpcStatusWant: "13", 2066 grpcMessageWant: "both must only have 1 value as per HTTP/2 spec", 2067 }, 2068 { 2069 name: "Multiple host headers grpc", 2070 headers: []struct { 2071 name string 2072 values []string 2073 }{ 2074 {name: ":method", values: []string{"POST"}}, 2075 {name: ":path", values: []string{"foo"}}, 2076 {name: ":authority", values: []string{"localhost"}}, 2077 {name: "content-type", values: []string{"application/grpc"}}, 2078 {name: "host", values: []string{"localhost", "localhost2"}}, 2079 }, 2080 httpStatusWant: "400", 2081 grpcStatusWant: "13", 2082 grpcMessageWant: "both must only have 1 value as per HTTP/2 spec", 2083 }, 2084 // If the client sends an HTTP/2 request with a :method header with a 2085 // value other than POST, as specified in the gRPC over HTTP/2 2086 // specification, the server should fail the RPC. 
2087 { 2088 name: "Client Sending Wrong Method", 2089 headers: []struct { 2090 name string 2091 values []string 2092 }{ 2093 {name: ":method", values: []string{"PUT"}}, 2094 {name: ":path", values: []string{"foo"}}, 2095 {name: ":authority", values: []string{"localhost"}}, 2096 {name: "content-type", values: []string{"application/grpc"}}, 2097 }, 2098 httpStatusWant: "405", 2099 grpcStatusWant: "13", 2100 grpcMessageWant: "which should be POST", 2101 }, 2102 { 2103 name: "Client Sending Wrong Content-Type", 2104 headers: []struct { 2105 name string 2106 values []string 2107 }{ 2108 {name: ":method", values: []string{"POST"}}, 2109 {name: ":path", values: []string{"foo"}}, 2110 {name: ":authority", values: []string{"localhost"}}, 2111 {name: "content-type", values: []string{"application/json"}}, 2112 }, 2113 httpStatusWant: "415", 2114 grpcStatusWant: "3", 2115 grpcMessageWant: `invalid gRPC request content-type "application/json"`, 2116 }, 2117 { 2118 name: "Client Sending Bad Timeout", 2119 headers: []struct { 2120 name string 2121 values []string 2122 }{ 2123 {name: ":method", values: []string{"POST"}}, 2124 {name: ":path", values: []string{"foo"}}, 2125 {name: ":authority", values: []string{"localhost"}}, 2126 {name: "content-type", values: []string{"application/grpc"}}, 2127 {name: "grpc-timeout", values: []string{"18f6n"}}, 2128 }, 2129 httpStatusWant: "400", 2130 grpcStatusWant: "13", 2131 grpcMessageWant: "malformed grpc-timeout", 2132 }, 2133 { 2134 name: "Client Sending Bad Binary Header", 2135 headers: []struct { 2136 name string 2137 values []string 2138 }{ 2139 {name: ":method", values: []string{"POST"}}, 2140 {name: ":path", values: []string{"foo"}}, 2141 {name: ":authority", values: []string{"localhost"}}, 2142 {name: "content-type", values: []string{"application/grpc"}}, 2143 {name: "foobar-bin", values: []string{"X()3e@#$-"}}, 2144 }, 2145 httpStatusWant: "400", 2146 grpcStatusWant: "13", 2147 grpcMessageWant: `header "foobar-bin": illegal base64 
data`, 2148 }, 2149 } 2150 for _, test := range tests { 2151 t.Run(test.name, func(t *testing.T) { 2152 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 2153 defer server.stop() 2154 // Create a client directly to not tie what you can send to API of 2155 // http2_client.go (i.e. control headers being sent). 2156 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 2157 if err != nil { 2158 t.Fatalf("Client failed to dial: %v", err) 2159 } 2160 defer mconn.Close() 2161 2162 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 2163 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 2164 } 2165 2166 framer := http2.NewFramer(mconn, mconn) 2167 framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) 2168 if err := framer.WriteSettings(); err != nil { 2169 t.Fatalf("Error while writing settings: %v", err) 2170 } 2171 2172 // result chan indicates that reader received a Headers Frame with 2173 // desired grpc status and message from server. An error will be passed 2174 // on it if any other frame is received. 2175 result := testutils.NewChannel() 2176 2177 // Launch a reader goroutine. 2178 go func() { 2179 for { 2180 frame, err := framer.ReadFrame() 2181 if err != nil { 2182 return 2183 } 2184 switch frame := frame.(type) { 2185 case *http2.SettingsFrame: 2186 // Do nothing. A settings frame is expected from server preface. 
2187 case *http2.MetaHeadersFrame: 2188 var httpStatus, grpcStatus, grpcMessage string 2189 for _, header := range frame.Fields { 2190 if header.Name == ":status" { 2191 httpStatus = header.Value 2192 } 2193 if header.Name == "grpc-status" { 2194 grpcStatus = header.Value 2195 } 2196 if header.Name == "grpc-message" { 2197 grpcMessage = header.Value 2198 } 2199 } 2200 if httpStatus != test.httpStatusWant { 2201 result.Send(fmt.Errorf("incorrect HTTP Status got %v, want %v", httpStatus, test.httpStatusWant)) 2202 return 2203 } 2204 if grpcStatus != test.grpcStatusWant { // grpc status code internal 2205 result.Send(fmt.Errorf("incorrect gRPC Status got %v, want %v", grpcStatus, test.grpcStatusWant)) 2206 return 2207 } 2208 if !strings.Contains(grpcMessage, test.grpcMessageWant) { 2209 result.Send(fmt.Errorf("incorrect gRPC message, want %q got %q", test.grpcMessageWant, grpcMessage)) 2210 return 2211 } 2212 2213 // Records that client successfully received a HeadersFrame 2214 // with expected Trailers-Only response. 2215 result.Send(nil) 2216 return 2217 default: 2218 // The server should send nothing but a single Settings and Headers frame. 2219 result.Send(errors.New("the client received a frame other than Settings or Headers")) 2220 } 2221 } 2222 }() 2223 2224 var buf bytes.Buffer 2225 henc := hpack.NewEncoder(&buf) 2226 2227 // Needs to build headers deterministically to conform to gRPC over 2228 // HTTP/2 spec. 
2229 for _, header := range test.headers { 2230 for _, value := range header.values { 2231 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 2232 t.Fatalf("Error while encoding header: %v", err) 2233 } 2234 } 2235 } 2236 2237 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 2238 t.Fatalf("Error while writing headers: %v", err) 2239 } 2240 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2241 defer cancel() 2242 r, err := result.Receive(ctx) 2243 if err != nil { 2244 t.Fatalf("Error receiving from channel: %v", err) 2245 } 2246 if r != nil { 2247 t.Fatalf("want nil, got %v", r) 2248 } 2249 }) 2250 } 2251 } 2252 2253 func (s) TestWriteHeaderConnectionError(t *testing.T) { 2254 server, client, cancel := setUp(t, 0, notifyCall) 2255 defer cancel() 2256 defer server.stop() 2257 2258 waitWhileTrue(t, func() (bool, error) { 2259 server.mu.Lock() 2260 defer server.mu.Unlock() 2261 2262 if len(server.conns) == 0 { 2263 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 2264 } 2265 return false, nil 2266 }) 2267 2268 server.mu.Lock() 2269 2270 if len(server.conns) != 1 { 2271 t.Fatalf("Server has %d connections from the client, want 1", len(server.conns)) 2272 } 2273 2274 // Get the server transport for the connection to the client. 2275 var serverTransport *http2Server 2276 for k := range server.conns { 2277 serverTransport = k.(*http2Server) 2278 } 2279 notifyChan := make(chan struct{}) 2280 server.h.notify = notifyChan 2281 server.mu.Unlock() 2282 2283 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2284 defer cancel() 2285 cstream, err := client.NewStream(ctx, &CallHdr{}) 2286 if err != nil { 2287 t.Fatalf("Client failed to create first stream. Err: %v", err) 2288 } 2289 2290 <-notifyChan // Wait for server stream to be established. 
2291 var sstream *ServerStream 2292 // Access stream on the server. 2293 serverTransport.mu.Lock() 2294 for _, v := range serverTransport.activeStreams { 2295 if v.id == cstream.id { 2296 sstream = v 2297 } 2298 } 2299 serverTransport.mu.Unlock() 2300 if sstream == nil { 2301 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id) 2302 } 2303 2304 client.Close(fmt.Errorf("closed manually by test")) 2305 2306 // Wait for server transport to be closed. 2307 <-serverTransport.done 2308 2309 // Write header on a closed server transport. 2310 err = sstream.SendHeader(metadata.MD{}) 2311 st := status.Convert(err) 2312 if st.Code() != codes.Unavailable { 2313 t.Fatalf("WriteHeader() failed with status code %s, want %s", st.Code(), codes.Unavailable) 2314 } 2315 } 2316 2317 func (s) TestPingPong1B(t *testing.T) { 2318 runPingPongTest(t, 1) 2319 } 2320 2321 func (s) TestPingPong1KB(t *testing.T) { 2322 runPingPongTest(t, 1024) 2323 } 2324 2325 func (s) TestPingPong64KB(t *testing.T) { 2326 runPingPongTest(t, 65536) 2327 } 2328 2329 func (s) TestPingPong1MB(t *testing.T) { 2330 runPingPongTest(t, 1048576) 2331 } 2332 2333 // This is a stress-test of flow control logic. 2334 func runPingPongTest(t *testing.T, msgSize int) { 2335 server, client, cancel := setUp(t, 0, pingpong) 2336 defer cancel() 2337 defer server.stop() 2338 defer client.Close(fmt.Errorf("closed manually by test")) 2339 waitWhileTrue(t, func() (bool, error) { 2340 server.mu.Lock() 2341 defer server.mu.Unlock() 2342 if len(server.conns) == 0 { 2343 return true, fmt.Errorf("timed out while waiting for server transport to be created") 2344 } 2345 return false, nil 2346 }) 2347 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2348 defer cancel() 2349 stream, err := client.NewStream(ctx, &CallHdr{}) 2350 if err != nil { 2351 t.Fatalf("Failed to create stream. 
Err: %v", err) 2352 } 2353 msg := make([]byte, msgSize) 2354 outgoingHeader := make([]byte, 5) 2355 outgoingHeader[0] = byte(0) 2356 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize)) 2357 opts := &WriteOptions{} 2358 incomingHeader := make([]byte, 5) 2359 2360 ctx, cancel = context.WithTimeout(ctx, 10*time.Millisecond) 2361 defer cancel() 2362 for ctx.Err() == nil { 2363 if err := stream.Write(outgoingHeader, newBufferSlice(msg), opts); err != nil { 2364 t.Fatalf("Error on client while writing message. Err: %v", err) 2365 } 2366 if _, err := stream.readTo(incomingHeader); err != nil { 2367 t.Fatalf("Error on client while reading data header. Err: %v", err) 2368 } 2369 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 2370 recvMsg := make([]byte, int(sz)) 2371 if _, err := stream.readTo(recvMsg); err != nil { 2372 t.Fatalf("Error on client while reading data. Err: %v", err) 2373 } 2374 } 2375 2376 stream.Write(nil, nil, &WriteOptions{Last: true}) 2377 if _, err := stream.readTo(incomingHeader); err != io.EOF { 2378 t.Fatalf("Client expected EOF from the server. 
Got: %v", err) 2379 } 2380 } 2381 2382 type tableSizeLimit struct { 2383 mu sync.Mutex 2384 limits []uint32 2385 } 2386 2387 func (t *tableSizeLimit) add(limit uint32) { 2388 t.mu.Lock() 2389 t.limits = append(t.limits, limit) 2390 t.mu.Unlock() 2391 } 2392 2393 func (t *tableSizeLimit) getLen() int { 2394 t.mu.Lock() 2395 defer t.mu.Unlock() 2396 return len(t.limits) 2397 } 2398 2399 func (t *tableSizeLimit) getIndex(i int) uint32 { 2400 t.mu.Lock() 2401 defer t.mu.Unlock() 2402 return t.limits[i] 2403 } 2404 2405 func (s) TestHeaderTblSize(t *testing.T) { 2406 limits := &tableSizeLimit{} 2407 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2408 e.SetMaxDynamicTableSizeLimit(v) 2409 limits.add(v) 2410 } 2411 defer func() { 2412 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2413 e.SetMaxDynamicTableSizeLimit(v) 2414 } 2415 }() 2416 2417 server, ct, cancel := setUp(t, 0, normal) 2418 defer cancel() 2419 defer ct.Close(fmt.Errorf("closed manually by test")) 2420 defer server.stop() 2421 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2422 defer ctxCancel() 2423 _, err := ct.NewStream(ctx, &CallHdr{}) 2424 if err != nil { 2425 t.Fatalf("failed to open stream: %v", err) 2426 } 2427 2428 var svrTransport ServerTransport 2429 var i int 2430 for i = 0; i < 1000; i++ { 2431 server.mu.Lock() 2432 if len(server.conns) != 0 { 2433 server.mu.Unlock() 2434 break 2435 } 2436 server.mu.Unlock() 2437 time.Sleep(10 * time.Millisecond) 2438 continue 2439 } 2440 if i == 1000 { 2441 t.Fatalf("unable to create any server transport after 10s") 2442 } 2443 2444 for st := range server.conns { 2445 svrTransport = st 2446 break 2447 } 2448 svrTransport.(*http2Server).controlBuf.put(&outgoingSettings{ 2449 ss: []http2.Setting{ 2450 { 2451 ID: http2.SettingHeaderTableSize, 2452 Val: uint32(100), 2453 }, 2454 }, 2455 }) 2456 2457 for i = 0; i < 1000; i++ { 2458 if limits.getLen() != 1 { 2459 time.Sleep(10 * time.Millisecond) 2460 continue 
2461 } 2462 if val := limits.getIndex(0); val != uint32(100) { 2463 t.Fatalf("expected limits[0] = 100, got %d", val) 2464 } 2465 break 2466 } 2467 if i == 1000 { 2468 t.Fatalf("expected len(limits) = 1 within 10s, got != 1") 2469 } 2470 2471 ct.controlBuf.put(&outgoingSettings{ 2472 ss: []http2.Setting{ 2473 { 2474 ID: http2.SettingHeaderTableSize, 2475 Val: uint32(200), 2476 }, 2477 }, 2478 }) 2479 2480 for i := 0; i < 1000; i++ { 2481 if limits.getLen() != 2 { 2482 time.Sleep(10 * time.Millisecond) 2483 continue 2484 } 2485 if val := limits.getIndex(1); val != uint32(200) { 2486 t.Fatalf("expected limits[1] = 200, got %d", val) 2487 } 2488 break 2489 } 2490 if i == 1000 { 2491 t.Fatalf("expected len(limits) = 2 within 10s, got != 2") 2492 } 2493 } 2494 2495 // attrTransportCreds is a transport credential implementation which stores 2496 // Attributes from the ClientHandshakeInfo struct passed in the context locally 2497 // for the test to inspect. 2498 type attrTransportCreds struct { 2499 credentials.TransportCredentials 2500 attr *attributes.Attributes 2501 } 2502 2503 func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 2504 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2505 ac.attr = ai.Attributes 2506 return rawConn, nil, nil 2507 } 2508 func (ac *attrTransportCreds) Info() credentials.ProtocolInfo { 2509 return credentials.ProtocolInfo{} 2510 } 2511 func (ac *attrTransportCreds) Clone() credentials.TransportCredentials { 2512 return nil 2513 } 2514 2515 // TestClientHandshakeInfo adds attributes to the resolver.Address passes to 2516 // NewHTTP2Client and verifies that these attributes are received by the 2517 // transport credential handshaker. 
2518 func (s) TestClientHandshakeInfo(t *testing.T) { 2519 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2520 defer server.stop() 2521 2522 const ( 2523 testAttrKey = "foo" 2524 testAttrVal = "bar" 2525 ) 2526 addr := resolver.Address{ 2527 Addr: "localhost:" + server.port, 2528 Attributes: attributes.New(testAttrKey, testAttrVal), 2529 } 2530 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 2531 defer cancel() 2532 creds := &attrTransportCreds{} 2533 2534 copts := ConnectOptions{ 2535 TransportCredentials: creds, 2536 ChannelzParent: channelzSubChannel(t), 2537 } 2538 tr, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {}) 2539 if err != nil { 2540 t.Fatalf("NewHTTP2Client(): %v", err) 2541 } 2542 defer tr.Close(fmt.Errorf("closed manually by test")) 2543 2544 wantAttr := attributes.New(testAttrKey, testAttrVal) 2545 if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2546 t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) 2547 } 2548 } 2549 2550 // TestClientHandshakeInfoDialer adds attributes to the resolver.Address passes to 2551 // NewHTTP2Client and verifies that these attributes are received by a custom 2552 // dialer. 
2553 func (s) TestClientHandshakeInfoDialer(t *testing.T) { 2554 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2555 defer server.stop() 2556 2557 const ( 2558 testAttrKey = "foo" 2559 testAttrVal = "bar" 2560 ) 2561 addr := resolver.Address{ 2562 Addr: "localhost:" + server.port, 2563 Attributes: attributes.New(testAttrKey, testAttrVal), 2564 } 2565 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 2566 defer cancel() 2567 2568 var attr *attributes.Attributes 2569 dialer := func(ctx context.Context, addr string) (net.Conn, error) { 2570 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2571 attr = ai.Attributes 2572 return (&net.Dialer{}).DialContext(ctx, "tcp", addr) 2573 } 2574 2575 copts := ConnectOptions{ 2576 Dialer: dialer, 2577 ChannelzParent: channelzSubChannel(t), 2578 } 2579 tr, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {}) 2580 if err != nil { 2581 t.Fatalf("NewHTTP2Client(): %v", err) 2582 } 2583 defer tr.Close(fmt.Errorf("closed manually by test")) 2584 2585 wantAttr := attributes.New(testAttrKey, testAttrVal) 2586 if gotAttr := attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2587 t.Errorf("Received attributes %v in custom dialer, want %v", gotAttr, wantAttr) 2588 } 2589 } 2590 2591 func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { 2592 testStream := func() *ClientStream { 2593 return &ClientStream{ 2594 Stream: &Stream{ 2595 buf: &recvBuffer{ 2596 c: make(chan recvMsg), 2597 mu: sync.Mutex{}, 2598 }, 2599 }, 2600 done: make(chan struct{}), 2601 headerChan: make(chan struct{}), 2602 } 2603 } 2604 2605 testClient := func(ts *ClientStream) *http2Client { 2606 return &http2Client{ 2607 mu: sync.Mutex{}, 2608 activeStreams: map[uint32]*ClientStream{ 2609 0: ts, 2610 }, 2611 controlBuf: newControlBuffer(make(<-chan struct{})), 2612 } 2613 } 2614 2615 for _, test := range []struct { 2616 name string 2617 // input 2618 metaHeaderFrame 
*http2.MetaHeadersFrame 2619 // output 2620 wantStatus *status.Status 2621 }{ 2622 { 2623 name: "valid header", 2624 metaHeaderFrame: &http2.MetaHeadersFrame{ 2625 Fields: []hpack.HeaderField{ 2626 {Name: "content-type", Value: "application/grpc"}, 2627 {Name: "grpc-status", Value: "0"}, 2628 {Name: ":status", Value: "200"}, 2629 }, 2630 }, 2631 // no error 2632 wantStatus: status.New(codes.OK, ""), 2633 }, 2634 { 2635 name: "missing content-type header", 2636 metaHeaderFrame: &http2.MetaHeadersFrame{ 2637 Fields: []hpack.HeaderField{ 2638 {Name: "grpc-status", Value: "0"}, 2639 {Name: ":status", Value: "200"}, 2640 }, 2641 }, 2642 wantStatus: status.New( 2643 codes.Unknown, 2644 "malformed header: missing HTTP content-type", 2645 ), 2646 }, 2647 { 2648 name: "invalid grpc status header field", 2649 metaHeaderFrame: &http2.MetaHeadersFrame{ 2650 Fields: []hpack.HeaderField{ 2651 {Name: "content-type", Value: "application/grpc"}, 2652 {Name: "grpc-status", Value: "xxxx"}, 2653 {Name: ":status", Value: "200"}, 2654 }, 2655 }, 2656 wantStatus: status.New( 2657 codes.Internal, 2658 "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2659 ), 2660 }, 2661 { 2662 name: "invalid http content type", 2663 metaHeaderFrame: &http2.MetaHeadersFrame{ 2664 Fields: []hpack.HeaderField{ 2665 {Name: "content-type", Value: "application/json"}, 2666 }, 2667 }, 2668 wantStatus: status.New( 2669 codes.Internal, 2670 "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"", 2671 ), 2672 }, 2673 { 2674 name: "http fallback and invalid http status", 2675 metaHeaderFrame: &http2.MetaHeadersFrame{ 2676 Fields: []hpack.HeaderField{ 2677 // No content type provided then fallback into handling http error. 
2678 {Name: ":status", Value: "xxxx"}, 2679 }, 2680 }, 2681 wantStatus: status.New( 2682 codes.Internal, 2683 "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2684 ), 2685 }, 2686 { 2687 name: "http2 frame size exceeds", 2688 metaHeaderFrame: &http2.MetaHeadersFrame{ 2689 Fields: nil, 2690 Truncated: true, 2691 }, 2692 wantStatus: status.New( 2693 codes.Internal, 2694 "peer header list size exceeded limit", 2695 ), 2696 }, 2697 { 2698 name: "bad status in grpc mode", 2699 metaHeaderFrame: &http2.MetaHeadersFrame{ 2700 Fields: []hpack.HeaderField{ 2701 {Name: "content-type", Value: "application/grpc"}, 2702 {Name: "grpc-status", Value: "0"}, 2703 {Name: ":status", Value: "504"}, 2704 }, 2705 }, 2706 wantStatus: status.New( 2707 codes.Unavailable, 2708 "unexpected HTTP status code received from server: 504 (Gateway Timeout)", 2709 ), 2710 }, 2711 { 2712 name: "missing http status", 2713 metaHeaderFrame: &http2.MetaHeadersFrame{ 2714 Fields: []hpack.HeaderField{ 2715 {Name: "content-type", Value: "application/grpc"}, 2716 }, 2717 }, 2718 wantStatus: status.New( 2719 codes.Internal, 2720 "malformed header: missing HTTP status", 2721 ), 2722 }, 2723 } { 2724 2725 t.Run(test.name, func(t *testing.T) { 2726 ts := testStream() 2727 s := testClient(ts) 2728 2729 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2730 FrameHeader: http2.FrameHeader{ 2731 StreamID: 0, 2732 }, 2733 } 2734 2735 s.operateHeaders(test.metaHeaderFrame) 2736 2737 got := ts.status 2738 want := test.wantStatus 2739 if got.Code() != want.Code() || got.Message() != want.Message() { 2740 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2741 } 2742 }) 2743 t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) { 2744 ts := testStream() 2745 s := testClient(ts) 2746 2747 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2748 FrameHeader: http2.FrameHeader{ 2749 StreamID: 0, 2750 Flags: 
http2.FlagHeadersEndStream, 2751 }, 2752 } 2753 2754 s.operateHeaders(test.metaHeaderFrame) 2755 2756 got := ts.status 2757 want := test.wantStatus 2758 if got.Code() != want.Code() || got.Message() != want.Message() { 2759 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2760 } 2761 }) 2762 } 2763 } 2764 2765 func TestConnectionError_Unwrap(t *testing.T) { 2766 err := connectionErrorf(false, os.ErrNotExist, "unwrap me") 2767 if !errors.Is(err, os.ErrNotExist) { 2768 t.Error("ConnectionError does not unwrap") 2769 } 2770 } 2771 2772 // Test that in the event of a graceful client transport shutdown, i.e., 2773 // clientTransport.Close(), client sends a goaway to the server with the correct 2774 // error code and debug data. 2775 func (s) TestClientSendsAGoAwayFrame(t *testing.T) { 2776 // Create a server. 2777 lis, err := net.Listen("tcp", "localhost:0") 2778 if err != nil { 2779 t.Fatalf("Error while listening: %v", err) 2780 } 2781 defer lis.Close() 2782 // greetDone is used to notify when server is done greeting the client. 2783 greetDone := make(chan struct{}) 2784 // errorCh verifies that desired GOAWAY not received by server 2785 errorCh := make(chan error) 2786 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2787 defer cancel() 2788 // Launch the server. 
2789 go func() { 2790 sconn, err := lis.Accept() 2791 if err != nil { 2792 t.Errorf("Error while accepting: %v", err) 2793 } 2794 defer sconn.Close() 2795 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 2796 t.Errorf("Error while writing settings ack: %v", err) 2797 return 2798 } 2799 sfr := http2.NewFramer(sconn, sconn) 2800 if err := sfr.WriteSettings(); err != nil { 2801 t.Errorf("Error while writing settings %v", err) 2802 return 2803 } 2804 fr, _ := sfr.ReadFrame() 2805 if _, ok := fr.(*http2.SettingsFrame); !ok { 2806 t.Errorf("Expected settings frame, got %v", fr) 2807 } 2808 fr, _ = sfr.ReadFrame() 2809 if fr, ok := fr.(*http2.SettingsFrame); !ok || !fr.IsAck() { 2810 t.Errorf("Expected settings ACK frame, got %v", fr) 2811 } 2812 fr, _ = sfr.ReadFrame() 2813 if fr, ok := fr.(*http2.HeadersFrame); !ok || !fr.Flags.Has(http2.FlagHeadersEndHeaders) { 2814 t.Errorf("Expected Headers frame with END_HEADERS frame, got %v", fr) 2815 } 2816 close(greetDone) 2817 2818 frame, err := sfr.ReadFrame() 2819 if err != nil { 2820 return 2821 } 2822 switch fr := frame.(type) { 2823 case *http2.GoAwayFrame: 2824 // Records that the server successfully received a GOAWAY frame. 
2825 goAwayFrame := fr 2826 if goAwayFrame.ErrCode == http2.ErrCodeNo { 2827 t.Logf("Received goAway frame from client") 2828 close(errorCh) 2829 } else { 2830 errorCh <- fmt.Errorf("received unexpected goAway frame: %v", err) 2831 close(errorCh) 2832 } 2833 return 2834 default: 2835 errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", err) 2836 close(errorCh) 2837 return 2838 } 2839 }() 2840 2841 ct, err := NewHTTP2Client(ctx, ctx, resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {}) 2842 if err != nil { 2843 t.Fatalf("Error while creating client transport: %v", err) 2844 } 2845 _, err = ct.NewStream(ctx, &CallHdr{}) 2846 if err != nil { 2847 t.Fatalf("failed to open stream: %v", err) 2848 } 2849 // Wait until server receives the headers and settings frame as part of greet. 2850 <-greetDone 2851 ct.Close(errors.New("manually closed by client")) 2852 t.Logf("Closed the client connection") 2853 select { 2854 case err := <-errorCh: 2855 if err != nil { 2856 t.Errorf("Error receiving the GOAWAY frame: %v", err) 2857 } 2858 case <-ctx.Done(): 2859 t.Errorf("Context timed out") 2860 } 2861 } 2862 2863 // readHangingConn is a wrapper around net.Conn that makes the Read() hang when 2864 // Close() is called. 2865 type readHangingConn struct { 2866 net.Conn 2867 readHangConn chan struct{} // Read() hangs until this channel is closed by Close(). 2868 closed *atomic.Bool // Set to true when Close() is called. 2869 } 2870 2871 func (hc *readHangingConn) Read(b []byte) (n int, err error) { 2872 n, err = hc.Conn.Read(b) 2873 if hc.closed.Load() { 2874 <-hc.readHangConn // hang the read till we want 2875 } 2876 return n, err 2877 } 2878 2879 func (hc *readHangingConn) Close() error { 2880 hc.closed.Store(true) 2881 return hc.Conn.Close() 2882 } 2883 2884 // Tests that closing a client transport does not return until the reader 2885 // goroutine exits. 
2886 func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) { 2887 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2888 defer cancel() 2889 2890 server := setUpServerOnly(t, 0, &ServerConfig{}, normal) 2891 defer server.stop() 2892 addr := resolver.Address{Addr: "localhost:" + server.port} 2893 2894 isReaderHanging := &atomic.Bool{} 2895 readHangConn := make(chan struct{}) 2896 copts := ConnectOptions{ 2897 Dialer: func(_ context.Context, addr string) (net.Conn, error) { 2898 conn, err := net.Dial("tcp", addr) 2899 if err != nil { 2900 return nil, err 2901 } 2902 return &readHangingConn{Conn: conn, readHangConn: readHangConn, closed: isReaderHanging}, nil 2903 }, 2904 ChannelzParent: channelzSubChannel(t), 2905 } 2906 2907 // Create a client transport with a custom dialer that hangs the Read() 2908 // after Close(). 2909 ct, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {}) 2910 if err != nil { 2911 t.Fatalf("Failed to create transport: %v", err) 2912 } 2913 2914 if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { 2915 t.Fatalf("Failed to open stream: %v", err) 2916 } 2917 2918 // Closing the client transport will result in the underlying net.Conn being 2919 // closed, which will result in readHangingConn.Read() to hang. This will 2920 // stall the exit of the reader goroutine, and will stall client 2921 // transport's Close from returning. 2922 transportClosed := make(chan struct{}) 2923 go func() { 2924 ct.Close(errors.New("manually closed by client")) 2925 close(transportClosed) 2926 }() 2927 2928 // Wait for a short duration and ensure that the client transport's Close() 2929 // does not return. 2930 select { 2931 case <-transportClosed: 2932 t.Fatal("Transport closed before reader completed") 2933 case <-time.After(defaultTestShortTimeout): 2934 } 2935 2936 // Closing the channel will unblock the reader goroutine and will ensure 2937 // that the client transport's Close() returns. 
2938 close(readHangConn) 2939 select { 2940 case <-transportClosed: 2941 case <-time.After(defaultTestTimeout): 2942 t.Fatal("Timeout when waiting for transport to close") 2943 } 2944 } 2945 2946 // hangingConn is a net.Conn wrapper for testing, simulating hanging connections 2947 // after a GOAWAY frame is sent, of which Write operations pause until explicitly 2948 // signaled or a timeout occurs. 2949 type hangingConn struct { 2950 net.Conn 2951 hangConn chan struct{} 2952 startHanging *atomic.Bool 2953 } 2954 2955 func (hc *hangingConn) Write(b []byte) (n int, err error) { 2956 n, err = hc.Conn.Write(b) 2957 if hc.startHanging.Load() { 2958 <-hc.hangConn 2959 } 2960 return n, err 2961 } 2962 2963 // Tests the scenario where a client transport is closed and writing of the 2964 // GOAWAY frame as part of the close does not complete because of a network 2965 // hang. The test verifies that the client transport is closed without waiting 2966 // for too long. 2967 func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) { 2968 // Override timer for writing GOAWAY to 0 so that the connection write 2969 // always times out. It is equivalent of real network hang when conn 2970 // write for goaway doesn't finish in specified deadline 2971 origGoAwayLoopyTimeout := goAwayLoopyWriterTimeout 2972 goAwayLoopyWriterTimeout = time.Millisecond 2973 defer func() { 2974 goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout 2975 }() 2976 2977 // Create the server set up. 
2978 connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2979 defer cancel() 2980 server := setUpServerOnly(t, 0, &ServerConfig{}, normal) 2981 defer server.stop() 2982 addr := resolver.Address{Addr: "localhost:" + server.port} 2983 isGreetingDone := &atomic.Bool{} 2984 hangConn := make(chan struct{}) 2985 defer close(hangConn) 2986 dialer := func(_ context.Context, addr string) (net.Conn, error) { 2987 conn, err := net.Dial("tcp", addr) 2988 if err != nil { 2989 return nil, err 2990 } 2991 return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil 2992 } 2993 copts := ConnectOptions{Dialer: dialer} 2994 copts.ChannelzParent = channelzSubChannel(t) 2995 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2996 defer cancel() 2997 // Create client transport with custom dialer 2998 ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {}) 2999 if connErr != nil { 3000 t.Fatalf("failed to create transport: %v", connErr) 3001 } 3002 3003 if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { 3004 t.Fatalf("Failed to open stream: %v", err) 3005 } 3006 3007 isGreetingDone.Store(true) 3008 ct.Close(errors.New("manually closed by client")) 3009 } 3010 3011 // TestReadHeaderMultipleBuffers tests the stream when the gRPC headers are 3012 // split across multiple buffers. It verifies that the reporting of the 3013 // number of bytes read for flow control is correct. 
3014 func (s) TestReadMessageHeaderMultipleBuffers(t *testing.T) { 3015 headerLen := 5 3016 recvBuffer := newRecvBuffer() 3017 recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, 3)}) 3018 recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, headerLen-3)}) 3019 bytesRead := 0 3020 s := Stream{ 3021 requestRead: func(int) {}, 3022 trReader: &transportReader{ 3023 reader: &recvBufferReader{ 3024 recv: recvBuffer, 3025 }, 3026 windowHandler: func(i int) { 3027 bytesRead += i 3028 }, 3029 }, 3030 } 3031 3032 header := make([]byte, headerLen) 3033 err := s.ReadMessageHeader(header) 3034 if err != nil { 3035 t.Fatalf("ReadHeader(%v) = %v", header, err) 3036 } 3037 if bytesRead != headerLen { 3038 t.Errorf("bytesRead = %d, want = %d", bytesRead, headerLen) 3039 } 3040 } 3041 3042 // Tests a scenario when the client doesn't send an RST frame when the 3043 // configured deadline is reached. The test verifies that the server sends an 3044 // RST stream only after the deadline is reached. 3045 func (s) TestServerSendsRSTAfterDeadlineToMisbehavedClient(t *testing.T) { 3046 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 3047 defer server.stop() 3048 // Create a client that can override server stream quota. 3049 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 3050 if err != nil { 3051 t.Fatalf("Clent failed to dial:%v", err) 3052 } 3053 defer mconn.Close() 3054 if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil { 3055 t.Fatalf("Failed to set write deadline: %v", err) 3056 } 3057 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 3058 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 3059 } 3060 // rstTimeChan chan indicates that reader received a RSTStream from server. 
3061 rstTimeChan := make(chan time.Time, 1) 3062 var mu sync.Mutex 3063 framer := http2.NewFramer(mconn, mconn) 3064 if err := framer.WriteSettings(); err != nil { 3065 t.Fatalf("Error while writing settings: %v", err) 3066 } 3067 go func() { // Launch a reader for this misbehaving client. 3068 for { 3069 frame, err := framer.ReadFrame() 3070 if err != nil { 3071 return 3072 } 3073 switch frame := frame.(type) { 3074 case *http2.PingFrame: 3075 // Write ping ack back so that server's BDP estimation works right. 3076 mu.Lock() 3077 framer.WritePing(true, frame.Data) 3078 mu.Unlock() 3079 case *http2.RSTStreamFrame: 3080 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeCancel { 3081 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeCancel", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 3082 } 3083 rstTimeChan <- time.Now() 3084 return 3085 default: 3086 // Do nothing. 3087 } 3088 } 3089 }() 3090 // Create a stream. 
3091 var buf bytes.Buffer 3092 henc := hpack.NewEncoder(&buf) 3093 if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil { 3094 t.Fatalf("Error while encoding header: %v", err) 3095 } 3096 if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { 3097 t.Fatalf("Error while encoding header: %v", err) 3098 } 3099 if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { 3100 t.Fatalf("Error while encoding header: %v", err) 3101 } 3102 if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { 3103 t.Fatalf("Error while encoding header: %v", err) 3104 } 3105 if err := henc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: "10m"}); err != nil { 3106 t.Fatalf("Error while encoding header: %v", err) 3107 } 3108 mu.Lock() 3109 startTime := time.Now() 3110 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 3111 mu.Unlock() 3112 t.Fatalf("Error while writing headers: %v", err) 3113 } 3114 mu.Unlock() 3115 3116 // Test server behavior for deadline expiration. 3117 var rstTime time.Time 3118 select { 3119 case <-time.After(5 * time.Second): 3120 t.Fatalf("Test timed out.") 3121 case rstTime = <-rstTimeChan: 3122 } 3123 3124 if got, want := rstTime.Sub(startTime), 10*time.Millisecond; got < want { 3125 t.Fatalf("RST frame received earlier than expected by duration: %v", want-got) 3126 } 3127 }