google.golang.org/grpc@v1.72.2/credentials/alts/internal/handshaker/handshaker_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package handshaker 20 21 import ( 22 "bytes" 23 "context" 24 "errors" 25 "fmt" 26 "testing" 27 "time" 28 29 "github.com/google/go-cmp/cmp" 30 "github.com/google/go-cmp/cmp/cmpopts" 31 grpc "google.golang.org/grpc" 32 core "google.golang.org/grpc/credentials/alts/internal" 33 altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" 34 "google.golang.org/grpc/credentials/alts/internal/testutil" 35 "google.golang.org/grpc/internal/envconfig" 36 "google.golang.org/grpc/internal/grpctest" 37 ) 38 39 type s struct { 40 grpctest.Tester 41 } 42 43 func Test(t *testing.T) { 44 grpctest.RunSubTests(t, s{}) 45 } 46 47 var ( 48 testRecordProtocol = rekeyRecordProtocolName 49 testKey = []byte{ 50 // 44 arbitrary bytes. 51 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 52 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b, 53 0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 54 } 55 testServiceAccount = "test_service_account" 56 testTargetServiceAccounts = []string{testServiceAccount} 57 testClientIdentity = &altspb.Identity{ 58 IdentityOneof: &altspb.Identity_Hostname{ 59 Hostname: "i_am_a_client", 60 }, 61 } 62 ) 63 64 const defaultTestTimeout = 10 * time.Second 65 66 // testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object. 67 type testRPCStream struct { 68 grpc.ClientStream 69 t *testing.T 70 isClient bool 71 // The resp expected to be returned by Recv(). Make sure this is set to 72 // the content the test requires before Recv() is invoked. 73 recvBuf *altspb.HandshakerResp 74 // false if it is the first access to Handshaker service on Envelope. 75 first bool 76 // useful for testing concurrent calls. 77 delay time.Duration 78 // The minimum expected value of the network_latency_ms field in a 79 // NextHandshakeMessageReq. 80 minExpectedNetworkLatency time.Duration 81 } 82 83 func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) { 84 resp := t.recvBuf 85 t.recvBuf = nil 86 return resp, nil 87 } 88 89 func (t *testRPCStream) Send(req *altspb.HandshakerReq) error { 90 var resp *altspb.HandshakerResp 91 if !t.first { 92 // Generate the bytes to be returned by Recv() for the initial 93 // handshaking. 94 t.first = true 95 if t.isClient { 96 resp = &altspb.HandshakerResp{ 97 OutFrames: testutil.MakeFrame("ClientInit"), 98 // Simulate consuming ServerInit. 99 BytesConsumed: 14, 100 } 101 } else { 102 resp = &altspb.HandshakerResp{ 103 OutFrames: testutil.MakeFrame("ServerInit"), 104 // Simulate consuming ClientInit. 105 BytesConsumed: 14, 106 } 107 } 108 } else { 109 switch req := req.ReqOneof.(type) { 110 case *altspb.HandshakerReq_Next: 111 // Compare the network_latency_ms field to the minimum expected network 112 // latency. 113 if nl := time.Duration(req.Next.NetworkLatencyMs) * time.Millisecond; nl < t.minExpectedNetworkLatency { 114 return fmt.Errorf("networkLatency (%v) is smaller than expected min network latency (%v)", nl, t.minExpectedNetworkLatency) 115 } 116 default: 117 return fmt.Errorf("handshake request has unexpected type: %v", req) 118 } 119 120 // Add delay to test concurrent calls. 121 cleanup := stat.Update() 122 defer cleanup() 123 time.Sleep(t.delay) 124 125 // Generate the response to be returned by Recv() for the 126 // follow-up handshaking. 127 result := &altspb.HandshakerResult{ 128 RecordProtocol: testRecordProtocol, 129 KeyData: testKey, 130 } 131 resp = &altspb.HandshakerResp{ 132 Result: result, 133 // Simulate consuming ClientFinished or ServerFinished. 134 BytesConsumed: 18, 135 } 136 } 137 t.recvBuf = resp 138 return nil 139 } 140 141 func (t *testRPCStream) CloseSend() error { 142 return nil 143 } 144 145 var stat testutil.Stats 146 147 func (s) TestClientHandshake(t *testing.T) { 148 for _, testCase := range []struct { 149 delay time.Duration 150 numberOfHandshakes int 151 readLatency time.Duration 152 }{ 153 {0 * time.Millisecond, 1, time.Duration(0)}, 154 {0 * time.Millisecond, 1, 2 * time.Millisecond}, 155 {100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes), time.Duration(0)}, 156 } { 157 errc := make(chan error) 158 stat.Reset() 159 160 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 161 defer cancel() 162 163 for i := 0; i < testCase.numberOfHandshakes; i++ { 164 stream := &testRPCStream{ 165 t: t, 166 isClient: true, 167 minExpectedNetworkLatency: testCase.readLatency, 168 } 169 // Preload the inbound frames. 170 f1 := testutil.MakeFrame("ServerInit") 171 f2 := testutil.MakeFrame("ServerFinished") 172 in := bytes.NewBuffer(f1) 173 in.Write(f2) 174 out := new(bytes.Buffer) 175 tc := testutil.NewTestConnWithReadLatency(in, out, testCase.readLatency) 176 chs := &altsHandshaker{ 177 stream: stream, 178 conn: tc, 179 clientOpts: &ClientHandshakerOptions{ 180 TargetServiceAccounts: testTargetServiceAccounts, 181 ClientIdentity: testClientIdentity, 182 }, 183 side: core.ClientSide, 184 } 185 go func() { 186 _, context, err := chs.ClientHandshake(ctx) 187 if err == nil && context == nil { 188 errc <- errors.New("expected non-nil ALTS context") 189 return 190 } 191 errc <- err 192 chs.Close() 193 }() 194 } 195 196 // Ensure that there are no errors. 197 for i := 0; i < testCase.numberOfHandshakes; i++ { 198 if err := <-errc; err != nil { 199 t.Errorf("ClientHandshake() = _, %v, want _, <nil>", err) 200 } 201 } 202 203 // Ensure that there are no concurrent calls more than the limit. 204 if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) { 205 t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes) 206 } 207 } 208 } 209 210 func (s) TestServerHandshake(t *testing.T) { 211 for _, testCase := range []struct { 212 delay time.Duration 213 numberOfHandshakes int 214 }{ 215 {0 * time.Millisecond, 1}, 216 {100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)}, 217 } { 218 errc := make(chan error) 219 stat.Reset() 220 221 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 222 defer cancel() 223 224 for i := 0; i < testCase.numberOfHandshakes; i++ { 225 stream := &testRPCStream{ 226 t: t, 227 isClient: false, 228 } 229 // Preload the inbound frames. 230 f1 := testutil.MakeFrame("ClientInit") 231 f2 := testutil.MakeFrame("ClientFinished") 232 in := bytes.NewBuffer(f1) 233 in.Write(f2) 234 out := new(bytes.Buffer) 235 tc := testutil.NewTestConn(in, out) 236 shs := &altsHandshaker{ 237 stream: stream, 238 conn: tc, 239 serverOpts: DefaultServerHandshakerOptions(), 240 side: core.ServerSide, 241 } 242 go func() { 243 _, context, err := shs.ServerHandshake(ctx) 244 if err == nil && context == nil { 245 errc <- errors.New("expected non-nil ALTS context") 246 return 247 } 248 errc <- err 249 shs.Close() 250 }() 251 } 252 253 // Ensure that there are no errors. 254 for i := 0; i < testCase.numberOfHandshakes; i++ { 255 if err := <-errc; err != nil { 256 t.Errorf("ServerHandshake() = _, %v, want _, <nil>", err) 257 } 258 } 259 260 // Ensure that there are no concurrent calls more than the limit. 261 if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) { 262 t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes) 263 } 264 } 265 } 266 267 // testUnresponsiveRPCStream is used for testing the PeerNotResponding case. 268 type testUnresponsiveRPCStream struct { 269 grpc.ClientStream 270 } 271 272 func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) { 273 return &altspb.HandshakerResp{}, nil 274 } 275 276 func (t *testUnresponsiveRPCStream) Send(*altspb.HandshakerReq) error { 277 return nil 278 } 279 280 func (t *testUnresponsiveRPCStream) CloseSend() error { 281 return nil 282 } 283 284 func (s) TestPeerNotResponding(t *testing.T) { 285 stream := &testUnresponsiveRPCStream{} 286 chs := &altsHandshaker{ 287 stream: stream, 288 conn: testutil.NewUnresponsiveTestConn(), 289 clientOpts: &ClientHandshakerOptions{ 290 TargetServiceAccounts: testTargetServiceAccounts, 291 ClientIdentity: testClientIdentity, 292 }, 293 side: core.ClientSide, 294 } 295 296 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 297 defer cancel() 298 _, context, err := chs.ClientHandshake(ctx) 299 chs.Close() 300 if context != nil { 301 t.Error("expected non-nil ALTS context") 302 } 303 if got, want := err, core.PeerNotRespondingError; got != want { 304 t.Errorf("ClientHandshake() = %v, want %v", got, want) 305 } 306 } 307 308 func (s) TestNewClientHandshaker(t *testing.T) { 309 conn := testutil.NewTestConn(nil, nil) 310 clientConn := &grpc.ClientConn{} 311 opts := &ClientHandshakerOptions{} 312 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 313 defer cancel() 314 hs, err := NewClientHandshaker(ctx, clientConn, conn, opts) 315 if err != nil { 316 t.Errorf("NewClientHandshaker returned unexpected error: %v", err) 317 } 318 expectedHs := &altsHandshaker{ 319 stream: nil, 320 conn: conn, 321 clientConn: clientConn, 322 clientOpts: opts, 323 serverOpts: nil, 324 side: core.ClientSide, 325 } 326 cmpOpts := []cmp.Option{ 327 cmp.AllowUnexported(altsHandshaker{}), 328 cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"), 329 } 330 if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) { 331 t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want) 332 } 333 if hs.(*altsHandshaker).stream != nil { 334 t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream") 335 } 336 if hs.(*altsHandshaker).clientConn != clientConn { 337 t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn") 338 } 339 hs.Close() 340 } 341 342 func (s) TestNewServerHandshaker(t *testing.T) { 343 conn := testutil.NewTestConn(nil, nil) 344 clientConn := &grpc.ClientConn{} 345 opts := &ServerHandshakerOptions{} 346 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 347 defer cancel() 348 hs, err := NewServerHandshaker(ctx, clientConn, conn, opts) 349 if err != nil { 350 t.Errorf("NewServerHandshaker returned unexpected error: %v", err) 351 } 352 expectedHs := &altsHandshaker{ 353 stream: nil, 354 conn: conn, 355 clientConn: clientConn, 356 clientOpts: nil, 357 serverOpts: opts, 358 side: core.ServerSide, 359 } 360 cmpOpts := []cmp.Option{ 361 cmp.AllowUnexported(altsHandshaker{}), 362 cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"), 363 } 364 if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) { 365 t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want) 366 } 367 if hs.(*altsHandshaker).stream != nil { 368 t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream") 369 } 370 if hs.(*altsHandshaker).clientConn != clientConn { 371 t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn") 372 } 373 hs.Close() 374 }