google.golang.org/grpc@v1.72.2/experimental/credentials/tls_ext_test.go (about)

     1  /*
     2   *
     3   * Copyright 2025 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package credentials_test
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"fmt"
    26  	"net"
    27  	"os"
    28  	"strings"
    29  	"testing"
    30  	"time"
    31  
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/codes"
    34  	credsstable "google.golang.org/grpc/credentials"
    35  	"google.golang.org/grpc/experimental/credentials"
    36  	"google.golang.org/grpc/internal/envconfig"
    37  	"google.golang.org/grpc/internal/grpctest"
    38  	"google.golang.org/grpc/internal/stubserver"
    39  	"google.golang.org/grpc/status"
    40  	"google.golang.org/grpc/testdata"
    41  
    42  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    43  	testpb "google.golang.org/grpc/interop/grpc_testing"
    44  )
    45  
    46  const defaultTestTimeout = 10 * time.Second
    47  
    48  type s struct {
    49  	grpctest.Tester
    50  }
    51  
    52  func Test(t *testing.T) {
    53  	grpctest.RunSubTests(t, s{})
    54  }
    55  
    56  var serverCert tls.Certificate
    57  var certPool *x509.CertPool
    58  var serverName = "x.test.example.com"
    59  
    60  func init() {
    61  	var err error
    62  	serverCert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
    63  	if err != nil {
    64  		panic(fmt.Sprintf("tls.LoadX509KeyPair(server1.pem, server1.key) failed: %v", err))
    65  	}
    66  
    67  	b, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
    68  	if err != nil {
    69  		panic(fmt.Sprintf("Error reading CA cert file: %v", err))
    70  	}
    71  	certPool = x509.NewCertPool()
    72  	if !certPool.AppendCertsFromPEM(b) {
    73  		panic("Error appending cert from PEM")
    74  	}
    75  }
    76  
    77  // Tests that the MinVersion of tls.Config is set to 1.2 if it is not already
    78  // set by the user.
    79  func (s) TestTLS_MinVersion12(t *testing.T) {
    80  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    81  	defer cancel()
    82  
    83  	testCases := []struct {
    84  		name      string
    85  		serverTLS func() *tls.Config
    86  	}{
    87  		{
    88  			name: "base_case",
    89  			serverTLS: func() *tls.Config {
    90  				return &tls.Config{
    91  					// MinVersion should be set to 1.2 by gRPC by default.
    92  					Certificates: []tls.Certificate{serverCert},
    93  				}
    94  			},
    95  		},
    96  		{
    97  			name: "fallback_to_base",
    98  			serverTLS: func() *tls.Config {
    99  				config := &tls.Config{
   100  					// MinVersion should be set to 1.2 by gRPC by default.
   101  					Certificates: []tls.Certificate{serverCert},
   102  				}
   103  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   104  					return nil, nil
   105  				}
   106  				return config
   107  			},
   108  		},
   109  		{
   110  			name: "dynamic_using_get_config_for_client",
   111  			serverTLS: func() *tls.Config {
   112  				return &tls.Config{
   113  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   114  						return &tls.Config{
   115  							// MinVersion should be set to 1.2 by gRPC by default.
   116  							Certificates: []tls.Certificate{serverCert},
   117  						}, nil
   118  					},
   119  				}
   120  			},
   121  		},
   122  	}
   123  
   124  	for _, tc := range testCases {
   125  		t.Run(tc.name, func(t *testing.T) {
   126  			// Create server creds without a minimum version.
   127  			serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
   128  			ss := stubserver.StubServer{
   129  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   130  					return &testpb.Empty{}, nil
   131  				},
   132  			}
   133  
   134  			// Create client creds that supports V1.0-V1.1.
   135  			clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
   136  				ServerName: serverName,
   137  				RootCAs:    certPool,
   138  				MinVersion: tls.VersionTLS10,
   139  				MaxVersion: tls.VersionTLS11,
   140  			})
   141  
   142  			// Start server and client separately, because Start() blocks on a
   143  			// successful connection, which we will not get.
   144  			if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
   145  				t.Fatalf("Error starting server: %v", err)
   146  			}
   147  			defer ss.Stop()
   148  
   149  			cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds))
   150  			if err != nil {
   151  				t.Fatalf("grpc.NewClient error: %v", err)
   152  			}
   153  			defer cc.Close()
   154  
   155  			client := testgrpc.NewTestServiceClient(cc)
   156  
   157  			const wantStr = "authentication handshake failed"
   158  			if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
   159  				t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
   160  			}
   161  
   162  		})
   163  	}
   164  }
   165  
   166  // Tests that the MinVersion of tls.Config is not changed if it is set by the
   167  // user.
   168  func (s) TestTLS_MinVersionOverridable(t *testing.T) {
   169  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   170  	defer cancel()
   171  
   172  	var allCipherSuites []uint16
   173  	for _, cs := range tls.CipherSuites() {
   174  		allCipherSuites = append(allCipherSuites, cs.ID)
   175  	}
   176  	testCases := []struct {
   177  		name      string
   178  		serverTLS func() *tls.Config
   179  	}{
   180  		{
   181  			name: "base_case",
   182  			serverTLS: func() *tls.Config {
   183  				return &tls.Config{
   184  					MinVersion:   tls.VersionTLS10,
   185  					Certificates: []tls.Certificate{serverCert},
   186  					CipherSuites: allCipherSuites,
   187  				}
   188  			},
   189  		},
   190  		{
   191  			name: "fallback_to_base",
   192  			serverTLS: func() *tls.Config {
   193  				config := &tls.Config{
   194  					MinVersion:   tls.VersionTLS10,
   195  					Certificates: []tls.Certificate{serverCert},
   196  					CipherSuites: allCipherSuites,
   197  				}
   198  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   199  					return nil, nil
   200  				}
   201  				return config
   202  			},
   203  		},
   204  		{
   205  			name: "dynamic_using_get_config_for_client",
   206  			serverTLS: func() *tls.Config {
   207  				return &tls.Config{
   208  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   209  						return &tls.Config{
   210  							MinVersion:   tls.VersionTLS10,
   211  							Certificates: []tls.Certificate{serverCert},
   212  							CipherSuites: allCipherSuites,
   213  						}, nil
   214  					},
   215  				}
   216  			},
   217  		},
   218  	}
   219  
   220  	for _, tc := range testCases {
   221  		t.Run(tc.name, func(t *testing.T) {
   222  			// Create server creds that allow v1.0.
   223  			serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
   224  			ss := stubserver.StubServer{
   225  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   226  					return &testpb.Empty{}, nil
   227  				},
   228  			}
   229  
   230  			// Create client creds that supports V1.0-V1.1.
   231  			clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
   232  				ServerName:   serverName,
   233  				RootCAs:      certPool,
   234  				CipherSuites: allCipherSuites,
   235  				MinVersion:   tls.VersionTLS10,
   236  				MaxVersion:   tls.VersionTLS11,
   237  			})
   238  
   239  			if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
   240  				t.Fatalf("Error starting stub server: %v", err)
   241  			}
   242  			defer ss.Stop()
   243  
   244  			if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   245  				t.Fatalf("EmptyCall err = %v; want <nil>", err)
   246  			}
   247  		})
   248  	}
   249  }
   250  
   251  // Tests that CipherSuites is set to exclude HTTP/2 forbidden suites by default.
   252  func (s) TestTLS_CipherSuites(t *testing.T) {
   253  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   254  	defer cancel()
   255  	testCases := []struct {
   256  		name      string
   257  		serverTLS func() *tls.Config
   258  	}{
   259  		{
   260  			name: "base_case",
   261  			serverTLS: func() *tls.Config {
   262  				return &tls.Config{
   263  					Certificates: []tls.Certificate{serverCert},
   264  				}
   265  			},
   266  		},
   267  		{
   268  			name: "fallback_to_base",
   269  			serverTLS: func() *tls.Config {
   270  				config := &tls.Config{
   271  					Certificates: []tls.Certificate{serverCert},
   272  				}
   273  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   274  					return nil, nil
   275  				}
   276  				return config
   277  			},
   278  		},
   279  		{
   280  			name: "dynamic_using_get_config_for_client",
   281  			serverTLS: func() *tls.Config {
   282  				return &tls.Config{
   283  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   284  						return &tls.Config{
   285  							Certificates: []tls.Certificate{serverCert},
   286  						}, nil
   287  					},
   288  				}
   289  			},
   290  		},
   291  	}
   292  
   293  	for _, tc := range testCases {
   294  		t.Run(tc.name, func(t *testing.T) {
   295  			// Create server creds without cipher suites.
   296  			serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
   297  			ss := stubserver.StubServer{
   298  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   299  					return &testpb.Empty{}, nil
   300  				},
   301  			}
   302  
   303  			// Create client creds that use a forbidden suite only.
   304  			clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
   305  				ServerName:   serverName,
   306  				RootCAs:      certPool,
   307  				CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   308  				MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
   309  			})
   310  
   311  			// Start server and client separately, because Start() blocks on a
   312  			// successful connection, which we will not get.
   313  			if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
   314  				t.Fatalf("Error starting server: %v", err)
   315  			}
   316  			defer ss.Stop()
   317  
   318  			cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds))
   319  			if err != nil {
   320  				t.Fatalf("grpc.NewClient error: %v", err)
   321  			}
   322  			defer cc.Close()
   323  
   324  			client := testgrpc.NewTestServiceClient(cc)
   325  
   326  			const wantStr = "authentication handshake failed"
   327  			if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
   328  				t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
   329  			}
   330  		})
   331  	}
   332  }
   333  
   334  // Tests that CipherSuites is not overridden when it is set.
   335  func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
   336  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   337  	defer cancel()
   338  
   339  	testCases := []struct {
   340  		name      string
   341  		serverTLS func() *tls.Config
   342  	}{
   343  		{
   344  			name: "base_case",
   345  			serverTLS: func() *tls.Config {
   346  				return &tls.Config{
   347  					Certificates: []tls.Certificate{serverCert},
   348  					CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   349  				}
   350  			},
   351  		},
   352  		{
   353  			name: "fallback_to_base",
   354  			serverTLS: func() *tls.Config {
   355  				config := &tls.Config{
   356  					Certificates: []tls.Certificate{serverCert},
   357  					CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   358  				}
   359  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   360  					return nil, nil
   361  				}
   362  				return config
   363  			},
   364  		},
   365  		{
   366  			name: "dynamic_using_get_config_for_client",
   367  			serverTLS: func() *tls.Config {
   368  				return &tls.Config{
   369  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   370  						return &tls.Config{
   371  							Certificates: []tls.Certificate{serverCert},
   372  							CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   373  						}, nil
   374  					},
   375  				}
   376  			},
   377  		},
   378  	}
   379  
   380  	for _, tc := range testCases {
   381  		t.Run(tc.name, func(t *testing.T) {
   382  			// Create server that allows only a forbidden cipher suite.
   383  			serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
   384  			ss := stubserver.StubServer{
   385  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   386  					return &testpb.Empty{}, nil
   387  				},
   388  			}
   389  
   390  			// Create server that allows only a forbidden cipher suite.
   391  			clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
   392  				ServerName:   serverName,
   393  				RootCAs:      certPool,
   394  				CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   395  				MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
   396  			})
   397  
   398  			if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
   399  				t.Fatalf("Error starting stub server: %v", err)
   400  			}
   401  			defer ss.Stop()
   402  
   403  			if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   404  				t.Fatalf("EmptyCall err = %v; want <nil>", err)
   405  			}
   406  		})
   407  	}
   408  }
   409  
   410  // TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured
   411  // correctly for a server that doesn't specify the NextProtos field and uses
   412  // GetConfigForClient to provide the TLS config during the handshake.
   413  func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
   414  	initialVal := envconfig.EnforceALPNEnabled
   415  	defer func() {
   416  		envconfig.EnforceALPNEnabled = initialVal
   417  	}()
   418  	envconfig.EnforceALPNEnabled = true
   419  
   420  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   421  	defer cancel()
   422  
   423  	// Create a server that doesn't set the NextProtos field.
   424  	serverCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
   425  		GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   426  			return &tls.Config{
   427  				Certificates: []tls.Certificate{serverCert},
   428  			}, nil
   429  		},
   430  	})
   431  
   432  	ss := stubserver.StubServer{
   433  		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   434  			return &testpb.Empty{}, nil
   435  		},
   436  	}
   437  
   438  	clientCreds := credsstable.NewTLS(&tls.Config{
   439  		ServerName: serverName,
   440  		RootCAs:    certPool,
   441  	})
   442  
   443  	if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
   444  		t.Fatalf("Error starting stub server: %v", err)
   445  	}
   446  	defer ss.Stop()
   447  
   448  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   449  		t.Fatalf("EmptyCall err = %v; want <nil>", err)
   450  	}
   451  }
   452  
   453  // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
   454  // connecting to a server that doesn't support ALPN.
   455  func (s) TestTLS_DisabledALPNClient(t *testing.T) {
   456  	initialVal := envconfig.EnforceALPNEnabled
   457  	defer func() {
   458  		envconfig.EnforceALPNEnabled = initialVal
   459  	}()
   460  
   461  	tests := []struct {
   462  		name         string
   463  		alpnEnforced bool
   464  		wantErr      bool
   465  	}{
   466  		{
   467  			name: "enforced",
   468  		},
   469  		{
   470  			name: "not_enforced",
   471  		},
   472  	}
   473  
   474  	for _, tc := range tests {
   475  		t.Run(tc.name, func(t *testing.T) {
   476  			envconfig.EnforceALPNEnabled = tc.alpnEnforced
   477  
   478  			listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
   479  				Certificates: []tls.Certificate{serverCert},
   480  				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
   481  			})
   482  			if err != nil {
   483  				t.Fatalf("Error starting TLS server: %v", err)
   484  			}
   485  
   486  			errCh := make(chan error, 1)
   487  			go func() {
   488  				conn, err := listener.Accept()
   489  				if err != nil {
   490  					errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
   491  				} else {
   492  					// The first write to the TLS listener initiates the TLS handshake.
   493  					conn.Write([]byte("Hello, World!"))
   494  					conn.Close()
   495  				}
   496  				close(errCh)
   497  			}()
   498  
   499  			serverAddr := listener.Addr().String()
   500  			conn, err := net.Dial("tcp", serverAddr)
   501  			if err != nil {
   502  				t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
   503  			}
   504  			defer conn.Close()
   505  
   506  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   507  			defer cancel()
   508  
   509  			clientCfg := tls.Config{
   510  				ServerName: serverName,
   511  				RootCAs:    certPool,
   512  				NextProtos: []string{"h2"},
   513  			}
   514  			_, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn)
   515  
   516  			if gotErr := (err != nil); gotErr != tc.wantErr {
   517  				t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
   518  			}
   519  
   520  			select {
   521  			case err := <-errCh:
   522  				if err != nil {
   523  					t.Fatalf("Unexpected error received from server: %v", err)
   524  				}
   525  			case <-ctx.Done():
   526  				t.Fatalf("Timeout waiting for error from server")
   527  			}
   528  		})
   529  	}
   530  }
   531  
   532  // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
   533  // accepting a request from a client that doesn't support ALPN.
   534  func (s) TestTLS_DisabledALPNServer(t *testing.T) {
   535  	initialVal := envconfig.EnforceALPNEnabled
   536  	defer func() {
   537  		envconfig.EnforceALPNEnabled = initialVal
   538  	}()
   539  
   540  	tests := []struct {
   541  		name         string
   542  		alpnEnforced bool
   543  		wantErr      bool
   544  	}{
   545  		{
   546  			name: "enforced",
   547  		},
   548  		{
   549  			name: "not_enforced",
   550  		},
   551  	}
   552  
   553  	for _, tc := range tests {
   554  		t.Run(tc.name, func(t *testing.T) {
   555  			envconfig.EnforceALPNEnabled = tc.alpnEnforced
   556  
   557  			listener, err := net.Listen("tcp", "localhost:0")
   558  			if err != nil {
   559  				t.Fatalf("Error starting server: %v", err)
   560  			}
   561  
   562  			errCh := make(chan error, 1)
   563  			go func() {
   564  				conn, err := listener.Accept()
   565  				if err != nil {
   566  					errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
   567  					return
   568  				}
   569  				defer conn.Close()
   570  				serverCfg := tls.Config{
   571  					Certificates: []tls.Certificate{serverCert},
   572  					NextProtos:   []string{"h2"},
   573  				}
   574  				_, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn)
   575  				if gotErr := (err != nil); gotErr != tc.wantErr {
   576  					t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
   577  				}
   578  				close(errCh)
   579  			}()
   580  
   581  			serverAddr := listener.Addr().String()
   582  			clientCfg := &tls.Config{
   583  				Certificates: []tls.Certificate{serverCert},
   584  				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
   585  				RootCAs:      certPool,
   586  				ServerName:   serverName,
   587  			}
   588  			conn, err := tls.Dial("tcp", serverAddr, clientCfg)
   589  			if err != nil {
   590  				t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
   591  			}
   592  			defer conn.Close()
   593  
   594  			select {
   595  			case <-time.After(defaultTestTimeout):
   596  				t.Fatal("Timed out waiting for completion")
   597  			case err := <-errCh:
   598  				if err != nil {
   599  					t.Fatalf("Unexpected server error: %v", err)
   600  				}
   601  			}
   602  		})
   603  	}
   604  }