google.golang.org/grpc@v1.72.2/xds/internal/clients/grpctransport/grpc_transport_test.go (about)

     1  /*
     2   *
     3   * Copyright 2025 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package grpctransport
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"net"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/google/go-cmp/cmp"
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/credentials"
    31  	"google.golang.org/grpc/credentials/insecure"
    32  	"google.golang.org/grpc/credentials/local"
    33  	"google.golang.org/grpc/internal/grpctest"
    34  	"google.golang.org/grpc/xds/internal/clients"
    35  	"google.golang.org/protobuf/proto"
    36  	"google.golang.org/protobuf/testing/protocmp"
    37  
    38  	v3discoverygrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
    39  	v3discoverypb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
    40  )
    41  
    42  const (
    43  	defaultTestTimeout = 10 * time.Second
    44  )
    45  
    46  type s struct {
    47  	grpctest.Tester
    48  }
    49  
    50  func Test(t *testing.T) {
    51  	grpctest.RunSubTests(t, s{})
    52  }
    53  
    54  // testServer implements the AggregatedDiscoveryServiceServer interface to test
    55  // the gRPC transport implementation.
    56  type testServer struct {
    57  	v3discoverygrpc.UnimplementedAggregatedDiscoveryServiceServer
    58  
    59  	address     string                               // address of the server
    60  	requestChan chan *v3discoverypb.DiscoveryRequest // channel to send the received requests on for verification
    61  	response    *v3discoverypb.DiscoveryResponse     // response to send back to the client from handler
    62  }
    63  
    64  // setupTestServer set up the gRPC server for AggregatedDiscoveryService. It
    65  // creates an instance of testServer that returns the provided response from
    66  // the StreamAggregatedResources() handler and registers it with a gRPC server.
    67  func setupTestServer(t *testing.T, response *v3discoverypb.DiscoveryResponse) *testServer {
    68  	t.Helper()
    69  
    70  	lis, err := net.Listen("tcp", "localhost:0")
    71  	if err != nil {
    72  		t.Fatalf("Failed to listen on localhost:0: %v", err)
    73  	}
    74  	ts := &testServer{
    75  		requestChan: make(chan *v3discoverypb.DiscoveryRequest),
    76  		address:     lis.Addr().String(),
    77  		response:    response,
    78  	}
    79  
    80  	s := grpc.NewServer()
    81  
    82  	v3discoverygrpc.RegisterAggregatedDiscoveryServiceServer(s, ts)
    83  	go s.Serve(lis)
    84  	t.Cleanup(s.Stop)
    85  
    86  	return ts
    87  }
    88  
    89  // StreamAggregatedResources handles bidirectional streaming of
    90  // DiscoveryRequest and DiscoveryResponse. It waits for a message from the
    91  // client on the stream, and then sends a discovery response message back to
    92  // the client. It also put the received message in requestChan for client to
    93  // verify if the correct request was received. It continues until the client
    94  // closes the stream.
    95  func (s *testServer) StreamAggregatedResources(stream v3discoverygrpc.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error {
    96  	ctx := stream.Context()
    97  
    98  	for {
    99  		// Receive a DiscoveryRequest from the client
   100  		req, err := stream.Recv()
   101  		if err == io.EOF {
   102  			return nil // Stream closed by client
   103  		}
   104  		if err != nil {
   105  			return err // Handle other errors
   106  		}
   107  
   108  		// Push received request for client to verify the correct request was
   109  		// received.
   110  		select {
   111  		case s.requestChan <- req:
   112  		case <-ctx.Done():
   113  			return ctx.Err()
   114  		}
   115  
   116  		// Send the response back to the client
   117  		if err := stream.Send(s.response); err != nil {
   118  			return err
   119  		}
   120  	}
   121  }
   122  
   123  type testCredentials struct {
   124  	credentials.Bundle
   125  	transportCredentials credentials.TransportCredentials
   126  }
   127  
   128  func (tc *testCredentials) TransportCredentials() credentials.TransportCredentials {
   129  	return tc.transportCredentials
   130  }
   131  
   132  // TestBuild_Success verifies that the Builder successfully creates a new
   133  // Transport with a non-nil grpc.ClientConn.
   134  func (s) TestBuild_Success(t *testing.T) {
   135  	serverCfg := clients.ServerIdentifier{
   136  		ServerURI:  "server-address",
   137  		Extensions: ServerIdentifierExtension{Credentials: &testCredentials{transportCredentials: local.NewCredentials()}},
   138  	}
   139  
   140  	b := &Builder{}
   141  	tr, err := b.Build(serverCfg)
   142  	if err != nil {
   143  		t.Fatalf("Build() failed: %v", err)
   144  	}
   145  	defer tr.Close()
   146  
   147  	if tr == nil {
   148  		t.Fatalf("Got nil transport from Build(), want non-nil")
   149  	}
   150  	if tr.(*grpcTransport).cc == nil {
   151  		t.Fatalf("Got nil grpc.ClientConn in transport, want non-nil")
   152  	}
   153  }
   154  
   155  // TestBuild_Failure verifies that the Builder returns error when incorrect
   156  // ServerIdentifier is provided.
   157  //
   158  // It covers the following scenarios:
   159  // - ServerURI is empty.
   160  // - Extensions is nil.
   161  // - Extensions is not ServerIdentifierExtension.
   162  // - Credentials are nil.
   163  func (s) TestBuild_Failure(t *testing.T) {
   164  	tests := []struct {
   165  		name      string
   166  		serverCfg clients.ServerIdentifier
   167  	}{
   168  		{
   169  			name: "ServerURI is empty",
   170  			serverCfg: clients.ServerIdentifier{
   171  				ServerURI:  "",
   172  				Extensions: ServerIdentifierExtension{Credentials: insecure.NewBundle()},
   173  			},
   174  		},
   175  		{
   176  			name:      "Extensions is nil",
   177  			serverCfg: clients.ServerIdentifier{ServerURI: "server-address"},
   178  		},
   179  		{
   180  			name: "Extensions is not a ServerIdentifierExtension",
   181  			serverCfg: clients.ServerIdentifier{
   182  				ServerURI:  "server-address",
   183  				Extensions: 1,
   184  			},
   185  		},
   186  		{
   187  			name: "ServerIdentifierExtension Credentials is nil",
   188  			serverCfg: clients.ServerIdentifier{
   189  				ServerURI:  "server-address",
   190  				Extensions: ServerIdentifierExtension{},
   191  			},
   192  		},
   193  	}
   194  	for _, test := range tests {
   195  		t.Run(test.name, func(t *testing.T) {
   196  			b := &Builder{}
   197  			tr, err := b.Build(test.serverCfg)
   198  			if err == nil {
   199  				t.Fatalf("Build() succeeded, want error")
   200  			}
   201  			if tr != nil {
   202  				t.Fatalf("Got non-nil transport from Build(), want nil")
   203  			}
   204  		})
   205  	}
   206  }
   207  
   208  // TestNewStream_Success verifies that NewStream() successfully creates a new
   209  // client stream for the server when provided a valid server URI.
   210  func (s) TestNewStream_Success(t *testing.T) {
   211  	ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"})
   212  
   213  	serverCfg := clients.ServerIdentifier{
   214  		ServerURI:  ts.address,
   215  		Extensions: ServerIdentifierExtension{Credentials: insecure.NewBundle()},
   216  	}
   217  	builder := Builder{}
   218  	transport, err := builder.Build(serverCfg)
   219  	if err != nil {
   220  		t.Fatalf("Failed to build transport: %v", err)
   221  	}
   222  	defer transport.Close()
   223  
   224  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   225  	defer cancel()
   226  	if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err != nil {
   227  		t.Fatalf("transport.NewStream() failed: %v", err)
   228  	}
   229  }
   230  
   231  // TestNewStream_Error verifies that NewStream() returns an error
   232  // when attempting to create a stream with an invalid server URI.
   233  func (s) TestNewStream_Error(t *testing.T) {
   234  	serverCfg := clients.ServerIdentifier{
   235  		ServerURI:  "invalid-server-uri",
   236  		Extensions: ServerIdentifierExtension{Credentials: insecure.NewBundle()},
   237  	}
   238  	builder := Builder{}
   239  	transport, err := builder.Build(serverCfg)
   240  	if err != nil {
   241  		t.Fatalf("Failed to build transport: %v", err)
   242  	}
   243  	defer transport.Close()
   244  
   245  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   246  	defer cancel()
   247  	if _, err = transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources"); err == nil {
   248  		t.Fatal("transport.NewStream() succeeded, want failure")
   249  	}
   250  }
   251  
   252  // TestStream_SendAndRecv verifies that Send() and Recv() successfully send
   253  // and receive messages on the stream to and from the gRPC server.
   254  //
   255  // It starts a gRPC test server using setupTestServer(). The test then sends a
   256  // testDiscoverRequest on the stream and verifies that the received discovery
   257  // request on the server is same as sent. It then wait to receive a
   258  // testDiscoverResponse from the server and verifies that the received
   259  // discovery response is same as sent from the server.
   260  func (s) TestStream_SendAndRecv(t *testing.T) {
   261  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout*2000)
   262  	defer cancel()
   263  
   264  	ts := setupTestServer(t, &v3discoverypb.DiscoveryResponse{VersionInfo: "1"})
   265  
   266  	// Build a grpc-based transport to the above server.
   267  	serverCfg := clients.ServerIdentifier{
   268  		ServerURI:  ts.address,
   269  		Extensions: ServerIdentifierExtension{Credentials: insecure.NewBundle()},
   270  	}
   271  	builder := Builder{}
   272  	transport, err := builder.Build(serverCfg)
   273  	if err != nil {
   274  		t.Fatalf("Failed to build transport: %v", err)
   275  	}
   276  	defer transport.Close()
   277  
   278  	// Create a new stream to the server.
   279  	stream, err := transport.NewStream(ctx, "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources")
   280  	if err != nil {
   281  		t.Fatalf("Failed to create stream: %v", err)
   282  	}
   283  
   284  	// Send a discovery request message on the stream.
   285  	testDiscoverRequest := &v3discoverypb.DiscoveryRequest{VersionInfo: "1"}
   286  	msg, err := proto.Marshal(testDiscoverRequest)
   287  	if err != nil {
   288  		t.Fatalf("Failed to marshal DiscoveryRequest: %v", err)
   289  	}
   290  	if err := stream.Send(msg); err != nil {
   291  		t.Fatalf("Failed to send message: %v", err)
   292  	}
   293  
   294  	// Verify that the DiscoveryRequest received on the server was same as
   295  	// sent.
   296  	select {
   297  	case gotReq := <-ts.requestChan:
   298  		if diff := cmp.Diff(testDiscoverRequest, gotReq, protocmp.Transform()); diff != "" {
   299  			t.Fatalf("Unexpected diff in request received on server (-want +got):\n%s", diff)
   300  		}
   301  	case <-ctx.Done():
   302  		t.Fatalf("Timeout waiting for request to reach server")
   303  	}
   304  
   305  	// Wait until response message is received from the server.
   306  	res, err := stream.Recv()
   307  	if err != nil {
   308  		t.Fatalf("Failed to receive message: %v", err)
   309  	}
   310  
   311  	// Verify that the DiscoveryResponse received was same as sent from the
   312  	// server.
   313  	var gotRes v3discoverypb.DiscoveryResponse
   314  	if err := proto.Unmarshal(res, &gotRes); err != nil {
   315  		t.Fatalf("Failed to unmarshal response from server to DiscoveryResponse: %v", err)
   316  	}
   317  	if diff := cmp.Diff(ts.response, &gotRes, protocmp.Transform()); diff != "" {
   318  		t.Fatalf("proto.Unmarshal(res, &gotRes) returned unexpected diff (-want +got):\n%s", diff)
   319  	}
   320  }