google.golang.org/grpc@v1.62.1/test/transport_test.go (about) 1 /* 2 * 3 * Copyright 2023 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 package test 19 20 import ( 21 "context" 22 "io" 23 "net" 24 "sync" 25 "testing" 26 27 "google.golang.org/grpc" 28 "google.golang.org/grpc/codes" 29 "google.golang.org/grpc/credentials" 30 "google.golang.org/grpc/internal/grpcsync" 31 "google.golang.org/grpc/internal/stubserver" 32 "google.golang.org/grpc/internal/transport" 33 "google.golang.org/grpc/status" 34 35 testgrpc "google.golang.org/grpc/interop/grpc_testing" 36 testpb "google.golang.org/grpc/interop/grpc_testing" 37 ) 38 39 // connWrapperWithCloseCh wraps a net.Conn and fires an event when closed. 40 type connWrapperWithCloseCh struct { 41 net.Conn 42 close *grpcsync.Event 43 } 44 45 // Close closes the connection and sends a value on the close channel. 46 func (cw *connWrapperWithCloseCh) Close() error { 47 cw.close.Fire() 48 return cw.Conn.Close() 49 } 50 51 // These custom creds are used for storing the connections made by the client. 52 // The closeCh in conn can be used to detect when conn is closed. 53 type transportRestartCheckCreds struct { 54 mu sync.Mutex 55 connections []*connWrapperWithCloseCh 56 } 57 58 func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 59 return rawConn, nil, nil 60 } 61 func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 62 c.mu.Lock() 63 defer c.mu.Unlock() 64 conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()} 65 c.connections = append(c.connections, conn) 66 return conn, nil, nil 67 } 68 func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo { 69 return credentials.ProtocolInfo{} 70 } 71 func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials { 72 return c 73 } 74 func (c *transportRestartCheckCreds) OverrideServerName(s string) error { 75 return nil 76 } 77 78 // Tests that the client transport drains and restarts when next stream ID exceeds 79 // MaxStreamID. This test also verifies that subsequent RPCs use a new client 80 // transport and the old transport is closed. 81 func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) { 82 // Set the transport's MaxStreamID to 4 to cause connection to drain after 2 RPCs. 83 originalMaxStreamID := transport.MaxStreamID 84 transport.MaxStreamID = 4 85 defer func() { 86 transport.MaxStreamID = originalMaxStreamID 87 }() 88 89 ss := &stubserver.StubServer{ 90 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { 91 if _, err := stream.Recv(); err != nil { 92 return status.Errorf(codes.Internal, "unexpected error receiving: %v", err) 93 } 94 if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil { 95 return status.Errorf(codes.Internal, "unexpected error sending: %v", err) 96 } 97 if recv, err := stream.Recv(); err != io.EOF { 98 return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err) 99 } 100 return nil 101 }, 102 } 103 104 creds := &transportRestartCheckCreds{} 105 if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil { 106 t.Fatalf("Starting stubServer: %v", err) 107 } 108 defer ss.Stop() 109 110 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 111 defer cancel() 112 113 var streams []testgrpc.TestService_FullDuplexCallClient 114 115 const numStreams = 3 116 // expected number of conns when each stream is created i.e., 3rd stream is created 117 // on a new connection. 118 expectedNumConns := [numStreams]int{1, 1, 2} 119 120 // Set up 3 streams. 121 for i := 0; i < numStreams; i++ { 122 s, err := ss.Client.FullDuplexCall(ctx) 123 if err != nil { 124 t.Fatalf("Creating FullDuplex stream: %v", err) 125 } 126 streams = append(streams, s) 127 // Verify expected num of conns after each stream is created. 128 if len(creds.connections) != expectedNumConns[i] { 129 t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i]) 130 } 131 } 132 133 // Verify all streams still work. 134 for i, stream := range streams { 135 if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { 136 t.Fatalf("Sending on stream %d: %v", i, err) 137 } 138 if _, err := stream.Recv(); err != nil { 139 t.Fatalf("Receiving on stream %d: %v", i, err) 140 } 141 } 142 143 for i, stream := range streams { 144 if err := stream.CloseSend(); err != nil { 145 t.Fatalf("CloseSend() on stream %d: %v", i, err) 146 } 147 } 148 149 // Verifying first connection was closed. 150 select { 151 case <-creds.connections[0].close.Done(): 152 case <-ctx.Done(): 153 t.Fatal("Timeout expired when waiting for first client transport to close") 154 } 155 }