gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/credentials/alts/internal/handshaker/handshaker_test.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package handshaker
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"errors"
    25  	"testing"
    26  	"time"
    27  
    28  	grpc "gitee.com/zhaochuninhefei/gmgo/grpc"
    29  	core "gitee.com/zhaochuninhefei/gmgo/grpc/credentials/alts/internal"
    30  	altspb "gitee.com/zhaochuninhefei/gmgo/grpc/credentials/alts/internal/proto/grpc_gcp"
    31  	"gitee.com/zhaochuninhefei/gmgo/grpc/credentials/alts/internal/testutil"
    32  	"gitee.com/zhaochuninhefei/gmgo/grpc/internal/grpctest"
    33  )
    34  
    35  type s struct {
    36  	grpctest.Tester
    37  }
    38  
    39  func Test(t *testing.T) {
    40  	grpctest.RunSubTests(t, s{})
    41  }
    42  
    43  var (
    44  	testRecordProtocol = rekeyRecordProtocolName
    45  	testKey            = []byte{
    46  		// 44 arbitrary bytes.
    47  		0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
    48  		0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b,
    49  		0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2,
    50  	}
    51  	testServiceAccount        = "test_service_account"
    52  	testTargetServiceAccounts = []string{testServiceAccount}
    53  	testClientIdentity        = &altspb.Identity{
    54  		IdentityOneof: &altspb.Identity_Hostname{
    55  			Hostname: "i_am_a_client",
    56  		},
    57  	}
    58  )
    59  
    60  const defaultTestTimeout = 10 * time.Second
    61  
    62  // testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object.
    63  type testRPCStream struct {
    64  	grpc.ClientStream
    65  	t        *testing.T
    66  	isClient bool
    67  	// The resp expected to be returned by Recv(). Make sure this is set to
    68  	// the content the test requires before Recv() is invoked.
    69  	recvBuf *altspb.HandshakerResp
    70  	// false if it is the first access to Handshaker service on Envelope.
    71  	first bool
    72  	// useful for testing concurrent calls.
    73  	delay time.Duration
    74  }
    75  
    76  func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
    77  	resp := t.recvBuf
    78  	t.recvBuf = nil
    79  	return resp, nil
    80  }
    81  
    82  func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
    83  	var resp *altspb.HandshakerResp
    84  	if !t.first {
    85  		// Generate the bytes to be returned by Recv() for the initial
    86  		// handshaking.
    87  		t.first = true
    88  		if t.isClient {
    89  			resp = &altspb.HandshakerResp{
    90  				OutFrames: testutil.MakeFrame("ClientInit"),
    91  				// Simulate consuming ServerInit.
    92  				BytesConsumed: 14,
    93  			}
    94  		} else {
    95  			resp = &altspb.HandshakerResp{
    96  				OutFrames: testutil.MakeFrame("ServerInit"),
    97  				// Simulate consuming ClientInit.
    98  				BytesConsumed: 14,
    99  			}
   100  		}
   101  	} else {
   102  		// Add delay to test concurrent calls.
   103  		cleanup := stat.Update()
   104  		defer cleanup()
   105  		time.Sleep(t.delay)
   106  
   107  		// Generate the response to be returned by Recv() for the
   108  		// follow-up handshaking.
   109  		result := &altspb.HandshakerResult{
   110  			RecordProtocol: testRecordProtocol,
   111  			KeyData:        testKey,
   112  		}
   113  		resp = &altspb.HandshakerResp{
   114  			Result: result,
   115  			// Simulate consuming ClientFinished or ServerFinished.
   116  			BytesConsumed: 18,
   117  		}
   118  	}
   119  	t.recvBuf = resp
   120  	return nil
   121  }
   122  
   123  func (t *testRPCStream) CloseSend() error {
   124  	return nil
   125  }
   126  
   127  var stat testutil.Stats
   128  
   129  func (s) TestClientHandshake(t *testing.T) {
   130  	for _, testCase := range []struct {
   131  		delay              time.Duration
   132  		numberOfHandshakes int
   133  	}{
   134  		{0 * time.Millisecond, 1},
   135  		{100 * time.Millisecond, 10 * maxPendingHandshakes},
   136  	} {
   137  		errc := make(chan error)
   138  		stat.Reset()
   139  
   140  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   141  		defer cancel()
   142  
   143  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   144  			stream := &testRPCStream{
   145  				t:        t,
   146  				isClient: true,
   147  			}
   148  			// Preload the inbound frames.
   149  			f1 := testutil.MakeFrame("ServerInit")
   150  			f2 := testutil.MakeFrame("ServerFinished")
   151  			in := bytes.NewBuffer(f1)
   152  			in.Write(f2)
   153  			out := new(bytes.Buffer)
   154  			tc := testutil.NewTestConn(in, out)
   155  			chs := &altsHandshaker{
   156  				stream: stream,
   157  				conn:   tc,
   158  				clientOpts: &ClientHandshakerOptions{
   159  					TargetServiceAccounts: testTargetServiceAccounts,
   160  					ClientIdentity:        testClientIdentity,
   161  				},
   162  				side: core.ClientSide,
   163  			}
   164  			go func() {
   165  				_, context, err := chs.ClientHandshake(ctx)
   166  				if err == nil && context == nil {
   167  					errc <- errors.New("expected non-nil ALTS context")
   168  					return
   169  				}
   170  				errc <- err
   171  				chs.Close()
   172  			}()
   173  		}
   174  
   175  		// Ensure all errors are expected.
   176  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   177  			if err := <-errc; err != nil && err != errDropped {
   178  				t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
   179  			}
   180  		}
   181  
   182  		// Ensure that there are no concurrent calls more than the limit.
   183  		if stat.MaxConcurrentCalls > maxPendingHandshakes {
   184  			t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
   185  		}
   186  	}
   187  }
   188  
   189  func (s) TestServerHandshake(t *testing.T) {
   190  	for _, testCase := range []struct {
   191  		delay              time.Duration
   192  		numberOfHandshakes int
   193  	}{
   194  		{0 * time.Millisecond, 1},
   195  		{100 * time.Millisecond, 10 * maxPendingHandshakes},
   196  	} {
   197  		errc := make(chan error)
   198  		stat.Reset()
   199  
   200  		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   201  		defer cancel()
   202  
   203  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   204  			stream := &testRPCStream{
   205  				t:        t,
   206  				isClient: false,
   207  			}
   208  			// Preload the inbound frames.
   209  			f1 := testutil.MakeFrame("ClientInit")
   210  			f2 := testutil.MakeFrame("ClientFinished")
   211  			in := bytes.NewBuffer(f1)
   212  			in.Write(f2)
   213  			out := new(bytes.Buffer)
   214  			tc := testutil.NewTestConn(in, out)
   215  			shs := &altsHandshaker{
   216  				stream:     stream,
   217  				conn:       tc,
   218  				serverOpts: DefaultServerHandshakerOptions(),
   219  				side:       core.ServerSide,
   220  			}
   221  			go func() {
   222  				_, context, err := shs.ServerHandshake(ctx)
   223  				if err == nil && context == nil {
   224  					errc <- errors.New("expected non-nil ALTS context")
   225  					return
   226  				}
   227  				errc <- err
   228  				shs.Close()
   229  			}()
   230  		}
   231  
   232  		// Ensure all errors are expected.
   233  		for i := 0; i < testCase.numberOfHandshakes; i++ {
   234  			if err := <-errc; err != nil && err != errDropped {
   235  				t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
   236  			}
   237  		}
   238  
   239  		// Ensure that there are no concurrent calls more than the limit.
   240  		if stat.MaxConcurrentCalls > maxPendingHandshakes {
   241  			t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
   242  		}
   243  	}
   244  }
   245  
   246  // testUnresponsiveRPCStream is used for testing the PeerNotResponding case.
   247  type testUnresponsiveRPCStream struct {
   248  	grpc.ClientStream
   249  }
   250  
   251  func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) {
   252  	return &altspb.HandshakerResp{}, nil
   253  }
   254  
   255  func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error {
   256  	return nil
   257  }
   258  
   259  func (t *testUnresponsiveRPCStream) CloseSend() error {
   260  	return nil
   261  }
   262  
   263  func (s) TestPeerNotResponding(t *testing.T) {
   264  	stream := &testUnresponsiveRPCStream{}
   265  	chs := &altsHandshaker{
   266  		stream: stream,
   267  		conn:   testutil.NewUnresponsiveTestConn(),
   268  		clientOpts: &ClientHandshakerOptions{
   269  			TargetServiceAccounts: testTargetServiceAccounts,
   270  			ClientIdentity:        testClientIdentity,
   271  		},
   272  		side: core.ClientSide,
   273  	}
   274  
   275  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   276  	defer cancel()
   277  	_, context, err := chs.ClientHandshake(ctx)
   278  	chs.Close()
   279  	if context != nil {
   280  		t.Error("expected non-nil ALTS context")
   281  	}
   282  	if got, want := err, core.PeerNotRespondingError; got != want {
   283  		t.Errorf("ClientHandshake() = %v, want %v", got, want)
   284  	}
   285  }