google.golang.org/grpc@v1.72.2/credentials/tls_ext_test.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package credentials_test
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"fmt"
    26  	"net"
    27  	"os"
    28  	"strings"
    29  	"testing"
    30  	"time"
    31  
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/credentials"
    35  	"google.golang.org/grpc/internal/envconfig"
    36  	"google.golang.org/grpc/internal/grpctest"
    37  	"google.golang.org/grpc/internal/stubserver"
    38  	"google.golang.org/grpc/status"
    39  	"google.golang.org/grpc/testdata"
    40  
    41  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    42  	testpb "google.golang.org/grpc/interop/grpc_testing"
    43  )
    44  
    45  const defaultTestTimeout = 10 * time.Second
    46  
    47  type s struct {
    48  	grpctest.Tester
    49  }
    50  
    51  func Test(t *testing.T) {
    52  	grpctest.RunSubTests(t, s{})
    53  }
    54  
    55  var serverCert tls.Certificate
    56  var certPool *x509.CertPool
    57  var serverName = "x.test.example.com"
    58  
    59  func init() {
    60  	var err error
    61  	serverCert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
    62  	if err != nil {
    63  		panic(fmt.Sprintf("tls.LoadX509KeyPair(server1.pem, server1.key) failed: %v", err))
    64  	}
    65  
    66  	b, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
    67  	if err != nil {
    68  		panic(fmt.Sprintf("Error reading CA cert file: %v", err))
    69  	}
    70  	certPool = x509.NewCertPool()
    71  	if !certPool.AppendCertsFromPEM(b) {
    72  		panic("Error appending cert from PEM")
    73  	}
    74  }
    75  
    76  // Tests that the MinVersion of tls.Config is set to 1.2 if it is not already
    77  // set by the user.
    78  func (s) TestTLS_MinVersion12(t *testing.T) {
    79  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    80  	defer cancel()
    81  
    82  	testCases := []struct {
    83  		name      string
    84  		serverTLS func() *tls.Config
    85  	}{
    86  		{
    87  			name: "base_case",
    88  			serverTLS: func() *tls.Config {
    89  				return &tls.Config{
    90  					// MinVersion should be set to 1.2 by gRPC by default.
    91  					Certificates: []tls.Certificate{serverCert},
    92  				}
    93  			},
    94  		},
    95  		{
    96  			name: "fallback_to_base",
    97  			serverTLS: func() *tls.Config {
    98  				config := &tls.Config{
    99  					// MinVersion should be set to 1.2 by gRPC by default.
   100  					Certificates: []tls.Certificate{serverCert},
   101  				}
   102  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   103  					return nil, nil
   104  				}
   105  				return config
   106  			},
   107  		},
   108  		{
   109  			name: "dynamic_using_get_config_for_client",
   110  			serverTLS: func() *tls.Config {
   111  				return &tls.Config{
   112  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   113  						return &tls.Config{
   114  							// MinVersion should be set to 1.2 by gRPC by default.
   115  							Certificates: []tls.Certificate{serverCert},
   116  						}, nil
   117  					},
   118  				}
   119  			},
   120  		},
   121  	}
   122  
   123  	for _, tc := range testCases {
   124  		t.Run(tc.name, func(t *testing.T) {
   125  			// Create server creds without a minimum version.
   126  			serverCreds := credentials.NewTLS(tc.serverTLS())
   127  			ss := stubserver.StubServer{
   128  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   129  					return &testpb.Empty{}, nil
   130  				},
   131  			}
   132  
   133  			// Create client creds that supports V1.0-V1.1.
   134  			clientCreds := credentials.NewTLS(&tls.Config{
   135  				ServerName: serverName,
   136  				RootCAs:    certPool,
   137  				MinVersion: tls.VersionTLS10,
   138  				MaxVersion: tls.VersionTLS11,
   139  			})
   140  
   141  			// Start server and client separately, because Start() blocks on a
   142  			// successful connection, which we will not get.
   143  			if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
   144  				t.Fatalf("Error starting server: %v", err)
   145  			}
   146  			defer ss.Stop()
   147  
   148  			cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds))
   149  			if err != nil {
   150  				t.Fatalf("grpc.NewClient error: %v", err)
   151  			}
   152  			defer cc.Close()
   153  
   154  			client := testgrpc.NewTestServiceClient(cc)
   155  
   156  			const wantStr = "authentication handshake failed"
   157  			if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
   158  				t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
   159  			}
   160  
   161  		})
   162  	}
   163  }
   164  
   165  // Tests that the MinVersion of tls.Config is not changed if it is set by the
   166  // user.
   167  func (s) TestTLS_MinVersionOverridable(t *testing.T) {
   168  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   169  	defer cancel()
   170  
   171  	var allCipherSuites []uint16
   172  	for _, cs := range tls.CipherSuites() {
   173  		allCipherSuites = append(allCipherSuites, cs.ID)
   174  	}
   175  	testCases := []struct {
   176  		name      string
   177  		serverTLS func() *tls.Config
   178  	}{
   179  		{
   180  			name: "base_case",
   181  			serverTLS: func() *tls.Config {
   182  				return &tls.Config{
   183  					MinVersion:   tls.VersionTLS10,
   184  					Certificates: []tls.Certificate{serverCert},
   185  					CipherSuites: allCipherSuites,
   186  				}
   187  			},
   188  		},
   189  		{
   190  			name: "fallback_to_base",
   191  			serverTLS: func() *tls.Config {
   192  				config := &tls.Config{
   193  					MinVersion:   tls.VersionTLS10,
   194  					Certificates: []tls.Certificate{serverCert},
   195  					CipherSuites: allCipherSuites,
   196  				}
   197  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   198  					return nil, nil
   199  				}
   200  				return config
   201  			},
   202  		},
   203  		{
   204  			name: "dynamic_using_get_config_for_client",
   205  			serverTLS: func() *tls.Config {
   206  				return &tls.Config{
   207  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   208  						return &tls.Config{
   209  							MinVersion:   tls.VersionTLS10,
   210  							Certificates: []tls.Certificate{serverCert},
   211  							CipherSuites: allCipherSuites,
   212  						}, nil
   213  					},
   214  				}
   215  			},
   216  		},
   217  	}
   218  
   219  	for _, tc := range testCases {
   220  		t.Run(tc.name, func(t *testing.T) {
   221  			// Create server creds that allow v1.0.
   222  			serverCreds := credentials.NewTLS(tc.serverTLS())
   223  			ss := stubserver.StubServer{
   224  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   225  					return &testpb.Empty{}, nil
   226  				},
   227  			}
   228  
   229  			// Create client creds that supports V1.0-V1.1.
   230  			clientCreds := credentials.NewTLS(&tls.Config{
   231  				ServerName:   serverName,
   232  				RootCAs:      certPool,
   233  				CipherSuites: allCipherSuites,
   234  				MinVersion:   tls.VersionTLS10,
   235  				MaxVersion:   tls.VersionTLS11,
   236  			})
   237  
   238  			if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
   239  				t.Fatalf("Error starting stub server: %v", err)
   240  			}
   241  			defer ss.Stop()
   242  
   243  			if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   244  				t.Fatalf("EmptyCall err = %v; want <nil>", err)
   245  			}
   246  		})
   247  	}
   248  }
   249  
   250  // Tests that CipherSuites is set to exclude HTTP/2 forbidden suites by default.
   251  func (s) TestTLS_CipherSuites(t *testing.T) {
   252  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   253  	defer cancel()
   254  	testCases := []struct {
   255  		name      string
   256  		serverTLS func() *tls.Config
   257  	}{
   258  		{
   259  			name: "base_case",
   260  			serverTLS: func() *tls.Config {
   261  				return &tls.Config{
   262  					Certificates: []tls.Certificate{serverCert},
   263  				}
   264  			},
   265  		},
   266  		{
   267  			name: "fallback_to_base",
   268  			serverTLS: func() *tls.Config {
   269  				config := &tls.Config{
   270  					Certificates: []tls.Certificate{serverCert},
   271  				}
   272  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   273  					return nil, nil
   274  				}
   275  				return config
   276  			},
   277  		},
   278  		{
   279  			name: "dynamic_using_get_config_for_client",
   280  			serverTLS: func() *tls.Config {
   281  				return &tls.Config{
   282  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   283  						return &tls.Config{
   284  							Certificates: []tls.Certificate{serverCert},
   285  						}, nil
   286  					},
   287  				}
   288  			},
   289  		},
   290  	}
   291  
   292  	for _, tc := range testCases {
   293  		t.Run(tc.name, func(t *testing.T) {
   294  			// Create server creds without cipher suites.
   295  			serverCreds := credentials.NewTLS(tc.serverTLS())
   296  			ss := stubserver.StubServer{
   297  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   298  					return &testpb.Empty{}, nil
   299  				},
   300  			}
   301  
   302  			// Create client creds that use a forbidden suite only.
   303  			clientCreds := credentials.NewTLS(&tls.Config{
   304  				ServerName:   serverName,
   305  				RootCAs:      certPool,
   306  				CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   307  				MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
   308  			})
   309  
   310  			// Start server and client separately, because Start() blocks on a
   311  			// successful connection, which we will not get.
   312  			if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
   313  				t.Fatalf("Error starting server: %v", err)
   314  			}
   315  			defer ss.Stop()
   316  
   317  			cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds))
   318  			if err != nil {
   319  				t.Fatalf("grpc.NewClient error: %v", err)
   320  			}
   321  			defer cc.Close()
   322  
   323  			client := testgrpc.NewTestServiceClient(cc)
   324  
   325  			const wantStr = "authentication handshake failed"
   326  			if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
   327  				t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
   328  			}
   329  		})
   330  	}
   331  }
   332  
   333  // Tests that CipherSuites is not overridden when it is set.
   334  func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
   335  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   336  	defer cancel()
   337  
   338  	testCases := []struct {
   339  		name      string
   340  		serverTLS func() *tls.Config
   341  	}{
   342  		{
   343  			name: "base_case",
   344  			serverTLS: func() *tls.Config {
   345  				return &tls.Config{
   346  					Certificates: []tls.Certificate{serverCert},
   347  					CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   348  				}
   349  			},
   350  		},
   351  		{
   352  			name: "fallback_to_base",
   353  			serverTLS: func() *tls.Config {
   354  				config := &tls.Config{
   355  					Certificates: []tls.Certificate{serverCert},
   356  					CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   357  				}
   358  				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   359  					return nil, nil
   360  				}
   361  				return config
   362  			},
   363  		},
   364  		{
   365  			name: "dynamic_using_get_config_for_client",
   366  			serverTLS: func() *tls.Config {
   367  				return &tls.Config{
   368  					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   369  						return &tls.Config{
   370  							Certificates: []tls.Certificate{serverCert},
   371  							CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   372  						}, nil
   373  					},
   374  				}
   375  			},
   376  		},
   377  	}
   378  
   379  	for _, tc := range testCases {
   380  		t.Run(tc.name, func(t *testing.T) {
   381  			// Create server that allows only a forbidden cipher suite.
   382  			serverCreds := credentials.NewTLS(tc.serverTLS())
   383  			ss := stubserver.StubServer{
   384  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   385  					return &testpb.Empty{}, nil
   386  				},
   387  			}
   388  
   389  			// Create server that allows only a forbidden cipher suite.
   390  			clientCreds := credentials.NewTLS(&tls.Config{
   391  				ServerName:   serverName,
   392  				RootCAs:      certPool,
   393  				CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
   394  				MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
   395  			})
   396  
   397  			if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
   398  				t.Fatalf("Error starting stub server: %v", err)
   399  			}
   400  			defer ss.Stop()
   401  
   402  			if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   403  				t.Fatalf("EmptyCall err = %v; want <nil>", err)
   404  			}
   405  		})
   406  	}
   407  }
   408  
   409  // TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured
   410  // correctly for a server that doesn't specify the NextProtos field and uses
   411  // GetConfigForClient to provide the TLS config during the handshake.
   412  func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
   413  	initialVal := envconfig.EnforceALPNEnabled
   414  	defer func() {
   415  		envconfig.EnforceALPNEnabled = initialVal
   416  	}()
   417  	envconfig.EnforceALPNEnabled = true
   418  
   419  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   420  	defer cancel()
   421  
   422  	// Create a server that doesn't set the NextProtos field.
   423  	serverCreds := credentials.NewTLS(&tls.Config{
   424  		GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   425  			return &tls.Config{
   426  				Certificates: []tls.Certificate{serverCert},
   427  			}, nil
   428  		},
   429  	})
   430  
   431  	ss := stubserver.StubServer{
   432  		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   433  			return &testpb.Empty{}, nil
   434  		},
   435  	}
   436  
   437  	clientCreds := credentials.NewTLS(&tls.Config{
   438  		ServerName: serverName,
   439  		RootCAs:    certPool,
   440  	})
   441  
   442  	if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
   443  		t.Fatalf("Error starting stub server: %v", err)
   444  	}
   445  	defer ss.Stop()
   446  
   447  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   448  		t.Fatalf("EmptyCall err = %v; want <nil>", err)
   449  	}
   450  }
   451  
   452  // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
   453  // connecting to a server that doesn't support ALPN.
   454  func (s) TestTLS_DisabledALPNClient(t *testing.T) {
   455  	initialVal := envconfig.EnforceALPNEnabled
   456  	defer func() {
   457  		envconfig.EnforceALPNEnabled = initialVal
   458  	}()
   459  
   460  	tests := []struct {
   461  		name         string
   462  		alpnEnforced bool
   463  		wantErr      bool
   464  	}{
   465  		{
   466  			name:         "enforced",
   467  			alpnEnforced: true,
   468  			wantErr:      true,
   469  		},
   470  		{
   471  			name: "not_enforced",
   472  		},
   473  	}
   474  
   475  	for _, tc := range tests {
   476  		t.Run(tc.name, func(t *testing.T) {
   477  			envconfig.EnforceALPNEnabled = tc.alpnEnforced
   478  
   479  			listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
   480  				Certificates: []tls.Certificate{serverCert},
   481  				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
   482  			})
   483  			if err != nil {
   484  				t.Fatalf("Error starting TLS server: %v", err)
   485  			}
   486  
   487  			errCh := make(chan error, 1)
   488  			go func() {
   489  				conn, err := listener.Accept()
   490  				if err != nil {
   491  					errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
   492  				} else {
   493  					// The first write to the TLS listener initiates the TLS handshake.
   494  					conn.Write([]byte("Hello, World!"))
   495  					conn.Close()
   496  				}
   497  				close(errCh)
   498  			}()
   499  
   500  			serverAddr := listener.Addr().String()
   501  			conn, err := net.Dial("tcp", serverAddr)
   502  			if err != nil {
   503  				t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
   504  			}
   505  			defer conn.Close()
   506  
   507  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   508  			defer cancel()
   509  
   510  			clientCfg := tls.Config{
   511  				ServerName: serverName,
   512  				RootCAs:    certPool,
   513  				NextProtos: []string{"h2"},
   514  			}
   515  			_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
   516  
   517  			if gotErr := (err != nil); gotErr != tc.wantErr {
   518  				t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
   519  			}
   520  
   521  			select {
   522  			case err := <-errCh:
   523  				if err != nil {
   524  					t.Fatalf("Unexpected error received from server: %v", err)
   525  				}
   526  			case <-ctx.Done():
   527  				t.Fatalf("Timeout waiting for error from server")
   528  			}
   529  		})
   530  	}
   531  }
   532  
   533  // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
   534  // accepting a request from a client that doesn't support ALPN.
   535  func (s) TestTLS_DisabledALPNServer(t *testing.T) {
   536  	initialVal := envconfig.EnforceALPNEnabled
   537  	defer func() {
   538  		envconfig.EnforceALPNEnabled = initialVal
   539  	}()
   540  
   541  	tests := []struct {
   542  		name         string
   543  		alpnEnforced bool
   544  		wantErr      bool
   545  	}{
   546  		{
   547  			name:         "enforced",
   548  			alpnEnforced: true,
   549  			wantErr:      true,
   550  		},
   551  		{
   552  			name: "not_enforced",
   553  		},
   554  	}
   555  
   556  	for _, tc := range tests {
   557  		t.Run(tc.name, func(t *testing.T) {
   558  			envconfig.EnforceALPNEnabled = tc.alpnEnforced
   559  
   560  			listener, err := net.Listen("tcp", "localhost:0")
   561  			if err != nil {
   562  				t.Fatalf("Error starting server: %v", err)
   563  			}
   564  
   565  			errCh := make(chan error, 1)
   566  			go func() {
   567  				conn, err := listener.Accept()
   568  				if err != nil {
   569  					errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
   570  					return
   571  				}
   572  				defer conn.Close()
   573  				serverCfg := tls.Config{
   574  					Certificates: []tls.Certificate{serverCert},
   575  					NextProtos:   []string{"h2"},
   576  				}
   577  				_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
   578  				if gotErr := (err != nil); gotErr != tc.wantErr {
   579  					t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
   580  				}
   581  				close(errCh)
   582  			}()
   583  
   584  			serverAddr := listener.Addr().String()
   585  			clientCfg := &tls.Config{
   586  				Certificates: []tls.Certificate{serverCert},
   587  				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
   588  				RootCAs:      certPool,
   589  				ServerName:   serverName,
   590  			}
   591  			conn, err := tls.Dial("tcp", serverAddr, clientCfg)
   592  			if err != nil {
   593  				t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
   594  			}
   595  			defer conn.Close()
   596  
   597  			select {
   598  			case <-time.After(defaultTestTimeout):
   599  				t.Fatal("Timed out waiting for completion")
   600  			case err := <-errCh:
   601  				if err != nil {
   602  					t.Fatalf("Unexpected server error: %v", err)
   603  				}
   604  			}
   605  		})
   606  	}
   607  }