google.golang.org/grpc@v1.72.2/test/creds_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package test 20 21 import ( 22 "context" 23 "errors" 24 "fmt" 25 "net" 26 "strings" 27 "testing" 28 "time" 29 30 "google.golang.org/grpc" 31 "google.golang.org/grpc/codes" 32 "google.golang.org/grpc/connectivity" 33 "google.golang.org/grpc/credentials" 34 "google.golang.org/grpc/credentials/insecure" 35 "google.golang.org/grpc/internal/testutils" 36 "google.golang.org/grpc/metadata" 37 "google.golang.org/grpc/resolver" 38 "google.golang.org/grpc/resolver/manual" 39 "google.golang.org/grpc/status" 40 "google.golang.org/grpc/tap" 41 "google.golang.org/grpc/testdata" 42 43 testgrpc "google.golang.org/grpc/interop/grpc_testing" 44 testpb "google.golang.org/grpc/interop/grpc_testing" 45 ) 46 47 const ( 48 bundlePerRPCOnly = "perRPCOnly" 49 bundleTLSOnly = "tlsOnly" 50 ) 51 52 type testCredsBundle struct { 53 t *testing.T 54 mode string 55 } 56 57 func (c *testCredsBundle) TransportCredentials() credentials.TransportCredentials { 58 if c.mode == bundlePerRPCOnly { 59 return insecure.NewCredentials() 60 } 61 62 creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") 63 if err != nil { 64 c.t.Logf("Failed to load credentials: %v", err) 65 return nil 66 } 67 return creds 68 } 69 70 func (c *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials { 71 if c.mode == bundleTLSOnly { 72 return nil 73 } 74 return testPerRPCCredentials{authdata: authdata} 75 } 76 77 func (c *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { 78 return &testCredsBundle{mode: mode}, nil 79 } 80 81 func (s) TestCredsBundleBoth(t *testing.T) { 82 te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"}) 83 te.tapHandle = authHandle 84 te.customDialOptions = []grpc.DialOption{ 85 grpc.WithCredentialsBundle(&testCredsBundle{t: t}), 86 } 87 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 88 if err != nil { 89 t.Fatalf("Failed to generate credentials %v", err) 90 } 91 te.customServerOptions = []grpc.ServerOption{ 92 grpc.Creds(creds), 93 } 94 te.startServer(&testServer{}) 95 defer te.tearDown() 96 97 cc := te.clientConn() 98 tc := testgrpc.NewTestServiceClient(cc) 99 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 100 defer cancel() 101 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 102 t.Fatalf("Test failed. Reason: %v", err) 103 } 104 } 105 106 func (s) TestCredsBundleTransportCredentials(t *testing.T) { 107 te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"}) 108 te.customDialOptions = []grpc.DialOption{ 109 grpc.WithCredentialsBundle(&testCredsBundle{t: t, mode: bundleTLSOnly}), 110 } 111 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 112 if err != nil { 113 t.Fatalf("Failed to generate credentials %v", err) 114 } 115 te.customServerOptions = []grpc.ServerOption{ 116 grpc.Creds(creds), 117 } 118 te.startServer(&testServer{}) 119 defer te.tearDown() 120 121 cc := te.clientConn() 122 tc := testgrpc.NewTestServiceClient(cc) 123 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 124 defer cancel() 125 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 126 t.Fatalf("Test failed. Reason: %v", err) 127 } 128 } 129 130 func (s) TestCredsBundlePerRPCCredentials(t *testing.T) { 131 te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"}) 132 te.tapHandle = authHandle 133 te.customDialOptions = []grpc.DialOption{ 134 grpc.WithCredentialsBundle(&testCredsBundle{t: t, mode: bundlePerRPCOnly}), 135 } 136 te.startServer(&testServer{}) 137 defer te.tearDown() 138 139 cc := te.clientConn() 140 tc := testgrpc.NewTestServiceClient(cc) 141 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 142 defer cancel() 143 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 144 t.Fatalf("Test failed. Reason: %v", err) 145 } 146 } 147 148 type clientTimeoutCreds struct { 149 credentials.TransportCredentials 150 timeoutReturned bool 151 } 152 153 func (c *clientTimeoutCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 154 if !c.timeoutReturned { 155 c.timeoutReturned = true 156 return nil, nil, context.DeadlineExceeded 157 } 158 return rawConn, nil, nil 159 } 160 161 func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { 162 return credentials.ProtocolInfo{} 163 } 164 165 func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials { 166 return nil 167 } 168 169 func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { 170 te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "empty"}) 171 te.userAgent = testAppUA 172 te.startServer(&testServer{security: te.e.security}) 173 defer te.tearDown() 174 175 cc := te.clientConn(grpc.WithTransportCredentials(&clientTimeoutCreds{})) 176 tc := testgrpc.NewTestServiceClient(cc) 177 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 178 defer cancel() 179 // This unary call should succeed, because ClientHandshake will succeed for the second time. 180 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { 181 te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want <nil>", err) 182 } 183 } 184 185 type methodTestCreds struct{} 186 187 func (m *methodTestCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { 188 ri, _ := credentials.RequestInfoFromContext(ctx) 189 return nil, status.Error(codes.Unknown, ri.Method) 190 } 191 192 func (m *methodTestCreds) RequireTransportSecurity() bool { return false } 193 194 func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { 195 const wantMethod = "/grpc.testing.TestService/EmptyCall" 196 te := newTest(t, env{name: "context-request-info", network: "tcp"}) 197 te.userAgent = testAppUA 198 te.startServer(&testServer{security: te.e.security}) 199 defer te.tearDown() 200 201 cc := te.clientConn(grpc.WithPerRPCCredentials(&methodTestCreds{})) 202 tc := testgrpc.NewTestServiceClient(cc) 203 204 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 205 defer cancel() 206 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { 207 t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) 208 } 209 210 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Convert(err).Message() != wantMethod { 211 t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) 212 } 213 } 214 215 const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" 216 217 type clientAlwaysFailCred struct { 218 credentials.TransportCredentials 219 } 220 221 func (c clientAlwaysFailCred) ClientHandshake(context.Context, string, net.Conn) (net.Conn, credentials.AuthInfo, error) { 222 return nil, nil, errors.New(clientAlwaysFailCredErrorMsg) 223 } 224 func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { 225 return credentials.ProtocolInfo{} 226 } 227 func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { 228 return nil 229 } 230 231 func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) { 232 te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"}) 233 te.startServer(&testServer{security: te.e.security}) 234 defer te.tearDown() 235 236 opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} 237 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 238 defer cancel() 239 cc, err := grpc.NewClient(te.srvAddr, opts...) 240 if err != nil { 241 t.Fatalf("NewClient(_) = %v, want %v", err, nil) 242 } 243 defer cc.Close() 244 245 tc := testgrpc.NewTestServiceClient(cc) 246 if _, err = tc.EmptyCall(ctx, &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { 247 return 248 } 249 te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) 250 } 251 252 func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) { 253 te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"}) 254 te.startServer(&testServer{security: te.e.security}) 255 defer te.tearDown() 256 257 opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} 258 cc, err := grpc.NewClient(te.srvAddr, opts...) 259 if err != nil { 260 t.Fatalf("NewClient(_) = %v, want %v", err, nil) 261 } 262 defer cc.Close() 263 264 // The DNS resolver may take more than defaultTestShortTimeout, we let the 265 // channel enter TransientFailure signalling that the first resolver state 266 // has been produced. 267 cc.Connect() 268 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 269 defer cancel() 270 testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) 271 272 tc := testgrpc.NewTestServiceClient(cc) 273 // Use a short context as WaitForReady waits for context expiration before 274 // failing the RPC. 275 ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) 276 defer cancel() 277 if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { 278 t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) 279 } 280 } 281 282 var ( 283 // test authdata 284 authdata = map[string]string{ 285 "test-key": "test-value", 286 "test-key2-bin": string([]byte{1, 2, 3}), 287 } 288 ) 289 290 type testPerRPCCredentials struct { 291 authdata map[string]string 292 errChan chan error 293 } 294 295 func (cr testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { 296 var err error 297 if cr.errChan != nil { 298 err = <-cr.errChan 299 } 300 return cr.authdata, err 301 } 302 303 func (cr testPerRPCCredentials) RequireTransportSecurity() bool { 304 return false 305 } 306 307 func authHandle(ctx context.Context, _ *tap.Info) (context.Context, error) { 308 md, ok := metadata.FromIncomingContext(ctx) 309 if !ok { 310 return ctx, fmt.Errorf("didn't find metadata in context") 311 } 312 for k, vwant := range authdata { 313 vgot, ok := md[k] 314 if !ok { 315 return ctx, fmt.Errorf("didn't find authdata key %v in context", k) 316 } 317 if vgot[0] != vwant { 318 return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant) 319 } 320 } 321 return ctx, nil 322 } 323 324 func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) { 325 for _, e := range listTestEnv() { 326 testPerRPCCredentialsViaDialOptions(t, e) 327 } 328 } 329 330 func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) { 331 te := newTest(t, e) 332 te.tapHandle = authHandle 333 te.perRPCCreds = testPerRPCCredentials{authdata: authdata} 334 te.startServer(&testServer{security: e.security}) 335 defer te.tearDown() 336 337 cc := te.clientConn() 338 tc := testgrpc.NewTestServiceClient(cc) 339 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 340 defer cancel() 341 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 342 t.Fatalf("Test failed. Reason: %v", err) 343 } 344 } 345 346 func (s) TestPerRPCCredentialsViaCallOptions(t *testing.T) { 347 for _, e := range listTestEnv() { 348 testPerRPCCredentialsViaCallOptions(t, e) 349 } 350 } 351 352 func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) { 353 te := newTest(t, e) 354 te.tapHandle = authHandle 355 te.startServer(&testServer{security: e.security}) 356 defer te.tearDown() 357 358 cc := te.clientConn() 359 tc := testgrpc.NewTestServiceClient(cc) 360 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 361 defer cancel() 362 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { 363 t.Fatalf("Test failed. Reason: %v", err) 364 } 365 } 366 367 func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) { 368 for _, e := range listTestEnv() { 369 testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e) 370 } 371 } 372 373 func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { 374 te := newTest(t, e) 375 te.perRPCCreds = testPerRPCCredentials{authdata: authdata} 376 // When credentials are provided via both dial options and call options, 377 // we apply both sets. 378 te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { 379 md, ok := metadata.FromIncomingContext(ctx) 380 if !ok { 381 return ctx, fmt.Errorf("couldn't find metadata in context") 382 } 383 for k, vwant := range authdata { 384 vgot, ok := md[k] 385 if !ok { 386 return ctx, fmt.Errorf("couldn't find metadata for key %v", k) 387 } 388 if len(vgot) != 2 { 389 return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot)) 390 } 391 if vgot[0] != vwant || vgot[1] != vwant { 392 return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant) 393 } 394 } 395 return ctx, nil 396 } 397 te.startServer(&testServer{security: e.security}) 398 defer te.tearDown() 399 400 cc := te.clientConn() 401 tc := testgrpc.NewTestServiceClient(cc) 402 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 403 defer cancel() 404 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { 405 t.Fatalf("Test failed. Reason: %v", err) 406 } 407 } 408 409 const testAuthority = "test.auth.ori.ty" 410 411 type authorityCheckCreds struct { 412 credentials.TransportCredentials 413 got string 414 } 415 416 func (c *authorityCheckCreds) ClientHandshake(_ context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 417 c.got = authority 418 return rawConn, nil, nil 419 } 420 func (c *authorityCheckCreds) Info() credentials.ProtocolInfo { 421 return credentials.ProtocolInfo{} 422 } 423 func (c *authorityCheckCreds) Clone() credentials.TransportCredentials { 424 return c 425 } 426 427 // This test makes sure that the authority client handshake gets is the endpoint 428 // in dial target, not the resolved ip address. 429 func (s) TestCredsHandshakeAuthority(t *testing.T) { 430 lis, err := net.Listen("tcp", "localhost:0") 431 if err != nil { 432 t.Fatal(err) 433 } 434 cred := &authorityCheckCreds{} 435 s := grpc.NewServer() 436 go s.Serve(lis) 437 defer s.Stop() 438 439 r := manual.NewBuilderWithScheme("whatever") 440 441 cc, err := grpc.NewClient(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred), grpc.WithResolvers(r)) 442 if err != nil { 443 t.Fatalf("grpc.NewClient(%q) = %v", lis.Addr().String(), err) 444 } 445 defer cc.Close() 446 cc.Connect() 447 r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) 448 449 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 450 defer cancel() 451 testutils.AwaitState(ctx, t, cc, connectivity.Ready) 452 453 if cred.got != testAuthority { 454 t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) 455 } 456 } 457 458 // This test makes sure that the authority client handshake gets is the endpoint 459 // of the ServerName of the address when it is set. 460 func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) { 461 const testServerName = "test.server.name" 462 463 lis, err := net.Listen("tcp", "localhost:0") 464 if err != nil { 465 t.Fatal(err) 466 } 467 cred := &authorityCheckCreds{} 468 s := grpc.NewServer() 469 go s.Serve(lis) 470 defer s.Stop() 471 472 r := manual.NewBuilderWithScheme("whatever") 473 474 cc, err := grpc.NewClient(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred), grpc.WithResolvers(r)) 475 if err != nil { 476 t.Fatalf("grpc.NewClient(%q) = %v", lis.Addr().String(), err) 477 } 478 defer cc.Close() 479 cc.Connect() 480 r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}}) 481 482 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 483 defer cancel() 484 testutils.AwaitState(ctx, t, cc, connectivity.Ready) 485 486 if cred.got != testServerName { 487 t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) 488 } 489 } 490 491 type serverDispatchCred struct { 492 rawConnCh chan net.Conn 493 } 494 495 func (c *serverDispatchCred) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 496 return rawConn, nil, nil 497 } 498 func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 499 select { 500 case c.rawConnCh <- rawConn: 501 default: 502 } 503 return nil, nil, credentials.ErrConnDispatched 504 } 505 func (c *serverDispatchCred) Info() credentials.ProtocolInfo { 506 return credentials.ProtocolInfo{} 507 } 508 func (c *serverDispatchCred) Clone() credentials.TransportCredentials { 509 return nil 510 } 511 func (c *serverDispatchCred) OverrideServerName(string) error { 512 return nil 513 } 514 func (c *serverDispatchCred) getRawConn() net.Conn { 515 return <-c.rawConnCh 516 } 517 518 func (s) TestServerCredsDispatch(t *testing.T) { 519 lis, err := net.Listen("tcp", "localhost:0") 520 if err != nil { 521 t.Fatal(err) 522 } 523 cred := &serverDispatchCred{ 524 rawConnCh: make(chan net.Conn, 1), 525 } 526 s := grpc.NewServer(grpc.Creds(cred)) 527 go s.Serve(lis) 528 defer s.Stop() 529 530 cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(cred)) 531 if err != nil { 532 t.Fatalf("grpc.NewClient(%q) = %v", lis.Addr().String(), err) 533 } 534 defer cc.Close() 535 cc.Connect() 536 537 rawConn := cred.getRawConn() 538 // Give grpc a chance to see the error and potentially close the connection. 539 // And check that connection is not closed after that. 540 time.Sleep(100 * time.Millisecond) 541 // Check rawConn is not closed. 542 if n, err := rawConn.Write([]byte{0}); n <= 0 || err != nil { 543 t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err) 544 } 545 }