google.golang.org/grpc@v1.62.1/test/creds_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package test 20 21 import ( 22 "context" 23 "errors" 24 "fmt" 25 "net" 26 "strings" 27 "testing" 28 "time" 29 30 "google.golang.org/grpc" 31 "google.golang.org/grpc/codes" 32 "google.golang.org/grpc/connectivity" 33 "google.golang.org/grpc/credentials" 34 "google.golang.org/grpc/credentials/insecure" 35 "google.golang.org/grpc/internal/testutils" 36 "google.golang.org/grpc/metadata" 37 "google.golang.org/grpc/resolver" 38 "google.golang.org/grpc/resolver/manual" 39 "google.golang.org/grpc/status" 40 "google.golang.org/grpc/tap" 41 "google.golang.org/grpc/testdata" 42 43 testgrpc "google.golang.org/grpc/interop/grpc_testing" 44 testpb "google.golang.org/grpc/interop/grpc_testing" 45 ) 46 47 const ( 48 bundlePerRPCOnly = "perRPCOnly" 49 bundleTLSOnly = "tlsOnly" 50 ) 51 52 type testCredsBundle struct { 53 t *testing.T 54 mode string 55 } 56 57 func (c *testCredsBundle) TransportCredentials() credentials.TransportCredentials { 58 if c.mode == bundlePerRPCOnly { 59 return insecure.NewCredentials() 60 } 61 62 creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") 63 if err != nil { 64 c.t.Logf("Failed to load credentials: %v", err) 65 return nil 66 } 67 return creds 68 } 69 70 func (c *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials { 71 if c.mode == bundleTLSOnly { 72 return nil 73 } 74 return testPerRPCCredentials{authdata: authdata} 75 } 76 77 func (c *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { 78 return &testCredsBundle{mode: mode}, nil 79 } 80 81 func (s) TestCredsBundleBoth(t *testing.T) { 82 te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"}) 83 te.tapHandle = authHandle 84 te.customDialOptions = []grpc.DialOption{ 85 grpc.WithCredentialsBundle(&testCredsBundle{t: t}), 86 } 87 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 88 if err != nil { 89 t.Fatalf("Failed to generate credentials %v", err) 90 } 91 te.customServerOptions = []grpc.ServerOption{ 92 grpc.Creds(creds), 93 } 94 te.startServer(&testServer{}) 95 defer te.tearDown() 96 97 cc := te.clientConn() 98 tc := testgrpc.NewTestServiceClient(cc) 99 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 100 defer cancel() 101 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 102 t.Fatalf("Test failed. Reason: %v", err) 103 } 104 } 105 106 func (s) TestCredsBundleTransportCredentials(t *testing.T) { 107 te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"}) 108 te.customDialOptions = []grpc.DialOption{ 109 grpc.WithCredentialsBundle(&testCredsBundle{t: t, mode: bundleTLSOnly}), 110 } 111 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 112 if err != nil { 113 t.Fatalf("Failed to generate credentials %v", err) 114 } 115 te.customServerOptions = []grpc.ServerOption{ 116 grpc.Creds(creds), 117 } 118 te.startServer(&testServer{}) 119 defer te.tearDown() 120 121 cc := te.clientConn() 122 tc := testgrpc.NewTestServiceClient(cc) 123 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 124 defer cancel() 125 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 126 t.Fatalf("Test failed. Reason: %v", err) 127 } 128 } 129 130 func (s) TestCredsBundlePerRPCCredentials(t *testing.T) { 131 te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"}) 132 te.tapHandle = authHandle 133 te.customDialOptions = []grpc.DialOption{ 134 grpc.WithCredentialsBundle(&testCredsBundle{t: t, mode: bundlePerRPCOnly}), 135 } 136 te.startServer(&testServer{}) 137 defer te.tearDown() 138 139 cc := te.clientConn() 140 tc := testgrpc.NewTestServiceClient(cc) 141 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 142 defer cancel() 143 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 144 t.Fatalf("Test failed. Reason: %v", err) 145 } 146 } 147 148 type clientTimeoutCreds struct { 149 credentials.TransportCredentials 150 timeoutReturned bool 151 } 152 153 func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 154 if !c.timeoutReturned { 155 c.timeoutReturned = true 156 return nil, nil, context.DeadlineExceeded 157 } 158 return rawConn, nil, nil 159 } 160 161 func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { 162 return credentials.ProtocolInfo{} 163 } 164 165 func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials { 166 return nil 167 } 168 169 func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { 170 te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "empty"}) 171 te.userAgent = testAppUA 172 te.startServer(&testServer{security: te.e.security}) 173 defer te.tearDown() 174 175 cc := te.clientConn(grpc.WithTransportCredentials(&clientTimeoutCreds{})) 176 tc := testgrpc.NewTestServiceClient(cc) 177 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 178 defer cancel() 179 // This unary call should succeed, because ClientHandshake will succeed for the second time. 180 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { 181 te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want <nil>", err) 182 } 183 } 184 185 type methodTestCreds struct{} 186 187 func (m *methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { 188 ri, _ := credentials.RequestInfoFromContext(ctx) 189 return nil, status.Errorf(codes.Unknown, ri.Method) 190 } 191 192 func (m *methodTestCreds) RequireTransportSecurity() bool { return false } 193 194 func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { 195 const wantMethod = "/grpc.testing.TestService/EmptyCall" 196 te := newTest(t, env{name: "context-request-info", network: "tcp"}) 197 te.userAgent = testAppUA 198 te.startServer(&testServer{security: te.e.security}) 199 defer te.tearDown() 200 201 cc := te.clientConn(grpc.WithPerRPCCredentials(&methodTestCreds{})) 202 tc := testgrpc.NewTestServiceClient(cc) 203 204 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 205 defer cancel() 206 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { 207 t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) 208 } 209 210 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Convert(err).Message() != wantMethod { 211 t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) 212 } 213 } 214 215 const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" 216 217 type clientAlwaysFailCred struct { 218 credentials.TransportCredentials 219 } 220 221 func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 222 return nil, nil, errors.New(clientAlwaysFailCredErrorMsg) 223 } 224 func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { 225 return credentials.ProtocolInfo{} 226 } 227 func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { 228 return nil 229 } 230 231 func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) { 232 te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"}) 233 te.startServer(&testServer{security: te.e.security}) 234 defer te.tearDown() 235 236 opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} 237 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 238 defer cancel() 239 cc, err := grpc.DialContext(ctx, te.srvAddr, opts...) 240 if err != nil { 241 t.Fatalf("Dial(_) = %v, want %v", err, nil) 242 } 243 defer cc.Close() 244 245 tc := testgrpc.NewTestServiceClient(cc) 246 for i := 0; i < 1000; i++ { 247 // This loop runs for at most 1 second. The first several RPCs will fail 248 // with Unavailable because the connection hasn't started. When the 249 // first connection failed with creds error, the next RPC should also 250 // fail with the expected error. 251 if _, err = tc.EmptyCall(ctx, &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { 252 return 253 } 254 time.Sleep(time.Millisecond) 255 } 256 te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) 257 } 258 259 func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) { 260 te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"}) 261 te.startServer(&testServer{security: te.e.security}) 262 defer te.tearDown() 263 264 opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} 265 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 266 defer cancel() 267 cc, err := grpc.DialContext(ctx, te.srvAddr, opts...) 268 if err != nil { 269 t.Fatalf("Dial(_) = %v, want %v", err, nil) 270 } 271 defer cc.Close() 272 273 tc := testgrpc.NewTestServiceClient(cc) 274 ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) 275 defer cancel() 276 if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { 277 t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) 278 } 279 } 280 281 var ( 282 // test authdata 283 authdata = map[string]string{ 284 "test-key": "test-value", 285 "test-key2-bin": string([]byte{1, 2, 3}), 286 } 287 ) 288 289 type testPerRPCCredentials struct { 290 authdata map[string]string 291 errChan chan error 292 } 293 294 func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { 295 var err error 296 if cr.errChan != nil { 297 err = <-cr.errChan 298 } 299 return cr.authdata, err 300 } 301 302 func (cr testPerRPCCredentials) RequireTransportSecurity() bool { 303 return false 304 } 305 306 func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) { 307 md, ok := metadata.FromIncomingContext(ctx) 308 if !ok { 309 return ctx, fmt.Errorf("didn't find metadata in context") 310 } 311 for k, vwant := range authdata { 312 vgot, ok := md[k] 313 if !ok { 314 return ctx, fmt.Errorf("didn't find authdata key %v in context", k) 315 } 316 if vgot[0] != vwant { 317 return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant) 318 } 319 } 320 return ctx, nil 321 } 322 323 func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) { 324 for _, e := range listTestEnv() { 325 testPerRPCCredentialsViaDialOptions(t, e) 326 } 327 } 328 329 func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) { 330 te := newTest(t, e) 331 te.tapHandle = authHandle 332 te.perRPCCreds = testPerRPCCredentials{authdata: authdata} 333 te.startServer(&testServer{security: e.security}) 334 defer te.tearDown() 335 336 cc := te.clientConn() 337 tc := testgrpc.NewTestServiceClient(cc) 338 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 339 defer cancel() 340 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { 341 t.Fatalf("Test failed. Reason: %v", err) 342 } 343 } 344 345 func (s) TestPerRPCCredentialsViaCallOptions(t *testing.T) { 346 for _, e := range listTestEnv() { 347 testPerRPCCredentialsViaCallOptions(t, e) 348 } 349 } 350 351 func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) { 352 te := newTest(t, e) 353 te.tapHandle = authHandle 354 te.startServer(&testServer{security: e.security}) 355 defer te.tearDown() 356 357 cc := te.clientConn() 358 tc := testgrpc.NewTestServiceClient(cc) 359 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 360 defer cancel() 361 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { 362 t.Fatalf("Test failed. Reason: %v", err) 363 } 364 } 365 366 func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) { 367 for _, e := range listTestEnv() { 368 testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e) 369 } 370 } 371 372 func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { 373 te := newTest(t, e) 374 te.perRPCCreds = testPerRPCCredentials{authdata: authdata} 375 // When credentials are provided via both dial options and call options, 376 // we apply both sets. 377 te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { 378 md, ok := metadata.FromIncomingContext(ctx) 379 if !ok { 380 return ctx, fmt.Errorf("couldn't find metadata in context") 381 } 382 for k, vwant := range authdata { 383 vgot, ok := md[k] 384 if !ok { 385 return ctx, fmt.Errorf("couldn't find metadata for key %v", k) 386 } 387 if len(vgot) != 2 { 388 return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot)) 389 } 390 if vgot[0] != vwant || vgot[1] != vwant { 391 return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant) 392 } 393 } 394 return ctx, nil 395 } 396 te.startServer(&testServer{security: e.security}) 397 defer te.tearDown() 398 399 cc := te.clientConn() 400 tc := testgrpc.NewTestServiceClient(cc) 401 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 402 defer cancel() 403 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { 404 t.Fatalf("Test failed. Reason: %v", err) 405 } 406 } 407 408 const testAuthority = "test.auth.ori.ty" 409 410 type authorityCheckCreds struct { 411 credentials.TransportCredentials 412 got string 413 } 414 415 func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 416 c.got = authority 417 return rawConn, nil, nil 418 } 419 func (c *authorityCheckCreds) Info() credentials.ProtocolInfo { 420 return credentials.ProtocolInfo{} 421 } 422 func (c *authorityCheckCreds) Clone() credentials.TransportCredentials { 423 return c 424 } 425 426 // This test makes sure that the authority client handshake gets is the endpoint 427 // in dial target, not the resolved ip address. 428 func (s) TestCredsHandshakeAuthority(t *testing.T) { 429 lis, err := net.Listen("tcp", "localhost:0") 430 if err != nil { 431 t.Fatal(err) 432 } 433 cred := &authorityCheckCreds{} 434 s := grpc.NewServer() 435 go s.Serve(lis) 436 defer s.Stop() 437 438 r := manual.NewBuilderWithScheme("whatever") 439 440 cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred), grpc.WithResolvers(r)) 441 if err != nil { 442 t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) 443 } 444 defer cc.Close() 445 r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) 446 447 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 448 defer cancel() 449 testutils.AwaitState(ctx, t, cc, connectivity.Ready) 450 451 if cred.got != testAuthority { 452 t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) 453 } 454 } 455 456 // This test makes sure that the authority client handshake gets is the endpoint 457 // of the ServerName of the address when it is set. 458 func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) { 459 const testServerName = "test.server.name" 460 461 lis, err := net.Listen("tcp", "localhost:0") 462 if err != nil { 463 t.Fatal(err) 464 } 465 cred := &authorityCheckCreds{} 466 s := grpc.NewServer() 467 go s.Serve(lis) 468 defer s.Stop() 469 470 r := manual.NewBuilderWithScheme("whatever") 471 472 cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred), grpc.WithResolvers(r)) 473 if err != nil { 474 t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) 475 } 476 defer cc.Close() 477 r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}}) 478 479 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 480 defer cancel() 481 testutils.AwaitState(ctx, t, cc, connectivity.Ready) 482 483 if cred.got != testServerName { 484 t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) 485 } 486 } 487 488 type serverDispatchCred struct { 489 rawConnCh chan net.Conn 490 } 491 492 func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 493 return rawConn, nil, nil 494 } 495 func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 496 select { 497 case c.rawConnCh <- rawConn: 498 default: 499 } 500 return nil, nil, credentials.ErrConnDispatched 501 } 502 func (c *serverDispatchCred) Info() credentials.ProtocolInfo { 503 return credentials.ProtocolInfo{} 504 } 505 func (c *serverDispatchCred) Clone() credentials.TransportCredentials { 506 return nil 507 } 508 func (c *serverDispatchCred) OverrideServerName(s string) error { 509 return nil 510 } 511 func (c *serverDispatchCred) getRawConn() net.Conn { 512 return <-c.rawConnCh 513 } 514 515 func (s) TestServerCredsDispatch(t *testing.T) { 516 lis, err := net.Listen("tcp", "localhost:0") 517 if err != nil { 518 t.Fatal(err) 519 } 520 cred := &serverDispatchCred{ 521 rawConnCh: make(chan net.Conn, 1), 522 } 523 s := grpc.NewServer(grpc.Creds(cred)) 524 go s.Serve(lis) 525 defer s.Stop() 526 527 cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred)) 528 if err != nil { 529 t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) 530 } 531 defer cc.Close() 532 533 rawConn := cred.getRawConn() 534 // Give grpc a chance to see the error and potentially close the connection. 535 // And check that connection is not closed after that. 536 time.Sleep(100 * time.Millisecond) 537 // Check rawConn is not closed. 538 if n, err := rawConn.Write([]byte{0}); n <= 0 || err != nil { 539 t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err) 540 } 541 }