google.golang.org/grpc@v1.62.1/credentials/alts/alts_test.go (about) 1 //go:build linux || windows 2 // +build linux windows 3 4 /* 5 * 6 * Copyright 2018 gRPC authors. 7 * 8 * Licensed under the Apache License, Version 2.0 (the "License"); 9 * you may not use this file except in compliance with the License. 10 * You may obtain a copy of the License at 11 * 12 * http://www.apache.org/licenses/LICENSE-2.0 13 * 14 * Unless required by applicable law or agreed to in writing, software 15 * distributed under the License is distributed on an "AS IS" BASIS, 16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 * See the License for the specific language governing permissions and 18 * limitations under the License. 19 * 20 */ 21 22 package alts 23 24 import ( 25 "context" 26 "reflect" 27 "sync" 28 "testing" 29 "time" 30 31 "google.golang.org/grpc" 32 "google.golang.org/grpc/codes" 33 "google.golang.org/grpc/credentials/alts/internal/handshaker" 34 "google.golang.org/grpc/credentials/alts/internal/handshaker/service" 35 altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" 36 altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" 37 "google.golang.org/grpc/credentials/alts/internal/testutil" 38 "google.golang.org/grpc/internal/grpctest" 39 "google.golang.org/grpc/internal/testutils" 40 testgrpc "google.golang.org/grpc/interop/grpc_testing" 41 testpb "google.golang.org/grpc/interop/grpc_testing" 42 "google.golang.org/grpc/peer" 43 "google.golang.org/grpc/status" 44 "google.golang.org/protobuf/proto" 45 ) 46 47 const ( 48 defaultTestLongTimeout = 60 * time.Second 49 defaultTestShortTimeout = 10 * time.Millisecond 50 ) 51 52 type s struct { 53 grpctest.Tester 54 } 55 56 func init() { 57 // The vmOnGCP global variable MUST be forced to true. Otherwise, if 58 // this test is run anywhere except on a GCP VM, then an ALTS handshake 59 // will immediately fail. 60 once.Do(func() {}) 61 vmOnGCP = true 62 } 63 64 func Test(t *testing.T) { 65 grpctest.RunSubTests(t, s{}) 66 } 67 68 func (s) TestInfoServerName(t *testing.T) { 69 // This is not testing any handshaker functionality, so it's fine to only 70 // use NewServerCreds and not NewClientCreds. 71 alts := NewServerCreds(DefaultServerOptions()) 72 if got, want := alts.Info().ServerName, ""; got != want { 73 t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want) 74 } 75 } 76 77 func (s) TestOverrideServerName(t *testing.T) { 78 wantServerName := "server.name" 79 // This is not testing any handshaker functionality, so it's fine to only 80 // use NewServerCreds and not NewClientCreds. 81 c := NewServerCreds(DefaultServerOptions()) 82 c.OverrideServerName(wantServerName) 83 if got, want := c.Info().ServerName, wantServerName; got != want { 84 t.Fatalf("c.Info().ServerName = %v, want %v", got, want) 85 } 86 } 87 88 func (s) TestCloneClient(t *testing.T) { 89 wantServerName := "server.name" 90 opt := DefaultClientOptions() 91 opt.TargetServiceAccounts = []string{"not", "empty"} 92 c := NewClientCreds(opt) 93 c.OverrideServerName(wantServerName) 94 cc := c.Clone() 95 if got, want := cc.Info().ServerName, wantServerName; got != want { 96 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 97 } 98 cc.OverrideServerName("") 99 if got, want := c.Info().ServerName, wantServerName; got != want { 100 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want) 101 } 102 if got, want := cc.Info().ServerName, ""; got != want { 103 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 104 } 105 106 ct := c.(*altsTC) 107 cct := cc.(*altsTC) 108 109 if ct.side != cct.side { 110 t.Errorf("cc.side = %q, want %q", cct.side, ct.side) 111 } 112 if ct.hsAddress != cct.hsAddress { 113 t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress) 114 } 115 if !reflect.DeepEqual(ct.accounts, cct.accounts) { 116 t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts) 117 } 118 } 119 120 func (s) TestCloneServer(t *testing.T) { 121 wantServerName := "server.name" 122 c := NewServerCreds(DefaultServerOptions()) 123 c.OverrideServerName(wantServerName) 124 cc := c.Clone() 125 if got, want := cc.Info().ServerName, wantServerName; got != want { 126 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 127 } 128 cc.OverrideServerName("") 129 if got, want := c.Info().ServerName, wantServerName; got != want { 130 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want) 131 } 132 if got, want := cc.Info().ServerName, ""; got != want { 133 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 134 } 135 136 ct := c.(*altsTC) 137 cct := cc.(*altsTC) 138 139 if ct.side != cct.side { 140 t.Errorf("cc.side = %q, want %q", cct.side, ct.side) 141 } 142 if ct.hsAddress != cct.hsAddress { 143 t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress) 144 } 145 if !reflect.DeepEqual(ct.accounts, cct.accounts) { 146 t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts) 147 } 148 } 149 150 func (s) TestInfo(t *testing.T) { 151 // This is not testing any handshaker functionality, so it's fine to only 152 // use NewServerCreds and not NewClientCreds. 153 c := NewServerCreds(DefaultServerOptions()) 154 info := c.Info() 155 if got, want := info.ProtocolVersion, ""; got != want { 156 t.Errorf("info.ProtocolVersion=%v, want %v", got, want) 157 } 158 if got, want := info.SecurityProtocol, "alts"; got != want { 159 t.Errorf("info.SecurityProtocol=%v, want %v", got, want) 160 } 161 if got, want := info.SecurityVersion, "1.0"; got != want { 162 t.Errorf("info.SecurityVersion=%v, want %v", got, want) 163 } 164 if got, want := info.ServerName, ""; got != want { 165 t.Errorf("info.ServerName=%v, want %v", got, want) 166 } 167 } 168 169 func (s) TestCompareRPCVersions(t *testing.T) { 170 for _, tc := range []struct { 171 v1 *altspb.RpcProtocolVersions_Version 172 v2 *altspb.RpcProtocolVersions_Version 173 output int 174 }{ 175 { 176 version(3, 2), 177 version(2, 1), 178 1, 179 }, 180 { 181 version(3, 2), 182 version(3, 1), 183 1, 184 }, 185 { 186 version(2, 1), 187 version(3, 2), 188 -1, 189 }, 190 { 191 version(3, 1), 192 version(3, 2), 193 -1, 194 }, 195 { 196 version(3, 2), 197 version(3, 2), 198 0, 199 }, 200 } { 201 if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want { 202 t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want) 203 } 204 } 205 } 206 207 func (s) TestCheckRPCVersions(t *testing.T) { 208 for _, tc := range []struct { 209 desc string 210 local *altspb.RpcProtocolVersions 211 peer *altspb.RpcProtocolVersions 212 output bool 213 maxCommonVersion *altspb.RpcProtocolVersions_Version 214 }{ 215 { 216 "local.max > peer.max and local.min > peer.min", 217 versions(2, 1, 3, 2), 218 versions(1, 2, 2, 1), 219 true, 220 version(2, 1), 221 }, 222 { 223 "local.max > peer.max and local.min < peer.min", 224 versions(1, 2, 3, 2), 225 versions(2, 1, 2, 1), 226 true, 227 version(2, 1), 228 }, 229 { 230 "local.max > peer.max and local.min = peer.min", 231 versions(2, 1, 3, 2), 232 versions(2, 1, 2, 1), 233 true, 234 version(2, 1), 235 }, 236 { 237 "local.max < peer.max and local.min > peer.min", 238 versions(2, 1, 2, 1), 239 versions(1, 2, 3, 2), 240 true, 241 version(2, 1), 242 }, 243 { 244 "local.max = peer.max and local.min > peer.min", 245 versions(2, 1, 2, 1), 246 versions(1, 2, 2, 1), 247 true, 248 version(2, 1), 249 }, 250 { 251 "local.max < peer.max and local.min < peer.min", 252 versions(1, 2, 2, 1), 253 versions(2, 1, 3, 2), 254 true, 255 version(2, 1), 256 }, 257 { 258 "local.max < peer.max and local.min = peer.min", 259 versions(1, 2, 2, 1), 260 versions(1, 2, 3, 2), 261 true, 262 version(2, 1), 263 }, 264 { 265 "local.max = peer.max and local.min < peer.min", 266 versions(1, 2, 2, 1), 267 versions(2, 1, 2, 1), 268 true, 269 version(2, 1), 270 }, 271 { 272 "all equal", 273 versions(2, 1, 2, 1), 274 versions(2, 1, 2, 1), 275 true, 276 version(2, 1), 277 }, 278 { 279 "max is smaller than min", 280 versions(2, 1, 1, 2), 281 versions(2, 1, 1, 2), 282 false, 283 nil, 284 }, 285 { 286 "no overlap, local > peer", 287 versions(4, 3, 6, 5), 288 versions(1, 0, 2, 1), 289 false, 290 nil, 291 }, 292 { 293 "no overlap, local < peer", 294 versions(1, 0, 2, 1), 295 versions(4, 3, 6, 5), 296 false, 297 nil, 298 }, 299 { 300 "no overlap, max < min", 301 versions(6, 5, 4, 3), 302 versions(2, 1, 1, 0), 303 false, 304 nil, 305 }, 306 } { 307 output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer) 308 if got, want := output, tc.output; got != want { 309 t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want) 310 } 311 if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) { 312 t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want) 313 } 314 } 315 } 316 317 // TestFullHandshake performs a full ALTS handshake between a test client and 318 // server, where both client and server offload to a local, fake handshaker 319 // service. 320 func (s) TestFullHandshake(t *testing.T) { 321 // Start the fake handshaker service and the server. 322 var wait sync.WaitGroup 323 defer wait.Wait() 324 stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait) 325 defer stopHandshaker() 326 stopServer, serverAddress := startServer(t, handshakerAddress, &wait) 327 defer stopServer() 328 329 // Ping the server, authenticating with ALTS. 330 establishAltsConnection(t, handshakerAddress, serverAddress) 331 332 // Close open connections to the fake handshaker service. 333 if err := service.CloseForTesting(); err != nil { 334 t.Errorf("service.CloseForTesting() failed: %v", err) 335 } 336 } 337 338 // TestConcurrentHandshakes performs a several, concurrent ALTS handshakes 339 // between a test client and server, where both client and server offload to a 340 // local, fake handshaker service. 341 func (s) TestConcurrentHandshakes(t *testing.T) { 342 // Set the max number of concurrent handshakes to 3, so that we can 343 // test the handshaker behavior when handshakes are queued by 344 // performing more than 3 concurrent handshakes (specifically, 10). 345 handshaker.ResetConcurrentHandshakeSemaphoreForTesting(3) 346 347 // Start the fake handshaker service and the server. 348 var wait sync.WaitGroup 349 defer wait.Wait() 350 stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait) 351 defer stopHandshaker() 352 stopServer, serverAddress := startServer(t, handshakerAddress, &wait) 353 defer stopServer() 354 355 // Ping the server, authenticating with ALTS. 356 var waitForConnections sync.WaitGroup 357 for i := 0; i < 10; i++ { 358 waitForConnections.Add(1) 359 go func() { 360 establishAltsConnection(t, handshakerAddress, serverAddress) 361 waitForConnections.Done() 362 }() 363 } 364 waitForConnections.Wait() 365 366 // Close open connections to the fake handshaker service. 367 if err := service.CloseForTesting(); err != nil { 368 t.Errorf("service.CloseForTesting() failed: %v", err) 369 } 370 } 371 372 func version(major, minor uint32) *altspb.RpcProtocolVersions_Version { 373 return &altspb.RpcProtocolVersions_Version{ 374 Major: major, 375 Minor: minor, 376 } 377 } 378 379 func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions { 380 return &altspb.RpcProtocolVersions{ 381 MinRpcVersion: version(minMajor, minMinor), 382 MaxRpcVersion: version(maxMajor, maxMinor), 383 } 384 } 385 386 func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) { 387 clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress}) 388 conn, err := grpc.Dial(serverAddress, grpc.WithTransportCredentials(clientCreds)) 389 if err != nil { 390 t.Fatalf("grpc.Dial(%v) failed: %v", serverAddress, err) 391 } 392 defer conn.Close() 393 ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout) 394 defer cancel() 395 c := testgrpc.NewTestServiceClient(conn) 396 var peer peer.Peer 397 success := false 398 for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) { 399 _, err = c.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.Peer(&peer)) 400 if err == nil { 401 success = true 402 break 403 } 404 if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded { 405 // The server is not ready yet or there were too many concurrent handshakes. 406 // Try again. 407 continue 408 } 409 t.Fatalf("c.UnaryCall() failed: %v", err) 410 } 411 if !success { 412 t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout) 413 } 414 415 // Check that peer.AuthInfo was populated with an ALTS AuthInfo 416 // instance. As a sanity check, also verify that the AuthType() and 417 // ApplicationProtocol() have the expected values. 418 if got, want := peer.AuthInfo.AuthType(), "alts"; got != want { 419 t.Errorf("authInfo.AuthType() = %s, want = %s", got, want) 420 } 421 authInfo, err := AuthInfoFromPeer(&peer) 422 if err != nil { 423 t.Errorf("AuthInfoFromPeer failed: %v", err) 424 } 425 if got, want := authInfo.ApplicationProtocol(), "grpc"; got != want { 426 t.Errorf("authInfo.ApplicationProtocol() = %s, want = %s", got, want) 427 } 428 } 429 430 func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) { 431 listener, err := testutils.LocalTCPListener() 432 if err != nil { 433 t.Fatalf("LocalTCPListener() failed: %v", err) 434 } 435 s := grpc.NewServer() 436 altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{}) 437 wait.Add(1) 438 go func() { 439 defer wait.Done() 440 if err := s.Serve(listener); err != nil { 441 t.Errorf("failed to serve: %v", err) 442 } 443 }() 444 return func() { s.Stop() }, listener.Addr().String() 445 } 446 447 func startServer(t *testing.T, handshakerServiceAddress string, wait *sync.WaitGroup) (stop func(), address string) { 448 listener, err := testutils.LocalTCPListener() 449 if err != nil { 450 t.Fatalf("LocalTCPListener() failed: %v", err) 451 } 452 serverOpts := &ServerOptions{HandshakerServiceAddress: handshakerServiceAddress} 453 creds := NewServerCreds(serverOpts) 454 s := grpc.NewServer(grpc.Creds(creds)) 455 testgrpc.RegisterTestServiceServer(s, &testServer{}) 456 wait.Add(1) 457 go func() { 458 defer wait.Done() 459 if err := s.Serve(listener); err != nil { 460 t.Errorf("s.Serve(%v) failed: %v", listener, err) 461 } 462 }() 463 return func() { s.Stop() }, listener.Addr().String() 464 } 465 466 type testServer struct { 467 testgrpc.UnimplementedTestServiceServer 468 } 469 470 func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 471 return &testpb.SimpleResponse{ 472 Payload: &testpb.Payload{}, 473 }, nil 474 }