google.golang.org/grpc@v1.72.2/experimental/credentials/tls_ext_test.go (about) 1 /* 2 * 3 * Copyright 2025 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package credentials_test 20 21 import ( 22 "context" 23 "crypto/tls" 24 "crypto/x509" 25 "fmt" 26 "net" 27 "os" 28 "strings" 29 "testing" 30 "time" 31 32 "google.golang.org/grpc" 33 "google.golang.org/grpc/codes" 34 credsstable "google.golang.org/grpc/credentials" 35 "google.golang.org/grpc/experimental/credentials" 36 "google.golang.org/grpc/internal/envconfig" 37 "google.golang.org/grpc/internal/grpctest" 38 "google.golang.org/grpc/internal/stubserver" 39 "google.golang.org/grpc/status" 40 "google.golang.org/grpc/testdata" 41 42 testgrpc "google.golang.org/grpc/interop/grpc_testing" 43 testpb "google.golang.org/grpc/interop/grpc_testing" 44 ) 45 46 const defaultTestTimeout = 10 * time.Second 47 48 type s struct { 49 grpctest.Tester 50 } 51 52 func Test(t *testing.T) { 53 grpctest.RunSubTests(t, s{}) 54 } 55 56 var serverCert tls.Certificate 57 var certPool *x509.CertPool 58 var serverName = "x.test.example.com" 59 60 func init() { 61 var err error 62 serverCert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 63 if err != nil { 64 panic(fmt.Sprintf("tls.LoadX509KeyPair(server1.pem, server1.key) failed: %v", err)) 65 } 66 67 b, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) 68 if err != nil { 69 panic(fmt.Sprintf("Error reading CA cert file: %v", err)) 70 } 71 certPool = x509.NewCertPool() 72 if !certPool.AppendCertsFromPEM(b) { 73 panic("Error appending cert from PEM") 74 } 75 } 76 77 // Tests that the MinVersion of tls.Config is set to 1.2 if it is not already 78 // set by the user. 79 func (s) TestTLS_MinVersion12(t *testing.T) { 80 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 81 defer cancel() 82 83 testCases := []struct { 84 name string 85 serverTLS func() *tls.Config 86 }{ 87 { 88 name: "base_case", 89 serverTLS: func() *tls.Config { 90 return &tls.Config{ 91 // MinVersion should be set to 1.2 by gRPC by default. 92 Certificates: []tls.Certificate{serverCert}, 93 } 94 }, 95 }, 96 { 97 name: "fallback_to_base", 98 serverTLS: func() *tls.Config { 99 config := &tls.Config{ 100 // MinVersion should be set to 1.2 by gRPC by default. 101 Certificates: []tls.Certificate{serverCert}, 102 } 103 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 104 return nil, nil 105 } 106 return config 107 }, 108 }, 109 { 110 name: "dynamic_using_get_config_for_client", 111 serverTLS: func() *tls.Config { 112 return &tls.Config{ 113 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 114 return &tls.Config{ 115 // MinVersion should be set to 1.2 by gRPC by default. 116 Certificates: []tls.Certificate{serverCert}, 117 }, nil 118 }, 119 } 120 }, 121 }, 122 } 123 124 for _, tc := range testCases { 125 t.Run(tc.name, func(t *testing.T) { 126 // Create server creds without a minimum version. 127 serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) 128 ss := stubserver.StubServer{ 129 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 130 return &testpb.Empty{}, nil 131 }, 132 } 133 134 // Create client creds that supports V1.0-V1.1. 135 clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ 136 ServerName: serverName, 137 RootCAs: certPool, 138 MinVersion: tls.VersionTLS10, 139 MaxVersion: tls.VersionTLS11, 140 }) 141 142 // Start server and client separately, because Start() blocks on a 143 // successful connection, which we will not get. 144 if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { 145 t.Fatalf("Error starting server: %v", err) 146 } 147 defer ss.Stop() 148 149 cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) 150 if err != nil { 151 t.Fatalf("grpc.NewClient error: %v", err) 152 } 153 defer cc.Close() 154 155 client := testgrpc.NewTestServiceClient(cc) 156 157 const wantStr = "authentication handshake failed" 158 if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { 159 t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) 160 } 161 162 }) 163 } 164 } 165 166 // Tests that the MinVersion of tls.Config is not changed if it is set by the 167 // user. 168 func (s) TestTLS_MinVersionOverridable(t *testing.T) { 169 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 170 defer cancel() 171 172 var allCipherSuites []uint16 173 for _, cs := range tls.CipherSuites() { 174 allCipherSuites = append(allCipherSuites, cs.ID) 175 } 176 testCases := []struct { 177 name string 178 serverTLS func() *tls.Config 179 }{ 180 { 181 name: "base_case", 182 serverTLS: func() *tls.Config { 183 return &tls.Config{ 184 MinVersion: tls.VersionTLS10, 185 Certificates: []tls.Certificate{serverCert}, 186 CipherSuites: allCipherSuites, 187 } 188 }, 189 }, 190 { 191 name: "fallback_to_base", 192 serverTLS: func() *tls.Config { 193 config := &tls.Config{ 194 MinVersion: tls.VersionTLS10, 195 Certificates: []tls.Certificate{serverCert}, 196 CipherSuites: allCipherSuites, 197 } 198 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 199 return nil, nil 200 } 201 return config 202 }, 203 }, 204 { 205 name: "dynamic_using_get_config_for_client", 206 serverTLS: func() *tls.Config { 207 return &tls.Config{ 208 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 209 return &tls.Config{ 210 MinVersion: tls.VersionTLS10, 211 Certificates: []tls.Certificate{serverCert}, 212 CipherSuites: allCipherSuites, 213 }, nil 214 }, 215 } 216 }, 217 }, 218 } 219 220 for _, tc := range testCases { 221 t.Run(tc.name, func(t *testing.T) { 222 // Create server creds that allow v1.0. 223 serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) 224 ss := stubserver.StubServer{ 225 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 226 return &testpb.Empty{}, nil 227 }, 228 } 229 230 // Create client creds that supports V1.0-V1.1. 231 clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ 232 ServerName: serverName, 233 RootCAs: certPool, 234 CipherSuites: allCipherSuites, 235 MinVersion: tls.VersionTLS10, 236 MaxVersion: tls.VersionTLS11, 237 }) 238 239 if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { 240 t.Fatalf("Error starting stub server: %v", err) 241 } 242 defer ss.Stop() 243 244 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { 245 t.Fatalf("EmptyCall err = %v; want <nil>", err) 246 } 247 }) 248 } 249 } 250 251 // Tests that CipherSuites is set to exclude HTTP/2 forbidden suites by default. 252 func (s) TestTLS_CipherSuites(t *testing.T) { 253 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 254 defer cancel() 255 testCases := []struct { 256 name string 257 serverTLS func() *tls.Config 258 }{ 259 { 260 name: "base_case", 261 serverTLS: func() *tls.Config { 262 return &tls.Config{ 263 Certificates: []tls.Certificate{serverCert}, 264 } 265 }, 266 }, 267 { 268 name: "fallback_to_base", 269 serverTLS: func() *tls.Config { 270 config := &tls.Config{ 271 Certificates: []tls.Certificate{serverCert}, 272 } 273 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 274 return nil, nil 275 } 276 return config 277 }, 278 }, 279 { 280 name: "dynamic_using_get_config_for_client", 281 serverTLS: func() *tls.Config { 282 return &tls.Config{ 283 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 284 return &tls.Config{ 285 Certificates: []tls.Certificate{serverCert}, 286 }, nil 287 }, 288 } 289 }, 290 }, 291 } 292 293 for _, tc := range testCases { 294 t.Run(tc.name, func(t *testing.T) { 295 // Create server creds without cipher suites. 296 serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) 297 ss := stubserver.StubServer{ 298 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 299 return &testpb.Empty{}, nil 300 }, 301 } 302 303 // Create client creds that use a forbidden suite only. 304 clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ 305 ServerName: serverName, 306 RootCAs: certPool, 307 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 308 MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. 309 }) 310 311 // Start server and client separately, because Start() blocks on a 312 // successful connection, which we will not get. 313 if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { 314 t.Fatalf("Error starting server: %v", err) 315 } 316 defer ss.Stop() 317 318 cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds)) 319 if err != nil { 320 t.Fatalf("grpc.NewClient error: %v", err) 321 } 322 defer cc.Close() 323 324 client := testgrpc.NewTestServiceClient(cc) 325 326 const wantStr = "authentication handshake failed" 327 if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { 328 t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) 329 } 330 }) 331 } 332 } 333 334 // Tests that CipherSuites is not overridden when it is set. 335 func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { 336 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 337 defer cancel() 338 339 testCases := []struct { 340 name string 341 serverTLS func() *tls.Config 342 }{ 343 { 344 name: "base_case", 345 serverTLS: func() *tls.Config { 346 return &tls.Config{ 347 Certificates: []tls.Certificate{serverCert}, 348 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 349 } 350 }, 351 }, 352 { 353 name: "fallback_to_base", 354 serverTLS: func() *tls.Config { 355 config := &tls.Config{ 356 Certificates: []tls.Certificate{serverCert}, 357 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 358 } 359 config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 360 return nil, nil 361 } 362 return config 363 }, 364 }, 365 { 366 name: "dynamic_using_get_config_for_client", 367 serverTLS: func() *tls.Config { 368 return &tls.Config{ 369 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 370 return &tls.Config{ 371 Certificates: []tls.Certificate{serverCert}, 372 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 373 }, nil 374 }, 375 } 376 }, 377 }, 378 } 379 380 for _, tc := range testCases { 381 t.Run(tc.name, func(t *testing.T) { 382 // Create server that allows only a forbidden cipher suite. 383 serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) 384 ss := stubserver.StubServer{ 385 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 386 return &testpb.Empty{}, nil 387 }, 388 } 389 390 // Create server that allows only a forbidden cipher suite. 391 clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ 392 ServerName: serverName, 393 RootCAs: certPool, 394 CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, 395 MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. 396 }) 397 398 if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { 399 t.Fatalf("Error starting stub server: %v", err) 400 } 401 defer ss.Stop() 402 403 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { 404 t.Fatalf("EmptyCall err = %v; want <nil>", err) 405 } 406 }) 407 } 408 } 409 410 // TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured 411 // correctly for a server that doesn't specify the NextProtos field and uses 412 // GetConfigForClient to provide the TLS config during the handshake. 413 func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) { 414 initialVal := envconfig.EnforceALPNEnabled 415 defer func() { 416 envconfig.EnforceALPNEnabled = initialVal 417 }() 418 envconfig.EnforceALPNEnabled = true 419 420 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 421 defer cancel() 422 423 // Create a server that doesn't set the NextProtos field. 424 serverCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ 425 GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { 426 return &tls.Config{ 427 Certificates: []tls.Certificate{serverCert}, 428 }, nil 429 }, 430 }) 431 432 ss := stubserver.StubServer{ 433 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 434 return &testpb.Empty{}, nil 435 }, 436 } 437 438 clientCreds := credsstable.NewTLS(&tls.Config{ 439 ServerName: serverName, 440 RootCAs: certPool, 441 }) 442 443 if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { 444 t.Fatalf("Error starting stub server: %v", err) 445 } 446 defer ss.Stop() 447 448 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { 449 t.Fatalf("EmptyCall err = %v; want <nil>", err) 450 } 451 } 452 453 // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when 454 // connecting to a server that doesn't support ALPN. 455 func (s) TestTLS_DisabledALPNClient(t *testing.T) { 456 initialVal := envconfig.EnforceALPNEnabled 457 defer func() { 458 envconfig.EnforceALPNEnabled = initialVal 459 }() 460 461 tests := []struct { 462 name string 463 alpnEnforced bool 464 wantErr bool 465 }{ 466 { 467 name: "enforced", 468 }, 469 { 470 name: "not_enforced", 471 }, 472 } 473 474 for _, tc := range tests { 475 t.Run(tc.name, func(t *testing.T) { 476 envconfig.EnforceALPNEnabled = tc.alpnEnforced 477 478 listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{ 479 Certificates: []tls.Certificate{serverCert}, 480 NextProtos: []string{}, // Empty list indicates ALPN is disabled. 481 }) 482 if err != nil { 483 t.Fatalf("Error starting TLS server: %v", err) 484 } 485 486 errCh := make(chan error, 1) 487 go func() { 488 conn, err := listener.Accept() 489 if err != nil { 490 errCh <- fmt.Errorf("listener.Accept returned error: %v", err) 491 } else { 492 // The first write to the TLS listener initiates the TLS handshake. 493 conn.Write([]byte("Hello, World!")) 494 conn.Close() 495 } 496 close(errCh) 497 }() 498 499 serverAddr := listener.Addr().String() 500 conn, err := net.Dial("tcp", serverAddr) 501 if err != nil { 502 t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err) 503 } 504 defer conn.Close() 505 506 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 507 defer cancel() 508 509 clientCfg := tls.Config{ 510 ServerName: serverName, 511 RootCAs: certPool, 512 NextProtos: []string{"h2"}, 513 } 514 _, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn) 515 516 if gotErr := (err != nil); gotErr != tc.wantErr { 517 t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) 518 } 519 520 select { 521 case err := <-errCh: 522 if err != nil { 523 t.Fatalf("Unexpected error received from server: %v", err) 524 } 525 case <-ctx.Done(): 526 t.Fatalf("Timeout waiting for error from server") 527 } 528 }) 529 } 530 } 531 532 // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when 533 // accepting a request from a client that doesn't support ALPN. 534 func (s) TestTLS_DisabledALPNServer(t *testing.T) { 535 initialVal := envconfig.EnforceALPNEnabled 536 defer func() { 537 envconfig.EnforceALPNEnabled = initialVal 538 }() 539 540 tests := []struct { 541 name string 542 alpnEnforced bool 543 wantErr bool 544 }{ 545 { 546 name: "enforced", 547 }, 548 { 549 name: "not_enforced", 550 }, 551 } 552 553 for _, tc := range tests { 554 t.Run(tc.name, func(t *testing.T) { 555 envconfig.EnforceALPNEnabled = tc.alpnEnforced 556 557 listener, err := net.Listen("tcp", "localhost:0") 558 if err != nil { 559 t.Fatalf("Error starting server: %v", err) 560 } 561 562 errCh := make(chan error, 1) 563 go func() { 564 conn, err := listener.Accept() 565 if err != nil { 566 errCh <- fmt.Errorf("listener.Accept returned error: %v", err) 567 return 568 } 569 defer conn.Close() 570 serverCfg := tls.Config{ 571 Certificates: []tls.Certificate{serverCert}, 572 NextProtos: []string{"h2"}, 573 } 574 _, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn) 575 if gotErr := (err != nil); gotErr != tc.wantErr { 576 t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) 577 } 578 close(errCh) 579 }() 580 581 serverAddr := listener.Addr().String() 582 clientCfg := &tls.Config{ 583 Certificates: []tls.Certificate{serverCert}, 584 NextProtos: []string{}, // Empty list indicates ALPN is disabled. 585 RootCAs: certPool, 586 ServerName: serverName, 587 } 588 conn, err := tls.Dial("tcp", serverAddr, clientCfg) 589 if err != nil { 590 t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err) 591 } 592 defer conn.Close() 593 594 select { 595 case <-time.After(defaultTestTimeout): 596 t.Fatal("Timed out waiting for completion") 597 case err := <-errCh: 598 if err != nil { 599 t.Fatalf("Unexpected server error: %v", err) 600 } 601 } 602 }) 603 } 604 }