google.golang.org/grpc@v1.72.2/test/transport_test.go (about) 1 /* 2 * 3 * Copyright 2023 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 package test 19 20 import ( 21 "context" 22 "encoding/binary" 23 "io" 24 "net" 25 "sync" 26 "testing" 27 28 "golang.org/x/net/http2" 29 "google.golang.org/grpc" 30 "google.golang.org/grpc/codes" 31 "google.golang.org/grpc/credentials" 32 "google.golang.org/grpc/credentials/insecure" 33 "google.golang.org/grpc/internal/grpcsync" 34 "google.golang.org/grpc/internal/stubserver" 35 "google.golang.org/grpc/internal/testutils" 36 "google.golang.org/grpc/internal/transport" 37 "google.golang.org/grpc/status" 38 39 testgrpc "google.golang.org/grpc/interop/grpc_testing" 40 testpb "google.golang.org/grpc/interop/grpc_testing" 41 ) 42 43 // connWrapperWithCloseCh wraps a net.Conn and fires an event when closed. 44 type connWrapperWithCloseCh struct { 45 net.Conn 46 close *grpcsync.Event 47 } 48 49 // Close closes the connection and sends a value on the close channel. 50 func (cw *connWrapperWithCloseCh) Close() error { 51 cw.close.Fire() 52 return cw.Conn.Close() 53 } 54 55 // These custom creds are used for storing the connections made by the client. 56 // The closeCh in conn can be used to detect when conn is closed. 57 type transportRestartCheckCreds struct { 58 mu sync.Mutex 59 connections []*connWrapperWithCloseCh 60 } 61 62 func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 63 return rawConn, nil, nil 64 } 65 func (c *transportRestartCheckCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 66 c.mu.Lock() 67 defer c.mu.Unlock() 68 conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()} 69 c.connections = append(c.connections, conn) 70 return conn, nil, nil 71 } 72 func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo { 73 return credentials.ProtocolInfo{} 74 } 75 func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials { 76 return c 77 } 78 func (c *transportRestartCheckCreds) OverrideServerName(string) error { 79 return nil 80 } 81 82 // Tests that the client transport drains and restarts when next stream ID exceeds 83 // MaxStreamID. This test also verifies that subsequent RPCs use a new client 84 // transport and the old transport is closed. 85 func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { 86 // Set the transport's MaxStreamID to 4 to cause connection to drain after 2 RPCs. 87 originalMaxStreamID := transport.MaxStreamID 88 transport.MaxStreamID = 4 89 defer func() { 90 transport.MaxStreamID = originalMaxStreamID 91 }() 92 93 ss := &stubserver.StubServer{ 94 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { 95 if _, err := stream.Recv(); err != nil { 96 return status.Errorf(codes.Internal, "unexpected error receiving: %v", err) 97 } 98 if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil { 99 return status.Errorf(codes.Internal, "unexpected error sending: %v", err) 100 } 101 if recv, err := stream.Recv(); err != io.EOF { 102 return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err) 103 } 104 return nil 105 }, 106 } 107 108 creds := &transportRestartCheckCreds{} 109 if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil { 110 t.Fatalf("Starting stubServer: %v", err) 111 } 112 defer ss.Stop() 113 114 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 115 defer cancel() 116 117 var streams []testgrpc.TestService_FullDuplexCallClient 118 119 const numStreams = 3 120 // expected number of conns when each stream is created i.e., 3rd stream is created 121 // on a new connection. 122 expectedNumConns := [numStreams]int{1, 1, 2} 123 124 // Set up 3 streams. 125 for i := 0; i < numStreams; i++ { 126 s, err := ss.Client.FullDuplexCall(ctx) 127 if err != nil { 128 t.Fatalf("Creating FullDuplex stream: %v", err) 129 } 130 streams = append(streams, s) 131 // Verify expected num of conns after each stream is created. 132 if len(creds.connections) != expectedNumConns[i] { 133 t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i]) 134 } 135 } 136 137 // Verify all streams still work. 138 for i, stream := range streams { 139 if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { 140 t.Fatalf("Sending on stream %d: %v", i, err) 141 } 142 if _, err := stream.Recv(); err != nil { 143 t.Fatalf("Receiving on stream %d: %v", i, err) 144 } 145 } 146 147 for i, stream := range streams { 148 if err := stream.CloseSend(); err != nil { 149 t.Fatalf("CloseSend() on stream %d: %v", i, err) 150 } 151 } 152 153 // Verifying first connection was closed. 154 select { 155 case <-creds.connections[0].close.Done(): 156 case <-ctx.Done(): 157 t.Fatal("Timeout expired when waiting for first client transport to close") 158 } 159 } 160 161 // Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while 162 // reading a gRPC message is correctly converted to a gRPC status with code 163 // CANCELLED. The test sends a data frame with a partial gRPC message, followed 164 // by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the 165 // client receives the correct status. 166 func (s) TestRSTDuringMessageRead(t *testing.T) { 167 lis, err := testutils.LocalTCPListener() 168 if err != nil { 169 t.Fatal(err) 170 } 171 defer lis.Close() 172 173 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 174 defer cancel() 175 cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) 176 if err != nil { 177 t.Fatalf("grpc.NewClient(%s) = %v", lis.Addr().String(), err) 178 } 179 defer cc.Close() 180 181 go func() { 182 conn, err := lis.Accept() 183 if err != nil { 184 t.Errorf("lis.Accept() = %v", err) 185 return 186 } 187 defer conn.Close() 188 framer := http2.NewFramer(conn, conn) 189 190 if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil { 191 t.Errorf("Error while reading client preface: %v", err) 192 return 193 } 194 if err := framer.WriteSettings(); err != nil { 195 t.Errorf("Error while writing settings: %v", err) 196 return 197 } 198 if err := framer.WriteSettingsAck(); err != nil { 199 t.Errorf("Error while writing settings: %v", err) 200 return 201 } 202 for ctx.Err() == nil { 203 frame, err := framer.ReadFrame() 204 if err != nil { 205 return 206 } 207 switch frame := frame.(type) { 208 case *http2.HeadersFrame: 209 // When the client creates a stream, write a partial gRPC 210 // message followed by an RST_STREAM. 211 const messageLen = 2048 212 buf := make([]byte, messageLen/2) 213 // Write the gRPC message length header. 214 binary.BigEndian.PutUint32(buf[1:5], uint32(messageLen)) 215 if err := framer.WriteData(1, false, buf); err != nil { 216 return 217 } 218 framer.WriteRSTStream(1, http2.ErrCodeCancel) 219 default: 220 t.Logf("Server received frame: %v", frame) 221 } 222 } 223 }() 224 225 // The server will send a partial gRPC message before cancelling the stream. 226 // The client should get a gRPC status with code CANCELLED. 227 client := testgrpc.NewTestServiceClient(cc) 228 if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Canceled { 229 t.Fatalf("client.EmptyCall() returned %v; want status with code %v", err, codes.Canceled) 230 } 231 }