google.golang.org/grpc@v1.62.1/credentials/alts/internal/handshaker/handshaker_test.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package handshaker
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"errors"
    25  	"fmt"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/google/go-cmp/cmp"
    30  	"github.com/google/go-cmp/cmp/cmpopts"
    31  	grpc "google.golang.org/grpc"
    32  	core "google.golang.org/grpc/credentials/alts/internal"
    33  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    34  	"google.golang.org/grpc/credentials/alts/internal/testutil"
    35  	"google.golang.org/grpc/internal/envconfig"
    36  	"google.golang.org/grpc/internal/grpctest"
    37  )
    38  
    39  type s struct {
    40  	grpctest.Tester
    41  }
    42  
    43  func Test(t *testing.T) {
    44  	grpctest.RunSubTests(t, s{})
    45  }
    46  
    47  var (
    48  	testRecordProtocol = rekeyRecordProtocolName
    49  	testKey            = []byte{
    50  		// 44 arbitrary bytes.
    51  		0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
    52  		0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b,
    53  		0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2,
    54  	}
    55  	testServiceAccount        = "test_service_account"
    56  	testTargetServiceAccounts = []string{testServiceAccount}
    57  	testClientIdentity        = &altspb.Identity{
    58  		IdentityOneof: &altspb.Identity_Hostname{
    59  			Hostname: "i_am_a_client",
    60  		},
    61  	}
    62  )
    63  
    64  const defaultTestTimeout = 10 * time.Second
    65  
    66  // testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object.
    67  type testRPCStream struct {
    68  	grpc.ClientStream
    69  	t        *testing.T
    70  	isClient bool
    71  	// The resp expected to be returned by Recv(). Make sure this is set to
    72  	// the content the test requires before Recv() is invoked.
    73  	recvBuf *altspb.HandshakerResp
    74  	// false if it is the first access to Handshaker service on Envelope.
    75  	first bool
    76  	// useful for testing concurrent calls.
    77  	delay time.Duration
    78  	// The minimum expected value of the network_latency_ms field in a
    79  	// NextHandshakeMessageReq.
    80  	minExpectedNetworkLatency time.Duration
    81  }
    82  
    83  func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
    84  	resp := t.recvBuf
    85  	t.recvBuf = nil
    86  	return resp, nil
    87  }
    88  
    89  func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
    90  	var resp *altspb.HandshakerResp
    91  	if !t.first {
    92  		// Generate the bytes to be returned by Recv() for the initial
    93  		// handshaking.
    94  		t.first = true
    95  		if t.isClient {
    96  			resp = &altspb.HandshakerResp{
    97  				OutFrames: testutil.MakeFrame("ClientInit"),
    98  				// Simulate consuming ServerInit.
    99  				BytesConsumed: 14,
   100  			}
   101  		} else {
   102  			resp = &altspb.HandshakerResp{
   103  				OutFrames: testutil.MakeFrame("ServerInit"),
   104  				// Simulate consuming ClientInit.
   105  				BytesConsumed: 14,
   106  			}
   107  		}
   108  	} else {
   109  		switch req := req.ReqOneof.(type) {
   110  		case *altspb.HandshakerReq_Next:
   111  			// Compare the network_latency_ms field to the minimum expected network
   112  			// latency.
   113  			if nl := time.Duration(req.Next.NetworkLatencyMs) * time.Millisecond; nl < t.minExpectedNetworkLatency {
   114  				return fmt.Errorf("networkLatency (%v) is smaller than expected min network latency (%v)", nl, t.minExpectedNetworkLatency)
   115  			}
   116  		default:
   117  			return fmt.Errorf("handshake request has unexpected type: %v", req)
   118  		}
   119  
   120  		// Add delay to test concurrent calls.
   121  		cleanup := stat.Update()
   122  		defer cleanup()
   123  		time.Sleep(t.delay)
   124  
   125  		// Generate the response to be returned by Recv() for the
   126  		// follow-up handshaking.
   127  		result := &altspb.HandshakerResult{
   128  			RecordProtocol: testRecordProtocol,
   129  			KeyData:        testKey,
   130  		}
   131  		resp = &altspb.HandshakerResp{
   132  			Result: result,
   133  			// Simulate consuming ClientFinished or ServerFinished.
   134  			BytesConsumed: 18,
   135  		}
   136  	}
   137  	t.recvBuf = resp
   138  	return nil
   139  }
   140  
   141  func (t *testRPCStream) CloseSend() error {
   142  	return nil
   143  }
   144  
   145  var stat testutil.Stats
   146  
   147  func (s) TestClientHandshake(t *testing.T) {
   148  	for _, testCase := range []struct {
   149  		delay              time.Duration
   150  		numberOfHandshakes int
   151  		readLatency        time.Duration
   152  	}{
   153  		{0 * time.Millisecond, 1, time.Duration(0)},
   154  		{0 * time.Millisecond, 1, 2 * time.Millisecond},
   155  		{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes), time.Duration(0)},
   156  	} {
   157  		errc := make(chan error)
   158  		stat.Reset()
   159  
   160  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   161  		defer cancel()
   162  
   163  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   164  			stream := &testRPCStream{
   165  				t:                         t,
   166  				isClient:                  true,
   167  				minExpectedNetworkLatency: testCase.readLatency,
   168  			}
   169  			// Preload the inbound frames.
   170  			f1 := testutil.MakeFrame("ServerInit")
   171  			f2 := testutil.MakeFrame("ServerFinished")
   172  			in := bytes.NewBuffer(f1)
   173  			in.Write(f2)
   174  			out := new(bytes.Buffer)
   175  			tc := testutil.NewTestConnWithReadLatency(in, out, testCase.readLatency)
   176  			chs := &altsHandshaker{
   177  				stream: stream,
   178  				conn:   tc,
   179  				clientOpts: &ClientHandshakerOptions{
   180  					TargetServiceAccounts: testTargetServiceAccounts,
   181  					ClientIdentity:        testClientIdentity,
   182  				},
   183  				side: core.ClientSide,
   184  			}
   185  			go func() {
   186  				_, context, err := chs.ClientHandshake(ctx)
   187  				if err == nil && context == nil {
   188  					errc <- errors.New("expected non-nil ALTS context")
   189  					return
   190  				}
   191  				errc <- err
   192  				chs.Close()
   193  			}()
   194  		}
   195  
   196  		// Ensure that there are no errors.
   197  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   198  			if err := <-errc; err != nil {
   199  				t.Errorf("ClientHandshake() = _, %v, want _, <nil>", err)
   200  			}
   201  		}
   202  
   203  		// Ensure that there are no concurrent calls more than the limit.
   204  		if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) {
   205  			t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes)
   206  		}
   207  	}
   208  }
   209  
   210  func (s) TestServerHandshake(t *testing.T) {
   211  	for _, testCase := range []struct {
   212  		delay              time.Duration
   213  		numberOfHandshakes int
   214  	}{
   215  		{0 * time.Millisecond, 1},
   216  		{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)},
   217  	} {
   218  		errc := make(chan error)
   219  		stat.Reset()
   220  
   221  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   222  		defer cancel()
   223  
   224  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   225  			stream := &testRPCStream{
   226  				t:        t,
   227  				isClient: false,
   228  			}
   229  			// Preload the inbound frames.
   230  			f1 := testutil.MakeFrame("ClientInit")
   231  			f2 := testutil.MakeFrame("ClientFinished")
   232  			in := bytes.NewBuffer(f1)
   233  			in.Write(f2)
   234  			out := new(bytes.Buffer)
   235  			tc := testutil.NewTestConn(in, out)
   236  			shs := &altsHandshaker{
   237  				stream:     stream,
   238  				conn:       tc,
   239  				serverOpts: DefaultServerHandshakerOptions(),
   240  				side:       core.ServerSide,
   241  			}
   242  			go func() {
   243  				_, context, err := shs.ServerHandshake(ctx)
   244  				if err == nil && context == nil {
   245  					errc <- errors.New("expected non-nil ALTS context")
   246  					return
   247  				}
   248  				errc <- err
   249  				shs.Close()
   250  			}()
   251  		}
   252  
   253  		// Ensure that there are no errors.
   254  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   255  			if err := <-errc; err != nil {
   256  				t.Errorf("ServerHandshake() = _, %v, want _, <nil>", err)
   257  			}
   258  		}
   259  
   260  		// Ensure that there are no concurrent calls more than the limit.
   261  		if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) {
   262  			t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes)
   263  		}
   264  	}
   265  }
   266  
   267  // testUnresponsiveRPCStream is used for testing the PeerNotResponding case.
   268  type testUnresponsiveRPCStream struct {
   269  	grpc.ClientStream
   270  }
   271  
   272  func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) {
   273  	return &altspb.HandshakerResp{}, nil
   274  }
   275  
   276  func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error {
   277  	return nil
   278  }
   279  
   280  func (t *testUnresponsiveRPCStream) CloseSend() error {
   281  	return nil
   282  }
   283  
   284  func (s) TestPeerNotResponding(t *testing.T) {
   285  	stream := &testUnresponsiveRPCStream{}
   286  	chs := &altsHandshaker{
   287  		stream: stream,
   288  		conn:   testutil.NewUnresponsiveTestConn(),
   289  		clientOpts: &ClientHandshakerOptions{
   290  			TargetServiceAccounts: testTargetServiceAccounts,
   291  			ClientIdentity:        testClientIdentity,
   292  		},
   293  		side: core.ClientSide,
   294  	}
   295  
   296  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   297  	defer cancel()
   298  	_, context, err := chs.ClientHandshake(ctx)
   299  	chs.Close()
   300  	if context != nil {
   301  		t.Error("expected non-nil ALTS context")
   302  	}
   303  	if got, want := err, core.PeerNotRespondingError; got != want {
   304  		t.Errorf("ClientHandshake() = %v, want %v", got, want)
   305  	}
   306  }
   307  
   308  func (s) TestNewClientHandshaker(t *testing.T) {
   309  	conn := testutil.NewTestConn(nil, nil)
   310  	clientConn := &grpc.ClientConn{}
   311  	opts := &ClientHandshakerOptions{}
   312  	hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
   313  	if err != nil {
   314  		t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
   315  	}
   316  	expectedHs := &altsHandshaker{
   317  		stream:     nil,
   318  		conn:       conn,
   319  		clientConn: clientConn,
   320  		clientOpts: opts,
   321  		serverOpts: nil,
   322  		side:       core.ClientSide,
   323  	}
   324  	cmpOpts := []cmp.Option{
   325  		cmp.AllowUnexported(altsHandshaker{}),
   326  		cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
   327  	}
   328  	if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
   329  		t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
   330  	}
   331  	if hs.(*altsHandshaker).stream != nil {
   332  		t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream")
   333  	}
   334  	if hs.(*altsHandshaker).clientConn != clientConn {
   335  		t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn")
   336  	}
   337  	hs.Close()
   338  }
   339  
   340  func (s) TestNewServerHandshaker(t *testing.T) {
   341  	conn := testutil.NewTestConn(nil, nil)
   342  	clientConn := &grpc.ClientConn{}
   343  	opts := &ServerHandshakerOptions{}
   344  	hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
   345  	if err != nil {
   346  		t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
   347  	}
   348  	expectedHs := &altsHandshaker{
   349  		stream:     nil,
   350  		conn:       conn,
   351  		clientConn: clientConn,
   352  		clientOpts: nil,
   353  		serverOpts: opts,
   354  		side:       core.ServerSide,
   355  	}
   356  	cmpOpts := []cmp.Option{
   357  		cmp.AllowUnexported(altsHandshaker{}),
   358  		cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
   359  	}
   360  	if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
   361  		t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
   362  	}
   363  	if hs.(*altsHandshaker).stream != nil {
   364  		t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream")
   365  	}
   366  	if hs.(*altsHandshaker).clientConn != clientConn {
   367  		t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn")
   368  	}
   369  	hs.Close()
   370  }