// gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/internal/transport/transport_test.go

/*
 *
 * Copyright 2014 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package transport

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"gitee.com/zhaochuninhefei/gmgo/grpc/attributes"
	"gitee.com/zhaochuninhefei/gmgo/grpc/codes"
	"gitee.com/zhaochuninhefei/gmgo/grpc/credentials"
	"gitee.com/zhaochuninhefei/gmgo/grpc/internal/grpctest"
	"gitee.com/zhaochuninhefei/gmgo/grpc/internal/leakcheck"
	"gitee.com/zhaochuninhefei/gmgo/grpc/internal/testutils"
	"gitee.com/zhaochuninhefei/gmgo/grpc/resolver"
	"gitee.com/zhaochuninhefei/gmgo/grpc/status"
	"gitee.com/zhaochuninhefei/gmgo/net/http2"
	"gitee.com/zhaochuninhefei/gmgo/net/http2/hpack"
	"github.com/google/go-cmp/cmp"
)

// s is the test-suite receiver: every `func (s) TestXxx` method in this file
// is run as a subtest by Test below. Embedding grpctest.Tester presumably
// adds per-test setup/teardown hooks (TODO confirm in grpctest).
type s struct {
	grpctest.Tester
}

// Test is the single top-level Go test entry point; it hands the s suite to
// grpctest.RunSubTests.
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}

// server is an in-process test server: start listens on lis and serves every
// accepted connection with a ServerTransport whose streams are handled
// according to the hType passed to start.
type server struct {
	lis        net.Listener
	port       string     // listener port as a string, set by start
	startedErr chan error // error (or nil) with server start value
	mu         sync.Mutex // guards conns and h (and is held while closing ready)
	conns      map[ServerTransport]bool
	h          *testStreamHandler // handler of the most recently accepted connection
	ready      chan struct{}      // closed by start in delayRead mode only
}

var (
	expectedRequest  = []byte("ping")
	expectedResponse = []byte("pong")
	// The large payloads are twice initialWindowSize, i.e. bigger than one
	// initial flow-control window.
	expectedRequestLarge       = make([]byte, initialWindowSize*2)
	expectedResponseLarge      = make([]byte, initialWindowSize*2)
	expectedInvalidHeaderField = "invalid/content-type"
)

// init stamps distinctive first and last bytes into the large payloads,
// presumably so that bytes.Equal comparisons catch truncated or shifted data.
func init() {
	expectedRequestLarge[0] = 'g'
	expectedRequestLarge[len(expectedRequestLarge)-1] = 'r'
	expectedResponseLarge[0] = 'p'
	expectedResponseLarge[len(expectedResponseLarge)-1] = 'c'
}

// testStreamHandler holds the per-connection state used by the handleStream*
// methods below.
type testStreamHandler struct {
	t *http2Server // server transport of the connection being served
	// notify is closed by a handler (handleStreamAndNotify,
	// handleStreamDelayRead) to signal the test goroutine.
	notify chan struct{}
	// getNotified is closed by the test to let handleStreamDelayRead start
	// reading.
	getNotified chan struct{}
}

// hType selects which stream handler server.start installs.
type hType int

const (
	normal                 hType = iota // handleStream: echo request/response check
	suspended                           // streams are accepted but never handled
	notifyCall                          // handleStreamAndNotify
	misbehaved                          // handleStreamMisbehave
	encodingRequiredStatus              // handleStreamEncodingRequiredStatus
	invalidHeaderField                  // handleStreamInvalidHeaderField
	delayRead                           // handleStreamDelayRead
	pingpong                            // handleStreamPingPong
)

// handleStreamAndNotify signals the test that a stream arrived by closing
// h.notify (when it is set and not already closed) from a new goroutine.
//
// NOTE(review): the select/close is check-then-act, so two goroutines racing
// here could both take the default branch and double-close h.notify. Tests
// using notifyCall install a fresh channel per expected stream. Also, h.notify
// is read here without server.mu while tests assign it under server.mu.
//
//goland:noinspection GoUnusedParameter
func (h *testStreamHandler) handleStreamAndNotify(s *Stream) {
	if h.notify == nil {
		return
	}
	go func() {
		select {
		case <-h.notify:
		default:
			close(h.notify)
		}
	}()
}

// handleStream reads one request, checks it against expectedRequest (or
// expectedRequestLarge for method "foo.Large"), replies with the matching
// expected response, and ends the stream with an OK status. A mismatch
// reports a test error and ends the stream with codes.Internal. A read error
// simply returns without writing a status.
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	p := make([]byte, len(req))
	_, err := s.Read(p)
	if err != nil {
		return
	}
	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		_ = h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
		return
	}
	// send a response back to the client.
	_ = h.t.Write(s, nil, resp, &Options{})
	// send the trailer to end the stream.
	_ = h.t.WriteStatus(s, status.New(codes.OK, ""))
}

// handleStreamPingPong echoes messages until the client half-closes. Each
// message has a 5-byte prefix (a flag byte, then a big-endian uint32 payload
// length) followed by the payload. It is written back with the same layout and
// a zero flag byte. io.EOF on the prefix read ends the stream with OK; any
// other read error ends it with codes.Internal.
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
	header := make([]byte, 5)
	for {
		if _, err := s.Read(header); err != nil {
			if err == io.EOF {
				_ = h.t.WriteStatus(s, status.New(codes.OK, ""))
				return
			}
			t.Errorf("Error on server while reading data header: %v", err)
			_ = h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
			return
		}
		sz := binary.BigEndian.Uint32(header[1:])
		msg := make([]byte, int(sz))
		if _, err := s.Read(msg); err != nil {
			t.Errorf("Error on server while reading message: %v", err)
			_ = h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
			return
		}
		buf := make([]byte, sz+5)
		buf[0] = byte(0)
		binary.BigEndian.PutUint32(buf[1:], sz)
		copy(buf[5:], msg)
		_ = h.t.Write(s, nil, buf, &Options{})
	}
}

// handleStreamMisbehave queues data frames directly on the server's controlBuf
// to overrun the client's flow-control window. Because it uses controlBuf
// rather than h.t.Write, it presumably bypasses the transport's own
// outbound-quota accounting (confirm).
//   - method "foo.Connection": exactly initialWindowSize bytes are sent, which
//     violates the connection-level window only.
//   - any other method: the last frame carries one extra byte, so the total is
//     initialWindowSize+1, which violates the stream-level window.
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
	conn, ok := s.st.(*http2Server)
	if !ok {
		t.Errorf("Failed to convert %v to *http2Server", s.st)
		_ = h.t.WriteStatus(s, status.New(codes.Internal, ""))
		return
	}
	var sent int
	p := make([]byte, http2MaxFrameLen)
	for sent < initialWindowSize {
		n := initialWindowSize - sent
		// The last message may be smaller than http2MaxFrameLen
		if n <= http2MaxFrameLen {
			if s.Method() == "foo.Connection" {
				// Violate connection level flow control window of client but do not
				// violate any stream level windows.
				p = make([]byte, n)
			} else {
				// Violate stream level flow control window of client.
				p = make([]byte, n+1)
			}
		}
		_ = conn.controlBuf.put(&dataFrame{
			streamID:    s.id,
			h:           nil,
			d:           p,
			onEachWrite: func() {},
		})
		sent += len(p)
	}
}

// handleStreamEncodingRequiredStatus ends the stream with encodingTestStatus
// (defined elsewhere in the package).
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *Stream) {
	// raw newline is not accepted by http2 framer so it must be encoded.
	_ = h.t.WriteStatus(s, encodingTestStatus)
}

// handleStreamInvalidHeaderField sends a headers frame (without ending the
// stream) whose content-type is expectedInvalidHeaderField.
func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
	var headerFields []hpack.HeaderField
	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
	_ = h.t.controlBuf.put(&headerFrame{
		streamID:  s.id,
		hf:        headerFields,
		endStream: false,
	})
}

// handleStreamDelayRead delays reads so that the other side has to halt on
// stream-level flow control.
// This handler assumes dynamic flow control is turned off and assumes window
// sizes to be set to defaultWindowSize.
//
// Coordination: it wraps s.wq.replenish to count the bytes credited back to
// this stream's write quota. Once the count equals defaultWindowSize it closes
// h.notify, telling the client to read. It waits (up to 10s) for the client to
// close h.getNotified before reading.
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	var (
		mu    sync.Mutex
		total int
	)
	s.wq.replenish = func(n int) {
		mu.Lock()
		total += n
		mu.Unlock()
		s.wq.realReplenish(n)
	}
	getTotal := func() int {
		mu.Lock()
		defer mu.Unlock()
		return total
	}
	done := make(chan struct{})
	defer close(done)
	go func() {
		for {
			select {
			// Prevent goroutine from leaking.
			case <-done:
				return
			default:
			}
			if getTotal() == defaultWindowSize {
				// Signal the client to start reading and
				// thereby send window update.
				close(h.notify)
				return
			}
			runtime.Gosched()
		}
	}()
	p := make([]byte, len(req))

	// Let the other side run out of stream-level window before
	// starting to read and thereby sending a window update.
	timer := time.NewTimer(time.Second * 10)
	select {
	case <-h.getNotified:
		timer.Stop()
	case <-timer.C:
		t.Errorf("Server timed-out.")
		return
	}
	_, err := s.Read(p)
	if err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
		return
	}

	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		return
	}
	// This write will cause server to run out of stream level,
	// flow control and the other side won't send a window update
	// until that happens.
	if err := h.t.Write(s, nil, resp, &Options{}); err != nil {
		t.Errorf("server Write got %v, want <nil>", err)
		return
	}
	// Read one more time to ensure that everything remains fine and
	// that the goroutine, that we launched earlier to signal client
	// to read, gets enough time to process.
	_, err = s.Read(p)
	if err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, nil", err)
		return
	}
	// send the trailer to end the stream.
	if err := h.t.WriteStatus(s, status.New(codes.OK, "")); err != nil {
		t.Errorf("server WriteStatus got %v, want <nil>", err)
		return
	}
}

// start starts the server. It listens on localhost:port (port 0 picks a free
// port) and reports the outcome exactly once on s.startedErr; callers block on
// that via wait before using s.lis or s.port. It then accepts connections until
// the listener is closed, serving each with the handler selected by ht. In
// delayRead mode it also closes s.ready once the connection's handler channels
// exist.
303 func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) { 304 var err error 305 if port == 0 { 306 s.lis, err = net.Listen("tcp", "localhost:0") 307 } else { 308 s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port)) 309 } 310 if err != nil { 311 s.startedErr <- fmt.Errorf("failed to listen: %v", err) 312 return 313 } 314 _, p, err := net.SplitHostPort(s.lis.Addr().String()) 315 if err != nil { 316 s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err) 317 return 318 } 319 s.port = p 320 s.conns = make(map[ServerTransport]bool) 321 s.startedErr <- nil 322 for { 323 conn, err := s.lis.Accept() 324 if err != nil { 325 return 326 } 327 transport, err := NewServerTransport(conn, serverConfig) 328 if err != nil { 329 return 330 } 331 s.mu.Lock() 332 if s.conns == nil { 333 s.mu.Unlock() 334 transport.Close() 335 return 336 } 337 s.conns[transport] = true 338 h := &testStreamHandler{t: transport.(*http2Server)} 339 s.h = h 340 s.mu.Unlock() 341 switch ht { 342 case notifyCall: 343 go transport.HandleStreams(h.handleStreamAndNotify, 344 func(ctx context.Context, _ string) context.Context { 345 return ctx 346 }) 347 case suspended: 348 go transport.HandleStreams(func(*Stream) {}, // Do nothing to handle the stream. 
349 func(ctx context.Context, method string) context.Context { 350 return ctx 351 }) 352 case misbehaved: 353 go transport.HandleStreams(func(s *Stream) { 354 go h.handleStreamMisbehave(t, s) 355 }, func(ctx context.Context, method string) context.Context { 356 return ctx 357 }) 358 case encodingRequiredStatus: 359 go transport.HandleStreams(func(s *Stream) { 360 go h.handleStreamEncodingRequiredStatus(s) 361 }, func(ctx context.Context, method string) context.Context { 362 return ctx 363 }) 364 case invalidHeaderField: 365 go transport.HandleStreams(func(s *Stream) { 366 go h.handleStreamInvalidHeaderField(s) 367 }, func(ctx context.Context, method string) context.Context { 368 return ctx 369 }) 370 case delayRead: 371 h.notify = make(chan struct{}) 372 h.getNotified = make(chan struct{}) 373 s.mu.Lock() 374 close(s.ready) 375 s.mu.Unlock() 376 go transport.HandleStreams(func(s *Stream) { 377 go h.handleStreamDelayRead(t, s) 378 }, func(ctx context.Context, method string) context.Context { 379 return ctx 380 }) 381 case pingpong: 382 go transport.HandleStreams(func(s *Stream) { 383 go h.handleStreamPingPong(t, s) 384 }, func(ctx context.Context, method string) context.Context { 385 return ctx 386 }) 387 default: 388 go transport.HandleStreams(func(s *Stream) { 389 go h.handleStream(t, s) 390 }, func(ctx context.Context, method string) context.Context { 391 return ctx 392 }) 393 } 394 } 395 } 396 397 func (s *server) wait(t *testing.T, timeout time.Duration) { 398 select { 399 case err := <-s.startedErr: 400 if err != nil { 401 t.Fatal(err) 402 } 403 case <-time.After(timeout): 404 t.Fatalf("Timed out after %v waiting for server to be ready", timeout) 405 } 406 } 407 408 func (s *server) stop() { 409 _ = s.lis.Close() 410 s.mu.Lock() 411 for c := range s.conns { 412 c.Close() 413 } 414 s.conns = nil 415 s.mu.Unlock() 416 } 417 418 func (s *server) addr() string { 419 if s.lis == nil { 420 return "" 421 } 422 return s.lis.Addr().String() 423 } 424 425 func 
setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server { 426 server := &server{startedErr: make(chan error, 1), ready: make(chan struct{})} 427 go server.start(t, port, serverConfig, ht) 428 server.wait(t, 2*time.Second) 429 return server 430 } 431 432 func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) { 433 return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}) 434 } 435 436 func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) { 437 server := setUpServerOnly(t, port, serverConfig, ht) 438 addr := resolver.Address{Addr: "localhost:" + server.port} 439 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 440 ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func() {}, func(GoAwayReason) {}, func() {}) 441 if connErr != nil { 442 cancel() // Do not cancel in success path. 443 t.Fatalf("failed to create transport: %v", connErr) 444 } 445 return server, ct.(*http2Client), cancel 446 } 447 448 func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) { 449 lis, err := net.Listen("tcp", "localhost:0") 450 if err != nil { 451 t.Fatalf("Failed to listen: %v", err) 452 } 453 // Launch a non responsive server. 
454 go func() { 455 defer func(lis net.Listener) { 456 _ = lis.Close() 457 }(lis) 458 conn, err := lis.Accept() 459 if err != nil { 460 t.Errorf("Error at server-side while accepting: %v", err) 461 close(connCh) 462 return 463 } 464 connCh <- conn 465 }() 466 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 467 tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func() {}, func(GoAwayReason) {}, func() {}) 468 if err != nil { 469 cancel() // Do not cancel in success path. 470 // Server clean-up. 471 _ = lis.Close() 472 if conn, ok := <-connCh; ok { 473 _ = conn.Close() 474 } 475 t.Fatalf("Failed to dial: %v", err) 476 } 477 return tr.(*http2Client), cancel 478 } 479 480 // TestInflightStreamClosing ensures that closing in-flight stream 481 // sends status error to concurrent stream reader. 482 func (s) TestInflightStreamClosing(t *testing.T) { 483 serverConfig := &ServerConfig{} 484 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 485 defer cancel() 486 defer server.stop() 487 defer client.Close(fmt.Errorf("closed manually by test")) 488 489 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 490 defer cancel() 491 stream, err := client.NewStream(ctx, &CallHdr{}) 492 if err != nil { 493 t.Fatalf("Client failed to create RPC request: %v", err) 494 } 495 496 donec := make(chan struct{}) 497 serr := status.Error(codes.Internal, "client connection is closing") 498 go func() { 499 defer close(donec) 500 if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr { 501 t.Errorf("unexpected Stream error %v, expected %v", err, serr) 502 } 503 }() 504 505 // should unblock concurrent stream.Read 506 client.CloseStream(stream, serr) 507 508 // wait for stream.Read error 509 timeout := time.NewTimer(5 * time.Second) 510 select { 511 case <-donec: 512 if !timeout.Stop() { 513 <-timeout.C 
514 } 515 case <-timeout.C: 516 t.Fatalf("Test timed out, expected a status error.") 517 } 518 } 519 520 func (s) TestClientSendAndReceive(t *testing.T) { 521 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 522 defer cancel() 523 callHdr := &CallHdr{ 524 Host: "localhost", 525 Method: "foo.Small", 526 } 527 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 528 defer ctxCancel() 529 s1, err1 := ct.NewStream(ctx, callHdr) 530 if err1 != nil { 531 t.Fatalf("failed to open stream: %v", err1) 532 } 533 if s1.id != 1 { 534 t.Fatalf("wrong stream id: %d", s1.id) 535 } 536 s2, err2 := ct.NewStream(ctx, callHdr) 537 if err2 != nil { 538 t.Fatalf("failed to open stream: %v", err2) 539 } 540 if s2.id != 3 { 541 t.Fatalf("wrong stream id: %d", s2.id) 542 } 543 opts := Options{Last: true} 544 if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF { 545 t.Fatalf("failed to send data: %v", err) 546 } 547 p := make([]byte, len(expectedResponse)) 548 _, recvErr := s1.Read(p) 549 if recvErr != nil || !bytes.Equal(p, expectedResponse) { 550 t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse) 551 } 552 _, recvErr = s1.Read(p) 553 if recvErr != io.EOF { 554 t.Fatalf("Error: %v; want <EOF>", recvErr) 555 } 556 ct.Close(fmt.Errorf("closed manually by test")) 557 server.stop() 558 } 559 560 func (s) TestClientErrorNotify(t *testing.T) { 561 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 562 defer cancel() 563 go server.stop() 564 // ct.reader should detect the error and activate ct.Error(). 
565 <-ct.Error() 566 ct.Close(fmt.Errorf("closed manually by test")) 567 } 568 569 func performOneRPC(ct ClientTransport) { 570 callHdr := &CallHdr{ 571 Host: "localhost", 572 Method: "foo.Small", 573 } 574 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 575 defer cancel() 576 s, err := ct.NewStream(ctx, callHdr) 577 if err != nil { 578 return 579 } 580 opts := Options{Last: true} 581 if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF { 582 time.Sleep(5 * time.Millisecond) 583 // The following s.Recv()'s could error out because the 584 // underlying transport is gone. 585 // 586 // Read response 587 p := make([]byte, len(expectedResponse)) 588 _, _ = s.Read(p) 589 // Read io.EOF 590 _, _ = s.Read(p) 591 } 592 } 593 594 func (s) TestClientMix(t *testing.T) { 595 s, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 596 defer cancel() 597 go func(s *server) { 598 time.Sleep(5 * time.Second) 599 s.stop() 600 }(s) 601 go func(ct ClientTransport) { 602 <-ct.Error() 603 ct.Close(fmt.Errorf("closed manually by test")) 604 }(ct) 605 for i := 0; i < 1000; i++ { 606 time.Sleep(10 * time.Millisecond) 607 go performOneRPC(ct) 608 } 609 } 610 611 func (s) TestLargeMessage(t *testing.T) { 612 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 613 defer cancel() 614 callHdr := &CallHdr{ 615 Host: "localhost", 616 Method: "foo.Large", 617 } 618 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 619 defer ctxCancel() 620 var wg sync.WaitGroup 621 for i := 0; i < 2; i++ { 622 wg.Add(1) 623 go func() { 624 defer wg.Done() 625 s, err := ct.NewStream(ctx, callHdr) 626 if err != nil { 627 t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 628 } 629 if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil && err != io.EOF { 630 t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) 631 } 632 p := make([]byte, len(expectedResponseLarge)) 633 if _, 
err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { 634 t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) 635 } 636 if _, err = s.Read(p); err != io.EOF { 637 t.Errorf("Failed to complete the stream %v; want <EOF>", err) 638 } 639 }() 640 } 641 wg.Wait() 642 ct.Close(fmt.Errorf("closed manually by test")) 643 server.stop() 644 } 645 646 func (s) TestLargeMessageWithDelayRead(t *testing.T) { 647 // Disable dynamic flow control. 648 sc := &ServerConfig{ 649 InitialWindowSize: defaultWindowSize, 650 InitialConnWindowSize: defaultWindowSize, 651 } 652 co := ConnectOptions{ 653 InitialWindowSize: defaultWindowSize, 654 InitialConnWindowSize: defaultWindowSize, 655 } 656 server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co) 657 defer cancel() 658 defer server.stop() 659 defer ct.Close(fmt.Errorf("closed manually by test")) 660 server.mu.Lock() 661 ready := server.ready 662 server.mu.Unlock() 663 callHdr := &CallHdr{ 664 Host: "localhost", 665 Method: "foo.Large", 666 } 667 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) 668 defer cancel() 669 s, err := ct.NewStream(ctx, callHdr) 670 if err != nil { 671 t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) 672 return 673 } 674 // Wait for server's handerler to be initialized 675 select { 676 case <-ready: 677 case <-ctx.Done(): 678 t.Fatalf("Client timed out waiting for server handler to be initialized.") 679 } 680 server.mu.Lock() 681 serviceHandler := server.h 682 server.mu.Unlock() 683 var ( 684 mu sync.Mutex 685 total int 686 ) 687 s.wq.replenish = func(n int) { 688 mu.Lock() 689 total += n 690 mu.Unlock() 691 s.wq.realReplenish(n) 692 } 693 getTotal := func() int { 694 mu.Lock() 695 defer mu.Unlock() 696 return total 697 } 698 done := make(chan struct{}) 699 defer close(done) 700 go func() { 701 for { 702 select { 703 // Prevent goroutine from leaking in case of error. 
704 case <-done: 705 return 706 default: 707 } 708 if getTotal() == defaultWindowSize { 709 // unblock server to be able to read and 710 // thereby send stream level window update. 711 close(serviceHandler.getNotified) 712 return 713 } 714 runtime.Gosched() 715 } 716 }() 717 // This write will cause client to run out of stream level, 718 // flow control and the other side won't send a window update 719 // until that happens. 720 if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{}); err != nil { 721 t.Fatalf("write(_, _, _) = %v, want <nil>", err) 722 } 723 p := make([]byte, len(expectedResponseLarge)) 724 725 // Wait for the other side to run out of stream level flow control before 726 // reading and thereby sending a window update. 727 select { 728 case <-serviceHandler.notify: 729 case <-ctx.Done(): 730 t.Fatalf("Client timed out") 731 } 732 if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { 733 t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err) 734 } 735 if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil { 736 t.Fatalf("Write(_, _, _) = %v, want <nil>", err) 737 } 738 if _, err = s.Read(p); err != io.EOF { 739 t.Fatalf("Failed to complete the stream %v; want <EOF>", err) 740 } 741 } 742 743 func (s) TestGracefulClose(t *testing.T) { 744 server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong) 745 defer cancel() 746 defer func() { 747 // Stop the server's listener to make the server's goroutines terminate 748 // (after the last active stream is done). 749 _ = server.lis.Close() 750 // Check for goroutine leaks (i.e. GracefulClose with an active stream 751 // doesn't eventually close the connection when that stream completes). 
752 leakcheck.Check(t) 753 // Correctly clean up the server 754 server.stop() 755 }() 756 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) 757 defer cancel() 758 s, err := ct.NewStream(ctx, &CallHdr{}) 759 if err != nil { 760 t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err) 761 } 762 msg := make([]byte, 1024) 763 outgoingHeader := make([]byte, 5) 764 outgoingHeader[0] = byte(0) 765 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg))) 766 incomingHeader := make([]byte, 5) 767 if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil { 768 t.Fatalf("Error while writing: %v", err) 769 } 770 if _, err := s.Read(incomingHeader); err != nil { 771 t.Fatalf("Error while reading: %v", err) 772 } 773 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 774 recvMsg := make([]byte, int(sz)) 775 if _, err := s.Read(recvMsg); err != nil { 776 t.Fatalf("Error while reading: %v", err) 777 } 778 ct.GracefulClose() 779 var wg sync.WaitGroup 780 // Expect the failure for all the follow-up streams because ct has been closed gracefully. 781 for i := 0; i < 200; i++ { 782 wg.Add(1) 783 go func() { 784 defer wg.Done() 785 str, err := ct.NewStream(ctx, &CallHdr{}) 786 if err != nil && err.(*NewStreamError).Err == ErrConnClosing { 787 return 788 } else if err != nil { 789 t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) 790 return 791 } 792 _ = ct.Write(str, nil, nil, &Options{Last: true}) 793 if _, err := str.Read(make([]byte, 8)); err != errStreamDrain && err != ErrConnClosing { 794 t.Errorf("_.Read(_) = _, %v, want _, %v or %v", err, errStreamDrain, ErrConnClosing) 795 } 796 }() 797 } 798 _ = ct.Write(s, nil, nil, &Options{Last: true}) 799 if _, err := s.Read(incomingHeader); err != io.EOF { 800 t.Fatalf("Client expected EOF from the server. Got: %v", err) 801 } 802 // The stream which was created before graceful close can still proceed. 
803 wg.Wait() 804 } 805 806 func (s) TestLargeMessageSuspension(t *testing.T) { 807 server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) 808 defer cancel() 809 callHdr := &CallHdr{ 810 Host: "localhost", 811 Method: "foo.Large", 812 } 813 // Set a long enough timeout for writing a large message out. 814 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 815 defer cancel() 816 s, err := ct.NewStream(ctx, callHdr) 817 if err != nil { 818 t.Fatalf("failed to open stream: %v", err) 819 } 820 // Launch a goroutine simillar to the stream monitoring goroutine in 821 // stream.go to keep track of context timeout and call CloseStream. 822 go func() { 823 <-ctx.Done() 824 ct.CloseStream(s, ContextErr(ctx.Err())) 825 }() 826 // Write should not be done successfully due to flow control. 827 msg := make([]byte, initialWindowSize*8) 828 _ = ct.Write(s, nil, msg, &Options{}) 829 err = ct.Write(s, nil, msg, &Options{Last: true}) 830 if err != errStreamDone { 831 t.Fatalf("Write got %v, want io.EOF", err) 832 } 833 expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) 834 if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { 835 t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) 836 } 837 ct.Close(fmt.Errorf("closed manually by test")) 838 server.stop() 839 } 840 841 func (s) TestMaxStreams(t *testing.T) { 842 serverConfig := &ServerConfig{ 843 MaxStreams: 1, 844 } 845 server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 846 defer cancel() 847 defer ct.Close(fmt.Errorf("closed manually by test")) 848 defer server.stop() 849 callHdr := &CallHdr{ 850 Host: "localhost", 851 Method: "foo.Large", 852 } 853 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 854 defer cancel() 855 s, err := ct.NewStream(ctx, callHdr) 856 if err != nil { 857 t.Fatalf("Failed to open stream: %v", err) 858 } 859 // Keep creating streams until one fails 
with deadline exceeded, marking the application 860 // of server settings on client. 861 var slist []*Stream 862 pctx, cancel := context.WithCancel(context.Background()) 863 defer cancel() 864 timer := time.NewTimer(time.Second * 10) 865 expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) 866 for { 867 select { 868 case <-timer.C: 869 t.Fatalf("Test timeout: client didn't receive server settings.") 870 default: 871 } 872 ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second)) 873 // This is only to get rid of govet. All these context are based on a base 874 // context which is canceled at the end of the test. 875 //goland:noinspection GoDeferInLoop 876 defer cancel() 877 if str, err := ct.NewStream(ctx, callHdr); err == nil { 878 slist = append(slist, str) 879 continue 880 } else if err.Error() != expectedErr.Error() { 881 t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) 882 } 883 timer.Stop() 884 break 885 } 886 done := make(chan struct{}) 887 // Try and create a new stream. 888 go func() { 889 defer close(done) 890 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) 891 defer cancel() 892 if _, err := ct.NewStream(ctx, callHdr); err != nil { 893 t.Errorf("Failed to open stream: %v", err) 894 } 895 }() 896 // Close all the extra streams created and make sure the new stream is not created. 897 for _, str := range slist { 898 ct.CloseStream(str, nil) 899 } 900 select { 901 case <-done: 902 t.Fatalf("Test failed: didn't expect new stream to be created just yet.") 903 default: 904 } 905 // Close the first stream created so that the new stream can finally be created. 
906 ct.CloseStream(s, nil) 907 <-done 908 ct.Close(fmt.Errorf("closed manually by test")) 909 <-ct.writerDone 910 if ct.maxConcurrentStreams != 1 { 911 t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) 912 } 913 } 914 915 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { 916 server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) 917 defer cancel() 918 callHdr := &CallHdr{ 919 Host: "localhost", 920 Method: "foo", 921 } 922 var sc *http2Server 923 // Wait until the server transport is setup. 924 for { 925 server.mu.Lock() 926 if len(server.conns) == 0 { 927 server.mu.Unlock() 928 time.Sleep(time.Millisecond) 929 continue 930 } 931 for k := range server.conns { 932 var ok bool 933 sc, ok = k.(*http2Server) 934 if !ok { 935 t.Fatalf("Failed to convert %v to *http2Server", k) 936 } 937 } 938 server.mu.Unlock() 939 break 940 } 941 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 942 defer cancel() 943 s, err := ct.NewStream(ctx, callHdr) 944 if err != nil { 945 t.Fatalf("Failed to open stream: %v", err) 946 } 947 _ = ct.controlBuf.put(&dataFrame{ 948 streamID: s.id, 949 endStream: false, 950 h: nil, 951 d: make([]byte, http2MaxFrameLen), 952 onEachWrite: func() {}, 953 }) 954 // Loop until the server side stream is created. 
955 var ss *Stream 956 for { 957 time.Sleep(time.Second) 958 sc.mu.Lock() 959 if len(sc.activeStreams) == 0 { 960 sc.mu.Unlock() 961 continue 962 } 963 ss = sc.activeStreams[s.id] 964 sc.mu.Unlock() 965 break 966 } 967 ct.Close(fmt.Errorf("closed manually by test")) 968 select { 969 case <-ss.Context().Done(): 970 if ss.Context().Err() != context.Canceled { 971 t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled) 972 } 973 case <-time.After(5 * time.Second): 974 t.Fatalf("Failed to cancel the context of the sever side stream.") 975 } 976 server.stop() 977 } 978 979 func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { 980 connectOptions := ConnectOptions{ 981 InitialWindowSize: defaultWindowSize, 982 InitialConnWindowSize: defaultWindowSize, 983 } 984 server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) 985 defer cancel() 986 defer server.stop() 987 defer client.Close(fmt.Errorf("closed manually by test")) 988 989 waitWhileTrue(t, func() (bool, error) { 990 server.mu.Lock() 991 defer server.mu.Unlock() 992 993 if len(server.conns) == 0 { 994 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 995 } 996 return false, nil 997 }) 998 999 var st *http2Server 1000 server.mu.Lock() 1001 for k := range server.conns { 1002 st = k.(*http2Server) 1003 } 1004 notifyChan := make(chan struct{}) 1005 server.h.notify = notifyChan 1006 server.mu.Unlock() 1007 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1008 defer cancel() 1009 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1010 if err != nil { 1011 t.Fatalf("Client failed to create first stream. Err: %v", err) 1012 } 1013 1014 <-notifyChan 1015 var sstream1 *Stream 1016 // Access stream on the server. 
1017 st.mu.Lock() 1018 for _, v := range st.activeStreams { 1019 if v.id == cstream1.id { 1020 sstream1 = v 1021 } 1022 } 1023 st.mu.Unlock() 1024 if sstream1 == nil { 1025 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) 1026 } 1027 // Exhaust client's connection window. 1028 if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { 1029 t.Fatalf("Server failed to write data. Err: %v", err) 1030 } 1031 notifyChan = make(chan struct{}) 1032 server.mu.Lock() 1033 server.h.notify = notifyChan 1034 server.mu.Unlock() 1035 // Create another stream on client. 1036 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1037 if err != nil { 1038 t.Fatalf("Client failed to create second stream. Err: %v", err) 1039 } 1040 <-notifyChan 1041 var sstream2 *Stream 1042 st.mu.Lock() 1043 for _, v := range st.activeStreams { 1044 if v.id == cstream2.id { 1045 sstream2 = v 1046 } 1047 } 1048 st.mu.Unlock() 1049 if sstream2 == nil { 1050 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) 1051 } 1052 // Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream. 1053 if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { 1054 t.Fatalf("Server failed to write data. Err: %v", err) 1055 } 1056 1057 // Client should be able to read data on second stream. 1058 if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil { 1059 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1060 } 1061 1062 // Client should be able to read data on first stream. 
1063 if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil { 1064 t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) 1065 } 1066 } 1067 1068 func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { 1069 serverConfig := &ServerConfig{ 1070 InitialWindowSize: defaultWindowSize, 1071 InitialConnWindowSize: defaultWindowSize, 1072 } 1073 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) 1074 defer cancel() 1075 defer server.stop() 1076 defer client.Close(fmt.Errorf("closed manually by test")) 1077 waitWhileTrue(t, func() (bool, error) { 1078 server.mu.Lock() 1079 defer server.mu.Unlock() 1080 1081 if len(server.conns) == 0 { 1082 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") 1083 } 1084 return false, nil 1085 }) 1086 var st *http2Server 1087 server.mu.Lock() 1088 for k := range server.conns { 1089 st = k.(*http2Server) 1090 } 1091 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1092 defer cancel() 1093 server.mu.Unlock() 1094 cstream1, err := client.NewStream(ctx, &CallHdr{}) 1095 if err != nil { 1096 t.Fatalf("Failed to create 1st stream. Err: %v", err) 1097 } 1098 // Exhaust server's connection window. 1099 if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { 1100 t.Fatalf("Client failed to write data. Err: %v", err) 1101 } 1102 //Client should be able to create another stream and send data on it. 1103 cstream2, err := client.NewStream(ctx, &CallHdr{}) 1104 if err != nil { 1105 t.Fatalf("Failed to create 2nd stream. Err: %v", err) 1106 } 1107 if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil { 1108 t.Fatalf("Client failed to write data. Err: %v", err) 1109 } 1110 // Get the streams on server. 
1111 waitWhileTrue(t, func() (bool, error) { 1112 st.mu.Lock() 1113 defer st.mu.Unlock() 1114 1115 if len(st.activeStreams) != 2 { 1116 return true, fmt.Errorf("timed-out while waiting for server to have created the streams") 1117 } 1118 return false, nil 1119 }) 1120 var sstream1 *Stream 1121 st.mu.Lock() 1122 for _, v := range st.activeStreams { 1123 if v.id == 1 { 1124 sstream1 = v 1125 } 1126 } 1127 st.mu.Unlock() 1128 // Reading from the stream on server should succeed. 1129 if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil { 1130 t.Fatalf("_.Read(_) = %v, want <nil>", err) 1131 } 1132 1133 if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF { 1134 t.Fatalf("_.Read(_) = %v, want io.EOF", err) 1135 } 1136 1137 } 1138 1139 func (s) TestServerWithMisbehavedClient(t *testing.T) { 1140 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1141 defer server.stop() 1142 // Create a client that can override server stream quota. 1143 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1144 if err != nil { 1145 t.Fatalf("Clent failed to dial:%v", err) 1146 } 1147 defer func(mconn net.Conn) { 1148 _ = mconn.Close() 1149 }(mconn) 1150 if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil { 1151 t.Fatalf("Failed to set write deadline: %v", err) 1152 } 1153 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1154 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1155 } 1156 // success chan indicates that reader received a RSTStream from server. 1157 success := make(chan struct{}) 1158 var mu sync.Mutex 1159 framer := http2.NewFramer(mconn, mconn) 1160 if err := framer.WriteSettings(); err != nil { 1161 t.Fatalf("Error while writing settings: %v", err) 1162 } 1163 go func() { // Launch a reader for this misbehaving client. 
1164 for { 1165 frame, err := framer.ReadFrame() 1166 if err != nil { 1167 return 1168 } 1169 switch frame := frame.(type) { 1170 case *http2.PingFrame: 1171 // Write ping ack back so that server's BDP estimation works right. 1172 mu.Lock() 1173 _ = framer.WritePing(true, frame.Data) 1174 mu.Unlock() 1175 case *http2.RSTStreamFrame: 1176 if frame.Header().StreamID != 1 || frame.ErrCode != http2.ErrCodeFlowControl { 1177 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, frame.ErrCode) 1178 } 1179 close(success) 1180 return 1181 default: 1182 // Do nothing. 1183 } 1184 1185 } 1186 }() 1187 // Create a stream. 1188 var buf bytes.Buffer 1189 henc := hpack.NewEncoder(&buf) 1190 // TODO(mmukhi): Remove unnecessary fields. 1191 if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil { 1192 t.Fatalf("Error while encoding header: %v", err) 1193 } 1194 if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { 1195 t.Fatalf("Error while encoding header: %v", err) 1196 } 1197 if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { 1198 t.Fatalf("Error while encoding header: %v", err) 1199 } 1200 if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { 1201 t.Fatalf("Error while encoding header: %v", err) 1202 } 1203 mu.Lock() 1204 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1205 mu.Unlock() 1206 t.Fatalf("Error while writing headers: %v", err) 1207 } 1208 mu.Unlock() 1209 1210 // Test server behavior for violation of stream flow control window size restriction. 
1211 timer := time.NewTimer(time.Second * 5) 1212 dbuf := make([]byte, http2MaxFrameLen) 1213 for { 1214 select { 1215 case <-timer.C: 1216 t.Fatalf("Test timed out.") 1217 case <-success: 1218 return 1219 default: 1220 } 1221 mu.Lock() 1222 if err := framer.WriteData(1, false, dbuf); err != nil { 1223 mu.Unlock() 1224 // Error here means the server could have closed the connection due to flow control 1225 // violation. Make sure that is the case by waiting for success chan to be closed. 1226 select { 1227 case <-timer.C: 1228 t.Fatalf("Error while writing data: %v", err) 1229 case <-success: 1230 return 1231 } 1232 } 1233 mu.Unlock() 1234 // This for loop is capable of hogging the CPU and cause starvation 1235 // in Go versions prior to 1.9, 1236 // in single CPU environment. Explicitly relinquish processor. 1237 runtime.Gosched() 1238 } 1239 } 1240 1241 func (s) TestClientWithMisbehavedServer(t *testing.T) { 1242 // Create a misbehaving server. 1243 lis, err := net.Listen("tcp", "localhost:0") 1244 if err != nil { 1245 t.Fatalf("Error while listening: %v", err) 1246 } 1247 defer func(lis net.Listener) { 1248 _ = lis.Close() 1249 }(lis) 1250 // success chan indicates that the server received 1251 // RSTStream from the client. 1252 success := make(chan struct{}) 1253 go func() { // Launch the misbehaving server. 
1254 sconn, err := lis.Accept() 1255 if err != nil { 1256 t.Errorf("Error while accepting: %v", err) 1257 return 1258 } 1259 defer func(sconn net.Conn) { 1260 _ = sconn.Close() 1261 }(sconn) 1262 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { 1263 t.Errorf("Error while reading clieng preface: %v", err) 1264 return 1265 } 1266 sfr := http2.NewFramer(sconn, sconn) 1267 if err := sfr.WriteSettingsAck(); err != nil { 1268 t.Errorf("Error while writing settings: %v", err) 1269 return 1270 } 1271 var mu sync.Mutex 1272 for { 1273 frame, err := sfr.ReadFrame() 1274 if err != nil { 1275 return 1276 } 1277 switch frame := frame.(type) { 1278 case *http2.HeadersFrame: 1279 // When the client creates a stream, violate the stream flow control. 1280 go func() { 1281 buf := make([]byte, http2MaxFrameLen) 1282 for { 1283 mu.Lock() 1284 if err := sfr.WriteData(1, false, buf); err != nil { 1285 mu.Unlock() 1286 return 1287 } 1288 mu.Unlock() 1289 // This for loop is capable of hogging the CPU and cause starvation 1290 // in Go versions prior to 1.9, 1291 // in single CPU environment. Explicitly relinquish processor. 
1292 runtime.Gosched() 1293 } 1294 }() 1295 case *http2.RSTStreamFrame: 1296 if frame.Header().StreamID != 1 || frame.ErrCode != http2.ErrCodeFlowControl { 1297 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, frame.ErrCode) 1298 } 1299 close(success) 1300 return 1301 case *http2.PingFrame: 1302 mu.Lock() 1303 _ = sfr.WritePing(true, frame.Data) 1304 mu.Unlock() 1305 default: 1306 } 1307 } 1308 }() 1309 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 1310 defer cancel() 1311 ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func() {}, func(GoAwayReason) {}, func() {}) 1312 if err != nil { 1313 t.Fatalf("Error while creating client transport: %v", err) 1314 } 1315 defer ct.Close(fmt.Errorf("closed manually by test")) 1316 str, err := ct.NewStream(connectCtx, &CallHdr{}) 1317 if err != nil { 1318 t.Fatalf("Error while creating stream: %v", err) 1319 } 1320 timer := time.NewTimer(time.Second * 5) 1321 go func() { // This go routine mimics the one in stream.go to call CloseStream. 
1322 <-str.Done() 1323 ct.CloseStream(str, nil) 1324 }() 1325 select { 1326 case <-timer.C: 1327 t.Fatalf("Test timed-out.") 1328 case <-success: 1329 } 1330 } 1331 1332 var encodingTestStatus = status.New(codes.Internal, "\n") 1333 1334 func (s) TestEncodingRequiredStatus(t *testing.T) { 1335 server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) 1336 defer cancel() 1337 callHdr := &CallHdr{ 1338 Host: "localhost", 1339 Method: "foo", 1340 } 1341 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1342 defer cancel() 1343 s, err := ct.NewStream(ctx, callHdr) 1344 if err != nil { 1345 return 1346 } 1347 opts := Options{Last: true} 1348 if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone { 1349 t.Fatalf("Failed to write the request: %v", err) 1350 } 1351 p := make([]byte, http2MaxFrameLen) 1352 if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF { 1353 t.Fatalf("Read got error %v, want %v", err, io.EOF) 1354 } 1355 if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { 1356 t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) 1357 } 1358 ct.Close(fmt.Errorf("closed manually by test")) 1359 server.stop() 1360 } 1361 1362 func (s) TestInvalidHeaderField(t *testing.T) { 1363 server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) 1364 defer cancel() 1365 callHdr := &CallHdr{ 1366 Host: "localhost", 1367 Method: "foo", 1368 } 1369 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1370 defer cancel() 1371 s, err := ct.NewStream(ctx, callHdr) 1372 if err != nil { 1373 return 1374 } 1375 p := make([]byte, http2MaxFrameLen) 1376 _, err = s.trReader.(*transportReader).Read(p) 1377 if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { 1378 t.Fatalf("Read got error %v, want error with code %s and contains %q", err, 
codes.Internal, expectedInvalidHeaderField) 1379 } 1380 ct.Close(fmt.Errorf("closed manually by test")) 1381 server.stop() 1382 } 1383 1384 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { 1385 server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) 1386 defer cancel() 1387 defer server.stop() 1388 defer ct.Close(fmt.Errorf("closed manually by test")) 1389 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1390 defer cancel() 1391 s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) 1392 if err != nil { 1393 t.Fatalf("failed to create the stream") 1394 } 1395 timer := time.NewTimer(time.Second) 1396 defer timer.Stop() 1397 select { 1398 case <-s.headerChan: 1399 case <-timer.C: 1400 t.Errorf("s.headerChan: got open, want closed") 1401 } 1402 } 1403 1404 func (s) TestIsReservedHeader(t *testing.T) { 1405 tests := []struct { 1406 h string 1407 want bool 1408 }{ 1409 {"", false}, // but should be rejected earlier 1410 {"foo", false}, 1411 {"content-type", true}, 1412 {"user-agent", true}, 1413 {":anything", true}, 1414 {"grpc-message-type", true}, 1415 {"grpc-encoding", true}, 1416 {"grpc-message", true}, 1417 {"grpc-status", true}, 1418 {"grpc-timeout", true}, 1419 {"te", true}, 1420 } 1421 for _, tt := range tests { 1422 got := isReservedHeader(tt.h) 1423 if got != tt.want { 1424 t.Errorf("isReservedHeader(%q) = %v; want %v", tt.h, got, tt.want) 1425 } 1426 } 1427 } 1428 1429 func (s) TestContextErr(t *testing.T) { 1430 for _, test := range []struct { 1431 // input 1432 errIn error 1433 // outputs 1434 errOut error 1435 }{ 1436 {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, 1437 {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())}, 1438 } { 1439 err := ContextErr(test.errIn) 1440 if err.Error() != test.errOut.Error() { 1441 t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) 1442 } 1443 } 
1444 } 1445 1446 type windowSizeConfig struct { 1447 serverStream int32 1448 serverConn int32 1449 clientStream int32 1450 clientConn int32 1451 } 1452 1453 func (s) TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) { 1454 wc := windowSizeConfig{ 1455 serverStream: 10 * 1024 * 1024, 1456 serverConn: 12 * 1024 * 1024, 1457 clientStream: 6 * 1024 * 1024, 1458 clientConn: 8 * 1024 * 1024, 1459 } 1460 testFlowControlAccountCheck(t, 1024*1024, wc) 1461 } 1462 1463 func (s) TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) { 1464 wc := windowSizeConfig{ 1465 serverStream: defaultWindowSize, 1466 // Note this is smaller than initialConnWindowSize which is the current default. 1467 serverConn: defaultWindowSize, 1468 clientStream: defaultWindowSize, 1469 clientConn: defaultWindowSize, 1470 } 1471 testFlowControlAccountCheck(t, 1024*1024, wc) 1472 } 1473 1474 func (s) TestAccountCheckDynamicWindowSmallMessage(t *testing.T) { 1475 testFlowControlAccountCheck(t, 1024, windowSizeConfig{}) 1476 } 1477 1478 func (s) TestAccountCheckDynamicWindowLargeMessage(t *testing.T) { 1479 testFlowControlAccountCheck(t, 1024*1024, windowSizeConfig{}) 1480 } 1481 1482 func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) { 1483 sc := &ServerConfig{ 1484 InitialWindowSize: wc.serverStream, 1485 InitialConnWindowSize: wc.serverConn, 1486 } 1487 co := ConnectOptions{ 1488 InitialWindowSize: wc.clientStream, 1489 InitialConnWindowSize: wc.clientConn, 1490 } 1491 server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co) 1492 defer cancel() 1493 defer server.stop() 1494 defer client.Close(fmt.Errorf("closed manually by test")) 1495 waitWhileTrue(t, func() (bool, error) { 1496 server.mu.Lock() 1497 defer server.mu.Unlock() 1498 if len(server.conns) == 0 { 1499 return true, fmt.Errorf("timed out while waiting for server transport to be created") 1500 } 1501 return false, nil 1502 }) 1503 var st *http2Server 1504 server.mu.Lock() 1505 for k := range 
server.conns { 1506 st = k.(*http2Server) 1507 } 1508 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1509 defer cancel() 1510 server.mu.Unlock() 1511 const numStreams = 10 1512 clientStreams := make([]*Stream, numStreams) 1513 for i := 0; i < numStreams; i++ { 1514 var err error 1515 clientStreams[i], err = client.NewStream(ctx, &CallHdr{}) 1516 if err != nil { 1517 t.Fatalf("Failed to create stream. Err: %v", err) 1518 } 1519 } 1520 var wg sync.WaitGroup 1521 // For each stream send pingpong messages to the server. 1522 for _, stream := range clientStreams { 1523 wg.Add(1) 1524 go func(stream *Stream) { 1525 defer wg.Done() 1526 buf := make([]byte, msgSize+5) 1527 buf[0] = byte(0) 1528 binary.BigEndian.PutUint32(buf[1:], uint32(msgSize)) 1529 opts := Options{} 1530 header := make([]byte, 5) 1531 for i := 1; i <= 10; i++ { 1532 if err := client.Write(stream, nil, buf, &opts); err != nil { 1533 t.Errorf("Error on client while writing message: %v", err) 1534 return 1535 } 1536 if _, err := stream.Read(header); err != nil { 1537 t.Errorf("Error on client while reading data frame header: %v", err) 1538 return 1539 } 1540 sz := binary.BigEndian.Uint32(header[1:]) 1541 recvMsg := make([]byte, int(sz)) 1542 if _, err := stream.Read(recvMsg); err != nil { 1543 t.Errorf("Error on client while reading data: %v", err) 1544 return 1545 } 1546 if len(recvMsg) != msgSize { 1547 t.Errorf("Length of message received by client: %v, want: %v", len(recvMsg), msgSize) 1548 return 1549 } 1550 } 1551 }(stream) 1552 } 1553 wg.Wait() 1554 serverStreams := map[uint32]*Stream{} 1555 loopyClientStreams := map[uint32]*outStream{} 1556 loopyServerStreams := map[uint32]*outStream{} 1557 // Get all the streams from server reader and writer and client writer. 
1558 st.mu.Lock() 1559 for _, stream := range clientStreams { 1560 id := stream.id 1561 serverStreams[id] = st.activeStreams[id] 1562 loopyServerStreams[id] = st.loopy.estdStreams[id] 1563 loopyClientStreams[id] = client.loopy.estdStreams[id] 1564 1565 } 1566 st.mu.Unlock() 1567 // Close all streams 1568 for _, stream := range clientStreams { 1569 _ = client.Write(stream, nil, nil, &Options{Last: true}) 1570 if _, err := stream.Read(make([]byte, 5)); err != io.EOF { 1571 t.Fatalf("Client expected an EOF from the server. Got: %v", err) 1572 } 1573 } 1574 // Close down both server and client so that their internals can be read without data 1575 // races. 1576 client.Close(fmt.Errorf("closed manually by test")) 1577 st.Close() 1578 <-st.readerDone 1579 <-st.writerDone 1580 <-client.readerDone 1581 <-client.writerDone 1582 for _, cstream := range clientStreams { 1583 id := cstream.id 1584 sstream := serverStreams[id] 1585 loopyServerStream := loopyServerStreams[id] 1586 loopyClientStream := loopyClientStreams[id] 1587 // Check stream flow control. 
1588 if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { 1589 t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) 1590 } 1591 if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(client.loopy.oiws)-loopyClientStream.bytesOutStanding { 1592 t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, client.loopy.oiws, loopyClientStream.bytesOutStanding) 1593 } 1594 } 1595 // Check transport flow control. 
1596 if client.fc.limit != client.fc.unacked+st.loopy.sendQuota { 1597 t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", client.fc.limit, client.fc.unacked, st.loopy.sendQuota) 1598 } 1599 if st.fc.limit != st.fc.unacked+client.loopy.sendQuota { 1600 t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, client.loopy.sendQuota) 1601 } 1602 } 1603 1604 func waitWhileTrue(t *testing.T, condition func() (bool, error)) { 1605 var ( 1606 wait bool 1607 err error 1608 ) 1609 timer := time.NewTimer(time.Second * 5) 1610 for { 1611 wait, err = condition() 1612 if wait { 1613 select { 1614 case <-timer.C: 1615 t.Fatalf(err.Error()) 1616 default: 1617 time.Sleep(50 * time.Millisecond) 1618 continue 1619 } 1620 } 1621 if !timer.Stop() { 1622 <-timer.C 1623 } 1624 break 1625 } 1626 } 1627 1628 // If any error occurs on a call to Stream.Read, future calls 1629 // should continue to return that same error. 
1630 func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { 1631 testRecvBuffer := newRecvBuffer() 1632 s := &Stream{ 1633 ctx: context.Background(), 1634 buf: testRecvBuffer, 1635 requestRead: func(int) {}, 1636 } 1637 s.trReader = &transportReader{ 1638 reader: &recvBufferReader{ 1639 ctx: s.ctx, 1640 ctxDone: s.ctx.Done(), 1641 recv: s.buf, 1642 freeBuffer: func(*bytes.Buffer) {}, 1643 }, 1644 windowHandler: func(int) {}, 1645 } 1646 testData := make([]byte, 1) 1647 testData[0] = 5 1648 testBuffer := bytes.NewBuffer(testData) 1649 testErr := errors.New("test error") 1650 s.write(recvMsg{buffer: testBuffer, err: testErr}) 1651 1652 inBuf := make([]byte, 1) 1653 actualCount, actualErr := s.Read(inBuf) 1654 if actualCount != 0 { 1655 t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount) 1656 } 1657 if actualErr.Error() != testErr.Error() { 1658 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1659 } 1660 1661 s.write(recvMsg{buffer: testBuffer, err: nil}) 1662 s.write(recvMsg{buffer: testBuffer, err: errors.New("different error from first")}) 1663 1664 for i := 0; i < 2; i++ { 1665 inBuf := make([]byte, 1) 1666 actualCount, actualErr := s.Read(inBuf) 1667 if actualCount != 0 { 1668 t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount) 1669 } 1670 if actualErr.Error() != testErr.Error() { 1671 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) 1672 } 1673 } 1674 } 1675 1676 // TestHeadersCausingStreamError tests headers that should cause a stream protocol 1677 // error, which would end up with a RST_STREAM being sent to the client and also 1678 // the server closing the stream. 
1679 func (s) TestHeadersCausingStreamError(t *testing.T) { 1680 tests := []struct { 1681 name string 1682 headers []struct { 1683 name string 1684 values []string 1685 } 1686 }{ 1687 // If the client sends an HTTP/2 request with a :method header with a 1688 // value other than POST, as specified in the gRPC over HTTP/2 1689 // specification, the server should close the stream. 1690 { 1691 name: "Client Sending Wrong Method", 1692 headers: []struct { 1693 name string 1694 values []string 1695 }{ 1696 {name: ":method", values: []string{"PUT"}}, 1697 {name: ":path", values: []string{"foo"}}, 1698 {name: ":authority", values: []string{"localhost"}}, 1699 {name: "content-type", values: []string{"application/grpc"}}, 1700 }, 1701 }, 1702 // "Transports must consider requests containing the Connection header 1703 // as malformed" - A41 Malformed requests map to a stream error of type 1704 // PROTOCOL_ERROR. 1705 { 1706 name: "Connection header present", 1707 headers: []struct { 1708 name string 1709 values []string 1710 }{ 1711 {name: ":method", values: []string{"POST"}}, 1712 {name: ":path", values: []string{"foo"}}, 1713 {name: ":authority", values: []string{"localhost"}}, 1714 {name: "content-type", values: []string{"application/grpc"}}, 1715 {name: "connection", values: []string{"not-supported"}}, 1716 }, 1717 }, 1718 // multiple :authority or multiple Host headers would make the eventual 1719 // :authority ambiguous as per A41. Since these headers won't have a 1720 // content-type that corresponds to a grpc-client, the server should 1721 // simply write a RST_STREAM to the wire. 1722 { 1723 // Note: multiple authority headers are handled by the framer 1724 // itself, which will cause a stream error. Thus, it will never get 1725 // to operateHeaders with the check in operateHeaders for stream 1726 // error, but the server transport will still send a stream error. 
1727 name: "Multiple authority headers", 1728 headers: []struct { 1729 name string 1730 values []string 1731 }{ 1732 {name: ":method", values: []string{"POST"}}, 1733 {name: ":path", values: []string{"foo"}}, 1734 {name: ":authority", values: []string{"localhost", "localhost2"}}, 1735 {name: "host", values: []string{"localhost"}}, 1736 }, 1737 }, 1738 } 1739 for _, test := range tests { 1740 t.Run(test.name, func(t *testing.T) { 1741 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1742 defer server.stop() 1743 // Create a client directly to not tie what you can send to API of 1744 // http2_client.go (i.e. control headers being sent). 1745 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1746 if err != nil { 1747 t.Fatalf("Client failed to dial: %v", err) 1748 } 1749 defer func(mconn net.Conn) { 1750 _ = mconn.Close() 1751 }(mconn) 1752 1753 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1754 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1755 } 1756 1757 framer := http2.NewFramer(mconn, mconn) 1758 if err := framer.WriteSettings(); err != nil { 1759 t.Fatalf("Error while writing settings: %v", err) 1760 } 1761 1762 // result chan indicates that reader received a RSTStream from server. 1763 // An error will be passed on it if any other frame is received. 1764 result := testutils.NewChannel() 1765 1766 // Launch a reader goroutine. 1767 go func() { 1768 for { 1769 frame, err := framer.ReadFrame() 1770 if err != nil { 1771 return 1772 } 1773 switch frame := frame.(type) { 1774 case *http2.SettingsFrame: 1775 // Do nothing. A settings frame is expected from server preface. 1776 case *http2.RSTStreamFrame: 1777 if frame.Header().StreamID != 1 || frame.ErrCode != http2.ErrCodeProtocol { 1778 // Client only created a single stream, so RST Stream should be for that single stream. 
1779 result.Send(fmt.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, frame.ErrCode)) 1780 } 1781 // Records that client successfully received RST Stream frame. 1782 result.Send(nil) 1783 return 1784 default: 1785 // The server should send nothing but a single RST Stream frame. 1786 result.Send(errors.New("the client received a frame other than RST Stream")) 1787 } 1788 } 1789 }() 1790 1791 var buf bytes.Buffer 1792 henc := hpack.NewEncoder(&buf) 1793 1794 // Needs to build headers deterministically to conform to gRPC over 1795 // HTTP/2 spec. 1796 for _, header := range test.headers { 1797 for _, value := range header.values { 1798 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 1799 t.Fatalf("Error while encoding header: %v", err) 1800 } 1801 } 1802 } 1803 1804 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1805 t.Fatalf("Error while writing headers: %v", err) 1806 } 1807 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1808 defer cancel() 1809 r, err := result.Receive(ctx) 1810 if err != nil { 1811 t.Fatalf("Error receiving from channel: %v", err) 1812 } 1813 if r != nil { 1814 t.Fatalf("want nil, got %v", r) 1815 } 1816 }) 1817 } 1818 } 1819 1820 // TestHeadersMultipleHosts tests that a request with multiple hosts gets 1821 // rejected with HTTP Status 400 and gRPC status Internal, regardless of whether 1822 // the client is speaking gRPC or not. 1823 func (s) TestHeadersMultipleHosts(t *testing.T) { 1824 tests := []struct { 1825 name string 1826 headers []struct { 1827 name string 1828 values []string 1829 } 1830 }{ 1831 // Note: multiple authority headers are handled by the framer itself, 1832 // which will cause a stream error. 
Thus, it will never get to 1833 // operateHeaders with the check in operateHeaders for possible grpc-status sent back. 1834 1835 // multiple :authority or multiple Host headers would make the eventual 1836 // :authority ambiguous as per A41. This takes precedence even over the 1837 // fact a request is non grpc. All of these requests should be rejected 1838 // with grpc-status Internal. 1839 { 1840 name: "Multiple host headers non grpc", 1841 headers: []struct { 1842 name string 1843 values []string 1844 }{ 1845 {name: ":method", values: []string{"POST"}}, 1846 {name: ":path", values: []string{"foo"}}, 1847 {name: ":authority", values: []string{"localhost"}}, 1848 {name: "host", values: []string{"localhost", "localhost2"}}, 1849 }, 1850 }, 1851 { 1852 name: "Multiple host headers grpc", 1853 headers: []struct { 1854 name string 1855 values []string 1856 }{ 1857 {name: ":method", values: []string{"POST"}}, 1858 {name: ":path", values: []string{"foo"}}, 1859 {name: ":authority", values: []string{"localhost"}}, 1860 {name: "content-type", values: []string{"application/grpc"}}, 1861 {name: "host", values: []string{"localhost", "localhost2"}}, 1862 }, 1863 }, 1864 } 1865 for _, test := range tests { 1866 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) 1867 //goland:noinspection GoDeferInLoop 1868 defer server.stop() 1869 // Create a client directly to not tie what you can send to API of 1870 // http2_client.go (i.e. control headers being sent). 
1871 mconn, err := net.Dial("tcp", server.lis.Addr().String()) 1872 if err != nil { 1873 t.Fatalf("Client failed to dial: %v", err) 1874 } 1875 //goland:noinspection GoDeferInLoop 1876 defer func(mconn net.Conn) { 1877 _ = mconn.Close() 1878 }(mconn) 1879 1880 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { 1881 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) 1882 } 1883 1884 framer := http2.NewFramer(mconn, mconn) 1885 framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) 1886 if err := framer.WriteSettings(); err != nil { 1887 t.Fatalf("Error while writing settings: %v", err) 1888 } 1889 1890 // result chan indicates that reader received a Headers Frame with 1891 // desired grpc status and message from server. An error will be passed 1892 // on it if any other frame is received. 1893 result := testutils.NewChannel() 1894 1895 // Launch a reader goroutine. 1896 go func() { 1897 for { 1898 frame, err := framer.ReadFrame() 1899 if err != nil { 1900 return 1901 } 1902 switch frame := frame.(type) { 1903 case *http2.SettingsFrame: 1904 // Do nothing. A settings frame is expected from server preface. 
1905 case *http2.MetaHeadersFrame: 1906 var resStatus, grpcStatus, grpcMessage string 1907 for _, header := range frame.Fields { 1908 if header.Name == ":status" { 1909 resStatus = header.Value 1910 } 1911 if header.Name == "grpc-status" { 1912 grpcStatus = header.Value 1913 } 1914 if header.Name == "grpc-message" { 1915 grpcMessage = header.Value 1916 } 1917 } 1918 if resStatus != "400" { 1919 result.Send(fmt.Errorf("incorrect HTTP Status got %v, want 200", resStatus)) 1920 return 1921 } 1922 if grpcStatus != "13" { // grpc status code internal 1923 result.Send(fmt.Errorf("incorrect gRPC Status got %v, want 13", grpcStatus)) 1924 return 1925 } 1926 if !strings.Contains(grpcMessage, "both must only have 1 value as per HTTP/2 spec") { 1927 result.Send(fmt.Errorf("incorrect gRPC message")) 1928 return 1929 } 1930 1931 // Records that client successfully received a HeadersFrame 1932 // with expected Trailers-Only response. 1933 result.Send(nil) 1934 return 1935 default: 1936 // The server should send nothing but a single Settings and Headers frame. 1937 result.Send(errors.New("the client received a frame other than Settings or Headers")) 1938 } 1939 } 1940 }() 1941 1942 var buf bytes.Buffer 1943 henc := hpack.NewEncoder(&buf) 1944 1945 // Needs to build headers deterministically to conform to gRPC over 1946 // HTTP/2 spec. 
1947 for _, header := range test.headers { 1948 for _, value := range header.values { 1949 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { 1950 t.Fatalf("Error while encoding header: %v", err) 1951 } 1952 } 1953 } 1954 1955 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { 1956 t.Fatalf("Error while writing headers: %v", err) 1957 } 1958 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1959 //goland:noinspection GoDeferInLoop 1960 defer cancel() 1961 r, err := result.Receive(ctx) 1962 if err != nil { 1963 t.Fatalf("Error receiving from channel: %v", err) 1964 } 1965 if r != nil { 1966 t.Fatalf("want nil, got %v", r) 1967 } 1968 } 1969 } 1970 1971 func (s) TestPingPong1B(t *testing.T) { 1972 runPingPongTest(t, 1) 1973 } 1974 1975 func (s) TestPingPong1KB(t *testing.T) { 1976 runPingPongTest(t, 1024) 1977 } 1978 1979 func (s) TestPingPong64KB(t *testing.T) { 1980 runPingPongTest(t, 65536) 1981 } 1982 1983 func (s) TestPingPong1MB(t *testing.T) { 1984 runPingPongTest(t, 1048576) 1985 } 1986 1987 //This is a stress-test of flow control logic. 1988 func runPingPongTest(t *testing.T, msgSize int) { 1989 server, client, cancel := setUp(t, 0, 0, pingpong) 1990 defer cancel() 1991 defer server.stop() 1992 defer client.Close(fmt.Errorf("closed manually by test")) 1993 waitWhileTrue(t, func() (bool, error) { 1994 server.mu.Lock() 1995 defer server.mu.Unlock() 1996 if len(server.conns) == 0 { 1997 return true, fmt.Errorf("timed out while waiting for server transport to be created") 1998 } 1999 return false, nil 2000 }) 2001 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2002 defer cancel() 2003 stream, err := client.NewStream(ctx, &CallHdr{}) 2004 if err != nil { 2005 t.Fatalf("Failed to create stream. 
Err: %v", err) 2006 } 2007 msg := make([]byte, msgSize) 2008 outgoingHeader := make([]byte, 5) 2009 outgoingHeader[0] = byte(0) 2010 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize)) 2011 opts := &Options{} 2012 incomingHeader := make([]byte, 5) 2013 done := make(chan struct{}) 2014 go func() { 2015 timer := time.NewTimer(time.Second * 5) 2016 <-timer.C 2017 close(done) 2018 }() 2019 for { 2020 select { 2021 case <-done: 2022 _ = client.Write(stream, nil, nil, &Options{Last: true}) 2023 if _, err := stream.Read(incomingHeader); err != io.EOF { 2024 t.Fatalf("Client expected EOF from the server. Got: %v", err) 2025 } 2026 return 2027 default: 2028 if err := client.Write(stream, outgoingHeader, msg, opts); err != nil { 2029 t.Fatalf("Error on client while writing message. Err: %v", err) 2030 } 2031 if _, err := stream.Read(incomingHeader); err != nil { 2032 t.Fatalf("Error on client while reading data header. Err: %v", err) 2033 } 2034 sz := binary.BigEndian.Uint32(incomingHeader[1:]) 2035 recvMsg := make([]byte, int(sz)) 2036 if _, err := stream.Read(recvMsg); err != nil { 2037 t.Fatalf("Error on client while reading data. 
Err: %v", err) 2038 } 2039 } 2040 } 2041 } 2042 2043 type tableSizeLimit struct { 2044 mu sync.Mutex 2045 limits []uint32 2046 } 2047 2048 func (t *tableSizeLimit) add(limit uint32) { 2049 t.mu.Lock() 2050 t.limits = append(t.limits, limit) 2051 t.mu.Unlock() 2052 } 2053 2054 func (t *tableSizeLimit) getLen() int { 2055 t.mu.Lock() 2056 defer t.mu.Unlock() 2057 return len(t.limits) 2058 } 2059 2060 func (t *tableSizeLimit) getIndex(i int) uint32 { 2061 t.mu.Lock() 2062 defer t.mu.Unlock() 2063 return t.limits[i] 2064 } 2065 2066 func (s) TestHeaderTblSize(t *testing.T) { 2067 limits := &tableSizeLimit{} 2068 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2069 e.SetMaxDynamicTableSizeLimit(v) 2070 limits.add(v) 2071 } 2072 defer func() { 2073 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { 2074 e.SetMaxDynamicTableSizeLimit(v) 2075 } 2076 }() 2077 2078 server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) 2079 defer cancel() 2080 defer ct.Close(fmt.Errorf("closed manually by test")) 2081 defer server.stop() 2082 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) 2083 defer ctxCancel() 2084 _, err := ct.NewStream(ctx, &CallHdr{}) 2085 if err != nil { 2086 t.Fatalf("failed to open stream: %v", err) 2087 } 2088 2089 var svrTransport ServerTransport 2090 var i int 2091 for i = 0; i < 1000; i++ { 2092 server.mu.Lock() 2093 if len(server.conns) != 0 { 2094 server.mu.Unlock() 2095 break 2096 } 2097 server.mu.Unlock() 2098 time.Sleep(10 * time.Millisecond) 2099 continue 2100 } 2101 if i == 1000 { 2102 t.Fatalf("unable to create any server transport after 10s") 2103 } 2104 2105 for st := range server.conns { 2106 svrTransport = st 2107 break 2108 } 2109 _ = svrTransport.(*http2Server).controlBuf.put(&outgoingSettings{ 2110 ss: []http2.Setting{ 2111 { 2112 ID: http2.SettingHeaderTableSize, 2113 Val: uint32(100), 2114 }, 2115 }, 2116 }) 2117 2118 for i = 0; i < 1000; i++ { 2119 if limits.getLen() != 1 { 2120 time.Sleep(10 * 
time.Millisecond) 2121 continue 2122 } 2123 if val := limits.getIndex(0); val != uint32(100) { 2124 t.Fatalf("expected limits[0] = 100, got %d", val) 2125 } 2126 break 2127 } 2128 if i == 1000 { 2129 t.Fatalf("expected len(limits) = 1 within 10s, got != 1") 2130 } 2131 2132 _ = ct.controlBuf.put(&outgoingSettings{ 2133 ss: []http2.Setting{ 2134 { 2135 ID: http2.SettingHeaderTableSize, 2136 Val: uint32(200), 2137 }, 2138 }, 2139 }) 2140 2141 for i := 0; i < 1000; i++ { 2142 if limits.getLen() != 2 { 2143 time.Sleep(10 * time.Millisecond) 2144 continue 2145 } 2146 if val := limits.getIndex(1); val != uint32(200) { 2147 t.Fatalf("expected limits[1] = 200, got %d", val) 2148 } 2149 break 2150 } 2151 if i == 1000 { 2152 t.Fatalf("expected len(limits) = 2 within 10s, got != 2") 2153 } 2154 } 2155 2156 // attrTransportCreds is a transport credential implementation which stores 2157 // Attributes from the ClientHandshakeInfo struct passed in the context locally 2158 // for the test to inspect. 2159 type attrTransportCreds struct { 2160 credentials.TransportCredentials 2161 attr *attributes.Attributes 2162 } 2163 2164 //goland:noinspection GoUnusedParameter 2165 func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 2166 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2167 ac.attr = ai.Attributes 2168 return rawConn, nil, nil 2169 } 2170 func (ac *attrTransportCreds) Info() credentials.ProtocolInfo { 2171 return credentials.ProtocolInfo{} 2172 } 2173 func (ac *attrTransportCreds) Clone() credentials.TransportCredentials { 2174 return nil 2175 } 2176 2177 // TestClientHandshakeInfo adds attributes to the resolver.Address passes to 2178 // NewClientTransport and verifies that these attributes are received by the 2179 // transport credential handshaker. 
2180 func (s) TestClientHandshakeInfo(t *testing.T) { 2181 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2182 defer server.stop() 2183 2184 const ( 2185 testAttrKey = "foo" 2186 testAttrVal = "bar" 2187 ) 2188 addr := resolver.Address{ 2189 Addr: "localhost:" + server.port, 2190 Attributes: attributes.New(testAttrKey, testAttrVal), 2191 } 2192 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 2193 defer cancel() 2194 creds := &attrTransportCreds{} 2195 2196 tr, err := NewClientTransport(ctx, context.Background(), addr, ConnectOptions{TransportCredentials: creds}, func() {}, func(GoAwayReason) {}, func() {}) 2197 if err != nil { 2198 t.Fatalf("NewClientTransport(): %v", err) 2199 } 2200 defer tr.Close(fmt.Errorf("closed manually by test")) 2201 2202 wantAttr := attributes.New(testAttrKey, testAttrVal) 2203 if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2204 t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) 2205 } 2206 } 2207 2208 // TestClientHandshakeInfoDialer adds attributes to the resolver.Address passes to 2209 // NewClientTransport and verifies that these attributes are received by a custom 2210 // dialer. 
2211 func (s) TestClientHandshakeInfoDialer(t *testing.T) { 2212 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong) 2213 defer server.stop() 2214 2215 const ( 2216 testAttrKey = "foo" 2217 testAttrVal = "bar" 2218 ) 2219 addr := resolver.Address{ 2220 Addr: "localhost:" + server.port, 2221 Attributes: attributes.New(testAttrKey, testAttrVal), 2222 } 2223 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) 2224 defer cancel() 2225 2226 var attr *attributes.Attributes 2227 dialer := func(ctx context.Context, addr string) (net.Conn, error) { 2228 ai := credentials.ClientHandshakeInfoFromContext(ctx) 2229 attr = ai.Attributes 2230 return (&net.Dialer{}).DialContext(ctx, "tcp", addr) 2231 } 2232 2233 tr, err := NewClientTransport(ctx, context.Background(), addr, ConnectOptions{Dialer: dialer}, func() {}, func(GoAwayReason) {}, func() {}) 2234 if err != nil { 2235 t.Fatalf("NewClientTransport(): %v", err) 2236 } 2237 defer tr.Close(fmt.Errorf("closed manually by test")) 2238 2239 wantAttr := attributes.New(testAttrKey, testAttrVal) 2240 if gotAttr := attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { 2241 t.Errorf("Received attributes %v in custom dialer, want %v", gotAttr, wantAttr) 2242 } 2243 } 2244 2245 func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { 2246 testStream := func() *Stream { 2247 return &Stream{ 2248 done: make(chan struct{}), 2249 headerChan: make(chan struct{}), 2250 buf: &recvBuffer{ 2251 c: make(chan recvMsg), 2252 mu: sync.Mutex{}, 2253 }, 2254 } 2255 } 2256 2257 testClient := func(ts *Stream) *http2Client { 2258 return &http2Client{ 2259 mu: sync.Mutex{}, 2260 activeStreams: map[uint32]*Stream{ 2261 0: ts, 2262 }, 2263 controlBuf: &controlBuffer{ 2264 ch: make(chan struct{}), 2265 done: make(chan struct{}), 2266 list: &itemList{}, 2267 }, 2268 } 2269 } 2270 2271 for _, test := range []struct { 2272 name string 2273 // input 2274 metaHeaderFrame 
*http2.MetaHeadersFrame 2275 // output 2276 wantStatus *status.Status 2277 }{ 2278 { 2279 name: "valid header", 2280 metaHeaderFrame: &http2.MetaHeadersFrame{ 2281 Fields: []hpack.HeaderField{ 2282 {Name: "content-type", Value: "application/grpc"}, 2283 {Name: "grpc-status", Value: "0"}, 2284 {Name: ":status", Value: "200"}, 2285 }, 2286 }, 2287 // no error 2288 wantStatus: status.New(codes.OK, ""), 2289 }, 2290 { 2291 name: "missing content-type header", 2292 metaHeaderFrame: &http2.MetaHeadersFrame{ 2293 Fields: []hpack.HeaderField{ 2294 {Name: "grpc-status", Value: "0"}, 2295 {Name: ":status", Value: "200"}, 2296 }, 2297 }, 2298 wantStatus: status.New( 2299 codes.Unknown, 2300 "malformed header: missing HTTP content-type", 2301 ), 2302 }, 2303 { 2304 name: "invalid grpc status header field", 2305 metaHeaderFrame: &http2.MetaHeadersFrame{ 2306 Fields: []hpack.HeaderField{ 2307 {Name: "content-type", Value: "application/grpc"}, 2308 {Name: "grpc-status", Value: "xxxx"}, 2309 {Name: ":status", Value: "200"}, 2310 }, 2311 }, 2312 wantStatus: status.New( 2313 codes.Internal, 2314 "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2315 ), 2316 }, 2317 { 2318 name: "invalid http content type", 2319 metaHeaderFrame: &http2.MetaHeadersFrame{ 2320 Fields: []hpack.HeaderField{ 2321 {Name: "content-type", Value: "application/json"}, 2322 }, 2323 }, 2324 wantStatus: status.New( 2325 codes.Internal, 2326 "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"", 2327 ), 2328 }, 2329 { 2330 name: "http fallback and invalid http status", 2331 metaHeaderFrame: &http2.MetaHeadersFrame{ 2332 Fields: []hpack.HeaderField{ 2333 // No content type provided then fallback into handling http error. 
2334 {Name: ":status", Value: "xxxx"}, 2335 }, 2336 }, 2337 wantStatus: status.New( 2338 codes.Internal, 2339 "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", 2340 ), 2341 }, 2342 { 2343 name: "http2 frame size exceeds", 2344 metaHeaderFrame: &http2.MetaHeadersFrame{ 2345 Fields: nil, 2346 Truncated: true, 2347 }, 2348 wantStatus: status.New( 2349 codes.Internal, 2350 "peer header list size exceeded limit", 2351 ), 2352 }, 2353 { 2354 name: "bad status in grpc mode", 2355 metaHeaderFrame: &http2.MetaHeadersFrame{ 2356 Fields: []hpack.HeaderField{ 2357 {Name: "content-type", Value: "application/grpc"}, 2358 {Name: "grpc-status", Value: "0"}, 2359 {Name: ":status", Value: "504"}, 2360 }, 2361 }, 2362 wantStatus: status.New( 2363 codes.Unavailable, 2364 "unexpected HTTP status code received from server: 504 (Gateway Timeout)", 2365 ), 2366 }, 2367 { 2368 name: "missing http status", 2369 metaHeaderFrame: &http2.MetaHeadersFrame{ 2370 Fields: []hpack.HeaderField{ 2371 {Name: "content-type", Value: "application/grpc"}, 2372 }, 2373 }, 2374 wantStatus: status.New( 2375 codes.Internal, 2376 "malformed header: missing HTTP status", 2377 ), 2378 }, 2379 } { 2380 2381 t.Run(test.name, func(t *testing.T) { 2382 ts := testStream() 2383 s := testClient(ts) 2384 2385 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2386 FrameHeader: http2.FrameHeader{ 2387 StreamID: 0, 2388 }, 2389 } 2390 2391 s.operateHeaders(test.metaHeaderFrame) 2392 2393 got := ts.status 2394 want := test.wantStatus 2395 if got.Code() != want.Code() || got.Message() != want.Message() { 2396 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2397 } 2398 }) 2399 t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) { 2400 ts := testStream() 2401 s := testClient(ts) 2402 2403 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ 2404 FrameHeader: http2.FrameHeader{ 2405 StreamID: 0, 2406 Flags: 
http2.FlagHeadersEndStream, 2407 }, 2408 } 2409 2410 s.operateHeaders(test.metaHeaderFrame) 2411 2412 got := ts.status 2413 want := test.wantStatus 2414 if got.Code() != want.Code() || got.Message() != want.Message() { 2415 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) 2416 } 2417 }) 2418 } 2419 }