google.golang.org/grpc@v1.74.2/xds/internal/clients/grpctransport/grpc_transport_test.go (about)

     1  /*
     2   *
     3   * Copyright 2025 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package grpctransport
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"net"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/google/go-cmp/cmp"
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/credentials"
    31  	"google.golang.org/grpc/credentials/insecure"
    32  	"google.golang.org/grpc/credentials/local"
    33  	"google.golang.org/grpc/internal/grpctest"
    34  	"google.golang.org/grpc/xds/internal/clients"
    35  	"google.golang.org/protobuf/proto"
    36  	"google.golang.org/protobuf/testing/protocmp"
    37  
    38  	v3discoverygrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
    39  	v3discoverypb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
    40  )
    41  
    42  const (
    43  	defaultTestTimeout      = 10 * time.Second
    44  	defaultTestShortTimeout = 10 * time.Millisecond // For events expected to *not* happen.
    45  )
    46  
    47  type s struct {
    48  	grpctest.Tester
    49  }
    50  
    51  func Test(t *testing.T) {
    52  	grpctest.RunSubTests(t, s{})
    53  }
    54  
    55  // testServer implements the AggregatedDiscoveryServiceServer interface to test
    56  // the gRPC transport implementation.
    57  type testServer struct {
    58  	v3discoverygrpc.UnimplementedAggregatedDiscoveryServiceServer
    59  
    60  	address     string                               // address of the server
    61  	requestChan chan *v3discoverypb.DiscoveryRequest // channel to send the received requests on for verification
    62  	response    *v3discoverypb.DiscoveryResponse     // response to send back to the client from handler
    63  }
    64  
    65  // setupTestServer set up the gRPC server for AggregatedDiscoveryService. It
    66  // creates an instance of testServer that returns the provided response from
    67  // the StreamAggregatedResources() handler and registers it with a gRPC server.
    68  func setupTestServer(t *testing.T, response *v3discoverypb.DiscoveryResponse) *testServer {
    69  	t.Helper()
    70  
    71  	lis, err := net.Listen("tcp", "localhost:0")
    72  	if err != nil {
    73  		t.Fatalf("Failed to listen on localhost:0: %v", err)
    74  	}
    75  	ts := &testServer{
    76  		requestChan: make(chan *v3discoverypb.DiscoveryRequest),
    77  		address:     lis.Addr().String(),
    78  		response:    response,
    79  	}
    80  
    81  	s := grpc.NewServer()
    82  
    83  	v3discoverygrpc.RegisterAggregatedDiscoveryServiceServer(s, ts)
    84  	go s.Serve(lis)
    85  	t.Cleanup(s.Stop)
    86  
    87  	return ts
    88  }
    89  
    90  // StreamAggregatedResources handles bidirectional streaming of
    91  // DiscoveryRequest and DiscoveryResponse. It waits for a message from the
    92  // client on the stream, and then sends a discovery response message back to
    93  // the client. It also put the received message in requestChan for client to
    94  // verify if the correct request was received. It continues until the client
    95  // closes the stream.
    96  func (s *testServer) StreamAggregatedResources(stream v3discoverygrpc.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error {
    97  	ctx := stream.Context()
    98  
    99  	for {
   100  		// Receive a DiscoveryRequest from the client
   101  		req, err := stream.Recv()
   102  		if err == io.EOF {
   103  			return nil // Stream closed by client
   104  		}
   105  		if err != nil {
   106  			return err // Handle other errors
   107  		}
   108  
   109  		// Push received request for client to verify the correct request was
   110  		// received.
   111  		select {
   112  		case s.requestChan <- req:
   113  		case <-ctx.Done():
   114  			return ctx.Err()
   115  		}
   116  
   117  		// Send the response back to the client
   118  		if err := stream.Send(s.response); err != nil {
   119  			return err
   120  		}
   121  	}
   122  }
   123  
   124  type testCredentials struct {
   125  	credentials.Bundle
   126  	transportCredentials credentials.TransportCredentials
   127  }
   128  
   129  func (tc *testCredentials) TransportCredentials() credentials.TransportCredentials {
   130  	return tc.transportCredentials
   131  }
   132  func (tc *testCredentials) PerRPCCredentials() credentials.PerRPCCredentials {
   133  	return nil
   134  }
   135  
   136  // TestBuild_Success verifies that the Builder successfully creates a new
   137  // Transport in both cases when provided clients.ServerIdentifer is same
   138  // one of the existing transport or a new one.
   139  func (s) TestBuild_Success(t *testing.T) {
   140  	configs := map[string]Config{
   141  		"local":    {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}},
   142  		"insecure": {Credentials: insecure.NewBundle()},
   143  	}
   144  	b := NewBuilder(configs)
   145  
   146  	serverID1 := clients.ServerIdentifier{
   147  		ServerURI:  "server-address",
   148  		Extensions: ServerIdentifierExtension{ConfigName: "local"},
   149  	}
   150  	tr1, err := b.Build(serverID1)
   151  	if err != nil {
   152  		t.Fatalf("Build(serverID1) call failed: %v", err)
   153  	}
   154  	defer tr1.Close()
   155  
   156  	serverID2 := clients.ServerIdentifier{
   157  		ServerURI:  "server-address",
   158  		Extensions: ServerIdentifierExtension{ConfigName: "local"},
   159  	}
   160  	tr2, err := b.Build(serverID2)
   161  	if err != nil {
   162  		t.Fatalf("Build(serverID2) call failed: %v", err)
   163  	}
   164  	defer tr2.Close()
   165  
   166  	serverID3 := clients.ServerIdentifier{
   167  		ServerURI:  "server-address",
   168  		Extensions: ServerIdentifierExtension{ConfigName: "insecure"},
   169  	}
   170  	tr3, err := b.Build(serverID3)
   171  	if err != nil {
   172  		t.Fatalf("Build(serverID3) call failed: %v", err)
   173  	}
   174  	defer tr3.Close()
   175  }
   176  
   177  // TestBuild_Failure verifies that the Builder returns error when incorrect
   178  // ServerIdentifier is provided.
   179  //
   180  // It covers the following scenarios:
   181  // - ServerURI is empty.
   182  // - Extensions is nil.
   183  // - Extensions is not ServerIdentifierExtension.
   184  // - Credentials are nil.
   185  func (s) TestBuild_Failure(t *testing.T) {
   186  	tests := []struct {
   187  		name     string
   188  		serverID clients.ServerIdentifier
   189  	}{
   190  		{
   191  			name: "ServerURI is empty",
   192  			serverID: clients.ServerIdentifier{
   193  				ServerURI:  "",
   194  				Extensions: ServerIdentifierExtension{ConfigName: "local"},
   195  			},
   196  		},
   197  		{
   198  			name:     "Extensions is nil",
   199  			serverID: clients.ServerIdentifier{ServerURI: "server-address"},
   200  		},
   201  		{
   202  			name: "Extensions is not a ServerIdentifierExtension",
   203  			serverID: clients.ServerIdentifier{
   204  				ServerURI:  "server-address",
   205  				Extensions: 1,
   206  			},
   207  		},
   208  		{
   209  			name: "ServerIdentifierExtension without ConfigName",
   210  			serverID: clients.ServerIdentifier{
   211  				ServerURI:  "server-address",
   212  				Extensions: ServerIdentifierExtension{},
   213  			},
   214  		},
   215  		{
   216  			name: "ServerIdentifierExtension ConfigName is not present",
   217  			serverID: clients.ServerIdentifier{
   218  				ServerURI:  "server-address",
   219  				Extensions: ServerIdentifierExtension{ConfigName: "unknown"},
   220  			},
   221  		},
   222  		{
   223  			name: "ServerIdentifierExtension ConfigName maps to nil credentials",
   224  			serverID: clients.ServerIdentifier{
   225  				ServerURI:  "server-address",
   226  				Extensions: ServerIdentifierExtension{ConfigName: "nil-credentials"},
   227  			},
   228  		},
   229  		{
   230  			name: "ServerIdentifierExtension is added as pointer",
   231  			serverID: clients.ServerIdentifier{
   232  				ServerURI:  "server-address",
   233  				Extensions: &ServerIdentifierExtension{ConfigName: "local"},
   234  			},
   235  		},
   236  	}
   237  	for _, test := range tests {
   238  		t.Run(test.name, func(t *testing.T) {
   239  			configs := map[string]Config{
   240  				"local":           {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}},
   241  				"nil-credentials": {Credentials: nil},
   242  			}
   243  			b := NewBuilder(configs)
   244  			tr, err := b.Build(test.serverID)
   245  			if err == nil {
   246  				t.Fatalf("Build() succeeded, want error")
   247  			}
   248  			if tr != nil {
   249  				t.Fatalf("Got non-nil transport from Build(), want nil")
   250  			}
   251  		})
   252  	}
   253  }
   254  
   255  // TestNewStream_Success verifies that NewStream() successfully creates a new
   256  // client stream for the server when provided a valid server URI and a config
   257  // with valid credentials.
   258  func (s) TestNewStream_Success(t *testing.T) {
   259  	ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"})
   260  
   261  	serverCfg := clients.ServerIdentifier{
   262  		ServerURI:  ts.address,
   263  		Extensions: ServerIdentifierExtension{ConfigName: "local"},
   264  	}
   265  	configs := map[string]Config{
   266  		"local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}},
   267  	}
   268  	builder := NewBuilder(configs)
   269  	transport, err := builder.Build(serverCfg)
   270  	if err != nil {
   271  		t.Fatalf("Failed to build transport: %v", err)
   272  	}
   273  	defer transport.Close()
   274  
   275  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   276  	defer cancel()
   277  	if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err != nil {
   278  		t.Fatalf("transport.NewStream() failed: %v", err)
   279  	}
   280  }
   281  
   282  // TestNewStream_Success_WithCustomGRPCNewClient verifies that NewStream()
   283  // successfully creates a new client stream for the server when provided a
   284  // valid server URI and a config with valid credentials and a custom gRPC
   285  // NewClient function.
   286  func (s) TestNewStream_Success_WithCustomGRPCNewClient(t *testing.T) {
   287  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   288  	defer cancel()
   289  
   290  	ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"})
   291  
   292  	// Create a custom dialer function that will be used by the gRPC client.
   293  	customDialerCalled := make(chan struct{}, 1)
   294  	customGRPCNewClient := func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
   295  		customDialerCalled <- struct{}{}
   296  		return grpc.NewClient(target, opts...)
   297  	}
   298  
   299  	configs := map[string]Config{
   300  		"custom-dialer-config": {
   301  			Credentials:   &testCredentials{transportCredentials: local.NewCredentials()},
   302  			GRPCNewClient: customGRPCNewClient,
   303  		},
   304  	}
   305  	builder := NewBuilder(configs)
   306  
   307  	serverID := clients.ServerIdentifier{
   308  		ServerURI:  ts.address,
   309  		Extensions: ServerIdentifierExtension{ConfigName: "custom-dialer-config"},
   310  	}
   311  
   312  	transport, err := builder.Build(serverID)
   313  	if err != nil {
   314  		t.Fatalf("builder.Build(%+v) failed: %v", serverID, err)
   315  	}
   316  	defer transport.Close()
   317  
   318  	select {
   319  	case <-customDialerCalled:
   320  	case <-ctx.Done():
   321  		t.Fatalf("Timeout waiting for custom dialer to be called: %v", ctx.Err())
   322  	}
   323  
   324  	// Verify that the transport works by creating a stream.
   325  	if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err != nil {
   326  		t.Fatalf("transport.NewStream() failed with custom dialer: %v", err)
   327  	}
   328  }
   329  
   330  // TestNewStream_Error verifies that NewStream() returns an error
   331  // when attempting to create a stream with an invalid server URI.
   332  func (s) TestNewStream_Error(t *testing.T) {
   333  	serverCfg := clients.ServerIdentifier{
   334  		ServerURI:  "invalid-server-uri",
   335  		Extensions: ServerIdentifierExtension{ConfigName: "local"},
   336  	}
   337  	configs := map[string]Config{
   338  		"local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}},
   339  	}
   340  	builder := NewBuilder(configs)
   341  	transport, err := builder.Build(serverCfg)
   342  	if err != nil {
   343  		t.Fatalf("Failed to build transport: %v", err)
   344  	}
   345  	defer transport.Close()
   346  
   347  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   348  	defer cancel()
   349  	if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err == nil {
   350  		t.Fatal("transport.NewStream() succeeded, want failure")
   351  	}
   352  }
   353  
   354  // TestStream_SendAndRecv verifies that Send() and Recv() successfully send
   355  // and receive messages on the stream to and from the gRPC server.
   356  //
   357  // It starts a gRPC test server using setupTestServer(). The test then sends a
   358  // testDiscoverRequest on the stream and verifies that the received discovery
   359  // request on the server is same as sent. It then wait to receive a
   360  // testDiscoverResponse from the server and verifies that the received
   361  // discovery response is same as sent from the server.
   362  func (s) TestStream_SendAndRecv(t *testing.T) {
   363  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   364  	defer cancel()
   365  
   366  	ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"})
   367  
   368  	// Build a grpc-based transport to the above server.
   369  	serverCfg := clients.ServerIdentifier{
   370  		ServerURI:  ts.address,
   371  		Extensions: ServerIdentifierExtension{ConfigName: "local"},
   372  	}
   373  	configs := map[string]Config{
   374  		"local": {Credentials: &testCredentials{transportCredentials: local.NewCredentials()}},
   375  	}
   376  	builder := NewBuilder(configs)
   377  	transport, err := builder.Build(serverCfg)
   378  	if err != nil {
   379  		t.Fatalf("Failed to build transport: %v", err)
   380  	}
   381  	defer transport.Close()
   382  
   383  	// Create a new stream to the server.
   384  	stream, err := transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources")
   385  	if err != nil {
   386  		t.Fatalf("Failed to create stream: %v", err)
   387  	}
   388  
   389  	// Send a discovery request message on the stream.
   390  	testDiscoverRequest := &v3discoverypb.DiscoveryRequest{VersionInfo: "1"}
   391  	msg, err := proto.Marshal(testDiscoverRequest)
   392  	if err != nil {
   393  		t.Fatalf("Failed to marshal DiscoveryRequest: %v", err)
   394  	}
   395  	if err := stream.Send(msg); err != nil {
   396  		t.Fatalf("Failed to send message: %v", err)
   397  	}
   398  
   399  	// Verify that the DiscoveryRequest received on the server was same as
   400  	// sent.
   401  	select {
   402  	case gotReq := <-ts.requestChan:
   403  		if diff := cmp.Diff(testDiscoverRequest, gotReq, protocmp.Transform()); diff != "" {
   404  			t.Fatalf("Unexpected diff in request received on server (-want +got):\n%s", diff)
   405  		}
   406  	case <-ctx.Done():
   407  		t.Fatalf("Timeout waiting for request to reach server")
   408  	}
   409  
   410  	// Wait until response message is received from the server.
   411  	res, err := stream.Recv()
   412  	if err != nil {
   413  		t.Fatalf("Failed to receive message: %v", err)
   414  	}
   415  
   416  	// Verify that the DiscoveryResponse received was same as sent from the
   417  	// server.
   418  	var gotRes v3discoverypb.DiscoveryResponse
   419  	if err := proto.Unmarshal(res, &gotRes); err != nil {
   420  		t.Fatalf("Failed to unmarshal response from server to DiscoveryResponse: %v", err)
   421  	}
   422  	if diff := cmp.Diff(ts.response, &gotRes, protocmp.Transform()); diff != "" {
   423  		t.Fatalf("proto.Unmarshal(res, &gotRes) returned unexpected diff (-want +got):\n%s", diff)
   424  	}
   425  }