gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/internal/transport/transport_test.go (about) 1 /* 2 * 3 * Copyright 2014 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package transport 20 21 import ( 22 "bytes" 23 "context" 24 "encoding/binary" 25 "errors" 26 "fmt" 27 "io" 28 "math" 29 "net" 30 "runtime" 31 "strconv" 32 "strings" 33 "sync" 34 "testing" 35 "time" 36 37 "gitee.com/ks-custle/core-gm/grpc/attributes" 38 "gitee.com/ks-custle/core-gm/grpc/codes" 39 "gitee.com/ks-custle/core-gm/grpc/credentials" 40 "gitee.com/ks-custle/core-gm/grpc/internal/grpctest" 41 "gitee.com/ks-custle/core-gm/grpc/internal/leakcheck" 42 "gitee.com/ks-custle/core-gm/grpc/internal/testutils" 43 "gitee.com/ks-custle/core-gm/grpc/resolver" 44 "gitee.com/ks-custle/core-gm/grpc/status" 45 "gitee.com/ks-custle/core-gm/net/http2" 46 "gitee.com/ks-custle/core-gm/net/http2/hpack" 47 "github.com/google/go-cmp/cmp" 48 ) 49 50 type s struct { 51 grpctest.Tester 52 } 53 54 func Test(t *testing.T) { 55 grpctest.RunSubTests(t, s{}) 56 } 57 58 type server struct { 59 lis net.Listener 60 port string 61 startedErr chan error // error (or nil) with server start value 62 mu sync.Mutex 63 conns map[ServerTransport]bool 64 h *testStreamHandler 65 ready chan struct{} 66 } 67 68 var ( 69 expectedRequest = []byte("ping") 70 expectedResponse = []byte("pong") 71 expectedRequestLarge = make([]byte, initialWindowSize*2) 72 expectedResponseLarge = make([]byte, initialWindowSize*2) 73 expectedInvalidHeaderField = "invalid/content-type" 74 ) 75 76 func init() { 77 expectedRequestLarge[0] = 'g' 78 expectedRequestLarge[len(expectedRequestLarge)-1] = 'r' 79 expectedResponseLarge[0] = 'p' 80 expectedResponseLarge[len(expectedResponseLarge)-1] = 'c' 81 } 82 83 type testStreamHandler struct { 84 t *http2Server 85 notify chan struct{} 86 getNotified chan struct{} 87 } 88 89 type hType int 90 91 const ( 92 normal hType = iota 93 suspended 94 notifyCall 95 misbehaved 96 encodingRequiredStatus 97 invalidHeaderField 98 delayRead 99 pingpong 100 ) 101 102 func (h *testStreamHandler) handleStreamAndNotify(s *Stream) { 103 if h.notify == nil { 104 return 105 } 106 go func() { 107 select { 108 case <-h.notify: 109 default: 110 close(h.notify) 111 } 112 }() 113 } 114 115 func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { 116 req := expectedRequest 117 resp := expectedResponse 118 if s.Method() == "foo.Large" { 119 req = expectedRequestLarge 120 resp = expectedResponseLarge 121 } 122 p := make([]byte, len(req)) 123 _, err := s.Read(p) 124 if err != nil { 125 return 126 } 127 if !bytes.Equal(p, req) { 128 t.Errorf("handleStream got %v, want %v", p, req) 129 h.t.WriteStatus(s, status.New(codes.Internal, "panic")) 130 return 131 } 132 // send a response back to the client. 133 h.t.Write(s, nil, resp, &Options{}) 134 // send the trailer to end the stream. 135 h.t.WriteStatus(s, status.New(codes.OK, "")) 136 } 137 138 func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { 139 header := make([]byte, 5) 140 for { 141 if _, err := s.Read(header); err != nil { 142 if err == io.EOF { 143 h.t.WriteStatus(s, status.New(codes.OK, "")) 144 return 145 } 146 t.Errorf("Error on server while reading data header: %v", err) 147 h.t.WriteStatus(s, status.New(codes.Internal, "panic")) 148 return 149 } 150 sz := binary.BigEndian.Uint32(header[1:]) 151 msg := make([]byte, int(sz)) 152 if _, err := s.Read(msg); err != nil { 153 t.Errorf("Error on server while reading message: %v", err) 154 h.t.WriteStatus(s, status.New(codes.Internal, "panic")) 155 return 156 } 157 buf := make([]byte, sz+5) 158 buf[0] = byte(0) 159 binary.BigEndian.PutUint32(buf[1:], uint32(sz)) 160 copy(buf[5:], msg) 161 h.t.Write(s, nil, buf, &Options{}) 162 } 163 } 164 165 func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { 166 conn, ok := s.st.(*http2Server) 167 if !ok { 168 t.Errorf("Failed to convert %v to *http2Server", s.st) 169 h.t.WriteStatus(s, status.New(codes.Internal, "")) 170 return 171 } 172 var sent int 173 p := make([]byte, http2MaxFrameLen) 174 for sent < initialWindowSize { 175 n := initialWindowSize - sent 176 // The last message may be smaller than http2MaxFrameLen 177 if n <= http2MaxFrameLen { 178 if s.Method() == "foo.Connection" { 179 // Violate connection level flow control window of client but do not 180 // violate any stream level windows. 181 p = make([]byte, n) 182 } else { 183 // Violate stream level flow control window of client. 184 p = make([]byte, n+1) 185 } 186 } 187 conn.controlBuf.put(&dataFrame{ 188 streamID: s.id, 189 h: nil, 190 d: p, 191 onEachWrite: func() {}, 192 }) 193 sent += len(p) 194 } 195 } 196 197 func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *Stream) { 198 // raw newline is not accepted by http2 framer so it must be encoded. 199 h.t.WriteStatus(s, encodingTestStatus) 200 } 201 202 func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) { 203 headerFields := []hpack.HeaderField{} 204 headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) 205 h.t.controlBuf.put(&headerFrame{ 206 streamID: s.id, 207 hf: headerFields, 208 endStream: false, 209 }) 210 } 211 212 // handleStreamDelayRead delays reads so that the other side has to halt on 213 // stream-level flow control. 214 // This handler assumes dynamic flow control is turned off and assumes window 215 // sizes to be set to defaultWindowSize. 216 func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { 217 req := expectedRequest 218 resp := expectedResponse 219 if s.Method() == "foo.Large" { 220 req = expectedRequestLarge 221 resp = expectedResponseLarge 222 } 223 var ( 224 mu sync.Mutex 225 total int 226 ) 227 s.wq.replenish = func(n int) { 228 mu.Lock() 229 total += n 230 mu.Unlock() 231 s.wq.realReplenish(n) 232 } 233 getTotal := func() int { 234 mu.Lock() 235 defer mu.Unlock() 236 return total 237 } 238 done := make(chan struct{}) 239 defer close(done) 240 go func() { 241 for { 242 select { 243 // Prevent goroutine from leaking. 244 case <-done: 245 return 246 default: 247 } 248 if getTotal() == defaultWindowSize { 249 // Signal the client to start reading and 250 // thereby send window update. 251 close(h.notify) 252 return 253 } 254 runtime.Gosched() 255 } 256 }() 257 p := make([]byte, len(req)) 258 259 // Let the other side run out of stream-level window before 260 // starting to read and thereby sending a window update. 261 timer := time.NewTimer(time.Second * 10) 262 select { 263 case <-h.getNotified: 264 timer.Stop() 265 case <-timer.C: 266 t.Errorf("Server timed-out.") 267 return 268 } 269 _, err := s.Read(p) 270 if err != nil { 271 t.Errorf("s.Read(_) = _, %v, want _, <nil>", err) 272 return 273 } 274 275 if !bytes.Equal(p, req) { 276 t.Errorf("handleStream got %v, want %v", p, req) 277 return 278 } 279 // This write will cause server to run out of stream level, 280 // flow control and the other side won't send a window update 281 // until that happens. 282 if err := h.t.Write(s, nil, resp, &Options{}); err != nil { 283 t.Errorf("server Write got %v, want <nil>", err) 284 return 285 } 286 // Read one more time to ensure that everything remains fine and 287 // that the goroutine, that we launched earlier to signal client 288 // to read, gets enough time to process. 289 _, err = s.Read(p) 290 if err != nil { 291 t.Errorf("s.Read(_) = _, %v, want _, nil", err) 292 return 293 } 294 // send the trailer to end the stream. 295 if err := h.t.WriteStatus(s, status.New(codes.OK, "")); err != nil { 296 t.Errorf("server WriteStatus got %v, want <nil>", err) 297 return 298 } 299 } 300 301 // start starts server. Other goroutines should block on s.readyChan for further operations. 302 func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) { 303 var err error 304 if port == 0 { 305 s.lis, err = net.Listen("tcp", "localhost:0") 306 } else { 307 s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port)) 308 } 309 if err != nil { 310 s.startedErr <- fmt.Errorf("failed to listen: %v", err) 311 return 312 } 313 _, p, err := net.SplitHostPort(s.lis.Addr().String()) 314 if err != nil { 315 s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err) 316 return 317 } 318 s.port = p 319 s.conns = make(map[ServerTransport]bool) 320 s.startedErr <- nil 321 for { 322 conn, err := s.lis.Accept() 323 if err != nil { 324 return 325 } 326 transport, err := NewServerTransport(conn, serverConfig) 327 if err != nil { 328 return 329 } 330 s.mu.Lock() 331 if s.conns == nil { 332 s.mu.Unlock() 333 transport.Close() 334 return 335 } 336 s.conns[transport] = true 337 h := &testStreamHandler{t: transport.(*http2Server)} 338 s.h = h 339 s.mu.Unlock() 340 switch ht { 341 case notifyCall: 342 go transport.HandleStreams(h.handleStreamAndNotify, 343 func(ctx context.Context, _ string) context.Context { 344 return ctx 345 }) 346 case suspended: 347 go transport.HandleStreams(func(*Stream) {}, // Do nothing to handle the stream. 348 func(ctx context.Context, method string) context.Context { 349 return ctx 350 }) 351 case misbehaved: 352 go transport.HandleStreams(func(s *Stream) { 353 go h.handleStreamMisbehave(t, s) 354 }, func(ctx context.Context, method string) context.Context { 355 return ctx 356 }) 357 case encodingRequiredStatus: 358 go transport.HandleStreams(func(s *Stream) { 359 go h.handleStreamEncodingRequiredStatus(s) 360 }, func(ctx context.Context, method string) context.Context { 361 return ctx 362 }) 363 case invalidHeaderField: 364 go transport.HandleStreams(func(s *Stream) { 365 go h.handleStreamInvalidHeaderField(s) 366 }, func(ctx context.Context, method string) context.Context { 367 return ctx 368 }) 369 case delayRead: 370 h.notify = make(chan struct{}) 371 h.getNotified = make(chan struct{}) 372 s.mu.Lock() 373 close(s.ready) 374 s.mu.Unlock() 375 go transport.HandleStreams(func(s *Stream) { 376 go h.handleStreamDelayRead(t, s) 377 }, func(ctx context.Context, method string) context.Context { 378 return ctx 379 }) 380 case pingpong: 381 go transport.HandleStreams(func(s *Stream) { 382 go h.handleStreamPingPong(t, s) 383 }, func(ctx context.Context, method string) context.Context { 384 return ctx 385 }) 386 default: 387 go transport.HandleStreams(func(s *Stream) { 388 go h.handleStream(t, s) 389 }, func(ctx context.Context, method string) context.Context { 390 return ctx 391 }) 392 } 393 } 394 } 395 396 func (s *server) wait(t *testing.T, timeout time.Duration) { 397 select { 398 case err := <-s.startedErr: 399 if err != nil { 400 t.Fatal(err) 401 } 402 case <-time.After(timeout): 403 t.Fatalf("Timed out after %v waiting for server to be ready", timeout) 404 } 405 } 406 407 func (s *server) stop() { 408 s.lis.Close() 409 s.mu.Lock() 410 for c := range s.conns { 411 c.Close() 412 } 413 s.conns = nil 414 s.mu.Unlock() 415 } 416 417 func (s *server) addr() string { 418 if s.lis == nil { 419 return "" 420 } 421 return s.lis.Addr().String() 422 } 423 424 func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server { 425 server := &server{startedErr: make(chan error, 1), ready: make(chan struct{})} 426 go server.start(t, port, serverConfig, ht) 427 server.wait(t, 2*time.Second) 428 return server 429 } 430 431 func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) { 432 return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}) 433 } 434 435 func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) { 436 server := setUpServerOnly(t, port, serverConfig, ht) 437 addr := resolver.Address{Addr: "localhost:" + server.port} 438 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 439 ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func() {}, func(GoAwayReason) {}, func() {}) 440 if connErr != nil { 441 cancel() // Do not cancel in success path. 442 t.Fatalf("failed to create transport: %v", connErr) 443 } 444 return server, ct.(*http2Client), cancel 445 } 446 447 func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) { 448 lis, err := net.Listen("tcp", "localhost:0") 449 if err != nil { 450 t.Fatalf("Failed to listen: %v", err) 451 } 452 // Launch a non responsive server. 453 go func() { 454 defer lis.Close() 455 conn, err := lis.Accept() 456 if err != nil { 457 t.Errorf("Error at server-side while accepting: %v", err) 458 close(connCh) 459 return 460 } 461 connCh <- conn 462 }() 463 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 464 tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func() {}, func(GoAwayReason) {}, func() {}) 465 if err != nil { 466 cancel() // Do not cancel in success path. 467 // Server clean-up. 468 lis.Close() 469 if conn, ok := <-connCh; ok { 470 conn.Close() 471 } 472 t.Fatalf("Failed to dial: %v", err) 473 } 474 return tr.(*http2Client), cancel 475 } 476 477 // TestInflightStreamClosing ensures that closing in-flight stream 478 // sends status error to concurrent stream reader. 479 func (s) TestInflightStreamClosing(t *testing.T) { 480 serverConfig := &ServerConfig{} 481 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 482 defer cancel() 483 defer server.stop() 484 defer client.Close(fmt.Errorf("closed manually by test")) 485 486 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 487 defer cancel() 488 stream, err := client.NewStream(ctx, &CallHdr{}) 489 if err != nil { 490 t.Fatalf("Client failed to create RPC request: %v", err) 491 } 492 493 donec := make(chan struct{}) 494 serr := status.Error(codes.Internal, "client connection is closing") 495 go func() { 496 defer close(donec) 497 if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr { 498 t.Errorf("unexpected Stream error %v, expected %v", err, serr) 499 } 500 }() 501 502 // should unblock concurrent stream.Read 503 client.CloseStream(stream, serr) 504 505 // wait for stream.Read error 506 timeout := time.NewTimer(5 * time.Second) 507 select { 508 case <-donec: 509 if !timeout.Stop() { 510 <-timeout.C 511 } 512 case <-timeout.C: 513 t.Fatalf("Test timed out, expected a status error.") 514 } 515 } 516 517 func (s) TestClientSendAndReceive(t *testing.T) { 518 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 519 defer cancel() 520 callHdr := &CallHdr{ 521 Host: "localhost", 522 Method: "foo.Small", 523 } 524 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 525 defer ctxCancel() 526 s1, err1 := ct.NewStream(ctx, callHdr) 527 if err1 != nil { 528 t.Fatalf("failed to open stream: %v", err1) 529 } 530 if s1.id != 1 { 531 t.Fatalf("wrong stream id: %d", s1.id) 532 } 533 s2, err2 := ct.NewStream(ctx, callHdr) 534 if err2 != nil { 535 t.Fatalf("failed to open stream: %v", err2) 536 } 537 if s2.id != 3 { 538 t.Fatalf("wrong stream id: %d", s2.id) 539 } 540 opts := Options{Last: true} 541 if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF { 542 t.Fatalf("failed to send data: %v", err) 543 } 544 p := make([]byte, len(expectedResponse)) 545 _, recvErr := s1.Read(p) 546 if recvErr != nil || !bytes.Equal(p, expectedResponse) { 547 t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse) 548 } 549 _, recvErr = s1.Read(p) 550 if recvErr != io.EOF { 551 t.Fatalf("Error: %v; want <EOF>", recvErr) 552 } 553 ct.Close(fmt.Errorf("closed manually by test")) 554 server.stop() 555 } 556 557 func (s) TestClientErrorNotify(t *testing.T) { 558 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 559 defer cancel() 560 go server.stop() 561 // ct.reader should detect the error and activate ct.Error(). 562 <-ct.Error() 563 ct.Close(fmt.Errorf("closed manually by test")) 564 } 565 566 func performOneRPC(ct ClientTransport) { 567 callHdr := &CallHdr{ 568 Host: "localhost", 569 Method: "foo.Small", 570 } 571 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 572 defer cancel() 573 s, err := ct.NewStream(ctx, callHdr) 574 if err != nil { 575 return 576 } 577 opts := Options{Last: true} 578 if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF { 579 time.Sleep(5 * time.Millisecond) 580 // The following s.Recv()'s could error out because the 581 // underlying transport is gone. 582 // 583 // Read response 584 p := make([]byte, len(expectedResponse)) 585 s.Read(p) 586 // Read io.EOF 587 s.Read(p) 588 } 589 } 590 591 func (s) TestClientMix(t *testing.T) { 592 s, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 593 defer cancel() 594 go func(s *server) { 595 time.Sleep(5 * time.Second) 596 s.stop() 597 }(s) 598 go func(ct ClientTransport) { 599 <-ct.Error() 600 ct.Close(fmt.Errorf("closed manually by test")) 601 }(ct) 602 for i := 0; i < 1000; i++ { 603 time.Sleep(10 * time.Millisecond) 604 go performOneRPC(ct) 605 } 606 } 607 608 func (s) TestLargeMessage(t *testing.T) { 609 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 610 defer cancel() 611 callHdr := &CallHdr{ 612 Host: "localhost", 613 Method: "foo.Large", 614 } 615 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 616 defer ctxCancel() 617 var wg sync.WaitGroup 618 for i := 0; i < 2; i++ { 619 wg.Add(1) 620 go func() { 621 defer wg.Done() 622 s, err := ct.NewStream(ctx, callHdr) 623 if err != nil { 624 t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 625 } 626 if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil && err != io.EOF { 627 t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) 628 } 629 p := make([]byte, len(expectedResponseLarge)) 630 if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { 631 t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) 632 } 633 if _, err = s.Read(p); err != io.EOF { 634 t.Errorf("Failed to complete the stream %v; want <EOF>", err) 635 } 636 }() 637 } 638 wg.Wait() 639 ct.Close(fmt.Errorf("closed manually by test")) 640 server.stop() 641 } 642 643 func (s) TestLargeMessageWithDelayRead(t *testing.T) { 644 // Disable dynamic flow control. 645 sc := &ServerConfig{ 646 InitialWindowSize: defaultWindowSize, 647 InitialConnWindowSize: defaultWindowSize, 648 } 649 co := ConnectOptions{ 650 InitialWindowSize: defaultWindowSize, 651 InitialConnWindowSize: defaultWindowSize, 652 } 653 server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co) 654 defer cancel() 655 defer server.stop() 656 defer ct.Close(fmt.Errorf("closed manually by test")) 657 server.mu.Lock() 658 ready := server.ready 659 server.mu.Unlock() 660 callHdr := &CallHdr{ 661 Host: "localhost", 662 Method: "foo.Large", 663 } 664 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) 665 defer cancel() 666 s, err := ct.NewStream(ctx, callHdr) 667 if err != nil { 668 t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 669 return 670 } 671 // Wait for server's handerler to be initialized 672 select { 673 case <-ready: 674 case <-ctx.Done(): 675 t.Fatalf("Client timed out waiting for server handler to be initialized.") 676 } 677 server.mu.Lock() 678 serviceHandler := server.h 679 server.mu.Unlock() 680 var ( 681 mu sync.Mutex 682 total int 683 ) 684 s.wq.replenish = func(n int) { 685 mu.Lock() 686 total += n 687 mu.Unlock() 688 s.wq.realReplenish(n) 689 } 690 getTotal := func() int { 691 mu.Lock() 692 defer mu.Unlock() 693 return total 694 } 695 done := make(chan struct{}) 696 defer close(done) 697 go func() { 698 for { 699 select { 700 // Prevent goroutine from leaking in case of error. 701 case <-done: 702 return 703 default: 704 } 705 if getTotal() == defaultWindowSize { 706 // unblock server to be able to read and 707 // thereby send stream level window update. 708 close(serviceHandler.getNotified) 709 return 710 } 711 runtime.Gosched() 712 } 713 }() 714 // This write will cause client to run out of stream level, 715 // flow control and the other side won't send a window update 716 // until that happens. 717 if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{}); err != nil { 718 t.Fatalf("write(_, _, _) = %v, want <nil>", err) 719 } 720 p := make([]byte, len(expectedResponseLarge)) 721 722 // Wait for the other side to run out of stream level flow control before 723 // reading and thereby sending a window update. 724 select { 725 case <-serviceHandler.notify: 726 case <-ctx.Done(): 727 t.Fatalf("Client timed out") 728 } 729 if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { 730 t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err) 731 } 732 if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil { 733 t.Fatalf("Write(_, _, _) = %v, want <nil>", err) 734 } 735 if _, err = s.Read(p); err != io.EOF { 736 t.Fatalf("Failed to complete the stream %v; want <EOF>", err) 737 } 738 } 739 740 func (s) TestGracefulClose(t *testing.T) { 741 server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong) 742 defer cancel() 743 defer func() { 744 // Stop the server's listener to make the server's goroutines terminate 745 // (after the last active stream is done). 746 server.lis.Close() 747 // Check for goroutine leaks (i.e. GracefulClose with an active stream 748 // doesn't eventually close the connection when that stream completes). 749 leakcheck.Check(t) 750 // Correctly clean up the server 751 server.stop() 752 }() 753 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) 754 defer cancel() 755 s, err := ct.NewStream(ctx, &CallHdr{}) 756 if err != nil { 757 t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err) 758 } 759 msg := make([]byte, 1024) 760 outgoingHeader := make([]byte, 5) 761 outgoingHeader[0] = byte(0) 762 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg))) 763 incomingHeader := make([]byte, 5) 764 if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil { 765 t.Fatalf("Error while writing: %v", err) 766 } 767 if _, err := s.Read(incomingHeader); err != nil { 768 t.Fatalf("Error while reading: %v", err) 769 } 770 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 771 recvMsg := make([]byte, int(sz)) 772 if _, err := s.Read(recvMsg); err != nil { 773 t.Fatalf("Error while reading: %v", err) 774 } 775 ct.GracefulClose() 776 var wg sync.WaitGroup 777 // Expect the failure for all the follow-up streams because ct has been closed gracefully. 778 for i := 0; i < 200; i++ { 779 wg.Add(1) 780 go func() { 781 defer wg.Done() 782 str, err := ct.NewStream(ctx, &CallHdr{}) 783 if err != nil && err.(*NewStreamError).Err == ErrConnClosing { 784 return 785 } else if err != nil { 786 t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) 787 return 788 } 789 ct.Write(str, nil, nil, &Options{Last: true}) 790 if _, err := str.Read(make([]byte, 8)); err != errStreamDrain && err != ErrConnClosing { 791 t.Errorf("_.Read(_) = _, %v, want _, %v or %v", err, errStreamDrain, ErrConnClosing) 792 } 793 }() 794 } 795 ct.Write(s, nil, nil, &Options{Last: true}) 796 if _, err := s.Read(incomingHeader); err != io.EOF { 797 t.Fatalf("Client expected EOF from the server. Got: %v", err) 798 } 799 // The stream which was created before graceful close can still proceed. 800 wg.Wait() 801 } 802 803 func (s) TestLargeMessageSuspension(t *testing.T) { 804 server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) 805 defer cancel() 806 callHdr := &CallHdr{ 807 Host: "localhost", 808 Method: "foo.Large", 809 } 810 // Set a long enough timeout for writing a large message out. 811 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 812 defer cancel() 813 s, err := ct.NewStream(ctx, callHdr) 814 if err != nil { 815 t.Fatalf("failed to open stream: %v", err) 816 } 817 // Launch a goroutine simillar to the stream monitoring goroutine in 818 // stream.go to keep track of context timeout and call CloseStream. 819 go func() { 820 <-ctx.Done() 821 ct.CloseStream(s, ContextErr(ctx.Err())) 822 }() 823 // Write should not be done successfully due to flow control. 824 msg := make([]byte, initialWindowSize*8) 825 ct.Write(s, nil, msg, &Options{}) 826 err = ct.Write(s, nil, msg, &Options{Last: true}) 827 if err != errStreamDone { 828 t.Fatalf("Write got %v, want io.EOF", err) 829 } 830 expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) 831 if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { 832 t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) 833 } 834 ct.Close(fmt.Errorf("closed manually by test")) 835 server.stop() 836 } 837 838 func (s) TestMaxStreams(t *testing.T) { 839 serverConfig := &ServerConfig{ 840 MaxStreams: 1, 841 } 842 server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 843 defer cancel() 844 defer ct.Close(fmt.Errorf("closed manually by test")) 845 defer server.stop() 846 callHdr := &CallHdr{ 847 Host: "localhost", 848 Method: "foo.Large", 849 } 850 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 851 defer cancel() 852 s, err := ct.NewStream(ctx, callHdr) 853 if err != nil { 854 t.Fatalf("Failed to open stream: %v", err) 855 } 856 // Keep creating streams until one fails with deadline exceeded, marking the application 857 // of server settings on client. 858 slist := []*Stream{} 859 pctx, cancel := context.WithCancel(context.Background()) 860 defer cancel() 861 timer := time.NewTimer(time.Second * 10) 862 expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) 863 for { 864 select { 865 case <-timer.C: 866 t.Fatalf("Test timeout: client didn't receive server settings.") 867 default: 868 } 869 ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second)) 870 // This is only to get rid of govet. All these context are based on a base 871 // context which is canceled at the end of the test. 872 defer cancel() 873 if str, err := ct.NewStream(ctx, callHdr); err == nil { 874 slist = append(slist, str) 875 continue 876 } else if err.Error() != expectedErr.Error() { 877 t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) 878 } 879 timer.Stop() 880 break 881 } 882 done := make(chan struct{}) 883 // Try and create a new stream. 884 go func() { 885 defer close(done) 886 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) 887 defer cancel() 888 if _, err := ct.NewStream(ctx, callHdr); err != nil { 889 t.Errorf("Failed to open stream: %v", err) 890 } 891 }() 892 // Close all the extra streams created and make sure the new stream is not created. 893 for _, str := range slist { 894 ct.CloseStream(str, nil) 895 } 896 select { 897 case <-done: 898 t.Fatalf("Test failed: didn't expect new stream to be created just yet.") 899 default: 900 } 901 // Close the first stream created so that the new stream can finally be created. 902 ct.CloseStream(s, nil) 903 <-done 904 ct.Close(fmt.Errorf("closed manually by test")) 905 <-ct.writerDone 906 if ct.maxConcurrentStreams != 1 { 907 t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) 908 } 909 } 910 911 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { 912 server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) 913 defer cancel() 914 callHdr := &CallHdr{ 915 Host: "localhost", 916 Method: "foo", 917 } 918 var sc *http2Server 919 // Wait until the server transport is setup. 920 for { 921 server.mu.Lock() 922 if len(server.conns) == 0 { 923 server.mu.Unlock() 924 time.Sleep(time.Millisecond) 925 continue 926 } 927 for k := range server.conns { 928 var ok bool 929 sc, ok = k.(*http2Server) 930 if !ok { 931 t.Fatalf("Failed to convert %v to *http2Server", k) 932 } 933 } 934 server.mu.Unlock() 935 break 936 } 937 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 938 defer cancel() 939 s, err := ct.NewStream(ctx, callHdr) 940 if err != nil { 941 t.Fatalf("Failed to open stream: %v", err) 942 } 943 ct.controlBuf.put(&dataFrame{ 944 streamID: s.id, 945 endStream: false, 946 h: nil, 947 d: make([]byte, http2MaxFrameLen), 948 onEachWrite: func() {}, 949 }) 950 // Loop until the server side stream is created. 951 var ss *Stream 952 for { 953 time.Sleep(time.Second) 954 sc.mu.Lock() 955 if len(sc.activeStreams) == 0 { 956 sc.mu.Unlock() 957 continue 958 } 959 ss = sc.activeStreams[s.id] 960 sc.mu.Unlock() 961 break 962 } 963 ct.Close(fmt.Errorf("closed manually by test")) 964 select { 965 case <-ss.Context().Done(): 966 if ss.Context().Err() != context.Canceled { 967 t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled) 968 } 969 case <-time.After(5 * time.Second): 970 t.Fatalf("Failed to cancel the context of the sever side stream.") 971 } 972 server.stop() 973 } 974 975 func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { 976 connectOptions := ConnectOptions{ 977 InitialWindowSize: defaultWindowSize, 978 InitialConnWindowSize: defaultWindowSize, 979 } 980 server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) 981 defer cancel() 982 defer server.stop() 983 defer client.Close(fmt.Errorf("closed manually by test")) 984 985 waitWhileTrue(t, func() (bool, error) { 986 server.mu.Lock() 987 defer server.mu.Unlock() 988 989 if len(server.conns) == 0 { 990 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 991 } 992 return false, nil 993 }) 994 995 var st *http2Server 996 server.mu.Lock() 997 for k := range server.conns { 998 st = k.(*http2Server) 999 } 1000 notifyChan := make(chan struct{}) 1001 server.h.notify = notifyChan 1002 server.mu.Unlock() 1003 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1004 defer cancel() 1005 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1006 if err != nil { 1007 t.Fatalf("Client failed to create first stream. Err: %v", err) 1008 } 1009 1010 <-notifyChan 1011 var sstream1 *Stream 1012 // Access stream on the server. 1013 st.mu.Lock() 1014 for _, v := range st.activeStreams { 1015 if v.id == cstream1.id { 1016 sstream1 = v 1017 } 1018 } 1019 st.mu.Unlock() 1020 if sstream1 == nil { 1021 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) 1022 } 1023 // Exhaust client's connection window. 1024 if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { 1025 t.Fatalf("Server failed to write data. Err: %v", err) 1026 } 1027 notifyChan = make(chan struct{}) 1028 server.mu.Lock() 1029 server.h.notify = notifyChan 1030 server.mu.Unlock() 1031 // Create another stream on client. 1032 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1033 if err != nil { 1034 t.Fatalf("Client failed to create second stream. Err: %v", err) 1035 } 1036 <-notifyChan 1037 var sstream2 *Stream 1038 st.mu.Lock() 1039 for _, v := range st.activeStreams { 1040 if v.id == cstream2.id { 1041 sstream2 = v 1042 } 1043 } 1044 st.mu.Unlock() 1045 if sstream2 == nil { 1046 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) 1047 } 1048 // Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream. 1049 if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { 1050 t.Fatalf("Server failed to write data. Err: %v", err) 1051 } 1052 1053 // Client should be able to read data on second stream. 1054 if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil { 1055 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1056 } 1057 1058 // Client should be able to read data on first stream. 1059 if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil { 1060 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1061 } 1062 } 1063 1064 func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { 1065 serverConfig := &ServerConfig{ 1066 InitialWindowSize: defaultWindowSize, 1067 InitialConnWindowSize: defaultWindowSize, 1068 } 1069 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 1070 defer cancel() 1071 defer server.stop() 1072 defer client.Close(fmt.Errorf("closed manually by test")) 1073 waitWhileTrue(t, func() (bool, error) { 1074 server.mu.Lock() 1075 defer server.mu.Unlock() 1076 1077 if len(server.conns) == 0 { 1078 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 1079 } 1080 return false, nil 1081 }) 1082 var st *http2Server 1083 server.mu.Lock() 1084 for k := range server.conns { 1085 st = k.(*http2Server) 1086 } 1087 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1088 defer cancel() 1089 server.mu.Unlock() 1090 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1091 if err != nil { 1092 t.Fatalf("Failed to create 1st stream. Err: %v", err) 1093 } 1094 // Exhaust server's connection window. 1095 if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { 1096 t.Fatalf("Client failed to write data. Err: %v", err) 1097 } 1098 //Client should be able to create another stream and send data on it. 1099 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1100 if err != nil { 1101 t.Fatalf("Failed to create 2nd stream. Err: %v", err) 1102 } 1103 if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil { 1104 t.Fatalf("Client failed to write data. Err: %v", err) 1105 } 1106 // Get the streams on server. 1107 waitWhileTrue(t, func() (bool, error) { 1108 st.mu.Lock() 1109 defer st.mu.Unlock() 1110 1111 if len(st.activeStreams) != 2 { 1112 return true, fmt.Errorf("timed-out while waiting for server to have created the streams") 1113 } 1114 return false, nil 1115 }) 1116 var sstream1 *Stream 1117 st.mu.Lock() 1118 for _, v := range st.activeStreams { 1119 if v.id == 1 { 1120 sstream1 = v 1121 } 1122 } 1123 st.mu.Unlock() 1124 // Reading from the stream on server should succeed. 1125 if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil { 1126 t.Fatalf("_.Read(_) = %v, want <nil>", err) 1127 } 1128 1129 if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF { 1130 t.Fatalf("_.Read(_) = %v, want io.EOF", err) 1131 } 1132 1133 } 1134 1135 func (s) TestServerWithMisbehavedClient(t *testing.T) { 1136 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1137 defer server.stop() 1138 // Create a client that can override server stream quota. 1139 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1140 if err != nil { 1141 t.Fatalf("Clent failed to dial:%v", err) 1142 } 1143 defer mconn.Close() 1144 if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil { 1145 t.Fatalf("Failed to set write deadline: %v", err) 1146 } 1147 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1148 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1149 } 1150 // success chan indicates that reader received a RSTStream from server. 1151 success := make(chan struct{}) 1152 var mu sync.Mutex 1153 framer := http2.NewFramer(mconn, mconn) 1154 if err := framer.WriteSettings(); err != nil { 1155 t.Fatalf("Error while writing settings: %v", err) 1156 } 1157 go func() { // Launch a reader for this misbehaving client. 1158 for { 1159 frame, err := framer.ReadFrame() 1160 if err != nil { 1161 return 1162 } 1163 switch frame := frame.(type) { 1164 case *http2.PingFrame: 1165 // Write ping ack back so that server's BDP estimation works right. 1166 mu.Lock() 1167 framer.WritePing(true, frame.Data) 1168 mu.Unlock() 1169 case *http2.RSTStreamFrame: 1170 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { 1171 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 1172 } 1173 close(success) 1174 return 1175 default: 1176 // Do nothing. 1177 } 1178 1179 } 1180 }() 1181 // Create a stream. 1182 var buf bytes.Buffer 1183 henc := hpack.NewEncoder(&buf) 1184 // TODO(mmukhi): Remove unnecessary fields. 1185 if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil { 1186 t.Fatalf("Error while encoding header: %v", err) 1187 } 1188 if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { 1189 t.Fatalf("Error while encoding header: %v", err) 1190 } 1191 if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { 1192 t.Fatalf("Error while encoding header: %v", err) 1193 } 1194 if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { 1195 t.Fatalf("Error while encoding header: %v", err) 1196 } 1197 mu.Lock() 1198 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1199 mu.Unlock() 1200 t.Fatalf("Error while writing headers: %v", err) 1201 } 1202 mu.Unlock() 1203 1204 // Test server behavior for violation of stream flow control window size restriction. 1205 timer := time.NewTimer(time.Second * 5) 1206 dbuf := make([]byte, http2MaxFrameLen) 1207 for { 1208 select { 1209 case <-timer.C: 1210 t.Fatalf("Test timed out.") 1211 case <-success: 1212 return 1213 default: 1214 } 1215 mu.Lock() 1216 if err := framer.WriteData(1, false, dbuf); err != nil { 1217 mu.Unlock() 1218 // Error here means the server could have closed the connection due to flow control 1219 // violation. Make sure that is the case by waiting for success chan to be closed. 1220 select { 1221 case <-timer.C: 1222 t.Fatalf("Error while writing data: %v", err) 1223 case <-success: 1224 return 1225 } 1226 } 1227 mu.Unlock() 1228 // This for loop is capable of hogging the CPU and cause starvation 1229 // in Go versions prior to 1.9, 1230 // in single CPU environment. Explicitly relinquish processor. 1231 runtime.Gosched() 1232 } 1233 } 1234 1235 func (s) TestClientWithMisbehavedServer(t *testing.T) { 1236 // Create a misbehaving server. 1237 lis, err := net.Listen("tcp", "localhost:0") 1238 if err != nil { 1239 t.Fatalf("Error while listening: %v", err) 1240 } 1241 defer lis.Close() 1242 // success chan indicates that the server received 1243 // RSTStream from the client. 1244 success := make(chan struct{}) 1245 go func() { // Launch the misbehaving server. 1246 sconn, err := lis.Accept() 1247 if err != nil { 1248 t.Errorf("Error while accepting: %v", err) 1249 return 1250 } 1251 defer sconn.Close() 1252 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 1253 t.Errorf("Error while reading clieng preface: %v", err) 1254 return 1255 } 1256 sfr := http2.NewFramer(sconn, sconn) 1257 if err := sfr.WriteSettingsAck(); err != nil { 1258 t.Errorf("Error while writing settings: %v", err) 1259 return 1260 } 1261 var mu sync.Mutex 1262 for { 1263 frame, err := sfr.ReadFrame() 1264 if err != nil { 1265 return 1266 } 1267 switch frame := frame.(type) { 1268 case *http2.HeadersFrame: 1269 // When the client creates a stream, violate the stream flow control. 1270 go func() { 1271 buf := make([]byte, http2MaxFrameLen) 1272 for { 1273 mu.Lock() 1274 if err := sfr.WriteData(1, false, buf); err != nil { 1275 mu.Unlock() 1276 return 1277 } 1278 mu.Unlock() 1279 // This for loop is capable of hogging the CPU and cause starvation 1280 // in Go versions prior to 1.9, 1281 // in single CPU environment. Explicitly relinquish processor. 1282 runtime.Gosched() 1283 } 1284 }() 1285 case *http2.RSTStreamFrame: 1286 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { 1287 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) 1288 } 1289 close(success) 1290 return 1291 case *http2.PingFrame: 1292 mu.Lock() 1293 sfr.WritePing(true, frame.Data) 1294 mu.Unlock() 1295 default: 1296 } 1297 } 1298 }() 1299 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 1300 defer cancel() 1301 ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func() {}, func(GoAwayReason) {}, func() {}) 1302 if err != nil { 1303 t.Fatalf("Error while creating client transport: %v", err) 1304 } 1305 defer ct.Close(fmt.Errorf("closed manually by test")) 1306 str, err := ct.NewStream(connectCtx, &CallHdr{}) 1307 if err != nil { 1308 t.Fatalf("Error while creating stream: %v", err) 1309 } 1310 timer := time.NewTimer(time.Second * 5) 1311 go func() { // This go routine mimics the one in stream.go to call CloseStream. 1312 <-str.Done() 1313 ct.CloseStream(str, nil) 1314 }() 1315 select { 1316 case <-timer.C: 1317 t.Fatalf("Test timed-out.") 1318 case <-success: 1319 } 1320 } 1321 1322 var encodingTestStatus = status.New(codes.Internal, "\n") 1323 1324 func (s) TestEncodingRequiredStatus(t *testing.T) { 1325 server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) 1326 defer cancel() 1327 callHdr := &CallHdr{ 1328 Host: "localhost", 1329 Method: "foo", 1330 } 1331 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1332 defer cancel() 1333 s, err := ct.NewStream(ctx, callHdr) 1334 if err != nil { 1335 return 1336 } 1337 opts := Options{Last: true} 1338 if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone { 1339 t.Fatalf("Failed to write the request: %v", err) 1340 } 1341 p := make([]byte, http2MaxFrameLen) 1342 if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF { 1343 t.Fatalf("Read got error %v, want %v", err, io.EOF) 1344 } 1345 if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { 1346 t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) 1347 } 1348 ct.Close(fmt.Errorf("closed manually by test")) 1349 server.stop() 1350 } 1351 1352 func (s) TestInvalidHeaderField(t *testing.T) { 1353 server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) 1354 defer cancel() 1355 callHdr := &CallHdr{ 1356 Host: "localhost", 1357 Method: "foo", 1358 } 1359 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1360 defer cancel() 1361 s, err := ct.NewStream(ctx, callHdr) 1362 if err != nil { 1363 return 1364 } 1365 p := make([]byte, http2MaxFrameLen) 1366 _, err = s.trReader.(*transportReader).Read(p) 1367 if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { 1368 t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField) 1369 } 1370 ct.Close(fmt.Errorf("closed manually by test")) 1371 server.stop() 1372 } 1373 1374 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { 1375 server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) 1376 defer cancel() 1377 defer server.stop() 1378 defer ct.Close(fmt.Errorf("closed manually by test")) 1379 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1380 defer cancel() 1381 s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) 1382 if err != nil { 1383 t.Fatalf("failed to create the stream") 1384 } 1385 timer := time.NewTimer(time.Second) 1386 defer timer.Stop() 1387 select { 1388 case <-s.headerChan: 1389 case <-timer.C: 1390 t.Errorf("s.headerChan: got open, want closed") 1391 } 1392 } 1393 1394 func (s) TestIsReservedHeader(t *testing.T) { 1395 tests := []struct { 1396 h string 1397 want bool 1398 }{ 1399 {"", false}, // but should be rejected earlier 1400 {"foo", false}, 1401 {"content-type", true}, 1402 {"user-agent", true}, 1403 {":anything", true}, 1404 {"grpc-message-type", true}, 1405 {"grpc-encoding", true}, 1406 {"grpc-message", true}, 1407 {"grpc-status", true}, 1408 {"grpc-timeout", true}, 1409 {"te", true}, 1410 } 1411 for _, tt := range tests { 1412 got := isReservedHeader(tt.h) 1413 if got != tt.want { 1414 t.Errorf("isReservedHeader(%q) = %v; want %v", tt.h, got, tt.want) 1415 } 1416 } 1417 } 1418 1419 func (s) TestContextErr(t *testing.T) { 1420 for _, test := range []struct { 1421 // input 1422 errIn error 1423 // outputs 1424 errOut error 1425 }{ 1426 {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, 1427 {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())}, 1428 } { 1429 err := ContextErr(test.errIn) 1430 if err.Error() != test.errOut.Error() { 1431 t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) 1432 } 1433 } 1434 } 1435 1436 type windowSizeConfig struct { 1437 serverStream int32 1438 serverConn int32 1439 clientStream int32 1440 clientConn int32 1441 } 1442 1443 func (s) TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) { 1444 wc := windowSizeConfig{ 1445 serverStream: 10 * 1024 * 1024, 1446 serverConn: 12 * 1024 * 1024, 1447 clientStream: 6 * 1024 * 1024, 1448 clientConn: 8 * 1024 * 1024, 1449 } 1450 testFlowControlAccountCheck(t, 1024*1024, wc) 1451 } 1452 1453 func (s) TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) { 1454 wc := windowSizeConfig{ 1455 serverStream: defaultWindowSize, 1456 // Note this is smaller than initialConnWindowSize which is the current default. 1457 serverConn: defaultWindowSize, 1458 clientStream: defaultWindowSize, 1459 clientConn: defaultWindowSize, 1460 } 1461 testFlowControlAccountCheck(t, 1024*1024, wc) 1462 } 1463 1464 func (s) TestAccountCheckDynamicWindowSmallMessage(t *testing.T) { 1465 testFlowControlAccountCheck(t, 1024, windowSizeConfig{}) 1466 } 1467 1468 func (s) TestAccountCheckDynamicWindowLargeMessage(t *testing.T) { 1469 testFlowControlAccountCheck(t, 1024*1024, windowSizeConfig{}) 1470 } 1471 1472 func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) { 1473 sc := &ServerConfig{ 1474 InitialWindowSize: wc.serverStream, 1475 InitialConnWindowSize: wc.serverConn, 1476 } 1477 co := ConnectOptions{ 1478 InitialWindowSize: wc.clientStream, 1479 InitialConnWindowSize: wc.clientConn, 1480 } 1481 server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co) 1482 defer cancel() 1483 defer server.stop() 1484 defer client.Close(fmt.Errorf("closed manually by test")) 1485 waitWhileTrue(t, func() (bool, error) { 1486 server.mu.Lock() 1487 defer server.mu.Unlock() 1488 if len(server.conns) == 0 { 1489 return true, fmt.Errorf("timed out while waiting for server transport to be created") 1490 } 1491 return false, nil 1492 }) 1493 var st *http2Server 1494 server.mu.Lock() 1495 for k := range server.conns { 1496 st = k.(*http2Server) 1497 } 1498 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1499 defer cancel() 1500 server.mu.Unlock() 1501 const numStreams = 10 1502 clientStreams := make([]*Stream, numStreams) 1503 for i := 0; i < numStreams; i++ { 1504 var err error 1505 clientStreams[i], err = client.NewStream(ctx, &CallHdr{}) 1506 if err != nil { 1507 t.Fatalf("Failed to create stream. Err: %v", err) 1508 } 1509 } 1510 var wg sync.WaitGroup 1511 // For each stream send pingpong messages to the server. 1512 for _, stream := range clientStreams { 1513 wg.Add(1) 1514 go func(stream *Stream) { 1515 defer wg.Done() 1516 buf := make([]byte, msgSize+5) 1517 buf[0] = byte(0) 1518 binary.BigEndian.PutUint32(buf[1:], uint32(msgSize)) 1519 opts := Options{} 1520 header := make([]byte, 5) 1521 for i := 1; i <= 10; i++ { 1522 if err := client.Write(stream, nil, buf, &opts); err != nil { 1523 t.Errorf("Error on client while writing message: %v", err) 1524 return 1525 } 1526 if _, err := stream.Read(header); err != nil { 1527 t.Errorf("Error on client while reading data frame header: %v", err) 1528 return 1529 } 1530 sz := binary.BigEndian.Uint32(header[1:]) 1531 recvMsg := make([]byte, int(sz)) 1532 if _, err := stream.Read(recvMsg); err != nil { 1533 t.Errorf("Error on client while reading data: %v", err) 1534 return 1535 } 1536 if len(recvMsg) != msgSize { 1537 t.Errorf("Length of message received by client: %v, want: %v", len(recvMsg), msgSize) 1538 return 1539 } 1540 } 1541 }(stream) 1542 } 1543 wg.Wait() 1544 serverStreams := map[uint32]*Stream{} 1545 loopyClientStreams := map[uint32]*outStream{} 1546 loopyServerStreams := map[uint32]*outStream{} 1547 // Get all the streams from server reader and writer and client writer. 1548 st.mu.Lock() 1549 for _, stream := range clientStreams { 1550 id := stream.id 1551 serverStreams[id] = st.activeStreams[id] 1552 loopyServerStreams[id] = st.loopy.estdStreams[id] 1553 loopyClientStreams[id] = client.loopy.estdStreams[id] 1554 1555 } 1556 st.mu.Unlock() 1557 // Close all streams 1558 for _, stream := range clientStreams { 1559 client.Write(stream, nil, nil, &Options{Last: true}) 1560 if _, err := stream.Read(make([]byte, 5)); err != io.EOF { 1561 t.Fatalf("Client expected an EOF from the server. Got: %v", err) 1562 } 1563 } 1564 // Close down both server and client so that their internals can be read without data 1565 // races. 1566 client.Close(fmt.Errorf("closed manually by test")) 1567 st.Close() 1568 <-st.readerDone 1569 <-st.writerDone 1570 <-client.readerDone 1571 <-client.writerDone 1572 for _, cstream := range clientStreams { 1573 id := cstream.id 1574 sstream := serverStreams[id] 1575 loopyServerStream := loopyServerStreams[id] 1576 loopyClientStream := loopyClientStreams[id] 1577 // Check stream flow control. 1578 if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { 1579 t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) 1580 } 1581 if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(client.loopy.oiws)-loopyClientStream.bytesOutStanding { 1582 t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, client.loopy.oiws, loopyClientStream.bytesOutStanding) 1583 } 1584 } 1585 // Check transport flow control. 1586 if client.fc.limit != client.fc.unacked+st.loopy.sendQuota { 1587 t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", client.fc.limit, client.fc.unacked, st.loopy.sendQuota) 1588 } 1589 if st.fc.limit != st.fc.unacked+client.loopy.sendQuota { 1590 t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, client.loopy.sendQuota) 1591 } 1592 } 1593 1594 func waitWhileTrue(t *testing.T, condition func() (bool, error)) { 1595 var ( 1596 wait bool 1597 err error 1598 ) 1599 timer := time.NewTimer(time.Second * 5) 1600 for { 1601 wait, err = condition() 1602 if wait { 1603 select { 1604 case <-timer.C: 1605 t.Fatalf(err.Error()) 1606 default: 1607 time.Sleep(50 * time.Millisecond) 1608 continue 1609 } 1610 } 1611 if !timer.Stop() { 1612 <-timer.C 1613 } 1614 break 1615 } 1616 } 1617 1618 // If any error occurs on a call to Stream.Read, future calls 1619 // should continue to return that same error. 1620 func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { 1621 testRecvBuffer := newRecvBuffer() 1622 s := &Stream{ 1623 ctx: context.Background(), 1624 buf: testRecvBuffer, 1625 requestRead: func(int) {}, 1626 } 1627 s.trReader = &transportReader{ 1628 reader: &recvBufferReader{ 1629 ctx: s.ctx, 1630 ctxDone: s.ctx.Done(), 1631 recv: s.buf, 1632 freeBuffer: func(*bytes.Buffer) {}, 1633 }, 1634 windowHandler: func(int) {}, 1635 } 1636 testData := make([]byte, 1) 1637 testData[0] = 5 1638 testBuffer := bytes.NewBuffer(testData) 1639 testErr := errors.New("test error") 1640 s.write(recvMsg{buffer: testBuffer, err: testErr}) 1641 1642 inBuf := make([]byte, 1) 1643 actualCount, actualErr := s.Read(inBuf) 1644 if actualCount != 0 { 1645 t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount) 1646 } 1647 if actualErr.Error() != testErr.Error() { 1648 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1649 } 1650 1651 s.write(recvMsg{buffer: testBuffer, err: nil}) 1652 s.write(recvMsg{buffer: testBuffer, err: errors.New("different error from first")}) 1653 1654 for i := 0; i < 2; i++ { 1655 inBuf := make([]byte, 1) 1656 actualCount, actualErr := s.Read(inBuf) 1657 if actualCount != 0 { 1658 t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount) 1659 } 1660 if actualErr.Error() != testErr.Error() { 1661 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1662 } 1663 } 1664 } 1665 1666 // TestHeadersCausingStreamError tests headers that should cause a stream protocol 1667 // error, which would end up with a RST_STREAM being sent to the client and also 1668 // the server closing the stream. 1669 func (s) TestHeadersCausingStreamError(t *testing.T) { 1670 tests := []struct { 1671 name string 1672 headers []struct { 1673 name string 1674 values []string 1675 } 1676 }{ 1677 // If the client sends an HTTP/2 request with a :method header with a 1678 // value other than POST, as specified in the gRPC over HTTP/2 1679 // specification, the server should close the stream. 1680 { 1681 name: "Client Sending Wrong Method", 1682 headers: []struct { 1683 name string 1684 values []string 1685 }{ 1686 {name: ":method", values: []string{"PUT"}}, 1687 {name: ":path", values: []string{"foo"}}, 1688 {name: ":authority", values: []string{"localhost"}}, 1689 {name: "content-type", values: []string{"application/grpc"}}, 1690 }, 1691 }, 1692 // "Transports must consider requests containing the Connection header 1693 // as malformed" - A41 Malformed requests map to a stream error of type 1694 // PROTOCOL_ERROR. 1695 { 1696 name: "Connection header present", 1697 headers: []struct { 1698 name string 1699 values []string 1700 }{ 1701 {name: ":method", values: []string{"POST"}}, 1702 {name: ":path", values: []string{"foo"}}, 1703 {name: ":authority", values: []string{"localhost"}}, 1704 {name: "content-type", values: []string{"application/grpc"}}, 1705 {name: "connection", values: []string{"not-supported"}}, 1706 }, 1707 }, 1708 // multiple :authority or multiple Host headers would make the eventual 1709 // :authority ambiguous as per A41. Since these headers won't have a 1710 // content-type that corresponds to a grpc-client, the server should 1711 // simply write a RST_STREAM to the wire. 1712 { 1713 // Note: multiple authority headers are handled by the framer 1714 // itself, which will cause a stream error. Thus, it will never get 1715 // to operateHeaders with the check in operateHeaders for stream 1716 // error, but the server transport will still send a stream error. 1717 name: "Multiple authority headers", 1718 headers: []struct { 1719 name string 1720 values []string 1721 }{ 1722 {name: ":method", values: []string{"POST"}}, 1723 {name: ":path", values: []string{"foo"}}, 1724 {name: ":authority", values: []string{"localhost", "localhost2"}}, 1725 {name: "host", values: []string{"localhost"}}, 1726 }, 1727 }, 1728 } 1729 for _, test := range tests { 1730 t.Run(test.name, func(t *testing.T) { 1731 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1732 defer server.stop() 1733 // Create a client directly to not tie what you can send to API of 1734 // http2_client.go (i.e. control headers being sent). 1735 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1736 if err != nil { 1737 t.Fatalf("Client failed to dial: %v", err) 1738 } 1739 defer mconn.Close() 1740 1741 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1742 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1743 } 1744 1745 framer := http2.NewFramer(mconn, mconn) 1746 if err := framer.WriteSettings(); err != nil { 1747 t.Fatalf("Error while writing settings: %v", err) 1748 } 1749 1750 // result chan indicates that reader received a RSTStream from server. 1751 // An error will be passed on it if any other frame is received. 1752 result := testutils.NewChannel() 1753 1754 // Launch a reader goroutine. 1755 go func() { 1756 for { 1757 frame, err := framer.ReadFrame() 1758 if err != nil { 1759 return 1760 } 1761 switch frame := frame.(type) { 1762 case *http2.SettingsFrame: 1763 // Do nothing. A settings frame is expected from server preface. 1764 case *http2.RSTStreamFrame: 1765 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol { 1766 // Client only created a single stream, so RST Stream should be for that single stream. 1767 result.Send(fmt.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))) 1768 } 1769 // Records that client successfully received RST Stream frame. 1770 result.Send(nil) 1771 return 1772 default: 1773 // The server should send nothing but a single RST Stream frame. 1774 result.Send(errors.New("the client received a frame other than RST Stream")) 1775 } 1776 } 1777 }() 1778 1779 var buf bytes.Buffer 1780 henc := hpack.NewEncoder(&buf) 1781 1782 // Needs to build headers deterministically to conform to gRPC over 1783 // HTTP/2 spec. 1784 for _, header := range test.headers { 1785 for _, value := range header.values { 1786 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 1787 t.Fatalf("Error while encoding header: %v", err) 1788 } 1789 } 1790 } 1791 1792 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1793 t.Fatalf("Error while writing headers: %v", err) 1794 } 1795 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1796 defer cancel() 1797 r, err := result.Receive(ctx) 1798 if err != nil { 1799 t.Fatalf("Error receiving from channel: %v", err) 1800 } 1801 if r != nil { 1802 t.Fatalf("want nil, got %v", r) 1803 } 1804 }) 1805 } 1806 } 1807 1808 // TestHeadersMultipleHosts tests that a request with multiple hosts gets 1809 // rejected with HTTP Status 400 and gRPC status Internal, regardless of whether 1810 // the client is speaking gRPC or not. 1811 func (s) TestHeadersMultipleHosts(t *testing.T) { 1812 tests := []struct { 1813 name string 1814 headers []struct { 1815 name string 1816 values []string 1817 } 1818 }{ 1819 // Note: multiple authority headers are handled by the framer itself, 1820 // which will cause a stream error. Thus, it will never get to 1821 // operateHeaders with the check in operateHeaders for possible grpc-status sent back. 1822 1823 // multiple :authority or multiple Host headers would make the eventual 1824 // :authority ambiguous as per A41. This takes precedence even over the 1825 // fact a request is non grpc. All of these requests should be rejected 1826 // with grpc-status Internal. 1827 { 1828 name: "Multiple host headers non grpc", 1829 headers: []struct { 1830 name string 1831 values []string 1832 }{ 1833 {name: ":method", values: []string{"POST"}}, 1834 {name: ":path", values: []string{"foo"}}, 1835 {name: ":authority", values: []string{"localhost"}}, 1836 {name: "host", values: []string{"localhost", "localhost2"}}, 1837 }, 1838 }, 1839 { 1840 name: "Multiple host headers grpc", 1841 headers: []struct { 1842 name string 1843 values []string 1844 }{ 1845 {name: ":method", values: []string{"POST"}}, 1846 {name: ":path", values: []string{"foo"}}, 1847 {name: ":authority", values: []string{"localhost"}}, 1848 {name: "content-type", values: []string{"application/grpc"}}, 1849 {name: "host", values: []string{"localhost", "localhost2"}}, 1850 }, 1851 }, 1852 } 1853 for _, test := range tests { 1854 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1855 defer server.stop() 1856 // Create a client directly to not tie what you can send to API of 1857 // http2_client.go (i.e. control headers being sent). 1858 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1859 if err != nil { 1860 t.Fatalf("Client failed to dial: %v", err) 1861 } 1862 defer mconn.Close() 1863 1864 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1865 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1866 } 1867 1868 framer := http2.NewFramer(mconn, mconn) 1869 framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) 1870 if err := framer.WriteSettings(); err != nil { 1871 t.Fatalf("Error while writing settings: %v", err) 1872 } 1873 1874 // result chan indicates that reader received a Headers Frame with 1875 // desired grpc status and message from server. An error will be passed 1876 // on it if any other frame is received. 1877 result := testutils.NewChannel() 1878 1879 // Launch a reader goroutine. 1880 go func() { 1881 for { 1882 frame, err := framer.ReadFrame() 1883 if err != nil { 1884 return 1885 } 1886 switch frame := frame.(type) { 1887 case *http2.SettingsFrame: 1888 // Do nothing. A settings frame is expected from server preface. 1889 case *http2.MetaHeadersFrame: 1890 var status, grpcStatus, grpcMessage string 1891 for _, header := range frame.Fields { 1892 if header.Name == ":status" { 1893 status = header.Value 1894 } 1895 if header.Name == "grpc-status" { 1896 grpcStatus = header.Value 1897 } 1898 if header.Name == "grpc-message" { 1899 grpcMessage = header.Value 1900 } 1901 } 1902 if status != "400" { 1903 result.Send(fmt.Errorf("incorrect HTTP Status got %v, want 200", status)) 1904 return 1905 } 1906 if grpcStatus != "13" { // grpc status code internal 1907 result.Send(fmt.Errorf("incorrect gRPC Status got %v, want 13", grpcStatus)) 1908 return 1909 } 1910 if !strings.Contains(grpcMessage, "both must only have 1 value as per HTTP/2 spec") { 1911 result.Send(fmt.Errorf("incorrect gRPC message")) 1912 return 1913 } 1914 1915 // Records that client successfully received a HeadersFrame 1916 // with expected Trailers-Only response. 1917 result.Send(nil) 1918 return 1919 default: 1920 // The server should send nothing but a single Settings and Headers frame. 1921 result.Send(errors.New("the client received a frame other than Settings or Headers")) 1922 } 1923 } 1924 }() 1925 1926 var buf bytes.Buffer 1927 henc := hpack.NewEncoder(&buf) 1928 1929 // Needs to build headers deterministically to conform to gRPC over 1930 // HTTP/2 spec. 1931 for _, header := range test.headers { 1932 for _, value := range header.values { 1933 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 1934 t.Fatalf("Error while encoding header: %v", err) 1935 } 1936 } 1937 } 1938 1939 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1940 t.Fatalf("Error while writing headers: %v", err) 1941 } 1942 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1943 defer cancel() 1944 r, err := result.Receive(ctx) 1945 if err != nil { 1946 t.Fatalf("Error receiving from channel: %v", err) 1947 } 1948 if r != nil { 1949 t.Fatalf("want nil, got %v", r) 1950 } 1951 } 1952 } 1953 1954 func (s) TestPingPong1B(t *testing.T) { 1955 runPingPongTest(t, 1) 1956 } 1957 1958 func (s) TestPingPong1KB(t *testing.T) { 1959 runPingPongTest(t, 1024) 1960 } 1961 1962 func (s) TestPingPong64KB(t *testing.T) { 1963 runPingPongTest(t, 65536) 1964 } 1965 1966 func (s) TestPingPong1MB(t *testing.T) { 1967 runPingPongTest(t, 1048576) 1968 } 1969 1970 // This is a stress-test of flow control logic. 1971 func runPingPongTest(t *testing.T, msgSize int) { 1972 server, client, cancel := setUp(t, 0, 0, pingpong) 1973 defer cancel() 1974 defer server.stop() 1975 defer client.Close(fmt.Errorf("closed manually by test")) 1976 waitWhileTrue(t, func() (bool, error) { 1977 server.mu.Lock() 1978 defer server.mu.Unlock() 1979 if len(server.conns) == 0 { 1980 return true, fmt.Errorf("timed out while waiting for server transport to be created") 1981 } 1982 return false, nil 1983 }) 1984 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1985 defer cancel() 1986 stream, err := client.NewStream(ctx, &CallHdr{}) 1987 if err != nil { 1988 t.Fatalf("Failed to create stream. Err: %v", err) 1989 } 1990 msg := make([]byte, msgSize) 1991 outgoingHeader := make([]byte, 5) 1992 outgoingHeader[0] = byte(0) 1993 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize)) 1994 opts := &Options{} 1995 incomingHeader := make([]byte, 5) 1996 done := make(chan struct{}) 1997 go func() { 1998 timer := time.NewTimer(time.Second * 5) 1999 <-timer.C 2000 close(done) 2001 }() 2002 for { 2003 select { 2004 case <-done: 2005 client.Write(stream, nil, nil, &Options{Last: true}) 2006 if _, err := stream.Read(incomingHeader); err != io.EOF { 2007 t.Fatalf("Client expected EOF from the server. Got: %v", err) 2008 } 2009 return 2010 default: 2011 if err := client.Write(stream, outgoingHeader, msg, opts); err != nil { 2012 t.Fatalf("Error on client while writing message. Err: %v", err) 2013 } 2014 if _, err := stream.Read(incomingHeader); err != nil { 2015 t.Fatalf("Error on client while reading data header. Err: %v", err) 2016 } 2017 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 2018 recvMsg := make([]byte, int(sz)) 2019 if _, err := stream.Read(recvMsg); err != nil { 2020 t.Fatalf("Error on client while reading data. Err: %v", err) 2021 } 2022 } 2023 } 2024 } 2025 2026 type tableSizeLimit struct { 2027 mu sync.Mutex 2028 limits []uint32 2029 } 2030 2031 func (t *tableSizeLimit) add(limit uint32) { 2032 t.mu.Lock() 2033 t.limits = append(t.limits, limit) 2034 t.mu.Unlock() 2035 } 2036 2037 func (t *tableSizeLimit) getLen() int { 2038 t.mu.Lock() 2039 defer t.mu.Unlock() 2040 return len(t.limits) 2041 } 2042 2043 func (t *tableSizeLimit) getIndex(i int) uint32 { 2044 t.mu.Lock() 2045 defer t.mu.Unlock() 2046 return t.limits[i] 2047 } 2048 2049 func (s) TestHeaderTblSize(t *testing.T) { 2050 limits := &tableSizeLimit{} 2051 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2052 e.SetMaxDynamicTableSizeLimit(v) 2053 limits.add(v) 2054 } 2055 defer func() { 2056 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2057 e.SetMaxDynamicTableSizeLimit(v) 2058 } 2059 }() 2060 2061 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 2062 defer cancel() 2063 defer ct.Close(fmt.Errorf("closed manually by test")) 2064 defer server.stop() 2065 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2066 defer ctxCancel() 2067 _, err := ct.NewStream(ctx, &CallHdr{}) 2068 if err != nil { 2069 t.Fatalf("failed to open stream: %v", err) 2070 } 2071 2072 var svrTransport ServerTransport 2073 var i int 2074 for i = 0; i < 1000; i++ { 2075 server.mu.Lock() 2076 if len(server.conns) != 0 { 2077 server.mu.Unlock() 2078 break 2079 } 2080 server.mu.Unlock() 2081 time.Sleep(10 * time.Millisecond) 2082 continue 2083 } 2084 if i == 1000 { 2085 t.Fatalf("unable to create any server transport after 10s") 2086 } 2087 2088 for st := range server.conns { 2089 svrTransport = st 2090 break 2091 } 2092 svrTransport.(*http2Server).controlBuf.put(&outgoingSettings{ 2093 ss: []http2.Setting{ 2094 { 2095 ID: http2.SettingHeaderTableSize, 2096 Val: uint32(100), 2097 }, 2098 }, 2099 }) 2100 2101 for i = 0; i < 1000; i++ { 2102 if limits.getLen() != 1 { 2103 time.Sleep(10 * time.Millisecond) 2104 continue 2105 } 2106 if val := limits.getIndex(0); val != uint32(100) { 2107 t.Fatalf("expected limits[0] = 100, got %d", val) 2108 } 2109 break 2110 } 2111 if i == 1000 { 2112 t.Fatalf("expected len(limits) = 1 within 10s, got != 1") 2113 } 2114 2115 ct.controlBuf.put(&outgoingSettings{ 2116 ss: []http2.Setting{ 2117 { 2118 ID: http2.SettingHeaderTableSize, 2119 Val: uint32(200), 2120 }, 2121 }, 2122 }) 2123 2124 for i := 0; i < 1000; i++ { 2125 if limits.getLen() != 2 { 2126 time.Sleep(10 * time.Millisecond) 2127 continue 2128 } 2129 if val := limits.getIndex(1); val != uint32(200) { 2130 t.Fatalf("expected limits[1] = 200, got %d", val) 2131 } 2132 break 2133 } 2134 if i == 1000 { 2135 t.Fatalf("expected len(limits) = 2 within 10s, got != 2") 2136 } 2137 } 2138 2139 // attrTransportCreds is a transport credential implementation which stores 2140 // Attributes from the ClientHandshakeInfo struct passed in the context locally 2141 // for the test to inspect. 2142 type attrTransportCreds struct { 2143 credentials.TransportCredentials 2144 attr *attributes.Attributes 2145 } 2146 2147 func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 2148 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2149 ac.attr = ai.Attributes 2150 return rawConn, nil, nil 2151 } 2152 func (ac *attrTransportCreds) Info() credentials.ProtocolInfo { 2153 return credentials.ProtocolInfo{} 2154 } 2155 func (ac *attrTransportCreds) Clone() credentials.TransportCredentials { 2156 return nil 2157 } 2158 2159 // TestClientHandshakeInfo adds attributes to the resolver.Address passes to 2160 // NewClientTransport and verifies that these attributes are received by the 2161 // transport credential handshaker. 2162 func (s) TestClientHandshakeInfo(t *testing.T) { 2163 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2164 defer server.stop() 2165 2166 const ( 2167 testAttrKey = "foo" 2168 testAttrVal = "bar" 2169 ) 2170 addr := resolver.Address{ 2171 Addr: "localhost:" + server.port, 2172 Attributes: attributes.New(testAttrKey, testAttrVal), 2173 } 2174 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 2175 defer cancel() 2176 creds := &attrTransportCreds{} 2177 2178 tr, err := NewClientTransport(ctx, context.Background(), addr, ConnectOptions{TransportCredentials: creds}, func() {}, func(GoAwayReason) {}, func() {}) 2179 if err != nil { 2180 t.Fatalf("NewClientTransport(): %v", err) 2181 } 2182 defer tr.Close(fmt.Errorf("closed manually by test")) 2183 2184 wantAttr := attributes.New(testAttrKey, testAttrVal) 2185 if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2186 t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) 2187 } 2188 } 2189 2190 // TestClientHandshakeInfoDialer adds attributes to the resolver.Address passes to 2191 // NewClientTransport and verifies that these attributes are received by a custom 2192 // dialer. 2193 func (s) TestClientHandshakeInfoDialer(t *testing.T) { 2194 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2195 defer server.stop() 2196 2197 const ( 2198 testAttrKey = "foo" 2199 testAttrVal = "bar" 2200 ) 2201 addr := resolver.Address{ 2202 Addr: "localhost:" + server.port, 2203 Attributes: attributes.New(testAttrKey, testAttrVal), 2204 } 2205 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 2206 defer cancel() 2207 2208 var attr *attributes.Attributes 2209 dialer := func(ctx context.Context, addr string) (net.Conn, error) { 2210 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2211 attr = ai.Attributes 2212 return (&net.Dialer{}).DialContext(ctx, "tcp", addr) 2213 } 2214 2215 tr, err := NewClientTransport(ctx, context.Background(), addr, ConnectOptions{Dialer: dialer}, func() {}, func(GoAwayReason) {}, func() {}) 2216 if err != nil { 2217 t.Fatalf("NewClientTransport(): %v", err) 2218 } 2219 defer tr.Close(fmt.Errorf("closed manually by test")) 2220 2221 wantAttr := attributes.New(testAttrKey, testAttrVal) 2222 if gotAttr := attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2223 t.Errorf("Received attributes %v in custom dialer, want %v", gotAttr, wantAttr) 2224 } 2225 } 2226 2227 func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { 2228 testStream := func() *Stream { 2229 return &Stream{ 2230 done: make(chan struct{}), 2231 headerChan: make(chan struct{}), 2232 buf: &recvBuffer{ 2233 c: make(chan recvMsg), 2234 mu: sync.Mutex{}, 2235 }, 2236 } 2237 } 2238 2239 testClient := func(ts *Stream) *http2Client { 2240 return &http2Client{ 2241 mu: sync.Mutex{}, 2242 activeStreams: map[uint32]*Stream{ 2243 0: ts, 2244 }, 2245 controlBuf: &controlBuffer{ 2246 ch: make(chan struct{}), 2247 done: make(chan struct{}), 2248 list: &itemList{}, 2249 }, 2250 } 2251 } 2252 2253 for _, test := range []struct { 2254 name string 2255 // input 2256 metaHeaderFrame *http2.MetaHeadersFrame 2257 // output 2258 wantStatus *status.Status 2259 }{ 2260 { 2261 name: "valid header", 2262 metaHeaderFrame: &http2.MetaHeadersFrame{ 2263 Fields: []hpack.HeaderField{ 2264 {Name: "content-type", Value: "application/grpc"}, 2265 {Name: "grpc-status", Value: "0"}, 2266 {Name: ":status", Value: "200"}, 2267 }, 2268 }, 2269 // no error 2270 wantStatus: status.New(codes.OK, ""), 2271 }, 2272 { 2273 name: "missing content-type header", 2274 metaHeaderFrame: &http2.MetaHeadersFrame{ 2275 Fields: []hpack.HeaderField{ 2276 {Name: "grpc-status", Value: "0"}, 2277 {Name: ":status", Value: "200"}, 2278 }, 2279 }, 2280 wantStatus: status.New( 2281 codes.Unknown, 2282 "malformed header: missing HTTP content-type", 2283 ), 2284 }, 2285 { 2286 name: "invalid grpc status header field", 2287 metaHeaderFrame: &http2.MetaHeadersFrame{ 2288 Fields: []hpack.HeaderField{ 2289 {Name: "content-type", Value: "application/grpc"}, 2290 {Name: "grpc-status", Value: "xxxx"}, 2291 {Name: ":status", Value: "200"}, 2292 }, 2293 }, 2294 wantStatus: status.New( 2295 codes.Internal, 2296 "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2297 ), 2298 }, 2299 { 2300 name: "invalid http content type", 2301 metaHeaderFrame: &http2.MetaHeadersFrame{ 2302 Fields: []hpack.HeaderField{ 2303 {Name: "content-type", Value: "application/json"}, 2304 }, 2305 }, 2306 wantStatus: status.New( 2307 codes.Internal, 2308 "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"", 2309 ), 2310 }, 2311 { 2312 name: "http fallback and invalid http status", 2313 metaHeaderFrame: &http2.MetaHeadersFrame{ 2314 Fields: []hpack.HeaderField{ 2315 // No content type provided then fallback into handling http error. 2316 {Name: ":status", Value: "xxxx"}, 2317 }, 2318 }, 2319 wantStatus: status.New( 2320 codes.Internal, 2321 "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2322 ), 2323 }, 2324 { 2325 name: "http2 frame size exceeds", 2326 metaHeaderFrame: &http2.MetaHeadersFrame{ 2327 Fields: nil, 2328 Truncated: true, 2329 }, 2330 wantStatus: status.New( 2331 codes.Internal, 2332 "peer header list size exceeded limit", 2333 ), 2334 }, 2335 { 2336 name: "bad status in grpc mode", 2337 metaHeaderFrame: &http2.MetaHeadersFrame{ 2338 Fields: []hpack.HeaderField{ 2339 {Name: "content-type", Value: "application/grpc"}, 2340 {Name: "grpc-status", Value: "0"}, 2341 {Name: ":status", Value: "504"}, 2342 }, 2343 }, 2344 wantStatus: status.New( 2345 codes.Unavailable, 2346 "unexpected HTTP status code received from server: 504 (Gateway Timeout)", 2347 ), 2348 }, 2349 { 2350 name: "missing http status", 2351 metaHeaderFrame: &http2.MetaHeadersFrame{ 2352 Fields: []hpack.HeaderField{ 2353 {Name: "content-type", Value: "application/grpc"}, 2354 }, 2355 }, 2356 wantStatus: status.New( 2357 codes.Internal, 2358 "malformed header: missing HTTP status", 2359 ), 2360 }, 2361 } { 2362 2363 t.Run(test.name, func(t *testing.T) { 2364 ts := testStream() 2365 s := testClient(ts) 2366 2367 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2368 FrameHeader: http2.FrameHeader{ 2369 StreamID: 0, 2370 }, 2371 } 2372 2373 s.operateHeaders(test.metaHeaderFrame) 2374 2375 got := ts.status 2376 want := test.wantStatus 2377 if got.Code() != want.Code() || got.Message() != want.Message() { 2378 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2379 } 2380 }) 2381 t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) { 2382 ts := testStream() 2383 s := testClient(ts) 2384 2385 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2386 FrameHeader: http2.FrameHeader{ 2387 StreamID: 0, 2388 Flags: http2.FlagHeadersEndStream, 2389 }, 2390 } 2391 2392 s.operateHeaders(test.metaHeaderFrame) 2393 2394 got := ts.status 2395 want := test.wantStatus 2396 if got.Code() != want.Code() || got.Message() != want.Message() { 2397 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2398 } 2399 }) 2400 } 2401 }