google.golang.org/grpc@v1.72.2/credentials/alts/alts_test.go (about)

     1  //go:build linux || windows
     2  // +build linux windows
     3  
     4  /*
     5   *
     6   * Copyright 2018 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     http://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  package alts
    23  
    24  import (
    25  	"context"
    26  	"reflect"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/credentials/alts/internal/handshaker"
    34  	"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
    35  	altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    36  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    37  	"google.golang.org/grpc/credentials/alts/internal/testutil"
    38  	"google.golang.org/grpc/internal/grpctest"
    39  	"google.golang.org/grpc/internal/stubserver"
    40  	"google.golang.org/grpc/internal/testutils"
    41  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    42  	testpb "google.golang.org/grpc/interop/grpc_testing"
    43  	"google.golang.org/grpc/peer"
    44  	"google.golang.org/grpc/status"
    45  	"google.golang.org/protobuf/proto"
    46  )
    47  
    48  const (
    49  	defaultTestLongTimeout  = 60 * time.Second
    50  	defaultTestShortTimeout = 10 * time.Millisecond
    51  )
    52  
    53  type s struct {
    54  	grpctest.Tester
    55  }
    56  
    57  func init() {
    58  	// The vmOnGCP global variable MUST be forced to true. Otherwise, if
    59  	// this test is run anywhere except on a GCP VM, then an ALTS handshake
    60  	// will immediately fail.
    61  	once.Do(func() {})
    62  	vmOnGCP = true
    63  }
    64  
    65  func Test(t *testing.T) {
    66  	grpctest.RunSubTests(t, s{})
    67  }
    68  
    69  func (s) TestInfoServerName(t *testing.T) {
    70  	// This is not testing any handshaker functionality, so it's fine to only
    71  	// use NewServerCreds and not NewClientCreds.
    72  	alts := NewServerCreds(DefaultServerOptions())
    73  	if got, want := alts.Info().ServerName, ""; got != want {
    74  		t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
    75  	}
    76  }
    77  
    78  func (s) TestOverrideServerName(t *testing.T) {
    79  	wantServerName := "server.name"
    80  	// This is not testing any handshaker functionality, so it's fine to only
    81  	// use NewServerCreds and not NewClientCreds.
    82  	c := NewServerCreds(DefaultServerOptions())
    83  	c.OverrideServerName(wantServerName)
    84  	if got, want := c.Info().ServerName, wantServerName; got != want {
    85  		t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
    86  	}
    87  }
    88  
    89  func (s) TestCloneClient(t *testing.T) {
    90  	wantServerName := "server.name"
    91  	opt := DefaultClientOptions()
    92  	opt.TargetServiceAccounts = []string{"not", "empty"}
    93  	c := NewClientCreds(opt)
    94  	c.OverrideServerName(wantServerName)
    95  	cc := c.Clone()
    96  	if got, want := cc.Info().ServerName, wantServerName; got != want {
    97  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
    98  	}
    99  	cc.OverrideServerName("")
   100  	if got, want := c.Info().ServerName, wantServerName; got != want {
   101  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
   102  	}
   103  	if got, want := cc.Info().ServerName, ""; got != want {
   104  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   105  	}
   106  
   107  	ct := c.(*altsTC)
   108  	cct := cc.(*altsTC)
   109  
   110  	if ct.side != cct.side {
   111  		t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
   112  	}
   113  	if ct.hsAddress != cct.hsAddress {
   114  		t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
   115  	}
   116  	if !reflect.DeepEqual(ct.accounts, cct.accounts) {
   117  		t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
   118  	}
   119  }
   120  
   121  func (s) TestCloneServer(t *testing.T) {
   122  	wantServerName := "server.name"
   123  	c := NewServerCreds(DefaultServerOptions())
   124  	c.OverrideServerName(wantServerName)
   125  	cc := c.Clone()
   126  	if got, want := cc.Info().ServerName, wantServerName; got != want {
   127  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   128  	}
   129  	cc.OverrideServerName("")
   130  	if got, want := c.Info().ServerName, wantServerName; got != want {
   131  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
   132  	}
   133  	if got, want := cc.Info().ServerName, ""; got != want {
   134  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   135  	}
   136  
   137  	ct := c.(*altsTC)
   138  	cct := cc.(*altsTC)
   139  
   140  	if ct.side != cct.side {
   141  		t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
   142  	}
   143  	if ct.hsAddress != cct.hsAddress {
   144  		t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
   145  	}
   146  	if !reflect.DeepEqual(ct.accounts, cct.accounts) {
   147  		t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
   148  	}
   149  }
   150  
   151  func (s) TestInfo(t *testing.T) {
   152  	// This is not testing any handshaker functionality, so it's fine to only
   153  	// use NewServerCreds and not NewClientCreds.
   154  	c := NewServerCreds(DefaultServerOptions())
   155  	info := c.Info()
   156  	if got, want := info.ProtocolVersion, ""; got != want {
   157  		t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
   158  	}
   159  	if got, want := info.SecurityProtocol, "alts"; got != want {
   160  		t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
   161  	}
   162  	if got, want := info.SecurityVersion, "1.0"; got != want {
   163  		t.Errorf("info.SecurityVersion=%v, want %v", got, want)
   164  	}
   165  	if got, want := info.ServerName, ""; got != want {
   166  		t.Errorf("info.ServerName=%v, want %v", got, want)
   167  	}
   168  }
   169  
   170  func (s) TestCompareRPCVersions(t *testing.T) {
   171  	for _, tc := range []struct {
   172  		v1     *altspb.RpcProtocolVersions_Version
   173  		v2     *altspb.RpcProtocolVersions_Version
   174  		output int
   175  	}{
   176  		{
   177  			version(3, 2),
   178  			version(2, 1),
   179  			1,
   180  		},
   181  		{
   182  			version(3, 2),
   183  			version(3, 1),
   184  			1,
   185  		},
   186  		{
   187  			version(2, 1),
   188  			version(3, 2),
   189  			-1,
   190  		},
   191  		{
   192  			version(3, 1),
   193  			version(3, 2),
   194  			-1,
   195  		},
   196  		{
   197  			version(3, 2),
   198  			version(3, 2),
   199  			0,
   200  		},
   201  	} {
   202  		if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
   203  			t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
   204  		}
   205  	}
   206  }
   207  
   208  func (s) TestCheckRPCVersions(t *testing.T) {
   209  	for _, tc := range []struct {
   210  		desc             string
   211  		local            *altspb.RpcProtocolVersions
   212  		peer             *altspb.RpcProtocolVersions
   213  		output           bool
   214  		maxCommonVersion *altspb.RpcProtocolVersions_Version
   215  	}{
   216  		{
   217  			"local.max > peer.max and local.min > peer.min",
   218  			versions(2, 1, 3, 2),
   219  			versions(1, 2, 2, 1),
   220  			true,
   221  			version(2, 1),
   222  		},
   223  		{
   224  			"local.max > peer.max and local.min < peer.min",
   225  			versions(1, 2, 3, 2),
   226  			versions(2, 1, 2, 1),
   227  			true,
   228  			version(2, 1),
   229  		},
   230  		{
   231  			"local.max > peer.max and local.min = peer.min",
   232  			versions(2, 1, 3, 2),
   233  			versions(2, 1, 2, 1),
   234  			true,
   235  			version(2, 1),
   236  		},
   237  		{
   238  			"local.max < peer.max and local.min > peer.min",
   239  			versions(2, 1, 2, 1),
   240  			versions(1, 2, 3, 2),
   241  			true,
   242  			version(2, 1),
   243  		},
   244  		{
   245  			"local.max = peer.max and local.min > peer.min",
   246  			versions(2, 1, 2, 1),
   247  			versions(1, 2, 2, 1),
   248  			true,
   249  			version(2, 1),
   250  		},
   251  		{
   252  			"local.max < peer.max and local.min < peer.min",
   253  			versions(1, 2, 2, 1),
   254  			versions(2, 1, 3, 2),
   255  			true,
   256  			version(2, 1),
   257  		},
   258  		{
   259  			"local.max < peer.max and local.min = peer.min",
   260  			versions(1, 2, 2, 1),
   261  			versions(1, 2, 3, 2),
   262  			true,
   263  			version(2, 1),
   264  		},
   265  		{
   266  			"local.max = peer.max and local.min < peer.min",
   267  			versions(1, 2, 2, 1),
   268  			versions(2, 1, 2, 1),
   269  			true,
   270  			version(2, 1),
   271  		},
   272  		{
   273  			"all equal",
   274  			versions(2, 1, 2, 1),
   275  			versions(2, 1, 2, 1),
   276  			true,
   277  			version(2, 1),
   278  		},
   279  		{
   280  			"max is smaller than min",
   281  			versions(2, 1, 1, 2),
   282  			versions(2, 1, 1, 2),
   283  			false,
   284  			nil,
   285  		},
   286  		{
   287  			"no overlap, local > peer",
   288  			versions(4, 3, 6, 5),
   289  			versions(1, 0, 2, 1),
   290  			false,
   291  			nil,
   292  		},
   293  		{
   294  			"no overlap, local < peer",
   295  			versions(1, 0, 2, 1),
   296  			versions(4, 3, 6, 5),
   297  			false,
   298  			nil,
   299  		},
   300  		{
   301  			"no overlap, max < min",
   302  			versions(6, 5, 4, 3),
   303  			versions(2, 1, 1, 0),
   304  			false,
   305  			nil,
   306  		},
   307  	} {
   308  		output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
   309  		if got, want := output, tc.output; got != want {
   310  			t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
   311  		}
   312  		if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
   313  			t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
   314  		}
   315  	}
   316  }
   317  
   318  // TestFullHandshake performs a full ALTS handshake between a test client and
   319  // server, where both client and server offload to a local, fake handshaker
   320  // service.
   321  func (s) TestFullHandshake(t *testing.T) {
   322  	// Start the fake handshaker service and the server.
   323  	var wait sync.WaitGroup
   324  	defer wait.Wait()
   325  	stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
   326  	defer stopHandshaker()
   327  	stopServer, serverAddress := startServer(t, handshakerAddress)
   328  	defer stopServer()
   329  
   330  	// Ping the server, authenticating with ALTS.
   331  	establishAltsConnection(t, handshakerAddress, serverAddress)
   332  
   333  	// Close open connections to the fake handshaker service.
   334  	if err := service.CloseForTesting(); err != nil {
   335  		t.Errorf("service.CloseForTesting() failed: %v", err)
   336  	}
   337  }
   338  
   339  // TestConcurrentHandshakes performs a several, concurrent ALTS handshakes
   340  // between a test client and server, where both client and server offload to a
   341  // local, fake handshaker service.
   342  func (s) TestConcurrentHandshakes(t *testing.T) {
   343  	// Set the max number of concurrent handshakes to 3, so that we can
   344  	// test the handshaker behavior when handshakes are queued by
   345  	// performing more than 3 concurrent handshakes (specifically, 10).
   346  	handshaker.ResetConcurrentHandshakeSemaphoreForTesting(3)
   347  
   348  	// Start the fake handshaker service and the server.
   349  	var wait sync.WaitGroup
   350  	defer wait.Wait()
   351  	stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
   352  	defer stopHandshaker()
   353  	stopServer, serverAddress := startServer(t, handshakerAddress)
   354  	defer stopServer()
   355  
   356  	// Ping the server, authenticating with ALTS.
   357  	var waitForConnections sync.WaitGroup
   358  	for i := 0; i < 10; i++ {
   359  		waitForConnections.Add(1)
   360  		go func() {
   361  			establishAltsConnection(t, handshakerAddress, serverAddress)
   362  			waitForConnections.Done()
   363  		}()
   364  	}
   365  	waitForConnections.Wait()
   366  
   367  	// Close open connections to the fake handshaker service.
   368  	if err := service.CloseForTesting(); err != nil {
   369  		t.Errorf("service.CloseForTesting() failed: %v", err)
   370  	}
   371  }
   372  
   373  func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
   374  	return &altspb.RpcProtocolVersions_Version{
   375  		Major: major,
   376  		Minor: minor,
   377  	}
   378  }
   379  
   380  func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
   381  	return &altspb.RpcProtocolVersions{
   382  		MinRpcVersion: version(minMajor, minMinor),
   383  		MaxRpcVersion: version(maxMajor, maxMinor),
   384  	}
   385  }
   386  
   387  func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
   388  	clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
   389  	conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds))
   390  	if err != nil {
   391  		t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err)
   392  	}
   393  	defer conn.Close()
   394  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
   395  	defer cancel()
   396  	c := testgrpc.NewTestServiceClient(conn)
   397  	var peer peer.Peer
   398  	success := false
   399  	for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
   400  		_, err = c.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.Peer(&peer))
   401  		if err == nil {
   402  			success = true
   403  			break
   404  		}
   405  		if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded {
   406  			// The server is not ready yet or there were too many concurrent handshakes.
   407  			// Try again.
   408  			continue
   409  		}
   410  		t.Fatalf("c.UnaryCall() failed: %v", err)
   411  	}
   412  	if !success {
   413  		t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout)
   414  	}
   415  
   416  	// Check that peer.AuthInfo was populated with an ALTS AuthInfo
   417  	// instance. As a sanity check, also verify that the AuthType() and
   418  	// ApplicationProtocol() have the expected values.
   419  	if got, want := peer.AuthInfo.AuthType(), "alts"; got != want {
   420  		t.Errorf("authInfo.AuthType() = %s, want = %s", got, want)
   421  	}
   422  	authInfo, err := AuthInfoFromPeer(&peer)
   423  	if err != nil {
   424  		t.Errorf("AuthInfoFromPeer failed: %v", err)
   425  	}
   426  	if got, want := authInfo.ApplicationProtocol(), "grpc"; got != want {
   427  		t.Errorf("authInfo.ApplicationProtocol() = %s, want = %s", got, want)
   428  	}
   429  }
   430  
   431  func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
   432  	listener, err := testutils.LocalTCPListener()
   433  	if err != nil {
   434  		t.Fatalf("LocalTCPListener() failed: %v", err)
   435  	}
   436  	s := grpc.NewServer()
   437  	altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{})
   438  	wait.Add(1)
   439  	go func() {
   440  		defer wait.Done()
   441  		if err := s.Serve(listener); err != nil {
   442  			t.Errorf("failed to serve: %v", err)
   443  		}
   444  	}()
   445  	return func() { s.Stop() }, listener.Addr().String()
   446  }
   447  
   448  func startServer(t *testing.T, handshakerServiceAddress string) (stop func(), address string) {
   449  	listener, err := testutils.LocalTCPListener()
   450  	if err != nil {
   451  		t.Fatalf("LocalTCPListener() failed: %v", err)
   452  	}
   453  	serverOpts := &ServerOptions{HandshakerServiceAddress: handshakerServiceAddress}
   454  	creds := NewServerCreds(serverOpts)
   455  	stub := &stubserver.StubServer{
   456  		Listener: listener,
   457  		UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   458  			return &testpb.SimpleResponse{
   459  				Payload: &testpb.Payload{},
   460  			}, nil
   461  		},
   462  		S: grpc.NewServer(grpc.Creds(creds)),
   463  	}
   464  	stubserver.StartTestService(t, stub)
   465  	return func() { stub.S.Stop() }, listener.Addr().String()
   466  }