google.golang.org/grpc@v1.74.2/xds/internal/clients/grpctransport/grpc_transport_test.go (about) 1 /* 2 * 3 * Copyright 2025 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package grpctransport 20 21 import ( 22 "context" 23 "io" 24 "net" 25 "testing" 26 "time" 27 28 "github.com/google/go-cmp/cmp" 29 "google.golang.org/grpc" 30 "google.golang.org/grpc/credentials" 31 "google.golang.org/grpc/credentials/insecure" 32 "google.golang.org/grpc/credentials/local" 33 "google.golang.org/grpc/internal/grpctest" 34 "google.golang.org/grpc/xds/internal/clients" 35 "google.golang.org/protobuf/proto" 36 "google.golang.org/protobuf/testing/protocmp" 37 38 v3discoverygrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" 39 v3discoverypb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" 40 ) 41 42 const ( 43 defaultTestTimeout = 10 * time.Second 44 defaultTestShortTimeout = 10 * time.Millisecond // For events expected to *not* happen. 45 ) 46 47 type s struct { 48 grpctest.Tester 49 } 50 51 func Test(t *testing.T) { 52 grpctest.RunSubTests(t, s{}) 53 } 54 55 // testServer implements the AggregatedDiscoveryServiceServer interface to test 56 // the gRPC transport implementation. 57 type testServer struct { 58 v3discoverygrpc.UnimplementedAggregatedDiscoveryServiceServer 59 60 address string // address of the server 61 requestChan chan *v3discoverypb.DiscoveryRequest // channel to send the received requests on for verification 62 response *v3discoverypb.DiscoveryResponse // response to send back to the client from handler 63 } 64 65 // setupTestServer set up the gRPC server for AggregatedDiscoveryService. It 66 // creates an instance of testServer that returns the provided response from 67 // the StreamAggregatedResources() handler and registers it with a gRPC server. 68 func setupTestServer(t *testing.T, response *v3discoverypb.DiscoveryResponse) *testServer { 69 t.Helper() 70 71 lis, err := net.Listen("tcp", "localhost:0") 72 if err != nil { 73 t.Fatalf("Failed to listen on localhost:0: %v", err) 74 } 75 ts := &testServer{ 76 requestChan: make(chan *v3discoverypb.DiscoveryRequest), 77 address: lis.Addr().String(), 78 response: response, 79 } 80 81 s := grpc.NewServer() 82 83 v3discoverygrpc.RegisterAggregatedDiscoveryServiceServer(s, ts) 84 go s.Serve(lis) 85 t.Cleanup(s.Stop) 86 87 return ts 88 } 89 90 // StreamAggregatedResources handles bidirectional streaming of 91 // DiscoveryRequest and DiscoveryResponse. It waits for a message from the 92 // client on the stream, and then sends a discovery response message back to 93 // the client. It also put the received message in requestChan for client to 94 // verify if the correct request was received. It continues until the client 95 // closes the stream. 96 func (s *testServer) StreamAggregatedResources(stream v3discoverygrpc.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error { 97 ctx := stream.Context() 98 99 for { 100 // Receive a DiscoveryRequest from the client 101 req, err := stream.Recv() 102 if err == io.EOF { 103 return nil // Stream closed by client 104 } 105 if err != nil { 106 return err // Handle other errors 107 } 108 109 // Push received request for client to verify the correct request was 110 // received. 111 select { 112 case s.requestChan <- req: 113 case <-ctx.Done(): 114 return ctx.Err() 115 } 116 117 // Send the response back to the client 118 if err := stream.Send(s.response); err != nil { 119 return err 120 } 121 } 122 } 123 124 type testCredentials struct { 125 credentials.Bundle 126 transportCredentials credentials.TransportCredentials 127 } 128 129 func (tc *testCredentials) TransportCredentials() credentials.TransportCredentials { 130 return tc.transportCredentials 131 } 132 func (tc *testCredentials) PerRPCCredentials() credentials.PerRPCCredentials { 133 return nil 134 } 135 136 // TestBuild_Success verifies that the Builder successfully creates a new 137 // Transport in both cases when provided clients.ServerIdentifer is same 138 // one of the existing transport or a new one. 139 func (s) TestBuild_Success(t *testing.T) { 140 configs := map[string]Config{ 141 "local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}}, 142 "insecure": {Credentials: insecure.NewBundle()}, 143 } 144 b := NewBuilder(configs) 145 146 serverID1 := clients.ServerIdentifier{ 147 ServerURI: "server-address", 148 Extensions: ServerIdentifierExtension{ConfigName: "local"}, 149 } 150 tr1, err := b.Build(serverID1) 151 if err != nil { 152 t.Fatalf("Build(serverID1) call failed: %v", err) 153 } 154 defer tr1.Close() 155 156 serverID2 := clients.ServerIdentifier{ 157 ServerURI: "server-address", 158 Extensions: ServerIdentifierExtension{ConfigName: "local"}, 159 } 160 tr2, err := b.Build(serverID2) 161 if err != nil { 162 t.Fatalf("Build(serverID2) call failed: %v", err) 163 } 164 defer tr2.Close() 165 166 serverID3 := clients.ServerIdentifier{ 167 ServerURI: "server-address", 168 Extensions: ServerIdentifierExtension{ConfigName: "insecure"}, 169 } 170 tr3, err := b.Build(serverID3) 171 if err != nil { 172 t.Fatalf("Build(serverID3) call failed: %v", err) 173 } 174 defer tr3.Close() 175 } 176 177 // TestBuild_Failure verifies that the Builder returns error when incorrect 178 // ServerIdentifier is provided. 179 // 180 // It covers the following scenarios: 181 // - ServerURI is empty. 182 // - Extensions is nil. 183 // - Extensions is not ServerIdentifierExtension. 184 // - Credentials are nil. 185 func (s) TestBuild_Failure(t *testing.T) { 186 tests := []struct { 187 name string 188 serverID clients.ServerIdentifier 189 }{ 190 { 191 name: "ServerURI is empty", 192 serverID: clients.ServerIdentifier{ 193 ServerURI: "", 194 Extensions: ServerIdentifierExtension{ConfigName: "local"}, 195 }, 196 }, 197 { 198 name: "Extensions is nil", 199 serverID: clients.ServerIdentifier{ServerURI: "server-address"}, 200 }, 201 { 202 name: "Extensions is not a ServerIdentifierExtension", 203 serverID: clients.ServerIdentifier{ 204 ServerURI: "server-address", 205 Extensions: 1, 206 }, 207 }, 208 { 209 name: "ServerIdentifierExtension without ConfigName", 210 serverID: clients.ServerIdentifier{ 211 ServerURI: "server-address", 212 Extensions: ServerIdentifierExtension{}, 213 }, 214 }, 215 { 216 name: "ServerIdentifierExtension ConfigName is not present", 217 serverID: clients.ServerIdentifier{ 218 ServerURI: "server-address", 219 Extensions: ServerIdentifierExtension{ConfigName: "unknown"}, 220 }, 221 }, 222 { 223 name: "ServerIdentifierExtension ConfigName maps to nil credentials", 224 serverID: clients.ServerIdentifier{ 225 ServerURI: "server-address", 226 Extensions: ServerIdentifierExtension{ConfigName: "nil-credentials"}, 227 }, 228 }, 229 { 230 name: "ServerIdentifierExtension is added as pointer", 231 serverID: clients.ServerIdentifier{ 232 ServerURI: "server-address", 233 Extensions: &ServerIdentifierExtension{ConfigName: "local"}, 234 }, 235 }, 236 } 237 for _, test := range tests { 238 t.Run(test.name, func(t *testing.T) { 239 configs := map[string]Config{ 240 "local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}}, 241 "nil-credentials": {Credentials: nil}, 242 } 243 b := NewBuilder(configs) 244 tr, err := b.Build(test.serverID) 245 if err == nil { 246 t.Fatalf("Build() succeeded, want error") 247 } 248 if tr != nil { 249 t.Fatalf("Got non-nil transport from Build(), want nil") 250 } 251 }) 252 } 253 } 254 255 // TestNewStream_Success verifies that NewStream() successfully creates a new 256 // client stream for the server when provided a valid server URI and a config 257 // with valid credentials. 258 func (s) TestNewStream_Success(t *testing.T) { 259 ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"}) 260 261 serverCfg := clients.ServerIdentifier{ 262 ServerURI: ts.address, 263 Extensions: ServerIdentifierExtension{ConfigName: "local"}, 264 } 265 configs := map[string]Config{ 266 "local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}}, 267 } 268 builder := NewBuilder(configs) 269 transport, err := builder.Build(serverCfg) 270 if err != nil { 271 t.Fatalf("Failed to build transport: %v", err) 272 } 273 defer transport.Close() 274 275 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 276 defer cancel() 277 if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err != nil { 278 t.Fatalf("transport.NewStream() failed: %v", err) 279 } 280 } 281 282 // TestNewStream_Success_WithCustomGRPCNewClient verifies that NewStream() 283 // successfully creates a new client stream for the server when provided a 284 // valid server URI and a config with valid credentials and a custom gRPC 285 // NewClient function. 286 func (s) TestNewStream_Success_WithCustomGRPCNewClient(t *testing.T) { 287 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 288 defer cancel() 289 290 ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"}) 291 292 // Create a custom dialer function that will be used by the gRPC client. 293 customDialerCalled := make(chan struct{}, 1) 294 customGRPCNewClient := func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { 295 customDialerCalled <- struct{}{} 296 return grpc.NewClient(target, opts...) 297 } 298 299 configs := map[string]Config{ 300 "custom-dialer-config": { 301 Credentials: &testCredentials{transportCredentials: local.NewCredentials()}, 302 GRPCNewClient: customGRPCNewClient, 303 }, 304 } 305 builder := NewBuilder(configs) 306 307 serverID := clients.ServerIdentifier{ 308 ServerURI: ts.address, 309 Extensions: ServerIdentifierExtension{ConfigName: "custom-dialer-config"}, 310 } 311 312 transport, err := builder.Build(serverID) 313 if err != nil { 314 t.Fatalf("builder.Build(%+v) failed: %v", serverID, err) 315 } 316 defer transport.Close() 317 318 select { 319 case <-customDialerCalled: 320 case <-ctx.Done(): 321 t.Fatalf("Timeout waiting for custom dialer to be called: %v", ctx.Err()) 322 } 323 324 // Verify that the transport works by creating a stream. 325 if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err != nil { 326 t.Fatalf("transport.NewStream() failed with custom dialer: %v", err) 327 } 328 } 329 330 // TestNewStream_Error verifies that NewStream() returns an error 331 // when attempting to create a stream with an invalid server URI. 332 func (s) TestNewStream_Error(t *testing.T) { 333 serverCfg := clients.ServerIdentifier{ 334 ServerURI: "invalid-server-uri", 335 Extensions: ServerIdentifierExtension{ConfigName: "local"}, 336 } 337 configs := map[string]Config{ 338 "local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}}, 339 } 340 builder := NewBuilder(configs) 341 transport, err := builder.Build(serverCfg) 342 if err != nil { 343 t.Fatalf("Failed to build transport: %v", err) 344 } 345 defer transport.Close() 346 347 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 348 defer cancel() 349 if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err == nil { 350 t.Fatal("transport.NewStream() succeeded, want failure") 351 } 352 } 353 354 // TestStream_SendAndRecv verifies that Send() and Recv() successfully send 355 // and receive messages on the stream to and from the gRPC server. 356 // 357 // It starts a gRPC test server using setupTestServer(). The test then sends a 358 // testDiscoverRequest on the stream and verifies that the received discovery 359 // request on the server is same as sent. It then wait to receive a 360 // testDiscoverResponse from the server and verifies that the received 361 // discovery response is same as sent from the server. 362 func (s) TestStream_SendAndRecv(t *testing.T) { 363 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 364 defer cancel() 365 366 ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"}) 367 368 // Build a grpc-based transport to the above server. 369 serverCfg := clients.ServerIdentifier{ 370 ServerURI: ts.address, 371 Extensions: ServerIdentifierExtension{ConfigName: "local"}, 372 } 373 configs := map[string]Config{ 374 "local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}}, 375 } 376 builder := NewBuilder(configs) 377 transport, err := builder.Build(serverCfg) 378 if err != nil { 379 t.Fatalf("Failed to build transport: %v", err) 380 } 381 defer transport.Close() 382 383 // Create a new stream to the server. 384 stream, err := transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources") 385 if err != nil { 386 t.Fatalf("Failed to create stream: %v", err) 387 } 388 389 // Send a discovery request message on the stream. 390 testDiscoverRequest := &v3discoverypb.DiscoveryRequest{VersionInfo: "1"} 391 msg, err := proto.Marshal(testDiscoverRequest) 392 if err != nil { 393 t.Fatalf("Failed to marshal DiscoveryRequest: %v", err) 394 } 395 if err := stream.Send(msg); err != nil { 396 t.Fatalf("Failed to send message: %v", err) 397 } 398 399 // Verify that the DiscoveryRequest received on the server was same as 400 // sent. 401 select { 402 case gotReq := <-ts.requestChan: 403 if diff := cmp.Diff(testDiscoverRequest, gotReq, protocmp.Transform()); diff != "" { 404 t.Fatalf("Unexpected diff in request received on server (-want +got):\n%s", diff) 405 } 406 case <-ctx.Done(): 407 t.Fatalf("Timeout waiting for request to reach server") 408 } 409 410 // Wait until response message is received from the server. 411 res, err := stream.Recv() 412 if err != nil { 413 t.Fatalf("Failed to receive message: %v", err) 414 } 415 416 // Verify that the DiscoveryResponse received was same as sent from the 417 // server. 418 var gotRes v3discoverypb.DiscoveryResponse 419 if err := proto.Unmarshal(res, &gotRes); err != nil { 420 t.Fatalf("Failed to unmarshal response from server to DiscoveryResponse: %v", err) 421 } 422 if diff := cmp.Diff(ts.response, &gotRes, protocmp.Transform()); diff != "" { 423 t.Fatalf("proto.Unmarshal(res, &gotRes) returned unexpected diff (-want +got):\n%s", diff) 424 } 425 }