google.golang.org/grpc@v1.74.2/internal/transport/transport_test.go

     1  /*
     2   *
     3   * Copyright 2014 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package transport
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"encoding/binary"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"math"
    29  	"net"
    30  	"os"
    31  	"runtime"
    32  	"strconv"
    33  	"strings"
    34  	"sync"
    35  	"sync/atomic"
    36  	"testing"
    37  	"time"
    38  
    39  	"github.com/google/go-cmp/cmp"
    40  	"golang.org/x/net/http2"
    41  	"golang.org/x/net/http2/hpack"
    42  	"google.golang.org/grpc/attributes"
    43  	"google.golang.org/grpc/codes"
    44  	"google.golang.org/grpc/credentials"
    45  	"google.golang.org/grpc/internal/channelz"
    46  	"google.golang.org/grpc/internal/grpctest"
    47  	"google.golang.org/grpc/internal/leakcheck"
    48  	"google.golang.org/grpc/internal/testutils"
    49  	"google.golang.org/grpc/mem"
    50  	"google.golang.org/grpc/metadata"
    51  	"google.golang.org/grpc/resolver"
    52  	"google.golang.org/grpc/status"
    53  )
    54  
    55  type s struct {
    56  	grpctest.Tester
    57  }
    58  
    59  func Test(t *testing.T) {
    60  	grpctest.RunSubTests(t, s{})
    61  }
    62  
    63  var (
    64  	expectedRequest            = []byte("ping")
    65  	expectedResponse           = []byte("pong")
    66  	expectedRequestLarge       = make([]byte, initialWindowSize*2)
    67  	expectedResponseLarge      = make([]byte, initialWindowSize*2)
    68  	expectedInvalidHeaderField = "invalid/content-type"
    69  )
    70  
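        // init sets distinguishable marker bytes ('g', 'r', 'p', 'c') at the
        // boundaries of the large request and response payloads.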
    71  func init() {
    72  	expectedRequestLarge[0] = 'g'
    73  	expectedRequestLarge[len(expectedRequestLarge)-1] = 'r'
    74  	expectedResponseLarge[0] = 'p'
    75  	expectedResponseLarge[len(expectedResponseLarge)-1] = 'c'
    76  }
    77  
    78  func newBufferSlice(b []byte) mem.BufferSlice {
    79  	return mem.BufferSlice{mem.SliceBuffer(b)}
    80  }
    81  
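        // readTo reads exactly len(p) bytes from the stream into p, freeing the
        // underlying buffers. A short read is reported as io.ErrUnexpectedEOF.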
    82  func (s *Stream) readTo(p []byte) (int, error) {
    83  	data, err := s.read(len(p))
    84  	defer data.Free()
    85  
    86  	if err != nil {
    87  		return 0, err
    88  	}
    89  
    90  	// err is nil past this point, so a short read can only mean
    91  	// that the stream ended before the full message arrived.
    92  	if data.Len() != len(p) {
    93  		// Report the truncation as an unexpected EOF.
    94  		return 0, io.ErrUnexpectedEOF
    95  	}
    96  
    97  	data.CopyTo(p)
    98  	return len(p), nil
    99  }
   100  
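        // testStreamHandler carries the server transport plus the notification
        // channels used by the notifyCall and delayRead handlers to coordinate
        // with the client side of the test.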
   101  type testStreamHandler struct {
   102  	t           *http2Server
   103  	notify      chan struct{}
   104  	getNotified chan struct{}
   105  }
   106  
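        // hType selects which server-side stream handler (*server).start installs;
        // each value maps to one of the handleStream* variants below.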
   107  type hType int
   108  
   109  const (
   110  	normal hType = iota
   111  	suspended
   112  	notifyCall
   113  	misbehaved
   114  	encodingRequiredStatus
   115  	invalidHeaderField
   116  	delayRead
   117  	pingpong
   118  )
   119  
   120  func (h *testStreamHandler) handleStreamAndNotify(*ServerStream) {
   121  	if h.notify == nil {
   122  		return
   123  	}
   124  	go func() {
   125  		select {
   126  		case <-h.notify:
   127  		default:
   128  			close(h.notify)
   129  		}
   130  	}()
   131  }
   132  
   133  func (h *testStreamHandler) handleStream(t *testing.T, s *ServerStream) {
   134  	req := expectedRequest
   135  	resp := expectedResponse
   136  	if s.Method() == "foo.Large" {
   137  		req = expectedRequestLarge
   138  		resp = expectedResponseLarge
   139  	}
   140  	p := make([]byte, len(req))
   141  	_, err := s.readTo(p)
   142  	if err != nil {
   143  		return
   144  	}
   145  	if !bytes.Equal(p, req) {
   146  		t.Errorf("handleStream got %v, want %v", p, req)
   147  		s.WriteStatus(status.New(codes.Internal, "panic"))
   148  		return
   149  	}
   150  	// send a response back to the client.
   151  	s.Write(nil, newBufferSlice(resp), &WriteOptions{})
   152  	// send the trailer to end the stream.
   153  	s.WriteStatus(status.New(codes.OK, ""))
   154  }
   155  
   156  func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *ServerStream) {
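        	// Each message is length-prefixed with the 5-byte gRPC data frame
        	// header: a 1-byte compression flag and a 4-byte big-endian length.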
   157  	header := make([]byte, 5)
   158  	for {
   159  		if _, err := s.readTo(header); err != nil {
   160  			if err == io.EOF {
   161  				s.WriteStatus(status.New(codes.OK, ""))
   162  				return
   163  			}
   164  			t.Errorf("Error on server while reading data header: %v", err)
   165  			s.WriteStatus(status.New(codes.Internal, "panic"))
   166  			return
   167  		}
   168  		sz := binary.BigEndian.Uint32(header[1:])
   169  		msg := make([]byte, int(sz))
   170  		if _, err := s.readTo(msg); err != nil {
   171  			t.Errorf("Error on server while reading message: %v", err)
   172  			s.WriteStatus(status.New(codes.Internal, "panic"))
   173  			return
   174  		}
   175  		buf := make([]byte, sz+5)
   176  		buf[0] = byte(0)
   177  		binary.BigEndian.PutUint32(buf[1:], uint32(sz))
   178  		copy(buf[5:], msg)
   179  		s.Write(nil, newBufferSlice(buf), &WriteOptions{})
   180  	}
   181  }
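
        // frameMessage is a small illustrative sketch, not used by the tests and
        // not part of the original file: it builds the framing that
        // handleStreamPingPong parses above, i.e. a zero compression flag, a
        // big-endian length, and then the payload.
        func frameMessage(msg []byte) []byte {
        	buf := make([]byte, 5+len(msg))
        	buf[0] = 0 // uncompressed
        	binary.BigEndian.PutUint32(buf[1:5], uint32(len(msg)))
        	copy(buf[5:], msg)
        	return buf
        }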
   182  
   183  func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *ServerStream) {
   184  	conn, ok := s.st.(*http2Server)
   185  	if !ok {
   186  		t.Errorf("Failed to convert %v to *http2Server", s.st)
   187  		s.WriteStatus(status.New(codes.Internal, ""))
   188  		return
   189  	}
   190  	var sent int
   191  	p := make([]byte, http2MaxFrameLen)
   192  	for sent < initialWindowSize {
   193  		n := initialWindowSize - sent
   194  		// The last message may be smaller than http2MaxFrameLen
   195  		if n <= http2MaxFrameLen {
   196  			if s.Method() == "foo.Connection" {
   197  				// Violate connection level flow control window of client but do not
   198  				// violate any stream level windows.
   199  				p = make([]byte, n)
   200  			} else {
   201  				// Violate stream level flow control window of client.
   202  				p = make([]byte, n+1)
   203  			}
   204  		}
   205  		data := newBufferSlice(p)
   206  		data.Ref()
   207  		conn.controlBuf.put(&dataFrame{
   208  			streamID:    s.id,
   209  			h:           nil,
   210  			data:        data,
   211  			onEachWrite: func() {},
   212  		})
   213  		sent += len(p)
   214  	}
   215  }
   216  
   217  func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *ServerStream) {
   218  	// A raw newline is not accepted by the HTTP/2 framer, so it must be encoded.
   219  	s.WriteStatus(encodingTestStatus)
   220  	// Drain any remaining buffers from the stream since it was closed early.
   221  	s.Read(math.MaxInt)
   222  }
   223  
   224  func (h *testStreamHandler) handleStreamInvalidHeaderField(s *ServerStream) {
   225  	headerFields := []hpack.HeaderField{}
   226  	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
   227  	h.t.controlBuf.put(&headerFrame{
   228  		streamID:  s.id,
   229  		hf:        headerFields,
   230  		endStream: false,
   231  	})
   232  }
   233  
   234  // handleStreamDelayRead delays reads so that the other side has to halt on
   235  // stream-level flow control.
   236  // This handler assumes that dynamic flow control is turned off and that
   237  // window sizes are set to defaultWindowSize.
   238  func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream) {
   239  	req := expectedRequest
   240  	resp := expectedResponse
   241  	if s.Method() == "foo.Large" {
   242  		req = expectedRequestLarge
   243  		resp = expectedResponseLarge
   244  	}
   245  	var (
   246  		mu    sync.Mutex
   247  		total int
   248  	)
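        	// Intercept write-quota replenishment so this handler can observe when
        	// the peer has returned a full defaultWindowSize of stream-level quota.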
   249  	s.wq.replenish = func(n int) {
   250  		mu.Lock()
   251  		total += n
   252  		mu.Unlock()
   253  		s.wq.realReplenish(n)
   254  	}
   255  	getTotal := func() int {
   256  		mu.Lock()
   257  		defer mu.Unlock()
   258  		return total
   259  	}
   260  	done := make(chan struct{})
   261  	defer close(done)
   262  	go func() {
   263  		for {
   264  			select {
   265  			// Prevent goroutine from leaking.
   266  			case <-done:
   267  				return
   268  			default:
   269  			}
   270  			if getTotal() == defaultWindowSize {
   271  				// Signal the client to start reading and
   272  				// thereby send window update.
   273  				close(h.notify)
   274  				return
   275  			}
   276  			runtime.Gosched()
   277  		}
   278  	}()
   279  	p := make([]byte, len(req))
   280  
   281  	// Let the other side run out of stream-level window before
   282  	// starting to read and thereby sending a window update.
   283  	timer := time.NewTimer(time.Second * 10)
   284  	select {
   285  	case <-h.getNotified:
   286  		timer.Stop()
   287  	case <-timer.C:
   288  		t.Errorf("Server timed out.")
   289  		return
   290  	}
   291  	_, err := s.readTo(p)
   292  	if err != nil {
   293  		t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
   294  		return
   295  	}
   296  
   297  	if !bytes.Equal(p, req) {
   298  		t.Errorf("handleStream got %v, want %v", p, req)
   299  		return
   300  	}
   301  	// This write will cause the server to run out of stream-level
   302  	// flow control; the other side won't send a window update
   303  	// until that happens.
   304  	if err := s.Write(nil, newBufferSlice(resp), &WriteOptions{}); err != nil {
   305  		t.Errorf("server Write got %v, want <nil>", err)
   306  		return
   307  	}
   308  	// Read one more time to ensure that everything remains fine and
   309  	// that the goroutine, that we launched earlier to signal client
   310  	// to read, gets enough time to process.
   311  	_, err = s.readTo(p)
   312  	if err != nil {
   313  		t.Errorf("s.Read(_) = _, %v, want _, nil", err)
   314  		return
   315  	}
   316  	// send the trailer to end the stream.
   317  	if err := s.WriteStatus(status.New(codes.OK, "")); err != nil {
   318  		t.Errorf("server WriteStatus got %v, want <nil>", err)
   319  		return
   320  	}
   321  }
   322  
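        // server is an in-process test server: it accepts raw TCP connections,
        // wraps each in a ServerTransport, and serves streams with the handler
        // selected by the hType passed to start.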
   323  type server struct {
   324  	lis              net.Listener
   325  	port             string
   326  	startedErr       chan error // error (or nil) with server start value
   327  	mu               sync.Mutex
   328  	conns            map[ServerTransport]net.Conn
   329  	h                *testStreamHandler
   330  	ready            chan struct{}
   331  	channelz         *channelz.Server
   332  	servingTasksDone chan struct{}
   333  }
   334  
   335  func newTestServer() *server {
   336  	return &server{
   337  		startedErr:       make(chan error, 1),
   338  		ready:            make(chan struct{}),
   339  		servingTasksDone: make(chan struct{}),
   340  		channelz:         channelz.RegisterServer("test server"),
   341  	}
   342  }
   343  
   344  // start starts the server. Other goroutines should block on s.startedErr (via s.wait) before performing further operations.
   345  func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
   346  	var err error
   347  	if port == 0 {
   348  		s.lis, err = net.Listen("tcp", "localhost:0")
   349  	} else {
   350  		s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
   351  	}
   352  	if err != nil {
   353  		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
   354  		return
   355  	}
   356  	_, p, err := net.SplitHostPort(s.lis.Addr().String())
   357  	if err != nil {
   358  		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
   359  		return
   360  	}
   361  	s.port = p
   362  	s.conns = make(map[ServerTransport]net.Conn)
   363  	s.startedErr <- nil
   364  	wg := sync.WaitGroup{}
   365  	defer func() {
   366  		wg.Wait()
   367  		close(s.servingTasksDone)
   368  	}()
   369  
   370  	for {
   371  		conn, err := s.lis.Accept()
   372  		if err != nil {
   373  			return
   374  		}
   375  		rawConn := conn
   376  		if serverConfig.MaxStreams == 0 {
   377  			serverConfig.MaxStreams = math.MaxUint32
   378  		}
   379  		transport, err := NewServerTransport(conn, serverConfig)
   380  		if err != nil {
   381  			return
   382  		}
   383  		s.mu.Lock()
   384  		if s.conns == nil {
   385  			s.mu.Unlock()
   386  			transport.Close(errors.New("s.conns is nil"))
   387  			return
   388  		}
   389  		s.conns[transport] = rawConn
   390  		h := &testStreamHandler{t: transport.(*http2Server)}
   391  		s.h = h
   392  		s.mu.Unlock()
   393  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   394  		defer cancel()
   395  		wg.Add(1)
   396  		switch ht {
   397  		case notifyCall:
   398  			go func() {
   399  				transport.HandleStreams(ctx, h.handleStreamAndNotify)
   400  				wg.Done()
   401  			}()
   402  		case suspended:
   403  			go func() {
   404  				transport.HandleStreams(ctx, func(*ServerStream) {})
   405  				wg.Done()
   406  			}()
   407  		case misbehaved:
   408  			go func() {
   409  				transport.HandleStreams(ctx, func(s *ServerStream) {
   410  					wg.Add(1)
   411  					go func() {
   412  						h.handleStreamMisbehave(t, s)
   413  						wg.Done()
   414  					}()
   415  				})
   416  				wg.Done()
   417  			}()
   418  		case encodingRequiredStatus:
   419  			go func() {
   420  				transport.HandleStreams(ctx, func(s *ServerStream) {
   421  					wg.Add(1)
   422  					go func() {
   423  						h.handleStreamEncodingRequiredStatus(s)
   424  						wg.Done()
   425  					}()
   426  				})
   427  				wg.Done()
   428  			}()
   429  		case invalidHeaderField:
   430  			go func() {
   431  				transport.HandleStreams(ctx, func(s *ServerStream) {
   432  					wg.Add(1)
   433  					go func() {
   434  						h.handleStreamInvalidHeaderField(s)
   435  						wg.Done()
   436  					}()
   437  				})
   438  				wg.Done()
   439  			}()
   440  		case delayRead:
   441  			h.notify = make(chan struct{})
   442  			h.getNotified = make(chan struct{})
   443  			s.mu.Lock()
   444  			close(s.ready)
   445  			s.mu.Unlock()
   446  			go func() {
   447  				transport.HandleStreams(ctx, func(s *ServerStream) {
   448  					wg.Add(1)
   449  					go func() {
   450  						h.handleStreamDelayRead(t, s)
   451  						wg.Done()
   452  					}()
   453  				})
   454  				wg.Done()
   455  			}()
   456  		case pingpong:
   457  			go func() {
   458  				transport.HandleStreams(ctx, func(s *ServerStream) {
   459  					wg.Add(1)
   460  					go func() {
   461  						h.handleStreamPingPong(t, s)
   462  						wg.Done()
   463  					}()
   464  				})
   465  				wg.Done()
   466  			}()
   467  		default:
   468  			go func() {
   469  				transport.HandleStreams(ctx, func(s *ServerStream) {
   470  					wg.Add(1)
   471  					go func() {
   472  						h.handleStream(t, s)
   473  						wg.Done()
   474  					}()
   475  				})
   476  				wg.Done()
   477  			}()
   478  		}
   479  	}
   480  }
   481  
   482  func (s *server) wait(t *testing.T, timeout time.Duration) {
   483  	select {
   484  	case err := <-s.startedErr:
   485  		if err != nil {
   486  			t.Fatal(err)
   487  		}
   488  	case <-time.After(timeout):
   489  		t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
   490  	}
   491  }
   492  
   493  func (s *server) stop() {
   494  	s.lis.Close()
   495  	s.mu.Lock()
   496  	for c := range s.conns {
   497  		c.Close(errors.New("server Stop called"))
   498  	}
   499  	s.conns = nil
   500  	s.mu.Unlock()
   501  	<-s.servingTasksDone
   502  }
   503  
   504  func (s *server) addr() string {
   505  	if s.lis == nil {
   506  		return ""
   507  	}
   508  	return s.lis.Addr().String()
   509  }
   510  
   511  func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server {
   512  	server := newTestServer()
   513  	sc.ChannelzParent = server.channelz
   514  	go server.start(t, port, sc, ht)
   515  	server.wait(t, 2*time.Second)
   516  	return server
   517  }
   518  
   519  func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
   520  	return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
   521  }
   522  
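        // setUpWithOptions starts a test server with the given config and handler
        // type, dials it, and returns the server, the client transport, and a
        // cleanup func that cancels the connect context.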
   523  func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
   524  	server := setUpServerOnly(t, port, sc, ht)
   525  	addr := resolver.Address{Addr: "localhost:" + server.port}
   526  	copts.ChannelzParent = channelzSubChannel(t)
   527  
   528  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   529  	t.Cleanup(cancel)
   530  	connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
   531  	ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {})
   532  	if connErr != nil {
   533  		cCancel() // Do not cancel in success path.
   534  		t.Fatalf("failed to create transport: %v", connErr)
   535  	}
   536  	return server, ct.(*http2Client), cCancel
   537  }
   538  
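        // setUpWithNoPingServer dials a minimal server that accepts a single
        // connection, writes its SETTINGS frame, and then goes silent; the
        // accepted net.Conn is delivered on connCh for the test to inspect or close.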
   539  func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) {
   540  	lis, err := net.Listen("tcp", "localhost:0")
   541  	if err != nil {
   542  		t.Fatalf("Failed to listen: %v", err)
   543  	}
   544  	// Launch a non-responsive server.
   545  	go func() {
   546  		defer lis.Close()
   547  		conn, err := lis.Accept()
   548  		if err != nil {
   549  			t.Errorf("Error at server-side while accepting: %v", err)
   550  			close(connCh)
   551  			return
   552  		}
   553  		framer := http2.NewFramer(conn, conn)
   554  		if err := framer.WriteSettings(); err != nil {
   555  			t.Errorf("Error at server-side while writing settings: %v", err)
   556  			close(connCh)
   557  			return
   558  		}
   559  		connCh <- conn
   560  	}()
   561  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   562  	t.Cleanup(cancel)
   563  	connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
   564  	tr, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
   565  	if err != nil {
   566  		cCancel() // Do not cancel in success path.
   567  		// Server clean-up.
   568  		lis.Close()
   569  		if conn, ok := <-connCh; ok {
   570  			conn.Close()
   571  		}
   572  		t.Fatalf("Failed to dial: %v", err)
   573  	}
   574  	return tr.(*http2Client), cCancel
   575  }
   576  
   577  // TestInflightStreamClosing ensures that closing an in-flight stream
   578  // sends a status error to a concurrent stream reader.
   579  func (s) TestInflightStreamClosing(t *testing.T) {
   580  	serverConfig := &ServerConfig{}
   581  	server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
   582  	defer cancel()
   583  	defer server.stop()
   584  	defer client.Close(fmt.Errorf("closed manually by test"))
   585  
   586  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   587  	defer cancel()
   588  	stream, err := client.NewStream(ctx, &CallHdr{})
   589  	if err != nil {
   590  		t.Fatalf("Client failed to create RPC request: %v", err)
   591  	}
   592  
   593  	donec := make(chan struct{})
   594  	serr := status.Error(codes.Internal, "client connection is closing")
   595  	go func() {
   596  		defer close(donec)
   597  		if _, err := stream.readTo(make([]byte, defaultWindowSize)); err != serr {
   598  			t.Errorf("unexpected Stream error %v, expected %v", err, serr)
   599  		}
   600  	}()
   601  
   602  	// Should unblock the concurrent stream read.
   603  	stream.Close(serr)
   604  
   605  	// Wait for the stream read error.
   606  	timeout := time.NewTimer(5 * time.Second)
   607  	select {
   608  	case <-donec:
   609  		if !timeout.Stop() {
   610  			<-timeout.C
   611  		}
   612  	case <-timeout.C:
   613  		t.Fatalf("Test timed out, expected a status error.")
   614  	}
   615  }
   616  
   617  // Tests that when streamID > MaxStreamID, the current client transport drains.
   618  func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
   619  	server, ct, cancel := setUp(t, 0, normal)
   620  	defer cancel()
   621  	defer server.stop()
   622  	callHdr := &CallHdr{
   623  		Host:   "localhost",
   624  		Method: "foo.Small",
   625  	}
   626  
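        	// Lower MaxStreamID so that the transport is expected to enter the
        	// draining state after only two streams (IDs 1 and 3) are created.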
   627  	originalMaxStreamID := MaxStreamID
   628  	MaxStreamID = 3
   629  	defer func() {
   630  		MaxStreamID = originalMaxStreamID
   631  	}()
   632  
   633  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   634  	defer ctxCancel()
   635  
   636  	s, err := ct.NewStream(ctx, callHdr)
   637  	if err != nil {
   638  		t.Fatalf("ct.NewStream() = %v", err)
   639  	}
   640  	if s.id != 1 {
   641  		t.Fatalf("Stream id: %d, want: 1", s.id)
   642  	}
   643  
   644  	if got, want := ct.stateForTesting(), reachable; got != want {
   645  		t.Fatalf("Client transport state %v, want %v", got, want)
   646  	}
   647  
   648  	// The expected stream ID here is 3 since stream IDs are incremented by 2.
   649  	s, err = ct.NewStream(ctx, callHdr)
   650  	if err != nil {
   651  		t.Fatalf("ct.NewStream() = %v", err)
   652  	}
   653  	if s.id != 3 {
   654  		t.Fatalf("Stream id: %d, want: 3", s.id)
   655  	}
   656  
   657  	// Verify that ct.state is draining when the next stream ID > MaxStreamID.
   658  	if got, want := ct.stateForTesting(), draining; got != want {
   659  		t.Fatalf("Client transport state %v, want %v", got, want)
   660  	}
   661  }
   662  
   663  func (s) TestClientSendAndReceive(t *testing.T) {
   664  	server, ct, cancel := setUp(t, 0, normal)
   665  	defer cancel()
   666  	callHdr := &CallHdr{
   667  		Host:   "localhost",
   668  		Method: "foo.Small",
   669  	}
   670  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   671  	defer ctxCancel()
   672  	s1, err1 := ct.NewStream(ctx, callHdr)
   673  	if err1 != nil {
   674  		t.Fatalf("failed to open stream: %v", err1)
   675  	}
   676  	if s1.id != 1 {
   677  		t.Fatalf("wrong stream id: %d", s1.id)
   678  	}
   679  	s2, err2 := ct.NewStream(ctx, callHdr)
   680  	if err2 != nil {
   681  		t.Fatalf("failed to open stream: %v", err2)
   682  	}
   683  	if s2.id != 3 {
   684  		t.Fatalf("wrong stream id: %d", s2.id)
   685  	}
   686  	opts := WriteOptions{Last: true}
   687  	if err := s1.Write(nil, newBufferSlice(expectedRequest), &opts); err != nil && err != io.EOF {
   688  		t.Fatalf("failed to send data: %v", err)
   689  	}
   690  	p := make([]byte, len(expectedResponse))
   691  	_, recvErr := s1.readTo(p)
   692  	if recvErr != nil || !bytes.Equal(p, expectedResponse) {
   693  		t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
   694  	}
   695  	_, recvErr = s1.readTo(p)
   696  	if recvErr != io.EOF {
   697  		t.Fatalf("Error: %v; want <EOF>", recvErr)
   698  	}
   699  	ct.Close(fmt.Errorf("closed manually by test"))
   700  	server.stop()
   701  }
   702  
   703  func (s) TestClientErrorNotify(t *testing.T) {
   704  	server, ct, cancel := setUp(t, 0, normal)
   705  	defer cancel()
   706  	go server.stop()
   707  	// ct.reader should detect the error and activate ct.Error().
   708  	<-ct.Error()
   709  	ct.Close(fmt.Errorf("closed manually by test"))
   710  }
   711  
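        // performOneRPC opens a stream and performs a best-effort request/response
        // exchange; errors are ignored because the transport may already be closed.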
   712  func performOneRPC(ct ClientTransport) {
   713  	callHdr := &CallHdr{
   714  		Host:   "localhost",
   715  		Method: "foo.Small",
   716  	}
   717  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   718  	defer cancel()
   719  	s, err := ct.NewStream(ctx, callHdr)
   720  	if err != nil {
   721  		return
   722  	}
   723  	opts := WriteOptions{Last: true}
   724  	if err := s.Write([]byte{}, newBufferSlice(expectedRequest), &opts); err == nil || err == io.EOF {
   725  		time.Sleep(5 * time.Millisecond)
   726  		// The following s.readTo() calls could error out because the
   727  		// underlying transport is gone.
   728  		//
   729  		// Read response
   730  		p := make([]byte, len(expectedResponse))
   731  		s.readTo(p)
   732  		// Read io.EOF
   733  		s.readTo(p)
   734  	}
   735  }
   736  
   737  func (s) TestClientMix(t *testing.T) {
   738  	s, ct, cancel := setUp(t, 0, normal)
   739  	defer cancel()
   740  	time.AfterFunc(time.Second, s.stop)
   741  	go func(ct ClientTransport) {
   742  		<-ct.Error()
   743  		ct.Close(fmt.Errorf("closed manually by test"))
   744  	}(ct)
   745  	for i := 0; i < 750; i++ {
   746  		time.Sleep(2 * time.Millisecond)
   747  		go performOneRPC(ct)
   748  	}
   749  }
   750  
   751  func (s) TestLargeMessage(t *testing.T) {
   752  	server, ct, cancel := setUp(t, 0, normal)
   753  	defer cancel()
   754  	callHdr := &CallHdr{
   755  		Host:   "localhost",
   756  		Method: "foo.Large",
   757  	}
   758  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   759  	defer ctxCancel()
   760  	var wg sync.WaitGroup
   761  	for i := 0; i < 2; i++ {
   762  		wg.Add(1)
   763  		go func() {
   764  			defer wg.Done()
   765  			s, err := ct.NewStream(ctx, callHdr)
   766  			if err != nil {
   767  				t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
   768  			}
   769  			if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{Last: true}); err != nil && err != io.EOF {
   770  			t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
   771  			}
   772  			p := make([]byte, len(expectedResponseLarge))
   773  			if _, err := s.readTo(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
   774  				t.Errorf("s.readTo(_) = _, %v; got %v, want %v", err, p, expectedResponseLarge)
   775  			}
   776  			if _, err = s.readTo(p); err != io.EOF {
   777  				t.Errorf("Failed to complete the stream %v; want <EOF>", err)
   778  			}
   779  		}()
   780  	}
   781  	wg.Wait()
   782  	ct.Close(fmt.Errorf("closed manually by test"))
   783  	server.stop()
   784  }
   785  
   786  func (s) TestLargeMessageWithDelayRead(t *testing.T) {
   787  	// Disable dynamic flow control.
   788  	sc := &ServerConfig{
   789  		InitialWindowSize:     defaultWindowSize,
   790  		InitialConnWindowSize: defaultWindowSize,
   791  		StaticWindowSize:      true,
   792  	}
   793  	co := ConnectOptions{
   794  		InitialWindowSize:     defaultWindowSize,
   795  		InitialConnWindowSize: defaultWindowSize,
   796  		StaticWindowSize:      true,
   797  	}
   798  	server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co)
   799  	defer cancel()
   800  	defer server.stop()
   801  	defer ct.Close(fmt.Errorf("closed manually by test"))
   802  	server.mu.Lock()
   803  	ready := server.ready
   804  	server.mu.Unlock()
   805  	callHdr := &CallHdr{
   806  		Host:   "localhost",
   807  		Method: "foo.Large",
   808  	}
   809  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   810  	defer cancel()
   811  	s, err := ct.NewStream(ctx, callHdr)
   812  	if err != nil {
   813  		t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
   814  		return
   815  	}
   816  	// Wait for server's handler to be initialized
   817  	select {
   818  	case <-ready:
   819  	case <-ctx.Done():
   820  		t.Fatalf("Client timed out waiting for server handler to be initialized.")
   821  	}
   822  	server.mu.Lock()
   823  	serviceHandler := server.h
   824  	server.mu.Unlock()
   825  	var (
   826  		mu    sync.Mutex
   827  		total int
   828  	)
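        	// Intercept write-quota replenishment so the test can observe when the
        	// server has returned a full defaultWindowSize of stream-level quota.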
   829  	s.wq.replenish = func(n int) {
   830  		mu.Lock()
   831  		total += n
   832  		mu.Unlock()
   833  		s.wq.realReplenish(n)
   834  	}
   835  	getTotal := func() int {
   836  		mu.Lock()
   837  		defer mu.Unlock()
   838  		return total
   839  	}
   840  	done := make(chan struct{})
   841  	defer close(done)
   842  	go func() {
   843  		for {
   844  			select {
   845  			// Prevent goroutine from leaking in case of error.
   846  			case <-done:
   847  				return
   848  			default:
   849  			}
   850  			if getTotal() == defaultWindowSize {
   851  				// unblock server to be able to read and
   852  				// thereby send stream level window update.
   853  				close(serviceHandler.getNotified)
   854  				return
   855  			}
   856  			runtime.Gosched()
   857  		}
   858  	}()
   859  	// This write will cause client to run out of stream level,
   860  	// flow control and the other side won't send a window update
   861  	// until that happens.
   862  	if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{}); err != nil {
   863  		t.Fatalf("Write(_, _, _) = %v, want <nil>", err)
   864  	}
   865  	p := make([]byte, len(expectedResponseLarge))
   866  
   867  	// Wait for the other side to run out of stream level flow control before
   868  	// reading and thereby sending a window update.
   869  	select {
   870  	case <-serviceHandler.notify:
   871  	case <-ctx.Done():
   872  		t.Fatalf("Client timed out")
   873  	}
   874  	if _, err := s.readTo(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
   875  		t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
   876  	}
   877  	if err := s.Write([]byte{}, newBufferSlice(expectedRequestLarge), &WriteOptions{Last: true}); err != nil {
   878  		t.Fatalf("Write(_, _, _) = %v, want <nil>", err)
   879  	}
   880  	if _, err = s.readTo(p); err != io.EOF {
   881  		t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
   882  	}
   883  }
   884  
   885  // TestGracefulClose ensures that GracefulClose allows in-flight streams to
   886  // proceed until they complete naturally, while not allowing creation of new
   887  // streams during this window.
   888  func (s) TestGracefulClose(t *testing.T) {
   889  	leakcheck.SetTrackingBufferPool(t)
   890  	server, ct, cancel := setUp(t, 0, pingpong)
   891  	defer cancel()
   892  	defer func() {
   893  		// Stop the server's listener to make the server's goroutines terminate
   894  		// (after the last active stream is done).
   895  		server.lis.Close()
   896  		// Check for goroutine leaks (i.e. GracefulClose with an active stream
   897  		// doesn't eventually close the connection when that stream completes).
   898  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   899  		defer cancel()
   900  		leakcheck.CheckGoroutines(ctx, t)
   901  		leakcheck.CheckTrackingBufferPool()
   902  		// Correctly clean up the server
   903  		server.stop()
   904  	}()
   905  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   906  	defer cancel()
   907  
   908  	// Create a stream that will exist for this whole test and confirm basic
   909  	// functionality.
   910  	s, err := ct.NewStream(ctx, &CallHdr{})
   911  	if err != nil {
   912  		t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err)
   913  	}
   914  	msg := make([]byte, 1024)
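        	// Hand-build the 5-byte gRPC data frame header for the pingpong
        	// handler: flag byte 0 (uncompressed) plus a big-endian length.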
   915  	outgoingHeader := make([]byte, 5)
   916  	outgoingHeader[0] = byte(0)
   917  	binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg)))
   918  	incomingHeader := make([]byte, 5)
   919  	if err := s.Write(outgoingHeader, newBufferSlice(msg), &WriteOptions{}); err != nil {
   920  		t.Fatalf("Error while writing: %v", err)
   921  	}
   922  	if _, err := s.readTo(incomingHeader); err != nil {
   923  		t.Fatalf("Error while reading: %v", err)
   924  	}
   925  	sz := binary.BigEndian.Uint32(incomingHeader[1:])
   926  	recvMsg := make([]byte, int(sz))
   927  	if _, err := s.readTo(recvMsg); err != nil {
   928  		t.Fatalf("Error while reading: %v", err)
   929  	}
   930  
   931  	// Gracefully close the transport, which should not affect the existing
   932  	// stream.
   933  	ct.GracefulClose()
   934  
   935  	var wg sync.WaitGroup
   936  	// Expect errors creating new streams because the client transport has been
   937  	// gracefully closed.
   938  	for i := 0; i < 200; i++ {
   939  		wg.Add(1)
   940  		go func() {
   941  			defer wg.Done()
   942  			_, err := ct.NewStream(ctx, &CallHdr{})
   943  			if err != nil && err.(*NewStreamError).Err == ErrConnClosing && err.(*NewStreamError).AllowTransparentRetry {
   944  				return
   945  			}
   946  			t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
   947  		}()
   948  	}
   949  
   950  	// Confirm the existing stream still functions as expected.
   951  	s.Write(nil, nil, &WriteOptions{Last: true})
   952  	if _, err := s.readTo(incomingHeader); err != io.EOF {
   953  		t.Fatalf("Client expected EOF from the server. Got: %v", err)
   954  	}
   955  	wg.Wait()
   956  }
   957  
   958  func (s) TestLargeMessageSuspension(t *testing.T) {
   959  	server, ct, cancel := setUp(t, 0, suspended)
   960  	defer cancel()
   961  	defer ct.Close(fmt.Errorf("closed manually by test"))
   962  	defer server.stop()
   963  	callHdr := &CallHdr{
   964  		Host:   "localhost",
   965  		Method: "foo.Large",
   966  	}
   967  	// Set a long enough timeout for writing a large message out.
   968  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   969  	defer cancel()
   970  	s, err := ct.NewStream(ctx, callHdr)
   971  	if err != nil {
   972  		t.Fatalf("failed to open stream: %v", err)
   973  	}
   974  	// The writes should not complete successfully due to flow control.
   975  	msg := make([]byte, initialWindowSize*8)
   976  	s.Write(nil, newBufferSlice(msg), &WriteOptions{})
   977  	err = s.Write(nil, newBufferSlice(msg), &WriteOptions{Last: true})
   978  	if err != errStreamDone {
   979  		t.Fatalf("Write got %v, want %v", err, errStreamDone)
   980  	}
   981  	// The server will send an RST stream frame on observing the deadline
   982  	// expiration, making the client stream fail with a DeadlineExceeded status.
   983  	_, err = s.readTo(make([]byte, 8))
   984  	if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
   985  		t.Fatalf("Read got unexpected error: %v, want status with code %v", err, codes.DeadlineExceeded)
   986  	}
   987  	if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want {
   988  		t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want)
   989  	}
   990  }
   991  
   992  func (s) TestMaxStreams(t *testing.T) {
   993  	serverConfig := &ServerConfig{
   994  		MaxStreams: 1,
   995  	}
   996  	server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
   997  	defer cancel()
   998  	defer ct.Close(fmt.Errorf("closed manually by test"))
   999  	defer server.stop()
  1000  	callHdr := &CallHdr{
  1001  		Host:   "localhost",
  1002  		Method: "foo.Large",
  1003  	}
  1004  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1005  	defer cancel()
  1006  	s, err := ct.NewStream(ctx, callHdr)
  1007  	if err != nil {
  1008  		t.Fatalf("Failed to open stream: %v", err)
  1009  	}
  1010  	// Keep creating streams until one fails with deadline exceeded, which
  1011  	// indicates that the server's MaxStreams setting has been applied on the client.
  1012  	slist := []*ClientStream{}
  1013  	pctx, cancel := context.WithCancel(context.Background())
  1014  	defer cancel()
  1015  	timer := time.NewTimer(time.Second * 10)
  1016  	expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
  1017  	for {
  1018  		select {
  1019  		case <-timer.C:
  1020  			t.Fatalf("Test timeout: client didn't receive server settings.")
  1021  		default:
  1022  		}
  1023  		ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second))
  1024  		// This is only to satisfy govet. All these contexts are derived from a
  1025  		// base context which is canceled at the end of the test.
  1026  		defer cancel()
  1027  		if str, err := ct.NewStream(ctx, callHdr); err == nil {
  1028  			slist = append(slist, str)
  1029  			continue
  1030  		} else if err.Error() != expectedErr.Error() {
  1031  			t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr)
  1032  		}
  1033  		timer.Stop()
  1034  		break
  1035  	}
  1036  	done := make(chan struct{})
  1037  	// Try to create a new stream.
  1038  	go func() {
  1039  		defer close(done)
  1040  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
  1041  		defer cancel()
  1042  		if _, err := ct.NewStream(ctx, callHdr); err != nil {
  1043  			t.Errorf("Failed to open stream: %v", err)
  1044  		}
  1045  	}()
  1046  	// Close all the extra streams created and make sure the new stream is not created.
  1047  	for _, str := range slist {
  1048  		str.Close(nil)
  1049  	}
  1050  	select {
  1051  	case <-done:
  1052  		t.Fatalf("Test failed: didn't expect new stream to be created just yet.")
  1053  	default:
  1054  	}
  1055  	// Close the first stream created so that the new stream can finally be created.
  1056  	s.Close(nil)
  1057  	<-done
  1058  	ct.Close(fmt.Errorf("closed manually by test"))
  1059  	<-ct.writerDone
  1060  	if ct.maxConcurrentStreams != 1 {
  1061  		t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams)
  1062  	}
  1063  }
  1064  
  1065  func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
  1066  	server, ct, cancel := setUp(t, 0, suspended)
  1067  	defer cancel()
  1068  	callHdr := &CallHdr{
  1069  		Host:   "localhost",
  1070  		Method: "foo",
  1071  	}
  1072  	var sc *http2Server
  1073  	// Wait until the server transport is setup.
  1074  	for {
  1075  		server.mu.Lock()
  1076  		if len(server.conns) == 0 {
  1077  			server.mu.Unlock()
  1078  			time.Sleep(time.Millisecond)
  1079  			continue
  1080  		}
  1081  		for k := range server.conns {
  1082  			var ok bool
  1083  			sc, ok = k.(*http2Server)
  1084  			if !ok {
  1085  				t.Fatalf("Failed to convert %v to *http2Server", k)
  1086  			}
  1087  		}
  1088  		server.mu.Unlock()
  1089  		break
  1090  	}
  1091  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1092  	defer cancel()
  1093  	s, err := ct.NewStream(ctx, callHdr)
  1094  	if err != nil {
  1095  		t.Fatalf("Failed to open stream: %v", err)
  1096  	}
  1097  	d := newBufferSlice(make([]byte, http2MaxFrameLen))
  1098  	d.Ref()
  1099  	ct.controlBuf.put(&dataFrame{
  1100  		streamID:    s.id,
  1101  		endStream:   false,
  1102  		h:           nil,
  1103  		data:        d,
  1104  		onEachWrite: func() {},
  1105  	})
  1106  	// Loop until the server side stream is created.
  1107  	var ss *ServerStream
  1108  	for {
  1109  		time.Sleep(time.Second)
  1110  		sc.mu.Lock()
  1111  		if len(sc.activeStreams) == 0 {
  1112  			sc.mu.Unlock()
  1113  			continue
  1114  		}
  1115  		ss = sc.activeStreams[s.id]
  1116  		sc.mu.Unlock()
  1117  		break
  1118  	}
  1119  	ct.Close(fmt.Errorf("closed manually by test"))
  1120  	select {
  1121  	case <-ss.Context().Done():
  1122  		if ss.Context().Err() != context.Canceled {
  1123  			t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled)
  1124  		}
  1125  	case <-time.After(5 * time.Second):
  1126  		t.Fatalf("Failed to cancel the context of the server-side stream.")
  1127  	}
  1128  	server.stop()
  1129  }
  1130  
  1131  func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) {
  1132  	connectOptions := ConnectOptions{
  1133  		InitialWindowSize:     defaultWindowSize,
  1134  		InitialConnWindowSize: defaultWindowSize,
  1135  	}
  1136  	server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions)
  1137  	defer cancel()
  1138  	defer server.stop()
  1139  	defer client.Close(fmt.Errorf("closed manually by test"))
  1140  
  1141  	waitWhileTrue(t, func() (bool, error) {
  1142  		server.mu.Lock()
  1143  		defer server.mu.Unlock()
  1144  
  1145  		if len(server.conns) == 0 {
  1146  			return true, fmt.Errorf("timed out while waiting for connection to be created on the server")
  1147  		}
  1148  		return false, nil
  1149  	})
  1150  
  1151  	var st *http2Server
  1152  	server.mu.Lock()
  1153  	for k := range server.conns {
  1154  		st = k.(*http2Server)
  1155  	}
  1156  	notifyChan := make(chan struct{})
  1157  	server.h.notify = notifyChan
  1158  	server.mu.Unlock()
  1159  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1160  	defer cancel()
  1161  	cstream1, err := client.NewStream(ctx, &CallHdr{})
  1162  	if err != nil {
  1163  		t.Fatalf("Client failed to create first stream. Err: %v", err)
  1164  	}
  1165  
  1166  	<-notifyChan
  1167  	var sstream1 *ServerStream
  1168  	// Access stream on the server.
  1169  	st.mu.Lock()
  1170  	for _, v := range st.activeStreams {
  1171  		if v.id == cstream1.id {
  1172  			sstream1 = v
  1173  		}
  1174  	}
  1175  	st.mu.Unlock()
  1176  	if sstream1 == nil {
  1177  		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
  1178  	}
  1179  	// Exhaust client's connection window.
  1180  	if err := sstream1.Write([]byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil {
  1181  		t.Fatalf("Server failed to write data. Err: %v", err)
  1182  	}
  1183  	notifyChan = make(chan struct{})
  1184  	server.mu.Lock()
  1185  	server.h.notify = notifyChan
  1186  	server.mu.Unlock()
  1187  	// Create another stream on client.
  1188  	cstream2, err := client.NewStream(ctx, &CallHdr{})
  1189  	if err != nil {
  1190  		t.Fatalf("Client failed to create second stream. Err: %v", err)
  1191  	}
  1192  	<-notifyChan
  1193  	var sstream2 *ServerStream
  1194  	st.mu.Lock()
  1195  	for _, v := range st.activeStreams {
  1196  		if v.id == cstream2.id {
  1197  			sstream2 = v
  1198  		}
  1199  	}
  1200  	st.mu.Unlock()
  1201  	if sstream2 == nil {
  1202  		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
  1203  	}
  1204  	// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
  1205  	if err := sstream2.Write([]byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil {
  1206  		t.Fatalf("Server failed to write data. Err: %v", err)
  1207  	}
  1208  
  1209  	// Client should be able to read data on second stream.
  1210  	if _, err := cstream2.readTo(make([]byte, defaultWindowSize)); err != nil {
  1211  		t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
  1212  	}
  1213  
  1214  	// Client should be able to read data on first stream.
  1215  	if _, err := cstream1.readTo(make([]byte, defaultWindowSize)); err != nil {
  1216  		t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
  1217  	}
  1218  }
  1219  
  1220  func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) {
  1221  	serverConfig := &ServerConfig{
  1222  		InitialWindowSize:     defaultWindowSize,
  1223  		InitialConnWindowSize: defaultWindowSize,
  1224  	}
  1225  	server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
  1226  	defer cancel()
  1227  	defer server.stop()
  1228  	defer client.Close(fmt.Errorf("closed manually by test"))
  1229  	waitWhileTrue(t, func() (bool, error) {
  1230  		server.mu.Lock()
  1231  		defer server.mu.Unlock()
  1232  
  1233  		if len(server.conns) == 0 {
  1234  			return true, fmt.Errorf("timed out while waiting for connection to be created on the server")
  1235  		}
  1236  		return false, nil
  1237  	})
  1238  	var st *http2Server
  1239  	server.mu.Lock()
  1240  	for k := range server.conns {
  1241  		st = k.(*http2Server)
  1242  	}
  1243  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1244  	defer cancel()
  1245  	server.mu.Unlock()
  1246  	cstream1, err := client.NewStream(ctx, &CallHdr{})
  1247  	if err != nil {
  1248  		t.Fatalf("Failed to create 1st stream. Err: %v", err)
  1249  	}
  1250  	// Exhaust server's connection window.
  1251  	if err := cstream1.Write(nil, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{Last: true}); err != nil {
  1252  		t.Fatalf("Client failed to write data. Err: %v", err)
  1253  	}
  1254  	// Client should be able to create another stream and send data on it.
  1255  	cstream2, err := client.NewStream(ctx, &CallHdr{})
  1256  	if err != nil {
  1257  		t.Fatalf("Failed to create 2nd stream. Err: %v", err)
  1258  	}
  1259  	if err := cstream2.Write(nil, newBufferSlice(make([]byte, defaultWindowSize)), &WriteOptions{}); err != nil {
  1260  		t.Fatalf("Client failed to write data. Err: %v", err)
  1261  	}
  1262  	// Get the streams on server.
  1263  	waitWhileTrue(t, func() (bool, error) {
  1264  		st.mu.Lock()
  1265  		defer st.mu.Unlock()
  1266  
  1267  		if len(st.activeStreams) != 2 {
  1268  			return true, fmt.Errorf("timed out while waiting for the server to have created the streams")
  1269  		}
  1270  		return false, nil
  1271  	})
  1272  	var sstream1 *ServerStream
  1273  	st.mu.Lock()
  1274  	for _, v := range st.activeStreams {
  1275  		if v.id == 1 {
  1276  			sstream1 = v
  1277  		}
  1278  	}
  1279  	st.mu.Unlock()
  1280  	// Reading from the stream on server should succeed.
  1281  	if _, err := sstream1.readTo(make([]byte, defaultWindowSize)); err != nil {
  1282  		t.Fatalf("_.Read(_) = %v, want <nil>", err)
  1283  	}
  1284  
  1285  	if _, err := sstream1.readTo(make([]byte, 1)); err != io.EOF {
  1286  		t.Fatalf("_.Read(_) = %v, want io.EOF", err)
  1287  	}
  1288  
  1289  }
  1290  
  1291  func (s) TestServerWithMisbehavedClient(t *testing.T) {
  1292  	server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
  1293  	defer server.stop()
  1294  	// Create a client that can violate the server's stream flow-control quota.
  1295  	mconn, err := net.Dial("tcp", server.lis.Addr().String())
  1296  	if err != nil {
  1297  		t.Fatalf("Client failed to dial: %v", err)
  1298  	}
  1299  	defer mconn.Close()
  1300  	if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil {
  1301  		t.Fatalf("Failed to set write deadline: %v", err)
  1302  	}
  1303  	if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
  1304  		t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
  1305  	}
  1306  	// The success chan indicates that the reader received an RSTStream from the server.
  1307  	success := make(chan struct{})
  1308  	var mu sync.Mutex
  1309  	framer := http2.NewFramer(mconn, mconn)
  1310  	if err := framer.WriteSettings(); err != nil {
  1311  		t.Fatalf("Error while writing settings: %v", err)
  1312  	}
  1313  	go func() { // Launch a reader for this misbehaving client.
  1314  		for {
  1315  			frame, err := framer.ReadFrame()
  1316  			if err != nil {
  1317  				return
  1318  			}
  1319  			switch frame := frame.(type) {
  1320  			case *http2.PingFrame:
  1321  				// Write ping ack back so that server's BDP estimation works right.
  1322  				mu.Lock()
  1323  				framer.WritePing(true, frame.Data)
  1324  				mu.Unlock()
  1325  			case *http2.RSTStreamFrame:
  1326  				if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl {
  1327  					t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))
  1328  				}
  1329  				close(success)
  1330  				return
  1331  			default:
  1332  				// Do nothing.
  1333  			}
  1334  
  1335  		}
  1336  	}()
  1337  	// Create a stream.
  1338  	var buf bytes.Buffer
  1339  	henc := hpack.NewEncoder(&buf)
  1340  	// TODO(mmukhi): Remove unnecessary fields.
  1341  	if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil {
  1342  		t.Fatalf("Error while encoding header: %v", err)
  1343  	}
  1344  	if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil {
  1345  		t.Fatalf("Error while encoding header: %v", err)
  1346  	}
  1347  	if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil {
  1348  		t.Fatalf("Error while encoding header: %v", err)
  1349  	}
  1350  	if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil {
  1351  		t.Fatalf("Error while encoding header: %v", err)
  1352  	}
  1353  	mu.Lock()
  1354  	if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
  1355  		mu.Unlock()
  1356  		t.Fatalf("Error while writing headers: %v", err)
  1357  	}
  1358  	mu.Unlock()
  1359  
  1360  	// Test server behavior for violation of stream flow control window size restriction.
  1361  	timer := time.NewTimer(time.Second * 5)
  1362  	dbuf := make([]byte, http2MaxFrameLen)
  1363  	for {
  1364  		select {
  1365  		case <-timer.C:
  1366  			t.Fatalf("Test timed out.")
  1367  		case <-success:
  1368  			return
  1369  		default:
  1370  		}
  1371  		mu.Lock()
  1372  		if err := framer.WriteData(1, false, dbuf); err != nil {
  1373  			mu.Unlock()
  1374  			// Error here means the server could have closed the connection due to flow control
  1375  			// violation. Make sure that is the case by waiting for success chan to be closed.
  1376  			select {
  1377  			case <-timer.C:
  1378  				t.Fatalf("Error while writing data: %v", err)
  1379  			case <-success:
  1380  				return
  1381  			}
  1382  		}
  1383  		mu.Unlock()
  1384  		// This for loop is capable of hogging the CPU and causing starvation
  1385  		// in Go versions prior to 1.9 in a single-CPU environment.
  1386  		// Explicitly relinquish the processor.
  1387  		runtime.Gosched()
  1388  	}
  1389  }
  1390  
  1391  func (s) TestClientHonorsConnectContext(t *testing.T) {
  1392  	// Create a server that will not send a preface.
  1393  	lis, err := net.Listen("tcp", "localhost:0")
  1394  	if err != nil {
  1395  		t.Fatalf("Error while listening: %v", err)
  1396  	}
  1397  	defer lis.Close()
  1398  	go func() { // Launch the misbehaving server.
  1399  		sconn, err := lis.Accept()
  1400  		if err != nil {
  1401  			t.Errorf("Error while accepting: %v", err)
  1402  			return
  1403  		}
  1404  		defer sconn.Close()
  1405  		if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
  1406  			t.Errorf("Error while reading client preface: %v", err)
  1407  			return
  1408  		}
  1409  		sfr := http2.NewFramer(sconn, sconn)
  1410  		// Do not write a settings frame, but read from the conn forever.
  1411  		for {
  1412  			if _, err := sfr.ReadFrame(); err != nil {
  1413  				return
  1414  			}
  1415  		}
  1416  	}()
  1417  
  1418  	// Test context cancellation.
  1419  	timeBefore := time.Now()
  1420  	connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1421  	time.AfterFunc(100*time.Millisecond, cancel)
  1422  
  1423  	parent := channelzSubChannel(t)
  1424  	copts := ConnectOptions{ChannelzParent: parent}
  1425  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1426  	defer cancel()
  1427  	_, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
  1428  	if err == nil {
  1429  		t.Fatalf("NewHTTP2Client() returned successfully; wanted error")
  1430  	}
  1431  	t.Logf("NewHTTP2Client() = _, %v", err)
  1432  	if time.Since(timeBefore) > 3*time.Second {
  1433  		t.Fatalf("NewHTTP2Client returned > 2.9s after context cancellation")
  1434  	}
  1435  
  1436  	// Test context deadline.
  1437  	connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
  1438  	defer cancel()
  1439  	_, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
  1440  	if err == nil {
  1441  		t.Fatalf("NewHTTP2Client() returned successfully; wanted error")
  1442  	}
  1443  	t.Logf("NewHTTP2Client() = _, %v", err)
  1444  }
  1445  
  1446  func (s) TestClientWithMisbehavedServer(t *testing.T) {
  1447  	// Create a misbehaving server.
  1448  	lis, err := net.Listen("tcp", "localhost:0")
  1449  	if err != nil {
  1450  		t.Fatalf("Error while listening: %v", err)
  1451  	}
  1452  	defer lis.Close()
  1453  	// The success chan indicates that the server received an
  1454  	// RSTStream from the client.
  1455  	success := make(chan struct{})
  1456  	go func() { // Launch the misbehaving server.
  1457  		sconn, err := lis.Accept()
  1458  		if err != nil {
  1459  			t.Errorf("Error while accepting: %v", err)
  1460  			return
  1461  		}
  1462  		defer sconn.Close()
  1463  		if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
  1464  			t.Errorf("Error while reading client preface: %v", err)
  1465  			return
  1466  		}
  1467  		sfr := http2.NewFramer(sconn, sconn)
  1468  		if err := sfr.WriteSettings(); err != nil {
  1469  			t.Errorf("Error while writing settings: %v", err)
  1470  			return
  1471  		}
  1472  		if err := sfr.WriteSettingsAck(); err != nil {
  1473  			t.Errorf("Error while writing settings: %v", err)
  1474  			return
  1475  		}
  1476  		var mu sync.Mutex
  1477  		for {
  1478  			frame, err := sfr.ReadFrame()
  1479  			if err != nil {
  1480  				return
  1481  			}
  1482  			switch frame := frame.(type) {
  1483  			case *http2.HeadersFrame:
  1484  				// When the client creates a stream, violate the stream flow control.
  1485  				go func() {
  1486  					buf := make([]byte, http2MaxFrameLen)
  1487  					for {
  1488  						mu.Lock()
  1489  						if err := sfr.WriteData(1, false, buf); err != nil {
  1490  							mu.Unlock()
  1491  							return
  1492  						}
  1493  						mu.Unlock()
  1494  						// This for loop is capable of hogging the CPU and causing
  1495  						// starvation in Go versions prior to 1.9 in a single-CPU
  1496  						// environment. Explicitly relinquish the processor.
  1497  						runtime.Gosched()
  1498  					}
  1499  				}()
  1500  			case *http2.RSTStreamFrame:
  1501  				if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl {
  1502  					t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))
  1503  				}
  1504  				close(success)
  1505  				return
  1506  			case *http2.PingFrame:
  1507  				mu.Lock()
  1508  				sfr.WritePing(true, frame.Data)
  1509  				mu.Unlock()
  1510  			default:
  1511  			}
  1512  		}
  1513  	}()
  1514  	connectCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
  1515  	defer cancel()
  1516  
  1517  	parent := channelzSubChannel(t)
  1518  	copts := ConnectOptions{ChannelzParent: parent}
  1519  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1520  	defer cancel()
  1521  	ct, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
  1522  	if err != nil {
  1523  		t.Fatalf("Error while creating client transport: %v", err)
  1524  	}
  1525  	defer ct.Close(fmt.Errorf("closed manually by test"))
  1526  
  1527  	str, err := ct.NewStream(connectCtx, &CallHdr{})
  1528  	if err != nil {
  1529  		t.Fatalf("Error while creating stream: %v", err)
  1530  	}
  1531  	timer := time.NewTimer(time.Second * 5)
  1532  	go func() { // This goroutine mimics the one in stream.go that closes the stream when it is done.
  1533  		<-str.Done()
  1534  		str.Close(nil)
  1535  	}()
  1536  	select {
  1537  	case <-timer.C:
  1538  		t.Fatalf("Test timed out.")
  1539  	case <-success:
  1540  	}
  1541  	// Drain the remaining buffers in the stream by reading until an error is
  1542  	// encountered.
  1543  	str.Read(math.MaxInt)
  1544  }
  1545  
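        // encodingTestStatus carries a raw newline in its message, which cannot be
        // sent verbatim in an HTTP/2 header field and therefore must be encoded.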
  1546  var encodingTestStatus = status.New(codes.Internal, "\n")
  1547  
  1548  func (s) TestEncodingRequiredStatus(t *testing.T) {
  1549  	server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
  1550  	defer cancel()
  1551  	callHdr := &CallHdr{
  1552  		Host:   "localhost",
  1553  		Method: "foo",
  1554  	}
  1555  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1556  	defer cancel()
  1557  	s, err := ct.NewStream(ctx, callHdr)
  1558  	if err != nil {
  1559  		return
  1560  	}
  1561  	opts := WriteOptions{Last: true}
  1562  	if err := s.Write(nil, newBufferSlice(expectedRequest), &opts); err != nil && err != errStreamDone {
  1563  		t.Fatalf("Failed to write the request: %v", err)
  1564  	}
  1565  	p := make([]byte, http2MaxFrameLen)
  1566  	if _, err := s.readTo(p); err != io.EOF {
  1567  		t.Fatalf("Read got error %v, want %v", err, io.EOF)
  1568  	}
  1569  	if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) {
  1570  		t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus)
  1571  	}
  1572  	ct.Close(fmt.Errorf("closed manually by test"))
  1573  	server.stop()
  1574  	// Drain any remaining buffers from the stream since it was closed early.
  1575  	s.Read(math.MaxInt)
  1576  }
  1577  
  1578  func (s) TestInvalidHeaderField(t *testing.T) {
  1579  	server, ct, cancel := setUp(t, 0, invalidHeaderField)
  1580  	defer cancel()
  1581  	callHdr := &CallHdr{
  1582  		Host:   "localhost",
  1583  		Method: "foo",
  1584  	}
  1585  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1586  	defer cancel()
  1587  	s, err := ct.NewStream(ctx, callHdr)
  1588  	if err != nil {
  1589  		return
  1590  	}
  1591  	p := make([]byte, http2MaxFrameLen)
  1592  	_, err = s.readTo(p)
  1593  	if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
  1594  		t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
  1595  	}
  1596  	ct.Close(fmt.Errorf("closed manually by test"))
  1597  	server.stop()
  1598  }
  1599  
  1600  func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
  1601  	server, ct, cancel := setUp(t, 0, invalidHeaderField)
  1602  	defer cancel()
  1603  	defer server.stop()
  1604  	defer ct.Close(fmt.Errorf("closed manually by test"))
  1605  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1606  	defer cancel()
  1607  	s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"})
  1608  	if err != nil {
  1609  		t.Fatalf("failed to create the stream")
  1610  	}
  1611  	timer := time.NewTimer(time.Second)
  1612  	defer timer.Stop()
  1613  	select {
  1614  	case <-s.headerChan:
  1615  	case <-timer.C:
  1616  		t.Errorf("s.headerChan: got open, want closed")
  1617  	}
  1618  }
  1619  
  1620  func (s) TestIsReservedHeader(t *testing.T) {
  1621  	tests := []struct {
  1622  		h    string
  1623  		want bool
  1624  	}{
  1625  		{"", false}, // but should be rejected earlier
  1626  		{"foo", false},
  1627  		{"content-type", true},
  1628  		{"user-agent", true},
  1629  		{":anything", true},
  1630  		{"grpc-message-type", true},
  1631  		{"grpc-encoding", true},
  1632  		{"grpc-message", true},
  1633  		{"grpc-status", true},
  1634  		{"grpc-timeout", true},
  1635  		{"te", true},
  1636  	}
  1637  	for _, tt := range tests {
  1638  		got := isReservedHeader(tt.h)
  1639  		if got != tt.want {
  1640  			t.Errorf("isReservedHeader(%q) = %v; want %v", tt.h, got, tt.want)
  1641  		}
  1642  	}
  1643  }
  1644  
  1645  func (s) TestContextErr(t *testing.T) {
  1646  	for _, test := range []struct {
  1647  		// input
  1648  		errIn error
  1649  		// outputs
  1650  		errOut error
  1651  	}{
  1652  		{context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())},
  1653  		{context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())},
  1654  	} {
  1655  		err := ContextErr(test.errIn)
  1656  		if err.Error() != test.errOut.Error() {
  1657  			t.Fatalf("ContextErr(%v) = %v\nwant %v", test.errIn, err, test.errOut)
  1658  		}
  1659  	}
  1660  }
  1661  
  1662  type windowSizeConfig struct {
  1663  	serverStream int32
  1664  	serverConn   int32
  1665  	clientStream int32
  1666  	clientConn   int32
  1667  }
  1668  
  1669  func (s) TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) {
  1670  	wc := windowSizeConfig{
  1671  		serverStream: 10 * 1024 * 1024,
  1672  		serverConn:   12 * 1024 * 1024,
  1673  		clientStream: 6 * 1024 * 1024,
  1674  		clientConn:   8 * 1024 * 1024,
  1675  	}
  1676  	testFlowControlAccountCheck(t, 1024*1024, wc)
  1677  }
  1678  
  1679  func (s) TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) {
  1680  	// These settings disable dynamic window sizing based on BDP estimation;
  1681  	// each value must be at least defaultWindowSize or the setting is ignored.
  1682  	wc := windowSizeConfig{
  1683  		serverStream: defaultWindowSize,
  1684  		serverConn:   defaultWindowSize,
  1685  		clientStream: defaultWindowSize,
  1686  		clientConn:   defaultWindowSize,
  1687  	}
  1688  	testFlowControlAccountCheck(t, 1024*1024, wc)
  1689  }
  1690  
  1691  func (s) TestAccountCheckDynamicWindowSmallMessage(t *testing.T) {
  1692  	testFlowControlAccountCheck(t, 1024, windowSizeConfig{})
  1693  }
  1694  
  1695  func (s) TestAccountCheckDynamicWindowLargeMessage(t *testing.T) {
  1696  	testFlowControlAccountCheck(t, 1024*1024, windowSizeConfig{})
  1697  }
  1698  
  1699  func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) {
  1700  	sc := &ServerConfig{
  1701  		InitialWindowSize:     wc.serverStream,
  1702  		InitialConnWindowSize: wc.serverConn,
  1703  		StaticWindowSize:      true,
  1704  	}
  1705  	co := ConnectOptions{
  1706  		InitialWindowSize:     wc.clientStream,
  1707  		InitialConnWindowSize: wc.clientConn,
  1708  		StaticWindowSize:      true,
  1709  	}
  1710  	server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co)
  1711  	defer cancel()
  1712  	defer server.stop()
  1713  	defer client.Close(fmt.Errorf("closed manually by test"))
  1714  	waitWhileTrue(t, func() (bool, error) {
  1715  		server.mu.Lock()
  1716  		defer server.mu.Unlock()
  1717  		if len(server.conns) == 0 {
  1718  			return true, fmt.Errorf("timed out while waiting for server transport to be created")
  1719  		}
  1720  		return false, nil
  1721  	})
  1722  	var st *http2Server
  1723  	server.mu.Lock()
  1724  	for k := range server.conns {
  1725  		st = k.(*http2Server)
  1726  	}
  1727  	server.mu.Unlock()
  1728  
  1729  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1730  	defer cancel()
  1731  	const numStreams = 5
  1732  	clientStreams := make([]*ClientStream, numStreams)
  1733  	for i := 0; i < numStreams; i++ {
  1734  		var err error
  1735  		clientStreams[i], err = client.NewStream(ctx, &CallHdr{})
  1736  		if err != nil {
  1737  			t.Fatalf("Failed to create stream. Err: %v", err)
  1738  		}
  1739  	}
  1740  	var wg sync.WaitGroup
  1741  	// For each stream, send ping-pong messages to the server.
  1742  	for _, stream := range clientStreams {
  1743  		wg.Add(1)
  1744  		go func(stream *ClientStream) {
  1745  			defer wg.Done()
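        			// Frame each message using the gRPC wire format:
        			//   buf[0]   = 0 (compression flag: uncompressed)
        			//   buf[1:5] = message length as a big-endian uint32
        			//   buf[5:]  = msgSize bytes of payload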
  1746  			buf := make([]byte, msgSize+5)
  1747  			buf[0] = byte(0)
  1748  			binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
  1749  			opts := WriteOptions{}
  1750  			header := make([]byte, 5)
  1751  			for i := 1; i <= 5; i++ {
  1752  				if err := stream.Write(nil, newBufferSlice(buf), &opts); err != nil {
  1753  					t.Errorf("Error on client while writing message %v on stream %v: %v", i, stream.id, err)
  1754  					return
  1755  				}
  1756  				if _, err := stream.readTo(header); err != nil {
  1757  					t.Errorf("Error on client while reading data frame header %v on stream %v: %v", i, stream.id, err)
  1758  					return
  1759  				}
  1760  				sz := binary.BigEndian.Uint32(header[1:])
  1761  				recvMsg := make([]byte, int(sz))
  1762  				if _, err := stream.readTo(recvMsg); err != nil {
  1763  					t.Errorf("Error on client while reading data %v on stream %v: %v", i, stream.id, err)
  1764  					return
  1765  				}
  1766  				if len(recvMsg) != msgSize {
  1767  					t.Errorf("Length of message %v received by client on stream %v: %v, want: %v", i, stream.id, len(recvMsg), msgSize)
  1768  					return
  1769  				}
  1770  			}
  1771  			t.Logf("stream %v done with pingpongs", stream.id)
  1772  		}(stream)
  1773  	}
  1774  	wg.Wait()
  1775  	serverStreams := map[uint32]*ServerStream{}
  1776  	loopyClientStreams := map[uint32]*outStream{}
  1777  	loopyServerStreams := map[uint32]*outStream{}
  1778  	// Collect the server-side streams and the loopy-writer streams on both the client and the server.
  1779  	st.mu.Lock()
  1780  	client.mu.Lock()
  1781  	for _, stream := range clientStreams {
  1782  		id := stream.id
  1783  		serverStreams[id] = st.activeStreams[id]
  1784  		loopyServerStreams[id] = st.loopy.estdStreams[id]
  1785  		loopyClientStreams[id] = client.loopy.estdStreams[id]
  1787  	}
  1788  	client.mu.Unlock()
  1789  	st.mu.Unlock()
  1790  	// Close all streams
  1791  	for _, stream := range clientStreams {
  1792  		stream.Write(nil, nil, &WriteOptions{Last: true})
  1793  		if _, err := stream.readTo(make([]byte, 5)); err != io.EOF {
  1794  			t.Fatalf("Client expected an EOF from the server. Got: %v", err)
  1795  		}
  1796  	}
  1797  	// Close down both server and client so that their internals can be read without data
  1798  	// races.
  1799  	client.Close(errors.New("closed manually by test"))
  1800  	st.Close(errors.New("closed manually by test"))
  1801  	<-st.readerDone
  1802  	<-st.loopyWriterDone
  1803  	<-client.readerDone
  1804  	<-client.writerDone
  1805  	for _, cstream := range clientStreams {
  1806  		id := cstream.id
  1807  		sstream := serverStreams[id]
  1808  		loopyServerStream := loopyServerStreams[id]
  1809  		loopyClientStream := loopyClientStreams[id]
  1810  		if loopyServerStream == nil {
  1811  			t.Fatalf("Unexpected nil loopyServerStream")
  1812  		}
  1813  		// Check stream flow control.
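        		// Once both sides are quiescent, the receiver's available inflow
        		// (limit + delta - pendingData - pendingUpdate) must equal the
        		// sender's view of the window (oiws - bytesOutStanding); any drift
        		// means data and window updates were not accounted for symmetrically.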
  1814  		if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding {
  1815  			t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding)
  1816  		}
  1817  		if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(client.loopy.oiws)-loopyClientStream.bytesOutStanding {
  1818  			t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, client.loopy.oiws, loopyClientStream.bytesOutStanding)
  1819  		}
  1820  	}
  1821  	// Check transport flow control.
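        	// Connection-level accounting: the bytes received but not yet acked to
        	// the peer (unacked) plus the send quota the peer still holds
        	// (sendQuota) must add up to the configured inflow limit.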
  1822  	if client.fc.limit != client.fc.unacked+st.loopy.sendQuota {
  1823  		t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", client.fc.limit, client.fc.unacked, st.loopy.sendQuota)
  1824  	}
  1825  	if st.fc.limit != st.fc.unacked+client.loopy.sendQuota {
  1826  		t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, client.loopy.sendQuota)
  1827  	}
  1828  }
  1829  
  1830  func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
  1831  	var (
  1832  		wait bool
  1833  		err  error
  1834  	)
  1835  	timer := time.NewTimer(time.Second * 5)
  1836  	for {
  1837  		wait, err = condition()
  1838  		if wait {
  1839  			select {
  1840  			case <-timer.C:
  1841  				t.Fatal(err)
  1842  			default:
  1843  				time.Sleep(50 * time.Millisecond)
  1844  				continue
  1845  			}
  1846  		}
  1847  		if !timer.Stop() {
  1848  			<-timer.C
  1849  		}
  1850  		break
  1851  	}
  1852  }
  1853  
  1854  // If any error occurs on a call to Stream.Read, future calls
  1855  // should continue to return that same error.
  1856  func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
  1857  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1858  	defer cancel()
  1859  	testRecvBuffer := newRecvBuffer()
  1860  	s := &Stream{
  1861  		ctx:         ctx,
  1862  		buf:         testRecvBuffer,
  1863  		requestRead: func(int) {},
  1864  	}
  1865  	s.trReader = &transportReader{
  1866  		reader: &recvBufferReader{
  1867  			ctx:     s.ctx,
  1868  			ctxDone: s.ctx.Done(),
  1869  			recv:    s.buf,
  1870  		},
  1871  		windowHandler: func(int) {},
  1872  	}
  1873  	testData := make([]byte, 1)
  1874  	testData[0] = 5
  1875  	testErr := errors.New("test error")
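        	// Queue a message carrying both one byte of data and an error; the
        	// stream must latch the error and return it on every subsequent Read.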
  1876  	s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: testErr})
  1877  
  1878  	inBuf := make([]byte, 1)
  1879  	actualCount, actualErr := s.readTo(inBuf)
  1880  	if actualCount != 0 {
  1881  		t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
  1882  	}
  1883  	if actualErr.Error() != testErr.Error() {
  1884  		t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
  1885  	}
  1886  
  1887  	s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: nil})
  1888  	s.write(recvMsg{buffer: mem.SliceBuffer(testData), err: errors.New("different error from first")})
  1889  
  1890  	for i := 0; i < 2; i++ {
  1891  		inBuf := make([]byte, 1)
  1892  		actualCount, actualErr := s.readTo(inBuf)
  1893  		if actualCount != 0 {
  1894  			t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
  1895  		}
  1896  		if actualErr.Error() != testErr.Error() {
  1897  			t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
  1898  		}
  1899  	}
  1900  }
  1901  
  1902  // TestHeadersCausingStreamError tests headers that should cause a stream
  1903  // protocol error, resulting in a RST_STREAM being sent to the client and the
  1904  // server closing the stream.
  1905  func (s) TestHeadersCausingStreamError(t *testing.T) {
  1906  	tests := []struct {
  1907  		name    string
  1908  		headers []struct {
  1909  			name   string
  1910  			values []string
  1911  		}
  1912  	}{
  1913  		// "Transports must consider requests containing the Connection header
  1914  		// as malformed" - A41. Malformed requests map to a stream error of
  1915  		// type PROTOCOL_ERROR.
  1916  		{
  1917  			name: "Connection header present",
  1918  			headers: []struct {
  1919  				name   string
  1920  				values []string
  1921  			}{
  1922  				{name: ":method", values: []string{"POST"}},
  1923  				{name: ":path", values: []string{"foo"}},
  1924  				{name: ":authority", values: []string{"localhost"}},
  1925  				{name: "content-type", values: []string{"application/grpc"}},
  1926  				{name: "connection", values: []string{"not-supported"}},
  1927  			},
  1928  		},
  1929  		// Multiple :authority or multiple host headers make the eventual
  1930  		// :authority ambiguous, as per A41. Since these requests don't carry a
  1931  		// content-type identifying a gRPC client, the server should simply
  1932  		// write a RST_STREAM to the wire.
  1933  		{
  1934  			// Note: multiple :authority headers are rejected by the framer
  1935  			// itself, causing a stream error, so the request never reaches the
  1936  			// stream-error check in operateHeaders; the server transport still
  1937  			// sends a stream error either way.
  1938  			name: "Multiple authority headers",
  1939  			headers: []struct {
  1940  				name   string
  1941  				values []string
  1942  			}{
  1943  				{name: ":method", values: []string{"POST"}},
  1944  				{name: ":path", values: []string{"foo"}},
  1945  				{name: ":authority", values: []string{"localhost", "localhost2"}},
  1946  				{name: "host", values: []string{"localhost"}},
  1947  			},
  1948  		},
  1949  	}
  1950  	for _, test := range tests {
  1951  		t.Run(test.name, func(t *testing.T) {
  1952  			server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
  1953  			defer server.stop()
  1954  			// Dial the server directly rather than going through http2_client.go's
  1955  			// API, so the test controls exactly which headers are sent.
  1956  			mconn, err := net.Dial("tcp", server.lis.Addr().String())
  1957  			if err != nil {
  1958  				t.Fatalf("Client failed to dial: %v", err)
  1959  			}
  1960  			defer mconn.Close()
  1961  
  1962  			if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
  1963  				t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
  1964  			}
  1965  
  1966  			framer := http2.NewFramer(mconn, mconn)
  1967  			if err := framer.WriteSettings(); err != nil {
  1968  				t.Fatalf("Error while writing settings: %v", err)
  1969  			}
  1970  
  1971  			// The result channel indicates that the reader received an RST_STREAM
  1972  			// from the server; an error is sent on it if any other frame arrives.
  1973  			result := testutils.NewChannel()
  1974  
  1975  			// Launch a reader goroutine.
  1976  			go func() {
  1977  				for {
  1978  					frame, err := framer.ReadFrame()
  1979  					if err != nil {
  1980  						return
  1981  					}
  1982  					switch frame := frame.(type) {
  1983  					case *http2.SettingsFrame:
  1984  						// Do nothing. A settings frame is expected as part of the server preface.
  1985  					case *http2.RSTStreamFrame:
  1986  						if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol {
  1987  							// Client only created a single stream, so RST Stream should be for that single stream.
  1988  							result.Send(fmt.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeProtocol", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)))
  1989  						}
  1990  						// Records that the client successfully received the RST_STREAM frame.
  1991  						result.Send(nil)
  1992  						return
  1993  					default:
  1994  						// The server should send nothing but a single RST Stream frame.
  1995  						result.Send(errors.New("the client received a frame other than RST Stream"))
  1996  					}
  1997  				}
  1998  			}()
  1999  
  2000  			var buf bytes.Buffer
  2001  			henc := hpack.NewEncoder(&buf)
  2002  
  2003  			// Build the headers deterministically so they conform to the gRPC
  2004  			// over HTTP/2 spec.
  2005  			for _, header := range test.headers {
  2006  				for _, value := range header.values {
  2007  					if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil {
  2008  						t.Fatalf("Error while encoding header: %v", err)
  2009  					}
  2010  				}
  2011  			}
  2012  
  2013  			if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
  2014  				t.Fatalf("Error while writing headers: %v", err)
  2015  			}
  2016  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2017  			defer cancel()
  2018  			r, err := result.Receive(ctx)
  2019  			if err != nil {
  2020  				t.Fatalf("Error receiving from channel: %v", err)
  2021  			}
  2022  			if r != nil {
  2023  				t.Fatalf("want nil, got %v", r)
  2024  			}
  2025  		})
  2026  	}
  2027  }
  2028  
  2029  // TestHeadersHTTPStatusGRPCStatus tests that requests with certain headers
  2030  // get a certain HTTP and gRPC status back.
  2031  func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) {
  2032  	tests := []struct {
  2033  		name    string
  2034  		headers []struct {
  2035  			name   string
  2036  			values []string
  2037  		}
  2038  		httpStatusWant  string
  2039  		grpcStatusWant  string
  2040  		grpcMessageWant string
  2041  	}{
  2042  		// Note: multiple :authority headers are rejected by the framer
  2043  		// itself, causing a stream error, so the request never reaches the
  2044  		// check in operateHeaders that could send a grpc-status back.
  2045  
  2046  		// Multiple :authority or multiple host headers make the eventual
  2047  		// :authority ambiguous, as per A41. This takes precedence even when
  2048  		// the request is not gRPC, and all such requests should be rejected
  2049  		// with grpc-status Internal. Thus, requests with multiple hosts should
  2050  		// be rejected with HTTP status 400 and gRPC status Internal (13),
  2051  		// regardless of whether the client is speaking gRPC or not.
  2053  		{
  2054  			name: "Multiple host headers non grpc",
  2055  			headers: []struct {
  2056  				name   string
  2057  				values []string
  2058  			}{
  2059  				{name: ":method", values: []string{"POST"}},
  2060  				{name: ":path", values: []string{"foo"}},
  2061  				{name: ":authority", values: []string{"localhost"}},
  2062  				{name: "host", values: []string{"localhost", "localhost2"}},
  2063  			},
  2064  			httpStatusWant:  "400",
  2065  			grpcStatusWant:  "13",
  2066  			grpcMessageWant: "both must only have 1 value as per HTTP/2 spec",
  2067  		},
  2068  		{
  2069  			name: "Multiple host headers grpc",
  2070  			headers: []struct {
  2071  				name   string
  2072  				values []string
  2073  			}{
  2074  				{name: ":method", values: []string{"POST"}},
  2075  				{name: ":path", values: []string{"foo"}},
  2076  				{name: ":authority", values: []string{"localhost"}},
  2077  				{name: "content-type", values: []string{"application/grpc"}},
  2078  				{name: "host", values: []string{"localhost", "localhost2"}},
  2079  			},
  2080  			httpStatusWant:  "400",
  2081  			grpcStatusWant:  "13",
  2082  			grpcMessageWant: "both must only have 1 value as per HTTP/2 spec",
  2083  		},
  2084  		// If the client sends an HTTP/2 request with a :method header with a
  2085  		// value other than POST, as specified in the gRPC over HTTP/2
  2086  		// specification, the server should fail the RPC.
  2087  		{
  2088  			name: "Client Sending Wrong Method",
  2089  			headers: []struct {
  2090  				name   string
  2091  				values []string
  2092  			}{
  2093  				{name: ":method", values: []string{"PUT"}},
  2094  				{name: ":path", values: []string{"foo"}},
  2095  				{name: ":authority", values: []string{"localhost"}},
  2096  				{name: "content-type", values: []string{"application/grpc"}},
  2097  			},
  2098  			httpStatusWant:  "405",
  2099  			grpcStatusWant:  "13",
  2100  			grpcMessageWant: "which should be POST",
  2101  		},
  2102  		{
  2103  			name: "Client Sending Wrong Content-Type",
  2104  			headers: []struct {
  2105  				name   string
  2106  				values []string
  2107  			}{
  2108  				{name: ":method", values: []string{"POST"}},
  2109  				{name: ":path", values: []string{"foo"}},
  2110  				{name: ":authority", values: []string{"localhost"}},
  2111  				{name: "content-type", values: []string{"application/json"}},
  2112  			},
  2113  			httpStatusWant:  "415",
  2114  			grpcStatusWant:  "3",
  2115  			grpcMessageWant: `invalid gRPC request content-type "application/json"`,
  2116  		},
  2117  		{
  2118  			name: "Client Sending Bad Timeout",
  2119  			headers: []struct {
  2120  				name   string
  2121  				values []string
  2122  			}{
  2123  				{name: ":method", values: []string{"POST"}},
  2124  				{name: ":path", values: []string{"foo"}},
  2125  				{name: ":authority", values: []string{"localhost"}},
  2126  				{name: "content-type", values: []string{"application/grpc"}},
  2127  				{name: "grpc-timeout", values: []string{"18f6n"}},
  2128  			},
  2129  			httpStatusWant:  "400",
  2130  			grpcStatusWant:  "13",
  2131  			grpcMessageWant: "malformed grpc-timeout",
  2132  		},
  2133  		{
  2134  			name: "Client Sending Bad Binary Header",
  2135  			headers: []struct {
  2136  				name   string
  2137  				values []string
  2138  			}{
  2139  				{name: ":method", values: []string{"POST"}},
  2140  				{name: ":path", values: []string{"foo"}},
  2141  				{name: ":authority", values: []string{"localhost"}},
  2142  				{name: "content-type", values: []string{"application/grpc"}},
  2143  				{name: "foobar-bin", values: []string{"X()3e@#$-"}},
  2144  			},
  2145  			httpStatusWant:  "400",
  2146  			grpcStatusWant:  "13",
  2147  			grpcMessageWant: `header "foobar-bin": illegal base64 data`,
  2148  		},
  2149  	}
  2150  	for _, test := range tests {
  2151  		t.Run(test.name, func(t *testing.T) {
  2152  			server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
  2153  			defer server.stop()
  2154  			// Dial the server directly rather than going through http2_client.go's
  2155  			// API, so the test controls exactly which headers are sent.
  2156  			mconn, err := net.Dial("tcp", server.lis.Addr().String())
  2157  			if err != nil {
  2158  				t.Fatalf("Client failed to dial: %v", err)
  2159  			}
  2160  			defer mconn.Close()
  2161  
  2162  			if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
  2163  				t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
  2164  			}
  2165  
  2166  			framer := http2.NewFramer(mconn, mconn)
  2167  			framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil)
  2168  			if err := framer.WriteSettings(); err != nil {
  2169  				t.Fatalf("Error while writing settings: %v", err)
  2170  			}
  2171  
  2172  			// The result channel indicates that the reader received a HEADERS
  2173  			// frame with the desired gRPC status and message from the server; an
  2174  			// error is sent on it if any other frame arrives.
  2175  			result := testutils.NewChannel()
  2176  
  2177  			// Launch a reader goroutine.
  2178  			go func() {
  2179  				for {
  2180  					frame, err := framer.ReadFrame()
  2181  					if err != nil {
  2182  						return
  2183  					}
  2184  					switch frame := frame.(type) {
  2185  					case *http2.SettingsFrame:
  2186  						// Do nothing. A settings frame is expected as part of the server preface.
  2187  					case *http2.MetaHeadersFrame:
  2188  						var httpStatus, grpcStatus, grpcMessage string
  2189  						for _, header := range frame.Fields {
  2190  							switch header.Name {
  2191  							case ":status":
  2192  								httpStatus = header.Value
  2193  							case "grpc-status":
  2194  								grpcStatus = header.Value
  2195  							case "grpc-message":
  2196  								grpcMessage = header.Value
  2197  							}
  2198  						}
  2200  						if httpStatus != test.httpStatusWant {
  2201  							result.Send(fmt.Errorf("incorrect HTTP status: got %v, want %v", httpStatus, test.httpStatusWant))
  2202  							return
  2203  						}
  2204  						if grpcStatus != test.grpcStatusWant {
  2205  							result.Send(fmt.Errorf("incorrect gRPC status: got %v, want %v", grpcStatus, test.grpcStatusWant))
  2206  							return
  2207  						}
  2208  						if !strings.Contains(grpcMessage, test.grpcMessageWant) {
  2209  							result.Send(fmt.Errorf("incorrect gRPC message: got %q, want %q", grpcMessage, test.grpcMessageWant))
  2210  							return
  2211  						}
  2212  
  2213  						// Records that the client successfully received a HEADERS
  2214  						// frame with the expected Trailers-Only response.
  2215  						result.Send(nil)
  2216  						return
  2217  					default:
  2218  						// The server should send nothing but a single Settings and Headers frame.
  2219  						result.Send(errors.New("the client received a frame other than Settings or Headers"))
  2220  					}
  2221  				}
  2222  			}()
  2223  
  2224  			var buf bytes.Buffer
  2225  			henc := hpack.NewEncoder(&buf)
  2226  
  2227  			// Build the headers deterministically so they conform to the gRPC
  2228  			// over HTTP/2 spec.
  2229  			for _, header := range test.headers {
  2230  				for _, value := range header.values {
  2231  					if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil {
  2232  						t.Fatalf("Error while encoding header: %v", err)
  2233  					}
  2234  				}
  2235  			}
  2236  
  2237  			if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
  2238  				t.Fatalf("Error while writing headers: %v", err)
  2239  			}
  2240  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2241  			defer cancel()
  2242  			r, err := result.Receive(ctx)
  2243  			if err != nil {
  2244  				t.Fatalf("Error receiving from channel: %v", err)
  2245  			}
  2246  			if r != nil {
  2247  				t.Fatalf("want nil, got %v", r)
  2248  			}
  2249  		})
  2250  	}
  2251  }
  2252  
  2253  func (s) TestWriteHeaderConnectionError(t *testing.T) {
  2254  	server, client, cancel := setUp(t, 0, notifyCall)
  2255  	defer cancel()
  2256  	defer server.stop()
  2257  
  2258  	waitWhileTrue(t, func() (bool, error) {
  2259  		server.mu.Lock()
  2260  		defer server.mu.Unlock()
  2261  
  2262  		if len(server.conns) == 0 {
  2263  			return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
  2264  		}
  2265  		return false, nil
  2266  	})
  2267  
  2268  	server.mu.Lock()
  2269  
  2270  	if len(server.conns) != 1 {
  2271  		t.Fatalf("Server has %d connections from the client, want 1", len(server.conns))
  2272  	}
  2273  
  2274  	// Get the server transport for the connection to the client.
  2275  	var serverTransport *http2Server
  2276  	for k := range server.conns {
  2277  		serverTransport = k.(*http2Server)
  2278  	}
  2279  	notifyChan := make(chan struct{})
  2280  	server.h.notify = notifyChan
  2281  	server.mu.Unlock()
  2282  
  2283  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2284  	defer cancel()
  2285  	cstream, err := client.NewStream(ctx, &CallHdr{})
  2286  	if err != nil {
  2287  		t.Fatalf("Client failed to create first stream. Err: %v", err)
  2288  	}
  2289  
  2290  	<-notifyChan // Wait for server stream to be established.
  2291  	var sstream *ServerStream
  2292  	// Access stream on the server.
  2293  	serverTransport.mu.Lock()
  2294  	for _, v := range serverTransport.activeStreams {
  2295  		if v.id == cstream.id {
  2296  			sstream = v
  2297  		}
  2298  	}
  2299  	serverTransport.mu.Unlock()
  2300  	if sstream == nil {
  2301  		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id)
  2302  	}
  2303  
  2304  	client.Close(fmt.Errorf("closed manually by test"))
  2305  
  2306  	// Wait for server transport to be closed.
  2307  	<-serverTransport.done
  2308  
  2309  	// Write header on a closed server transport.
  2310  	err = sstream.SendHeader(metadata.MD{})
  2311  	st := status.Convert(err)
  2312  	if st.Code() != codes.Unavailable {
  2313  		t.Fatalf("WriteHeader() failed with status code %s, want %s", st.Code(), codes.Unavailable)
  2314  	}
  2315  }
  2316  
  2317  func (s) TestPingPong1B(t *testing.T) {
  2318  	runPingPongTest(t, 1)
  2319  }
  2320  
  2321  func (s) TestPingPong1KB(t *testing.T) {
  2322  	runPingPongTest(t, 1024)
  2323  }
  2324  
  2325  func (s) TestPingPong64KB(t *testing.T) {
  2326  	runPingPongTest(t, 65536)
  2327  }
  2328  
  2329  func (s) TestPingPong1MB(t *testing.T) {
  2330  	runPingPongTest(t, 1048576)
  2331  }
  2332  
  2333  // This is a stress-test of flow control logic.
  2334  func runPingPongTest(t *testing.T, msgSize int) {
  2335  	server, client, cancel := setUp(t, 0, pingpong)
  2336  	defer cancel()
  2337  	defer server.stop()
  2338  	defer client.Close(fmt.Errorf("closed manually by test"))
  2339  	waitWhileTrue(t, func() (bool, error) {
  2340  		server.mu.Lock()
  2341  		defer server.mu.Unlock()
  2342  		if len(server.conns) == 0 {
  2343  			return true, fmt.Errorf("timed out while waiting for server transport to be created")
  2344  		}
  2345  		return false, nil
  2346  	})
  2347  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2348  	defer cancel()
  2349  	stream, err := client.NewStream(ctx, &CallHdr{})
  2350  	if err != nil {
  2351  		t.Fatalf("Failed to create stream. Err: %v", err)
  2352  	}
  2353  	msg := make([]byte, msgSize)
  2354  	outgoingHeader := make([]byte, 5)
  2355  	outgoingHeader[0] = byte(0)
  2356  	binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize))
  2357  	opts := &WriteOptions{}
  2358  	incomingHeader := make([]byte, 5)
  2359  
  2360  	ctx, cancel = context.WithTimeout(ctx, 10*time.Millisecond)
  2361  	defer cancel()
  2362  	for ctx.Err() == nil {
  2363  		if err := stream.Write(outgoingHeader, newBufferSlice(msg), opts); err != nil {
  2364  			t.Fatalf("Error on client while writing message. Err: %v", err)
  2365  		}
  2366  		if _, err := stream.readTo(incomingHeader); err != nil {
  2367  			t.Fatalf("Error on client while reading data header. Err: %v", err)
  2368  		}
  2369  		sz := binary.BigEndian.Uint32(incomingHeader[1:])
  2370  		recvMsg := make([]byte, int(sz))
  2371  		if _, err := stream.readTo(recvMsg); err != nil {
  2372  			t.Fatalf("Error on client while reading data. Err: %v", err)
  2373  		}
  2374  	}
  2375  
  2376  	stream.Write(nil, nil, &WriteOptions{Last: true})
  2377  	if _, err := stream.readTo(incomingHeader); err != io.EOF {
  2378  		t.Fatalf("Client expected EOF from the server. Got: %v", err)
  2379  	}
  2380  }
  2381  
  2382  type tableSizeLimit struct {
  2383  	mu     sync.Mutex
  2384  	limits []uint32
  2385  }
  2386  
  2387  func (t *tableSizeLimit) add(limit uint32) {
  2388  	t.mu.Lock()
  2389  	t.limits = append(t.limits, limit)
  2390  	t.mu.Unlock()
  2391  }
  2392  
  2393  func (t *tableSizeLimit) getLen() int {
  2394  	t.mu.Lock()
  2395  	defer t.mu.Unlock()
  2396  	return len(t.limits)
  2397  }
  2398  
  2399  func (t *tableSizeLimit) getIndex(i int) uint32 {
  2400  	t.mu.Lock()
  2401  	defer t.mu.Unlock()
  2402  	return t.limits[i]
  2403  }
  2404  
  2405  func (s) TestHeaderTblSize(t *testing.T) {
  2406  	limits := &tableSizeLimit{}
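        	// updateHeaderTblSize is a package-level hook invoked whenever a
        	// SETTINGS_HEADER_TABLE_SIZE update is applied to the HPACK encoder;
        	// wrap it so the test can record every value the transport applies.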
  2407  	updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
  2408  		e.SetMaxDynamicTableSizeLimit(v)
  2409  		limits.add(v)
  2410  	}
  2411  	defer func() {
  2412  		updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
  2413  			e.SetMaxDynamicTableSizeLimit(v)
  2414  		}
  2415  	}()
  2416  
  2417  	server, ct, cancel := setUp(t, 0, normal)
  2418  	defer cancel()
  2419  	defer ct.Close(fmt.Errorf("closed manually by test"))
  2420  	defer server.stop()
  2421  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2422  	defer ctxCancel()
  2423  	_, err := ct.NewStream(ctx, &CallHdr{})
  2424  	if err != nil {
  2425  		t.Fatalf("failed to open stream: %v", err)
  2426  	}
  2427  
  2428  	var svrTransport ServerTransport
  2429  	var i int
  2430  	for i = 0; i < 1000; i++ {
  2431  		server.mu.Lock()
  2432  		if len(server.conns) != 0 {
  2433  			server.mu.Unlock()
  2434  			break
  2435  		}
  2436  		server.mu.Unlock()
  2437  		time.Sleep(10 * time.Millisecond)
  2439  	}
  2440  	if i == 1000 {
  2441  		t.Fatalf("unable to create any server transport after 10s")
  2442  	}
  2443  
  2444  	for st := range server.conns {
  2445  		svrTransport = st
  2446  		break
  2447  	}
  2448  	svrTransport.(*http2Server).controlBuf.put(&outgoingSettings{
  2449  		ss: []http2.Setting{
  2450  			{
  2451  				ID:  http2.SettingHeaderTableSize,
  2452  				Val: uint32(100),
  2453  			},
  2454  		},
  2455  	})
  2456  
  2457  	for i = 0; i < 1000; i++ {
  2458  		if limits.getLen() != 1 {
  2459  			time.Sleep(10 * time.Millisecond)
  2460  			continue
  2461  		}
  2462  		if val := limits.getIndex(0); val != uint32(100) {
  2463  			t.Fatalf("expected limits[0] = 100, got %d", val)
  2464  		}
  2465  		break
  2466  	}
  2467  	if i == 1000 {
  2468  		t.Fatalf("expected len(limits) == 1 within 10s, got %d", limits.getLen())
  2469  	}
  2470  
  2471  	ct.controlBuf.put(&outgoingSettings{
  2472  		ss: []http2.Setting{
  2473  			{
  2474  				ID:  http2.SettingHeaderTableSize,
  2475  				Val: uint32(200),
  2476  			},
  2477  		},
  2478  	})
  2479  
  2480  	for i = 0; i < 1000; i++ {
  2481  		if limits.getLen() != 2 {
  2482  			time.Sleep(10 * time.Millisecond)
  2483  			continue
  2484  		}
  2485  		if val := limits.getIndex(1); val != uint32(200) {
  2486  			t.Fatalf("expected limits[1] = 200, got %d", val)
  2487  		}
  2488  		break
  2489  	}
  2490  	if i == 1000 {
  2491  		t.Fatalf("expected len(limits) == 2 within 10s, got %d", limits.getLen())
  2492  	}
  2493  }
  2494  
  2495  // attrTransportCreds is a transport credential implementation that stores the
  2496  // Attributes from the ClientHandshakeInfo struct passed in the context, so
  2497  // the test can inspect them.
  2498  type attrTransportCreds struct {
  2499  	credentials.TransportCredentials
  2500  	attr *attributes.Attributes
  2501  }
  2502  
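        // ClientHandshake performs no real handshake; it records the Attributes
        // from the ClientHandshakeInfo in the context and returns the connection
        // unmodified.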
  2503  func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  2504  	ai := credentials.ClientHandshakeInfoFromContext(ctx)
  2505  	ac.attr = ai.Attributes
  2506  	return rawConn, nil, nil
  2507  }
  2508  func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
  2509  	return credentials.ProtocolInfo{}
  2510  }
  2511  func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
  2512  	return nil
  2513  }
  2514  
  2515  // TestClientHandshakeInfo adds attributes to the resolver.Address passed to
  2516  // NewHTTP2Client and verifies that these attributes are received by the
  2517  // transport credential handshaker.
  2518  func (s) TestClientHandshakeInfo(t *testing.T) {
  2519  	server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong)
  2520  	defer server.stop()
  2521  
  2522  	const (
  2523  		testAttrKey = "foo"
  2524  		testAttrVal = "bar"
  2525  	)
  2526  	addr := resolver.Address{
  2527  		Addr:       "localhost:" + server.port,
  2528  		Attributes: attributes.New(testAttrKey, testAttrVal),
  2529  	}
  2530  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
  2531  	defer cancel()
  2532  	creds := &attrTransportCreds{}
  2533  
  2534  	copts := ConnectOptions{
  2535  		TransportCredentials: creds,
  2536  		ChannelzParent:       channelzSubChannel(t),
  2537  	}
  2538  	tr, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {})
  2539  	if err != nil {
  2540  		t.Fatalf("NewHTTP2Client(): %v", err)
  2541  	}
  2542  	defer tr.Close(fmt.Errorf("closed manually by test"))
  2543  
  2544  	wantAttr := attributes.New(testAttrKey, testAttrVal)
  2545  	if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
  2546  		t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr)
  2547  	}
  2548  }
  2549  
  2550  // TestClientHandshakeInfoDialer adds attributes to the resolver.Address passed
  2551  // to NewHTTP2Client and verifies that these attributes are received by a
  2552  // custom dialer.
  2553  func (s) TestClientHandshakeInfoDialer(t *testing.T) {
  2554  	server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong)
  2555  	defer server.stop()
  2556  
  2557  	const (
  2558  		testAttrKey = "foo"
  2559  		testAttrVal = "bar"
  2560  	)
  2561  	addr := resolver.Address{
  2562  		Addr:       "localhost:" + server.port,
  2563  		Attributes: attributes.New(testAttrKey, testAttrVal),
  2564  	}
  2565  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
  2566  	defer cancel()
  2567  
  2568  	var attr *attributes.Attributes
  2569  	dialer := func(ctx context.Context, addr string) (net.Conn, error) {
  2570  		ai := credentials.ClientHandshakeInfoFromContext(ctx)
  2571  		attr = ai.Attributes
  2572  		return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
  2573  	}
  2574  
  2575  	copts := ConnectOptions{
  2576  		Dialer:         dialer,
  2577  		ChannelzParent: channelzSubChannel(t),
  2578  	}
  2579  	tr, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {})
  2580  	if err != nil {
  2581  		t.Fatalf("NewHTTP2Client(): %v", err)
  2582  	}
  2583  	defer tr.Close(fmt.Errorf("closed manually by test"))
  2584  
  2585  	wantAttr := attributes.New(testAttrKey, testAttrVal)
  2586  	if gotAttr := attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
  2587  		t.Errorf("Received attributes %v in custom dialer, want %v", gotAttr, wantAttr)
  2588  	}
  2589  }
  2590  
  2591  func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
  2592  	testStream := func() *ClientStream {
  2593  		return &ClientStream{
  2594  			Stream: &Stream{
  2595  				buf: &recvBuffer{
  2596  					c:  make(chan recvMsg),
  2597  					mu: sync.Mutex{},
  2598  				},
  2599  			},
  2600  			done:       make(chan struct{}),
  2601  			headerChan: make(chan struct{}),
  2602  		}
  2603  	}
  2604  
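        	// testClient builds a minimal http2Client whose activeStreams map holds
        	// just the test stream, which is all operateHeaders needs to route a
        	// MetaHeadersFrame for stream ID 0 to that stream.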
  2605  	testClient := func(ts *ClientStream) *http2Client {
  2606  		return &http2Client{
  2607  			mu: sync.Mutex{},
  2608  			activeStreams: map[uint32]*ClientStream{
  2609  				0: ts,
  2610  			},
  2611  			controlBuf: newControlBuffer(make(<-chan struct{})),
  2612  		}
  2613  	}
  2614  
  2615  	for _, test := range []struct {
  2616  		name string
  2617  		// input
  2618  		metaHeaderFrame *http2.MetaHeadersFrame
  2619  		// output
  2620  		wantStatus *status.Status
  2621  	}{
  2622  		{
  2623  			name: "valid header",
  2624  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2625  				Fields: []hpack.HeaderField{
  2626  					{Name: "content-type", Value: "application/grpc"},
  2627  					{Name: "grpc-status", Value: "0"},
  2628  					{Name: ":status", Value: "200"},
  2629  				},
  2630  			},
  2631  			// no error
  2632  			wantStatus: status.New(codes.OK, ""),
  2633  		},
  2634  		{
  2635  			name: "missing content-type header",
  2636  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2637  				Fields: []hpack.HeaderField{
  2638  					{Name: "grpc-status", Value: "0"},
  2639  					{Name: ":status", Value: "200"},
  2640  				},
  2641  			},
  2642  			wantStatus: status.New(
  2643  				codes.Unknown,
  2644  				"malformed header: missing HTTP content-type",
  2645  			),
  2646  		},
  2647  		{
  2648  			name: "invalid grpc status header field",
  2649  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2650  				Fields: []hpack.HeaderField{
  2651  					{Name: "content-type", Value: "application/grpc"},
  2652  					{Name: "grpc-status", Value: "xxxx"},
  2653  					{Name: ":status", Value: "200"},
  2654  				},
  2655  			},
  2656  			wantStatus: status.New(
  2657  				codes.Internal,
  2658  				"transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax",
  2659  			),
  2660  		},
  2661  		{
  2662  			name: "invalid http content type",
  2663  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2664  				Fields: []hpack.HeaderField{
  2665  					{Name: "content-type", Value: "application/json"},
  2666  				},
  2667  			},
  2668  			wantStatus: status.New(
  2669  				codes.Internal,
  2670  				"malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"",
  2671  			),
  2672  		},
  2673  		{
  2674  			name: "http fallback and invalid http status",
  2675  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2676  				Fields: []hpack.HeaderField{
  2677  					// No content type provided then fallback into handling http error.
  2678  					{Name: ":status", Value: "xxxx"},
  2679  				},
  2680  			},
  2681  			wantStatus: status.New(
  2682  				codes.Internal,
  2683  				"transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax",
  2684  			),
  2685  		},
  2686  		{
  2687  			name: "http2 frame size exceeds",
  2688  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2689  				Fields:    nil,
  2690  				Truncated: true,
  2691  			},
  2692  			wantStatus: status.New(
  2693  				codes.Internal,
  2694  				"peer header list size exceeded limit",
  2695  			),
  2696  		},
  2697  		{
  2698  			name: "bad status in grpc mode",
  2699  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2700  				Fields: []hpack.HeaderField{
  2701  					{Name: "content-type", Value: "application/grpc"},
  2702  					{Name: "grpc-status", Value: "0"},
  2703  					{Name: ":status", Value: "504"},
  2704  				},
  2705  			},
  2706  			wantStatus: status.New(
  2707  				codes.Unavailable,
  2708  				"unexpected HTTP status code received from server: 504 (Gateway Timeout)",
  2709  			),
  2710  		},
  2711  		{
  2712  			name: "missing http status",
  2713  			metaHeaderFrame: &http2.MetaHeadersFrame{
  2714  				Fields: []hpack.HeaderField{
  2715  					{Name: "content-type", Value: "application/grpc"},
  2716  				},
  2717  			},
  2718  			wantStatus: status.New(
  2719  				codes.Internal,
  2720  				"malformed header: missing HTTP status",
  2721  			),
  2722  		},
  2723  	} {
  2725  		t.Run(test.name, func(t *testing.T) {
  2726  			ts := testStream()
  2727  			s := testClient(ts)
  2728  
  2729  			test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{
  2730  				FrameHeader: http2.FrameHeader{
  2731  					StreamID: 0,
  2732  				},
  2733  			}
  2734  
  2735  			s.operateHeaders(test.metaHeaderFrame)
  2736  
  2737  			got := ts.status
  2738  			want := test.wantStatus
  2739  			if got.Code() != want.Code() || got.Message() != want.Message() {
  2740  				t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want)
  2741  			}
  2742  		})
  2743  		t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) {
  2744  			ts := testStream()
  2745  			s := testClient(ts)
  2746  
  2747  			test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{
  2748  				FrameHeader: http2.FrameHeader{
  2749  					StreamID: 0,
  2750  					Flags:    http2.FlagHeadersEndStream,
  2751  				},
  2752  			}
  2753  
  2754  			s.operateHeaders(test.metaHeaderFrame)
  2755  
  2756  			got := ts.status
  2757  			want := test.wantStatus
  2758  			if got.Code() != want.Code() || got.Message() != want.Message() {
  2759  				t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want)
  2760  			}
  2761  		})
  2762  	}
  2763  }
  2764  
  2765  func TestConnectionError_Unwrap(t *testing.T) {
  2766  	err := connectionErrorf(false, os.ErrNotExist, "unwrap me")
  2767  	if !errors.Is(err, os.ErrNotExist) {
  2768  		t.Error("ConnectionError does not unwrap")
  2769  	}
  2770  }
  2771  
  2772  // Test that in the event of a graceful client transport shutdown, i.e.,
  2773  // clientTransport.Close(), the client sends a GOAWAY to the server with the
  2774  // correct error code and debug data.
  2775  func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
  2776  	// Create a server.
  2777  	lis, err := net.Listen("tcp", "localhost:0")
  2778  	if err != nil {
  2779  		t.Fatalf("Error while listening: %v", err)
  2780  	}
  2781  	defer lis.Close()
  2782  	// greetDone is used to notify when server is done greeting the client.
  2783  	greetDone := make(chan struct{})
  2784  	// errorCh reports an error if the server does not receive the expected GOAWAY.
  2785  	errorCh := make(chan error)
  2786  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2787  	defer cancel()
  2788  	// Launch the server.
  2789  	go func() {
  2790  		sconn, err := lis.Accept()
  2791  		if err != nil {
  2792  			t.Errorf("Error while accepting: %v", err)
  2793  			t.Errorf("Error while accepting: %v", err)
        			return
  2794  		}
  2795  		if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
  2796  			t.Errorf("Error while reading client preface: %v", err)
  2797  			return
  2798  		}
  2799  		sfr := http2.NewFramer(sconn, sconn)
  2800  		if err := sfr.WriteSettings(); err != nil {
  2801  			t.Errorf("Error while writing settings: %v", err)
  2802  			return
  2803  		}
  2804  		fr, _ := sfr.ReadFrame()
  2805  		if _, ok := fr.(*http2.SettingsFrame); !ok {
  2806  			t.Errorf("Expected settings frame, got %v", fr)
  2807  		}
  2808  		fr, _ = sfr.ReadFrame()
  2809  		if fr, ok := fr.(*http2.SettingsFrame); !ok || !fr.IsAck() {
  2810  			t.Errorf("Expected settings ACK frame, got %v", fr)
  2811  		}
  2812  		fr, _ = sfr.ReadFrame()
  2813  		if fr, ok := fr.(*http2.HeadersFrame); !ok || !fr.Flags.Has(http2.FlagHeadersEndHeaders) {
  2814  			t.Errorf("Expected HEADERS frame with END_HEADERS flag, got %v", fr)
  2815  		}
  2816  		close(greetDone)
  2817  
  2818  		frame, err := sfr.ReadFrame()
  2819  		if err != nil {
  2820  			return
  2821  		}
  2822  		switch fr := frame.(type) {
  2823  		case *http2.GoAwayFrame:
  2824  			// Records that the server successfully received a GOAWAY frame.
  2825  			goAwayFrame := fr
  2826  			if goAwayFrame.ErrCode == http2.ErrCodeNo {
  2827  				t.Logf("Received goAway frame from client")
  2828  				close(errorCh)
  2829  			} else {
  2830  				errorCh <- fmt.Errorf("received GOAWAY frame with unexpected error code: %v", goAwayFrame.ErrCode)
  2831  				close(errorCh)
  2832  			}
  2833  			return
  2834  		default:
  2835  			errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", frame)
  2836  			close(errorCh)
  2837  			return
  2838  		}
  2839  	}()
  2840  
  2841  	ct, err := NewHTTP2Client(ctx, ctx, resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
  2842  	if err != nil {
  2843  		t.Fatalf("Error while creating client transport: %v", err)
  2844  	}
  2845  	_, err = ct.NewStream(ctx, &CallHdr{})
  2846  	if err != nil {
  2847  		t.Fatalf("failed to open stream: %v", err)
  2848  	}
  2849  	// Wait until server receives the headers and settings frame as part of greet.
  2850  	<-greetDone
  2851  	ct.Close(errors.New("manually closed by client"))
  2852  	t.Logf("Closed the client connection")
  2853  	select {
  2854  	case err := <-errorCh:
  2855  		if err != nil {
  2856  			t.Errorf("Error receiving the GOAWAY frame: %v", err)
  2857  		}
  2858  	case <-ctx.Done():
  2859  		t.Errorf("Context timed out")
  2860  	}
  2861  }
  2862  
  2863  // readHangingConn is a wrapper around net.Conn that makes the Read() hang when
  2864  // Close() is called.
  2865  type readHangingConn struct {
  2866  	net.Conn
  2867  	readHangConn chan struct{} // Read() hangs until this channel is closed by Close().
  2868  	closed       *atomic.Bool  // Set to true when Close() is called.
  2869  }
  2870  
  2871  func (hc *readHangingConn) Read(b []byte) (n int, err error) {
  2872  	n, err = hc.Conn.Read(b)
  2873  	if hc.closed.Load() {
  2874  		<-hc.readHangConn // Hang the read until the test closes readHangConn.
  2875  	}
  2876  	return n, err
  2877  }
  2878  
  2879  func (hc *readHangingConn) Close() error {
  2880  	hc.closed.Store(true)
  2881  	return hc.Conn.Close()
  2882  }
  2883  
  2884  // Tests that closing a client transport does not return until the reader
  2885  // goroutine exits.
  2886  func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) {
  2887  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2888  	defer cancel()
  2889  
  2890  	server := setUpServerOnly(t, 0, &ServerConfig{}, normal)
  2891  	defer server.stop()
  2892  	addr := resolver.Address{Addr: "localhost:" + server.port}
  2893  
  2894  	isReaderHanging := &atomic.Bool{}
  2895  	readHangConn := make(chan struct{})
  2896  	copts := ConnectOptions{
  2897  		Dialer: func(_ context.Context, addr string) (net.Conn, error) {
  2898  			conn, err := net.Dial("tcp", addr)
  2899  			if err != nil {
  2900  				return nil, err
  2901  			}
  2902  			return &readHangingConn{Conn: conn, readHangConn: readHangConn, closed: isReaderHanging}, nil
  2903  		},
  2904  		ChannelzParent: channelzSubChannel(t),
  2905  	}
  2906  
  2907  	// Create a client transport with a custom dialer that hangs the Read()
  2908  	// after Close().
  2909  	ct, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {})
  2910  	if err != nil {
  2911  		t.Fatalf("Failed to create transport: %v", err)
  2912  	}
  2913  
  2914  	if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
  2915  		t.Fatalf("Failed to open stream: %v", err)
  2916  	}
  2917  
  2918  	// Closing the client transport closes the underlying net.Conn, which
  2919  	// causes readHangingConn.Read() to hang. This stalls the exit of the
  2920  	// reader goroutine, and in turn stalls the client transport's Close from
  2921  	// returning.
  2922  	transportClosed := make(chan struct{})
  2923  	go func() {
  2924  		ct.Close(errors.New("manually closed by client"))
  2925  		close(transportClosed)
  2926  	}()
  2927  
  2928  	// Wait for a short duration and ensure that the client transport's Close()
  2929  	// does not return.
  2930  	select {
  2931  	case <-transportClosed:
  2932  		t.Fatal("Transport closed before reader completed")
  2933  	case <-time.After(defaultTestShortTimeout):
  2934  	}
  2935  
  2936  	// Closing the channel will unblock the reader goroutine and will ensure
  2937  	// that the client transport's Close() returns.
  2938  	close(readHangConn)
  2939  	select {
  2940  	case <-transportClosed:
  2941  	case <-time.After(defaultTestTimeout):
  2942  		t.Fatal("Timeout when waiting for transport to close")
  2943  	}
  2944  }
  2945  
  2946  // hangingConn is a net.Conn wrapper for testing that simulates a hanging
  2947  // connection: once startHanging is set, Write operations block until they are
  2948  // explicitly signaled via hangConn.
  2949  type hangingConn struct {
  2950  	net.Conn
  2951  	hangConn     chan struct{}
  2952  	startHanging *atomic.Bool
  2953  }
  2954  
  2955  func (hc *hangingConn) Write(b []byte) (n int, err error) {
  2956  	n, err = hc.Conn.Write(b)
  2957  	if hc.startHanging.Load() {
  2958  		<-hc.hangConn
  2959  	}
  2960  	return n, err
  2961  }
  2962  
  2963  // Tests the scenario where a client transport is closed and the write of the
  2964  // GOAWAY frame as part of the close does not complete because of a network
  2965  // hang. The test verifies that the client transport's Close returns without
  2966  // waiting for too long.
  2967  func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) {
  2968  	// Override the timeout for writing the GOAWAY frame to one millisecond so
  2969  	// that the connection write always times out. This is equivalent to a real
  2970  	// network hang where the GOAWAY write doesn't finish within the deadline.
  2971  	origGoAwayLoopyTimeout := goAwayLoopyWriterTimeout
  2972  	goAwayLoopyWriterTimeout = time.Millisecond
  2973  	defer func() {
  2974  		goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout
  2975  	}()
  2976  
  2977  	// Set up the server.
  2978  	connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2979  	defer cancel()
  2980  	server := setUpServerOnly(t, 0, &ServerConfig{}, normal)
  2981  	defer server.stop()
  2982  	addr := resolver.Address{Addr: "localhost:" + server.port}
  2983  	isGreetingDone := &atomic.Bool{}
  2984  	hangConn := make(chan struct{})
  2985  	defer close(hangConn)
  2986  	dialer := func(_ context.Context, addr string) (net.Conn, error) {
  2987  		conn, err := net.Dial("tcp", addr)
  2988  		if err != nil {
  2989  			return nil, err
  2990  		}
  2991  		return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil
  2992  	}
  2993  	copts := ConnectOptions{Dialer: dialer}
  2994  	copts.ChannelzParent = channelzSubChannel(t)
  2995  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  2996  	defer cancel()
  2997  	// Create a client transport with the custom hanging dialer.
  2998  	ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {})
  2999  	if connErr != nil {
  3000  		t.Fatalf("failed to create transport: %v", connErr)
  3001  	}
  3002  
  3003  	if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
  3004  		t.Fatalf("Failed to open stream: %v", err)
  3005  	}
  3006  
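	// From this point on, every write on the connection hangs until hangConn
	// is closed. Close must therefore time out its GOAWAY write instead of
	// waiting for it to complete.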
  3007  	isGreetingDone.Store(true)
  3008  	ct.Close(errors.New("manually closed by client"))
  3009  }
  3010  
  3011  // TestReadMessageHeaderMultipleBuffers tests reading a message header that is
  3012  // split across multiple buffers. It verifies that the number of bytes read is
  3013  // reported correctly for flow control.
  3014  func (s) TestReadMessageHeaderMultipleBuffers(t *testing.T) {
  3015  	headerLen := 5
  3016  	recvBuffer := newRecvBuffer()
  3017  	recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, 3)})
  3018  	recvBuffer.put(recvMsg{buffer: make(mem.SliceBuffer, headerLen-3)})
  3019  	bytesRead := 0
  3020  	s := Stream{
  3021  		requestRead: func(int) {},
  3022  		trReader: &transportReader{
  3023  			reader: &recvBufferReader{
  3024  				recv: recvBuffer,
  3025  			},
  3026  			windowHandler: func(i int) {
  3027  				bytesRead += i
  3028  			},
  3029  		},
  3030  	}
  3031  
  3032  	header := make([]byte, headerLen)
  3033  	err := s.ReadMessageHeader(header)
  3034  	if err != nil {
  3035  		t.Fatalf("ReadMessageHeader(%v) = %v", header, err)
  3036  	}
  3037  	if bytesRead != headerLen {
  3038  		t.Errorf("bytesRead = %d, want = %d", bytesRead, headerLen)
  3039  	}
  3040  }
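
// The headerLen of 5 used in the test above matches gRPC's length-prefixed
// message framing: a 1-byte compression flag followed by a 4-byte big-endian
// payload length. exampleMessageHeader below is a hypothetical, illustrative
// sketch of that framing; it is not used by the transport code under test.
func exampleMessageHeader(compressed bool, payloadLen uint32) []byte {
	hdr := make([]byte, 5)
	if compressed {
		// A flag byte of 1 marks the payload as compressed.
		hdr[0] = 1
	}
	// The remaining four bytes carry the payload length in big-endian order.
	binary.BigEndian.PutUint32(hdr[1:], payloadLen)
	return hdr
}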
  3041  
  3042  // Tests a scenario in which the client does not send an RST frame when the
  3043  // configured deadline is reached. The test verifies that the server sends an
  3044  // RST_STREAM only after the deadline expires.
  3045  func (s) TestServerSendsRSTAfterDeadlineToMisbehavedClient(t *testing.T) {
  3046  	server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
  3047  	defer server.stop()
  3048  	// Dial a raw TCP connection to act as a misbehaving client.
  3049  	mconn, err := net.Dial("tcp", server.lis.Addr().String())
  3050  	if err != nil {
  3051  		t.Fatalf("Client failed to dial: %v", err)
  3052  	}
  3053  	defer mconn.Close()
  3054  	if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil {
  3055  		t.Fatalf("Failed to set write deadline: %v", err)
  3056  	}
  3057  	if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
  3058  		t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
  3059  	}
  3060  	// rstTimeChan carries the time at which the reader received an RST_STREAM frame from the server.
  3061  	rstTimeChan := make(chan time.Time, 1)
  3062  	var mu sync.Mutex
  3063  	framer := http2.NewFramer(mconn, mconn)
  3064  	if err := framer.WriteSettings(); err != nil {
  3065  		t.Fatalf("Error while writing settings: %v", err)
  3066  	}
  3067  	go func() { // Launch a reader for this misbehaving client.
  3068  		for {
  3069  			frame, err := framer.ReadFrame()
  3070  			if err != nil {
  3071  				return
  3072  			}
  3073  			switch frame := frame.(type) {
  3074  			case *http2.PingFrame:
  3075  				// Write ping ack back so that server's BDP estimation works right.
  3076  				mu.Lock()
  3077  				framer.WritePing(true, frame.Data)
  3078  				mu.Unlock()
  3079  			case *http2.RSTStreamFrame:
  3080  			if frame.Header().StreamID != 1 || frame.ErrCode != http2.ErrCodeCancel {
  3081  				t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeCancel", frame.Header().StreamID, frame.ErrCode)
  3082  				}
  3083  				rstTimeChan <- time.Now()
  3084  				return
  3085  			default:
  3086  				// Do nothing.
  3087  			}
  3088  		}
  3089  	}()
  3090  	// Create a stream.
  3091  	var buf bytes.Buffer
  3092  	henc := hpack.NewEncoder(&buf)
  3093  	if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil {
  3094  		t.Fatalf("Error while encoding header: %v", err)
  3095  	}
  3096  	if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil {
  3097  		t.Fatalf("Error while encoding header: %v", err)
  3098  	}
  3099  	if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil {
  3100  		t.Fatalf("Error while encoding header: %v", err)
  3101  	}
  3102  	if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil {
  3103  		t.Fatalf("Error while encoding header: %v", err)
  3104  	}
  3105  	if err := henc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: "10m"}); err != nil {
  3106  		t.Fatalf("Error while encoding header: %v", err)
  3107  	}
  3108  	mu.Lock()
  3109  	startTime := time.Now()
  3110  	if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
  3111  		mu.Unlock()
  3112  		t.Fatalf("Error while writing headers: %v", err)
  3113  	}
  3114  	mu.Unlock()
  3115  
  3116  	// Test server behavior for deadline expiration.
  3117  	var rstTime time.Time
  3118  	select {
  3119  	case <-time.After(5 * time.Second):
  3120  		t.Fatal("Test timed out.")
  3121  	case rstTime = <-rstTimeChan:
  3122  	}
  3123  
  3124  	if got, want := rstTime.Sub(startTime), 10*time.Millisecond; got < want {
  3125  		t.Fatalf("RST frame received earlier than expected by duration: %v", want-got)
  3126  	}
  3127  }
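
// The "grpc-timeout: 10m" header written above encodes a 10 millisecond
// deadline: a gRPC timeout is an ASCII integer followed by a unit suffix
// ("H" hours, "M" minutes, "S" seconds, "m" milliseconds, "u" microseconds,
// "n" nanoseconds), which is why the test expects the RST_STREAM no earlier
// than 10ms after the headers are written. exampleGRPCTimeout below is a
// hypothetical sketch of the millisecond encoding, not transport package API.
func exampleGRPCTimeout(d time.Duration) string {
	// Encode in milliseconds, e.g. 10*time.Millisecond -> "10m".
	return strconv.FormatInt(d.Milliseconds(), 10) + "m"
}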