google.golang.org/grpc@v1.72.2/test/transport_test.go (about)

     1  /*
     2  *
     3  * Copyright 2023 gRPC authors.
     4  *
     5  * Licensed under the Apache License, Version 2.0 (the "License");
     6  * you may not use this file except in compliance with the License.
     7  * You may obtain a copy of the License at
     8  *
     9  *     http://www.apache.org/licenses/LICENSE-2.0
    10  *
    11  * Unless required by applicable law or agreed to in writing, software
    12  * distributed under the License is distributed on an "AS IS" BASIS,
    13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  * See the License for the specific language governing permissions and
    15  * limitations under the License.
    16  *
    17   */
    18  package test
    19  
    20  import (
    21  	"context"
    22  	"encoding/binary"
    23  	"io"
    24  	"net"
    25  	"sync"
    26  	"testing"
    27  
    28  	"golang.org/x/net/http2"
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/codes"
    31  	"google.golang.org/grpc/credentials"
    32  	"google.golang.org/grpc/credentials/insecure"
    33  	"google.golang.org/grpc/internal/grpcsync"
    34  	"google.golang.org/grpc/internal/stubserver"
    35  	"google.golang.org/grpc/internal/testutils"
    36  	"google.golang.org/grpc/internal/transport"
    37  	"google.golang.org/grpc/status"
    38  
    39  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    40  	testpb "google.golang.org/grpc/interop/grpc_testing"
    41  )
    42  
// connWrapperWithCloseCh wraps a net.Conn and fires an event when closed.
type connWrapperWithCloseCh struct {
	// Conn is the wrapped connection; every method except Close is promoted
	// from it unchanged.
	net.Conn
	// close is fired by Close, so its Done channel can be used to observe that
	// Close was called on this wrapper.
	close *grpcsync.Event
}
    48  
    49  // Close closes the connection and sends a value on the close channel.
    50  func (cw *connWrapperWithCloseCh) Close() error {
    51  	cw.close.Fire()
    52  	return cw.Conn.Close()
    53  }
    54  
// These custom creds are used for storing the connections made by the client.
// The close event in each stored conn can be used to detect when that conn is
// closed.
type transportRestartCheckCreds struct {
	// mu guards connections; ClientHandshake appends to the slice while
	// holding it.
	mu sync.Mutex
	// connections holds one wrapper per client handshake, in the order the
	// handshakes were performed.
	connections []*connWrapperWithCloseCh
}
    61  
    62  func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    63  	return rawConn, nil, nil
    64  }
    65  func (c *transportRestartCheckCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    66  	c.mu.Lock()
    67  	defer c.mu.Unlock()
    68  	conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()}
    69  	c.connections = append(c.connections, conn)
    70  	return conn, nil, nil
    71  }
    72  func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
    73  	return credentials.ProtocolInfo{}
    74  }
// Clone returns the receiver itself rather than a copy, so every "clone"
// shares the same mutex and the same recorded connections slice.
func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials {
	return c
}
// OverrideServerName ignores its argument and always returns nil.
func (c *transportRestartCheckCreds) OverrideServerName(string) error {
	return nil
}
    81  
    82  // Tests that the client transport drains and restarts when next stream ID exceeds
    83  // MaxStreamID. This test also verifies that subsequent RPCs use a new client
    84  // transport and the old transport is closed.
    85  func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
    86  	// Set the transport's MaxStreamID to 4 to cause connection to drain after 2 RPCs.
    87  	originalMaxStreamID := transport.MaxStreamID
    88  	transport.MaxStreamID = 4
    89  	defer func() {
    90  		transport.MaxStreamID = originalMaxStreamID
    91  	}()
    92  
    93  	ss := &stubserver.StubServer{
    94  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
    95  			if _, err := stream.Recv(); err != nil {
    96  				return status.Errorf(codes.Internal, "unexpected error receiving: %v", err)
    97  			}
    98  			if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
    99  				return status.Errorf(codes.Internal, "unexpected error sending: %v", err)
   100  			}
   101  			if recv, err := stream.Recv(); err != io.EOF {
   102  				return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err)
   103  			}
   104  			return nil
   105  		},
   106  	}
   107  
   108  	creds := &transportRestartCheckCreds{}
   109  	if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil {
   110  		t.Fatalf("Starting stubServer: %v", err)
   111  	}
   112  	defer ss.Stop()
   113  
   114  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   115  	defer cancel()
   116  
   117  	var streams []testgrpc.TestService_FullDuplexCallClient
   118  
   119  	const numStreams = 3
   120  	// expected number of conns when each stream is created i.e., 3rd stream is created
   121  	// on a new connection.
   122  	expectedNumConns := [numStreams]int{1, 1, 2}
   123  
   124  	// Set up 3 streams.
   125  	for i := 0; i < numStreams; i++ {
   126  		s, err := ss.Client.FullDuplexCall(ctx)
   127  		if err != nil {
   128  			t.Fatalf("Creating FullDuplex stream: %v", err)
   129  		}
   130  		streams = append(streams, s)
   131  		// Verify expected num of conns after each stream is created.
   132  		if len(creds.connections) != expectedNumConns[i] {
   133  			t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i])
   134  		}
   135  	}
   136  
   137  	// Verify all streams still work.
   138  	for i, stream := range streams {
   139  		if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
   140  			t.Fatalf("Sending on stream %d: %v", i, err)
   141  		}
   142  		if _, err := stream.Recv(); err != nil {
   143  			t.Fatalf("Receiving on stream %d: %v", i, err)
   144  		}
   145  	}
   146  
   147  	for i, stream := range streams {
   148  		if err := stream.CloseSend(); err != nil {
   149  			t.Fatalf("CloseSend() on stream %d: %v", i, err)
   150  		}
   151  	}
   152  
   153  	// Verifying first connection was closed.
   154  	select {
   155  	case <-creds.connections[0].close.Done():
   156  	case <-ctx.Done():
   157  		t.Fatal("Timeout expired when waiting for first client transport to close")
   158  	}
   159  }
   160  
   161  // Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while
   162  // reading a gRPC message is correctly converted to a gRPC status with code
   163  // CANCELLED. The test sends a data frame with a partial gRPC message, followed
   164  // by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the
   165  // client receives the correct status.
   166  func (s) TestRSTDuringMessageRead(t *testing.T) {
   167  	lis, err := testutils.LocalTCPListener()
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  	defer lis.Close()
   172  
   173  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   174  	defer cancel()
   175  	cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
   176  	if err != nil {
   177  		t.Fatalf("grpc.NewClient(%s) = %v", lis.Addr().String(), err)
   178  	}
   179  	defer cc.Close()
   180  
   181  	go func() {
   182  		conn, err := lis.Accept()
   183  		if err != nil {
   184  			t.Errorf("lis.Accept() = %v", err)
   185  			return
   186  		}
   187  		defer conn.Close()
   188  		framer := http2.NewFramer(conn, conn)
   189  
   190  		if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil {
   191  			t.Errorf("Error while reading client preface: %v", err)
   192  			return
   193  		}
   194  		if err := framer.WriteSettings(); err != nil {
   195  			t.Errorf("Error while writing settings: %v", err)
   196  			return
   197  		}
   198  		if err := framer.WriteSettingsAck(); err != nil {
   199  			t.Errorf("Error while writing settings: %v", err)
   200  			return
   201  		}
   202  		for ctx.Err() == nil {
   203  			frame, err := framer.ReadFrame()
   204  			if err != nil {
   205  				return
   206  			}
   207  			switch frame := frame.(type) {
   208  			case *http2.HeadersFrame:
   209  				// When the client creates a stream, write a partial gRPC
   210  				// message followed by an RST_STREAM.
   211  				const messageLen = 2048
   212  				buf := make([]byte, messageLen/2)
   213  				// Write the gRPC message length header.
   214  				binary.BigEndian.PutUint32(buf[1:5], uint32(messageLen))
   215  				if err := framer.WriteData(1, false, buf); err != nil {
   216  					return
   217  				}
   218  				framer.WriteRSTStream(1, http2.ErrCodeCancel)
   219  			default:
   220  				t.Logf("Server received frame: %v", frame)
   221  			}
   222  		}
   223  	}()
   224  
   225  	// The server will send a partial gRPC message before cancelling the stream.
   226  	// The client should get a gRPC status with code CANCELLED.
   227  	client := testgrpc.NewTestServiceClient(cc)
   228  	if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Canceled {
   229  		t.Fatalf("client.EmptyCall() returned %v; want status with code %v", err, codes.Canceled)
   230  	}
   231  }