google.golang.org/grpc@v1.62.1/credentials/alts/internal/handshaker/handshaker_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package handshaker 20 21 import ( 22 "bytes" 23 "context" 24 "errors" 25 "fmt" 26 "testing" 27 "time" 28 29 "github.com/google/go-cmp/cmp" 30 "github.com/google/go-cmp/cmp/cmpopts" 31 grpc "google.golang.org/grpc" 32 core "google.golang.org/grpc/credentials/alts/internal" 33 altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" 34 "google.golang.org/grpc/credentials/alts/internal/testutil" 35 "google.golang.org/grpc/internal/envconfig" 36 "google.golang.org/grpc/internal/grpctest" 37 ) 38 39 type s struct { 40 grpctest.Tester 41 } 42 43 func Test(t *testing.T) { 44 grpctest.RunSubTests(t, s{}) 45 } 46 47 var ( 48 testRecordProtocol = rekeyRecordProtocolName 49 testKey = []byte{ 50 // 44 arbitrary bytes. 51 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 52 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b, 53 0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 54 } 55 testServiceAccount = "test_service_account" 56 testTargetServiceAccounts = []string{testServiceAccount} 57 testClientIdentity = &altspb.Identity{ 58 IdentityOneof: &altspb.Identity_Hostname{ 59 Hostname: "i_am_a_client", 60 }, 61 } 62 ) 63 64 const defaultTestTimeout = 10 * time.Second 65 66 // testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object. 67 type testRPCStream struct { 68 grpc.ClientStream 69 t *testing.T 70 isClient bool 71 // The resp expected to be returned by Recv(). Make sure this is set to 72 // the content the test requires before Recv() is invoked. 73 recvBuf *altspb.HandshakerResp 74 // false if it is the first access to Handshaker service on Envelope. 75 first bool 76 // useful for testing concurrent calls. 77 delay time.Duration 78 // The minimum expected value of the network_latency_ms field in a 79 // NextHandshakeMessageReq. 80 minExpectedNetworkLatency time.Duration 81 } 82 83 func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) { 84 resp := t.recvBuf 85 t.recvBuf = nil 86 return resp, nil 87 } 88 89 func (t *testRPCStream) Send(req *altspb.HandshakerReq) error { 90 var resp *altspb.HandshakerResp 91 if !t.first { 92 // Generate the bytes to be returned by Recv() for the initial 93 // handshaking. 94 t.first = true 95 if t.isClient { 96 resp = &altspb.HandshakerResp{ 97 OutFrames: testutil.MakeFrame("ClientInit"), 98 // Simulate consuming ServerInit. 99 BytesConsumed: 14, 100 } 101 } else { 102 resp = &altspb.HandshakerResp{ 103 OutFrames: testutil.MakeFrame("ServerInit"), 104 // Simulate consuming ClientInit. 105 BytesConsumed: 14, 106 } 107 } 108 } else { 109 switch req := req.ReqOneof.(type) { 110 case *altspb.HandshakerReq_Next: 111 // Compare the network_latency_ms field to the minimum expected network 112 // latency. 113 if nl := time.Duration(req.Next.NetworkLatencyMs) * time.Millisecond; nl < t.minExpectedNetworkLatency { 114 return fmt.Errorf("networkLatency (%v) is smaller than expected min network latency (%v)", nl, t.minExpectedNetworkLatency) 115 } 116 default: 117 return fmt.Errorf("handshake request has unexpected type: %v", req) 118 } 119 120 // Add delay to test concurrent calls. 121 cleanup := stat.Update() 122 defer cleanup() 123 time.Sleep(t.delay) 124 125 // Generate the response to be returned by Recv() for the 126 // follow-up handshaking. 127 result := &altspb.HandshakerResult{ 128 RecordProtocol: testRecordProtocol, 129 KeyData: testKey, 130 } 131 resp = &altspb.HandshakerResp{ 132 Result: result, 133 // Simulate consuming ClientFinished or ServerFinished. 134 BytesConsumed: 18, 135 } 136 } 137 t.recvBuf = resp 138 return nil 139 } 140 141 func (t *testRPCStream) CloseSend() error { 142 return nil 143 } 144 145 var stat testutil.Stats 146 147 func (s) TestClientHandshake(t *testing.T) { 148 for _, testCase := range []struct { 149 delay time.Duration 150 numberOfHandshakes int 151 readLatency time.Duration 152 }{ 153 {0 * time.Millisecond, 1, time.Duration(0)}, 154 {0 * time.Millisecond, 1, 2 * time.Millisecond}, 155 {100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes), time.Duration(0)}, 156 } { 157 errc := make(chan error) 158 stat.Reset() 159 160 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 161 defer cancel() 162 163 for i := 0; i < testCase.numberOfHandshakes; i++ { 164 stream := &testRPCStream{ 165 t: t, 166 isClient: true, 167 minExpectedNetworkLatency: testCase.readLatency, 168 } 169 // Preload the inbound frames. 170 f1 := testutil.MakeFrame("ServerInit") 171 f2 := testutil.MakeFrame("ServerFinished") 172 in := bytes.NewBuffer(f1) 173 in.Write(f2) 174 out := new(bytes.Buffer) 175 tc := testutil.NewTestConnWithReadLatency(in, out, testCase.readLatency) 176 chs := &altsHandshaker{ 177 stream: stream, 178 conn: tc, 179 clientOpts: &ClientHandshakerOptions{ 180 TargetServiceAccounts: testTargetServiceAccounts, 181 ClientIdentity: testClientIdentity, 182 }, 183 side: core.ClientSide, 184 } 185 go func() { 186 _, context, err := chs.ClientHandshake(ctx) 187 if err == nil && context == nil { 188 errc <- errors.New("expected non-nil ALTS context") 189 return 190 } 191 errc <- err 192 chs.Close() 193 }() 194 } 195 196 // Ensure that there are no errors. 197 for i := 0; i < testCase.numberOfHandshakes; i++ { 198 if err := <-errc; err != nil { 199 t.Errorf("ClientHandshake() = _, %v, want _, <nil>", err) 200 } 201 } 202 203 // Ensure that there are no concurrent calls more than the limit. 204 if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) { 205 t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes) 206 } 207 } 208 } 209 210 func (s) TestServerHandshake(t *testing.T) { 211 for _, testCase := range []struct { 212 delay time.Duration 213 numberOfHandshakes int 214 }{ 215 {0 * time.Millisecond, 1}, 216 {100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)}, 217 } { 218 errc := make(chan error) 219 stat.Reset() 220 221 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 222 defer cancel() 223 224 for i := 0; i < testCase.numberOfHandshakes; i++ { 225 stream := &testRPCStream{ 226 t: t, 227 isClient: false, 228 } 229 // Preload the inbound frames. 230 f1 := testutil.MakeFrame("ClientInit") 231 f2 := testutil.MakeFrame("ClientFinished") 232 in := bytes.NewBuffer(f1) 233 in.Write(f2) 234 out := new(bytes.Buffer) 235 tc := testutil.NewTestConn(in, out) 236 shs := &altsHandshaker{ 237 stream: stream, 238 conn: tc, 239 serverOpts: DefaultServerHandshakerOptions(), 240 side: core.ServerSide, 241 } 242 go func() { 243 _, context, err := shs.ServerHandshake(ctx) 244 if err == nil && context == nil { 245 errc <- errors.New("expected non-nil ALTS context") 246 return 247 } 248 errc <- err 249 shs.Close() 250 }() 251 } 252 253 // Ensure that there are no errors. 254 for i := 0; i < testCase.numberOfHandshakes; i++ { 255 if err := <-errc; err != nil { 256 t.Errorf("ServerHandshake() = _, %v, want _, <nil>", err) 257 } 258 } 259 260 // Ensure that there are no concurrent calls more than the limit. 261 if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) { 262 t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes) 263 } 264 } 265 } 266 267 // testUnresponsiveRPCStream is used for testing the PeerNotResponding case. 268 type testUnresponsiveRPCStream struct { 269 grpc.ClientStream 270 } 271 272 func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) { 273 return &altspb.HandshakerResp{}, nil 274 } 275 276 func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error { 277 return nil 278 } 279 280 func (t *testUnresponsiveRPCStream) CloseSend() error { 281 return nil 282 } 283 284 func (s) TestPeerNotResponding(t *testing.T) { 285 stream := &testUnresponsiveRPCStream{} 286 chs := &altsHandshaker{ 287 stream: stream, 288 conn: testutil.NewUnresponsiveTestConn(), 289 clientOpts: &ClientHandshakerOptions{ 290 TargetServiceAccounts: testTargetServiceAccounts, 291 ClientIdentity: testClientIdentity, 292 }, 293 side: core.ClientSide, 294 } 295 296 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 297 defer cancel() 298 _, context, err := chs.ClientHandshake(ctx) 299 chs.Close() 300 if context != nil { 301 t.Error("expected non-nil ALTS context") 302 } 303 if got, want := err, core.PeerNotRespondingError; got != want { 304 t.Errorf("ClientHandshake() = %v, want %v", got, want) 305 } 306 } 307 308 func (s) TestNewClientHandshaker(t *testing.T) { 309 conn := testutil.NewTestConn(nil, nil) 310 clientConn := &grpc.ClientConn{} 311 opts := &ClientHandshakerOptions{} 312 hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts) 313 if err != nil { 314 t.Errorf("NewClientHandshaker returned unexpected error: %v", err) 315 } 316 expectedHs := &altsHandshaker{ 317 stream: nil, 318 conn: conn, 319 clientConn: clientConn, 320 clientOpts: opts, 321 serverOpts: nil, 322 side: core.ClientSide, 323 } 324 cmpOpts := []cmp.Option{ 325 cmp.AllowUnexported(altsHandshaker{}), 326 cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"), 327 } 328 if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) { 329 t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want) 330 } 331 if hs.(*altsHandshaker).stream != nil { 332 t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream") 333 } 334 if hs.(*altsHandshaker).clientConn != clientConn { 335 t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn") 336 } 337 hs.Close() 338 } 339 340 func (s) TestNewServerHandshaker(t *testing.T) { 341 conn := testutil.NewTestConn(nil, nil) 342 clientConn := &grpc.ClientConn{} 343 opts := &ServerHandshakerOptions{} 344 hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts) 345 if err != nil { 346 t.Errorf("NewServerHandshaker returned unexpected error: %v", err) 347 } 348 expectedHs := &altsHandshaker{ 349 stream: nil, 350 conn: conn, 351 clientConn: clientConn, 352 clientOpts: nil, 353 serverOpts: opts, 354 side: core.ServerSide, 355 } 356 cmpOpts := []cmp.Option{ 357 cmp.AllowUnexported(altsHandshaker{}), 358 cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"), 359 } 360 if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) { 361 t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want) 362 } 363 if hs.(*altsHandshaker).stream != nil { 364 t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream") 365 } 366 if hs.(*altsHandshaker).clientConn != clientConn { 367 t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn") 368 } 369 hs.Close() 370 }