google.golang.org/grpc@v1.72.2/credentials/tls_ext_test.go (about) 1 /* 2 * 3 * Copyright 2023 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package credentials_test 20 21 import ( 22 "context" 23 "crypto/tls" 24 "crypto/x509" 25 "fmt" 26 "net" 27 "os" 28 "strings" 29 "testing" 30 "time" 31 32 "google.golang.org/grpc" 33 "google.golang.org/grpc/codes" 34 "google.golang.org/grpc/credentials" 35 "google.golang.org/grpc/internal/envconfig" 36 "google.golang.org/grpc/internal/grpctest" 37 "google.golang.org/grpc/internal/stubserver" 38 "google.golang.org/grpc/status" 39 "google.golang.org/grpc/testdata" 40 41 testgrpc "google.golang.org/grpc/interop/grpc_testing" 42 testpb "google.golang.org/grpc/interop/grpc_testing" 43 ) 44 45 const defaultTestTimeout = 10 * time.Second 46 47 type s struct { 48 grpctest.Tester 49 } 50 51 func Test(t *testing.T) { 52 grpctest.RunSubTests(t, s{}) 53 } 54 55 var serverCert tls.Certificate 56 var certPool *x509.CertPool 57 var serverName = "x.test.example.com" 58 59 func init() { 60 var err error 61 serverCert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 62 if err != nil { 63 panic(fmt.Sprintf("tls.LoadX509KeyPair(server1.pem, server1.key) failed: %v", err)) 64 } 65 66 b, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) 67 if err != nil { 68 panic(fmt.Sprintf("Error reading CA cert file: %v", err)) 69 } 70 certPool = x509.NewCertPool() 71 if !certPool.AppendCertsFromPEM(b) { 72 panic("Error appending cert from PEM") 73 } 74 } 75 76 // Tests that the MinVersion of tls.Config is set to 1.2 if it is not already 77 // set by the user. 78 func (s) TestTLS_MinVersion12(t *testing.T) { 79 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 80 defer cancel() 81 82 testCases := []struct { 83 name string 84 serverTLS func() *tls.Config 85 }{ 86 { 87 name: "base_case", 88 serverTLS: func() *tls.Config { 89 return &tls.Config{ 90 // MinVersion should be set to 1.2 by gRPC by default. 91 Certificates: []tls.Certificate{serverCert}, 92 } 93 }, 94 }, 95 { 96 name: "fallback_to_base", 97 serverTLS: func() *tls.Config { 98 config := &tls.Config{ 99 // MinVersion should be set to 1.2 by gRPC by default. 100 Certificates: []tls.Certificate{serverCert}, 101 } 102 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 103 return nil, nil 104 } 105 return config 106 }, 107 }, 108 { 109 name: "dynamic_using_get_config_for_client", 110 serverTLS: func() *tls.Config { 111 return &tls.Config{ 112 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 113 return &tls.Config{ 114 // MinVersion should be set to 1.2 by gRPC by default. 115 Certificates: []tls.Certificate{serverCert}, 116 }, nil 117 }, 118 } 119 }, 120 }, 121 } 122 123 for _, tc := range testCases { 124 t.Run(tc.name, func(t *testing.T) { 125 // Create server creds without a minimum version. 126 serverCreds := credentials.NewTLS(tc.serverTLS()) 127 ss := stubserver.StubServer{ 128 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 129 return &testpb.Empty{}, nil 130 }, 131 } 132 133 // Create client creds that supports V1.0-V1.1. 134 clientCreds := credentials.NewTLS(&tls.Config{ 135 ServerName: serverName, 136 RootCAs: certPool, 137 MinVersion: tls.VersionTLS10, 138 MaxVersion: tls.VersionTLS11, 139 }) 140 141 // Start server and client separately, because Start() blocks on a 142 // successful connection, which we will not get. 143 if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { 144 t.Fatalf("Error starting server: %v", err) 145 } 146 defer ss.Stop() 147 148 cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) 149 if err != nil { 150 t.Fatalf("grpc.NewClient error: %v", err) 151 } 152 defer cc.Close() 153 154 client := testgrpc.NewTestServiceClient(cc) 155 156 const wantStr = "authentication handshake failed" 157 if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { 158 t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) 159 } 160 161 }) 162 } 163 } 164 165 // Tests that the MinVersion of tls.Config is not changed if it is set by the 166 // user. 167 func (s) TestTLS_MinVersionOverridable(t *testing.T) { 168 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 169 defer cancel() 170 171 var allCipherSuites []uint16 172 for _, cs := range tls.CipherSuites() { 173 allCipherSuites = append(allCipherSuites, cs.ID) 174 } 175 testCases := []struct { 176 name string 177 serverTLS func() *tls.Config 178 }{ 179 { 180 name: "base_case", 181 serverTLS: func() *tls.Config { 182 return &tls.Config{ 183 MinVersion: tls.VersionTLS10, 184 Certificates: []tls.Certificate{serverCert}, 185 CipherSuites: allCipherSuites, 186 } 187 }, 188 }, 189 { 190 name: "fallback_to_base", 191 serverTLS: func() *tls.Config { 192 config := &tls.Config{ 193 MinVersion: tls.VersionTLS10, 194 Certificates: []tls.Certificate{serverCert}, 195 CipherSuites: allCipherSuites, 196 } 197 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 198 return nil, nil 199 } 200 return config 201 }, 202 }, 203 { 204 name: "dynamic_using_get_config_for_client", 205 serverTLS: func() *tls.Config { 206 return &tls.Config{ 207 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 208 return &tls.Config{ 209 MinVersion: tls.VersionTLS10, 210 Certificates: []tls.Certificate{serverCert}, 211 CipherSuites: allCipherSuites, 212 }, nil 213 }, 214 } 215 }, 216 }, 217 } 218 219 for _, tc := range testCases { 220 t.Run(tc.name, func(t *testing.T) { 221 // Create server creds that allow v1.0. 222 serverCreds := credentials.NewTLS(tc.serverTLS()) 223 ss := stubserver.StubServer{ 224 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 225 return &testpb.Empty{}, nil 226 }, 227 } 228 229 // Create client creds that supports V1.0-V1.1. 230 clientCreds := credentials.NewTLS(&tls.Config{ 231 ServerName: serverName, 232 RootCAs: certPool, 233 CipherSuites: allCipherSuites, 234 MinVersion: tls.VersionTLS10, 235 MaxVersion: tls.VersionTLS11, 236 }) 237 238 if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { 239 t.Fatalf("Error starting stub server: %v", err) 240 } 241 defer ss.Stop() 242 243 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { 244 t.Fatalf("EmptyCall err = %v; want <nil>", err) 245 } 246 }) 247 } 248 } 249 250 // Tests that CipherSuites is set to exclude HTTP/2 forbidden suites by default. 251 func (s) TestTLS_CipherSuites(t *testing.T) { 252 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 253 defer cancel() 254 testCases := []struct { 255 name string 256 serverTLS func() *tls.Config 257 }{ 258 { 259 name: "base_case", 260 serverTLS: func() *tls.Config { 261 return &tls.Config{ 262 Certificates: []tls.Certificate{serverCert}, 263 } 264 }, 265 }, 266 { 267 name: "fallback_to_base", 268 serverTLS: func() *tls.Config { 269 config := &tls.Config{ 270 Certificates: []tls.Certificate{serverCert}, 271 } 272 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 273 return nil, nil 274 } 275 return config 276 }, 277 }, 278 { 279 name: "dynamic_using_get_config_for_client", 280 serverTLS: func() *tls.Config { 281 return &tls.Config{ 282 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 283 return &tls.Config{ 284 Certificates: []tls.Certificate{serverCert}, 285 }, nil 286 }, 287 } 288 }, 289 }, 290 } 291 292 for _, tc := range testCases { 293 t.Run(tc.name, func(t *testing.T) { 294 // Create server creds without cipher suites. 295 serverCreds := credentials.NewTLS(tc.serverTLS()) 296 ss := stubserver.StubServer{ 297 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 298 return &testpb.Empty{}, nil 299 }, 300 } 301 302 // Create client creds that use a forbidden suite only. 303 clientCreds := credentials.NewTLS(&tls.Config{ 304 ServerName: serverName, 305 RootCAs: certPool, 306 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 307 MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. 308 }) 309 310 // Start server and client separately, because Start() blocks on a 311 // successful connection, which we will not get. 312 if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { 313 t.Fatalf("Error starting server: %v", err) 314 } 315 defer ss.Stop() 316 317 cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds)) 318 if err != nil { 319 t.Fatalf("grpc.NewClient error: %v", err) 320 } 321 defer cc.Close() 322 323 client := testgrpc.NewTestServiceClient(cc) 324 325 const wantStr = "authentication handshake failed" 326 if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { 327 t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) 328 } 329 }) 330 } 331 } 332 333 // Tests that CipherSuites is not overridden when it is set. 334 func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { 335 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 336 defer cancel() 337 338 testCases := []struct { 339 name string 340 serverTLS func() *tls.Config 341 }{ 342 { 343 name: "base_case", 344 serverTLS: func() *tls.Config { 345 return &tls.Config{ 346 Certificates: []tls.Certificate{serverCert}, 347 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 348 } 349 }, 350 }, 351 { 352 name: "fallback_to_base", 353 serverTLS: func() *tls.Config { 354 config := &tls.Config{ 355 Certificates: []tls.Certificate{serverCert}, 356 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 357 } 358 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 359 return nil, nil 360 } 361 return config 362 }, 363 }, 364 { 365 name: "dynamic_using_get_config_for_client", 366 serverTLS: func() *tls.Config { 367 return &tls.Config{ 368 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 369 return &tls.Config{ 370 Certificates: []tls.Certificate{serverCert}, 371 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 372 }, nil 373 }, 374 } 375 }, 376 }, 377 } 378 379 for _, tc := range testCases { 380 t.Run(tc.name, func(t *testing.T) { 381 // Create server that allows only a forbidden cipher suite. 382 serverCreds := credentials.NewTLS(tc.serverTLS()) 383 ss := stubserver.StubServer{ 384 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 385 return &testpb.Empty{}, nil 386 }, 387 } 388 389 // Create server that allows only a forbidden cipher suite. 390 clientCreds := credentials.NewTLS(&tls.Config{ 391 ServerName: serverName, 392 RootCAs: certPool, 393 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 394 MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. 395 }) 396 397 if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { 398 t.Fatalf("Error starting stub server: %v", err) 399 } 400 defer ss.Stop() 401 402 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { 403 t.Fatalf("EmptyCall err = %v; want <nil>", err) 404 } 405 }) 406 } 407 } 408 409 // TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured 410 // correctly for a server that doesn't specify the NextProtos field and uses 411 // GetConfigForClient to provide the TLS config during the handshake. 412 func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) { 413 initialVal := envconfig.EnforceALPNEnabled 414 defer func() { 415 envconfig.EnforceALPNEnabled = initialVal 416 }() 417 envconfig.EnforceALPNEnabled = true 418 419 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 420 defer cancel() 421 422 // Create a server that doesn't set the NextProtos field. 423 serverCreds := credentials.NewTLS(&tls.Config{ 424 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 425 return &tls.Config{ 426 Certificates: []tls.Certificate{serverCert}, 427 }, nil 428 }, 429 }) 430 431 ss := stubserver.StubServer{ 432 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 433 return &testpb.Empty{}, nil 434 }, 435 } 436 437 clientCreds := credentials.NewTLS(&tls.Config{ 438 ServerName: serverName, 439 RootCAs: certPool, 440 }) 441 442 if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { 443 t.Fatalf("Error starting stub server: %v", err) 444 } 445 defer ss.Stop() 446 447 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { 448 t.Fatalf("EmptyCall err = %v; want <nil>", err) 449 } 450 } 451 452 // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when 453 // connecting to a server that doesn't support ALPN. 454 func (s) TestTLS_DisabledALPNClient(t *testing.T) { 455 initialVal := envconfig.EnforceALPNEnabled 456 defer func() { 457 envconfig.EnforceALPNEnabled = initialVal 458 }() 459 460 tests := []struct { 461 name string 462 alpnEnforced bool 463 wantErr bool 464 }{ 465 { 466 name: "enforced", 467 alpnEnforced: true, 468 wantErr: true, 469 }, 470 { 471 name: "not_enforced", 472 }, 473 } 474 475 for _, tc := range tests { 476 t.Run(tc.name, func(t *testing.T) { 477 envconfig.EnforceALPNEnabled = tc.alpnEnforced 478 479 listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{ 480 Certificates: []tls.Certificate{serverCert}, 481 NextProtos: []string{}, // Empty list indicates ALPN is disabled. 482 }) 483 if err != nil { 484 t.Fatalf("Error starting TLS server: %v", err) 485 } 486 487 errCh := make(chan error, 1) 488 go func() { 489 conn, err := listener.Accept() 490 if err != nil { 491 errCh <- fmt.Errorf("listener.Accept returned error: %v", err) 492 } else { 493 // The first write to the TLS listener initiates the TLS handshake. 494 conn.Write([]byte("Hello, World!")) 495 conn.Close() 496 } 497 close(errCh) 498 }() 499 500 serverAddr := listener.Addr().String() 501 conn, err := net.Dial("tcp", serverAddr) 502 if err != nil { 503 t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err) 504 } 505 defer conn.Close() 506 507 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 508 defer cancel() 509 510 clientCfg := tls.Config{ 511 ServerName: serverName, 512 RootCAs: certPool, 513 NextProtos: []string{"h2"}, 514 } 515 _, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn) 516 517 if gotErr := (err != nil); gotErr != tc.wantErr { 518 t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) 519 } 520 521 select { 522 case err := <-errCh: 523 if err != nil { 524 t.Fatalf("Unexpected error received from server: %v", err) 525 } 526 case <-ctx.Done(): 527 t.Fatalf("Timeout waiting for error from server") 528 } 529 }) 530 } 531 } 532 533 // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when 534 // accepting a request from a client that doesn't support ALPN. 535 func (s) TestTLS_DisabledALPNServer(t *testing.T) { 536 initialVal := envconfig.EnforceALPNEnabled 537 defer func() { 538 envconfig.EnforceALPNEnabled = initialVal 539 }() 540 541 tests := []struct { 542 name string 543 alpnEnforced bool 544 wantErr bool 545 }{ 546 { 547 name: "enforced", 548 alpnEnforced: true, 549 wantErr: true, 550 }, 551 { 552 name: "not_enforced", 553 }, 554 } 555 556 for _, tc := range tests { 557 t.Run(tc.name, func(t *testing.T) { 558 envconfig.EnforceALPNEnabled = tc.alpnEnforced 559 560 listener, err := net.Listen("tcp", "localhost:0") 561 if err != nil { 562 t.Fatalf("Error starting server: %v", err) 563 } 564 565 errCh := make(chan error, 1) 566 go func() { 567 conn, err := listener.Accept() 568 if err != nil { 569 errCh <- fmt.Errorf("listener.Accept returned error: %v", err) 570 return 571 } 572 defer conn.Close() 573 serverCfg := tls.Config{ 574 Certificates: []tls.Certificate{serverCert}, 575 NextProtos: []string{"h2"}, 576 } 577 _, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn) 578 if gotErr := (err != nil); gotErr != tc.wantErr { 579 t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) 580 } 581 close(errCh) 582 }() 583 584 serverAddr := listener.Addr().String() 585 clientCfg := &tls.Config{ 586 Certificates: []tls.Certificate{serverCert}, 587 NextProtos: []string{}, // Empty list indicates ALPN is disabled. 588 RootCAs: certPool, 589 ServerName: serverName, 590 } 591 conn, err := tls.Dial("tcp", serverAddr, clientCfg) 592 if err != nil { 593 t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err) 594 } 595 defer conn.Close() 596 597 select { 598 case <-time.After(defaultTestTimeout): 599 t.Fatal("Timed out waiting for completion") 600 case err := <-errCh: 601 if err != nil { 602 t.Fatalf("Unexpected server error: %v", err) 603 } 604 } 605 }) 606 } 607 }