google.golang.org/grpc@v1.72.2/credentials/alts/alts_test.go (about) 1 //go:build linux || windows 2 // +build linux windows 3 4 /* 5 * 6 * Copyright 2018 gRPC authors. 7 * 8 * Licensed under the Apache License, Version 2.0 (the "License"); 9 * you may not use this file except in compliance with the License. 10 * You may obtain a copy of the License at 11 * 12 * http://www.apache.org/licenses/LICENSE-2.0 13 * 14 * Unless required by applicable law or agreed to in writing, software 15 * distributed under the License is distributed on an "AS IS" BASIS, 16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 * See the License for the specific language governing permissions and 18 * limitations under the License. 19 * 20 */ 21 22 package alts 23 24 import ( 25 "context" 26 "reflect" 27 "sync" 28 "testing" 29 "time" 30 31 "google.golang.org/grpc" 32 "google.golang.org/grpc/codes" 33 "google.golang.org/grpc/credentials/alts/internal/handshaker" 34 "google.golang.org/grpc/credentials/alts/internal/handshaker/service" 35 altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" 36 altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" 37 "google.golang.org/grpc/credentials/alts/internal/testutil" 38 "google.golang.org/grpc/internal/grpctest" 39 "google.golang.org/grpc/internal/stubserver" 40 "google.golang.org/grpc/internal/testutils" 41 testgrpc "google.golang.org/grpc/interop/grpc_testing" 42 testpb "google.golang.org/grpc/interop/grpc_testing" 43 "google.golang.org/grpc/peer" 44 "google.golang.org/grpc/status" 45 "google.golang.org/protobuf/proto" 46 ) 47 48 const ( 49 defaultTestLongTimeout = 60 * time.Second 50 defaultTestShortTimeout = 10 * time.Millisecond 51 ) 52 53 type s struct { 54 grpctest.Tester 55 } 56 57 func init() { 58 // The vmOnGCP global variable MUST be forced to true. Otherwise, if 59 // this test is run anywhere except on a GCP VM, then an ALTS handshake 60 // will immediately fail. 61 once.Do(func() {}) 62 vmOnGCP = true 63 } 64 65 func Test(t *testing.T) { 66 grpctest.RunSubTests(t, s{}) 67 } 68 69 func (s) TestInfoServerName(t *testing.T) { 70 // This is not testing any handshaker functionality, so it's fine to only 71 // use NewServerCreds and not NewClientCreds. 72 alts := NewServerCreds(DefaultServerOptions()) 73 if got, want := alts.Info().ServerName, ""; got != want { 74 t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want) 75 } 76 } 77 78 func (s) TestOverrideServerName(t *testing.T) { 79 wantServerName := "server.name" 80 // This is not testing any handshaker functionality, so it's fine to only 81 // use NewServerCreds and not NewClientCreds. 82 c := NewServerCreds(DefaultServerOptions()) 83 c.OverrideServerName(wantServerName) 84 if got, want := c.Info().ServerName, wantServerName; got != want { 85 t.Fatalf("c.Info().ServerName = %v, want %v", got, want) 86 } 87 } 88 89 func (s) TestCloneClient(t *testing.T) { 90 wantServerName := "server.name" 91 opt := DefaultClientOptions() 92 opt.TargetServiceAccounts = []string{"not", "empty"} 93 c := NewClientCreds(opt) 94 c.OverrideServerName(wantServerName) 95 cc := c.Clone() 96 if got, want := cc.Info().ServerName, wantServerName; got != want { 97 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 98 } 99 cc.OverrideServerName("") 100 if got, want := c.Info().ServerName, wantServerName; got != want { 101 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want) 102 } 103 if got, want := cc.Info().ServerName, ""; got != want { 104 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 105 } 106 107 ct := c.(*altsTC) 108 cct := cc.(*altsTC) 109 110 if ct.side != cct.side { 111 t.Errorf("cc.side = %q, want %q", cct.side, ct.side) 112 } 113 if ct.hsAddress != cct.hsAddress { 114 t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress) 115 } 116 if !reflect.DeepEqual(ct.accounts, cct.accounts) { 117 t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts) 118 } 119 } 120 121 func (s) TestCloneServer(t *testing.T) { 122 wantServerName := "server.name" 123 c := NewServerCreds(DefaultServerOptions()) 124 c.OverrideServerName(wantServerName) 125 cc := c.Clone() 126 if got, want := cc.Info().ServerName, wantServerName; got != want { 127 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 128 } 129 cc.OverrideServerName("") 130 if got, want := c.Info().ServerName, wantServerName; got != want { 131 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want) 132 } 133 if got, want := cc.Info().ServerName, ""; got != want { 134 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) 135 } 136 137 ct := c.(*altsTC) 138 cct := cc.(*altsTC) 139 140 if ct.side != cct.side { 141 t.Errorf("cc.side = %q, want %q", cct.side, ct.side) 142 } 143 if ct.hsAddress != cct.hsAddress { 144 t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress) 145 } 146 if !reflect.DeepEqual(ct.accounts, cct.accounts) { 147 t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts) 148 } 149 } 150 151 func (s) TestInfo(t *testing.T) { 152 // This is not testing any handshaker functionality, so it's fine to only 153 // use NewServerCreds and not NewClientCreds. 154 c := NewServerCreds(DefaultServerOptions()) 155 info := c.Info() 156 if got, want := info.ProtocolVersion, ""; got != want { 157 t.Errorf("info.ProtocolVersion=%v, want %v", got, want) 158 } 159 if got, want := info.SecurityProtocol, "alts"; got != want { 160 t.Errorf("info.SecurityProtocol=%v, want %v", got, want) 161 } 162 if got, want := info.SecurityVersion, "1.0"; got != want { 163 t.Errorf("info.SecurityVersion=%v, want %v", got, want) 164 } 165 if got, want := info.ServerName, ""; got != want { 166 t.Errorf("info.ServerName=%v, want %v", got, want) 167 } 168 } 169 170 func (s) TestCompareRPCVersions(t *testing.T) { 171 for _, tc := range []struct { 172 v1 *altspb.RpcProtocolVersions_Version 173 v2 *altspb.RpcProtocolVersions_Version 174 output int 175 }{ 176 { 177 version(3, 2), 178 version(2, 1), 179 1, 180 }, 181 { 182 version(3, 2), 183 version(3, 1), 184 1, 185 }, 186 { 187 version(2, 1), 188 version(3, 2), 189 -1, 190 }, 191 { 192 version(3, 1), 193 version(3, 2), 194 -1, 195 }, 196 { 197 version(3, 2), 198 version(3, 2), 199 0, 200 }, 201 } { 202 if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want { 203 t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want) 204 } 205 } 206 } 207 208 func (s) TestCheckRPCVersions(t *testing.T) { 209 for _, tc := range []struct { 210 desc string 211 local *altspb.RpcProtocolVersions 212 peer *altspb.RpcProtocolVersions 213 output bool 214 maxCommonVersion *altspb.RpcProtocolVersions_Version 215 }{ 216 { 217 "local.max > peer.max and local.min > peer.min", 218 versions(2, 1, 3, 2), 219 versions(1, 2, 2, 1), 220 true, 221 version(2, 1), 222 }, 223 { 224 "local.max > peer.max and local.min < peer.min", 225 versions(1, 2, 3, 2), 226 versions(2, 1, 2, 1), 227 true, 228 version(2, 1), 229 }, 230 { 231 "local.max > peer.max and local.min = peer.min", 232 versions(2, 1, 3, 2), 233 versions(2, 1, 2, 1), 234 true, 235 version(2, 1), 236 }, 237 { 238 "local.max < peer.max and local.min > peer.min", 239 versions(2, 1, 2, 1), 240 versions(1, 2, 3, 2), 241 true, 242 version(2, 1), 243 }, 244 { 245 "local.max = peer.max and local.min > peer.min", 246 versions(2, 1, 2, 1), 247 versions(1, 2, 2, 1), 248 true, 249 version(2, 1), 250 }, 251 { 252 "local.max < peer.max and local.min < peer.min", 253 versions(1, 2, 2, 1), 254 versions(2, 1, 3, 2), 255 true, 256 version(2, 1), 257 }, 258 { 259 "local.max < peer.max and local.min = peer.min", 260 versions(1, 2, 2, 1), 261 versions(1, 2, 3, 2), 262 true, 263 version(2, 1), 264 }, 265 { 266 "local.max = peer.max and local.min < peer.min", 267 versions(1, 2, 2, 1), 268 versions(2, 1, 2, 1), 269 true, 270 version(2, 1), 271 }, 272 { 273 "all equal", 274 versions(2, 1, 2, 1), 275 versions(2, 1, 2, 1), 276 true, 277 version(2, 1), 278 }, 279 { 280 "max is smaller than min", 281 versions(2, 1, 1, 2), 282 versions(2, 1, 1, 2), 283 false, 284 nil, 285 }, 286 { 287 "no overlap, local > peer", 288 versions(4, 3, 6, 5), 289 versions(1, 0, 2, 1), 290 false, 291 nil, 292 }, 293 { 294 "no overlap, local < peer", 295 versions(1, 0, 2, 1), 296 versions(4, 3, 6, 5), 297 false, 298 nil, 299 }, 300 { 301 "no overlap, max < min", 302 versions(6, 5, 4, 3), 303 versions(2, 1, 1, 0), 304 false, 305 nil, 306 }, 307 } { 308 output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer) 309 if got, want := output, tc.output; got != want { 310 t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want) 311 } 312 if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) { 313 t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want) 314 } 315 } 316 } 317 318 // TestFullHandshake performs a full ALTS handshake between a test client and 319 // server, where both client and server offload to a local, fake handshaker 320 // service. 321 func (s) TestFullHandshake(t *testing.T) { 322 // Start the fake handshaker service and the server. 323 var wait sync.WaitGroup 324 defer wait.Wait() 325 stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait) 326 defer stopHandshaker() 327 stopServer, serverAddress := startServer(t, handshakerAddress) 328 defer stopServer() 329 330 // Ping the server, authenticating with ALTS. 331 establishAltsConnection(t, handshakerAddress, serverAddress) 332 333 // Close open connections to the fake handshaker service. 334 if err := service.CloseForTesting(); err != nil { 335 t.Errorf("service.CloseForTesting() failed: %v", err) 336 } 337 } 338 339 // TestConcurrentHandshakes performs a several, concurrent ALTS handshakes 340 // between a test client and server, where both client and server offload to a 341 // local, fake handshaker service. 342 func (s) TestConcurrentHandshakes(t *testing.T) { 343 // Set the max number of concurrent handshakes to 3, so that we can 344 // test the handshaker behavior when handshakes are queued by 345 // performing more than 3 concurrent handshakes (specifically, 10). 346 handshaker.ResetConcurrentHandshakeSemaphoreForTesting(3) 347 348 // Start the fake handshaker service and the server. 349 var wait sync.WaitGroup 350 defer wait.Wait() 351 stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait) 352 defer stopHandshaker() 353 stopServer, serverAddress := startServer(t, handshakerAddress) 354 defer stopServer() 355 356 // Ping the server, authenticating with ALTS. 357 var waitForConnections sync.WaitGroup 358 for i := 0; i < 10; i++ { 359 waitForConnections.Add(1) 360 go func() { 361 establishAltsConnection(t, handshakerAddress, serverAddress) 362 waitForConnections.Done() 363 }() 364 } 365 waitForConnections.Wait() 366 367 // Close open connections to the fake handshaker service. 368 if err := service.CloseForTesting(); err != nil { 369 t.Errorf("service.CloseForTesting() failed: %v", err) 370 } 371 } 372 373 func version(major, minor uint32) *altspb.RpcProtocolVersions_Version { 374 return &altspb.RpcProtocolVersions_Version{ 375 Major: major, 376 Minor: minor, 377 } 378 } 379 380 func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions { 381 return &altspb.RpcProtocolVersions{ 382 MinRpcVersion: version(minMajor, minMinor), 383 MaxRpcVersion: version(maxMajor, maxMinor), 384 } 385 } 386 387 func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) { 388 clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress}) 389 conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds)) 390 if err != nil { 391 t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err) 392 } 393 defer conn.Close() 394 ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout) 395 defer cancel() 396 c := testgrpc.NewTestServiceClient(conn) 397 var peer peer.Peer 398 success := false 399 for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) { 400 _, err = c.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.Peer(&peer)) 401 if err == nil { 402 success = true 403 break 404 } 405 if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded { 406 // The server is not ready yet or there were too many concurrent handshakes. 407 // Try again. 408 continue 409 } 410 t.Fatalf("c.UnaryCall() failed: %v", err) 411 } 412 if !success { 413 t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout) 414 } 415 416 // Check that peer.AuthInfo was populated with an ALTS AuthInfo 417 // instance. As a sanity check, also verify that the AuthType() and 418 // ApplicationProtocol() have the expected values. 419 if got, want := peer.AuthInfo.AuthType(), "alts"; got != want { 420 t.Errorf("authInfo.AuthType() = %s, want = %s", got, want) 421 } 422 authInfo, err := AuthInfoFromPeer(&peer) 423 if err != nil { 424 t.Errorf("AuthInfoFromPeer failed: %v", err) 425 } 426 if got, want := authInfo.ApplicationProtocol(), "grpc"; got != want { 427 t.Errorf("authInfo.ApplicationProtocol() = %s, want = %s", got, want) 428 } 429 } 430 431 func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) { 432 listener, err := testutils.LocalTCPListener() 433 if err != nil { 434 t.Fatalf("LocalTCPListener() failed: %v", err) 435 } 436 s := grpc.NewServer() 437 altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{}) 438 wait.Add(1) 439 go func() { 440 defer wait.Done() 441 if err := s.Serve(listener); err != nil { 442 t.Errorf("failed to serve: %v", err) 443 } 444 }() 445 return func() { s.Stop() }, listener.Addr().String() 446 } 447 448 func startServer(t *testing.T, handshakerServiceAddress string) (stop func(), address string) { 449 listener, err := testutils.LocalTCPListener() 450 if err != nil { 451 t.Fatalf("LocalTCPListener() failed: %v", err) 452 } 453 serverOpts := &ServerOptions{HandshakerServiceAddress: handshakerServiceAddress} 454 creds := NewServerCreds(serverOpts) 455 stub := &stubserver.StubServer{ 456 Listener: listener, 457 UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 458 return &testpb.SimpleResponse{ 459 Payload: &testpb.Payload{}, 460 }, nil 461 }, 462 S: grpc.NewServer(grpc.Creds(creds)), 463 } 464 stubserver.StartTestService(t, stub) 465 return func() { stub.S.Stop() }, listener.Addr().String() 466 }