vitess.io/vitess@v0.16.2/go/vt/tlstest/tlstest_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package tlstest
    18  
    19  import (
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"strings"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/stretchr/testify/assert"
    31  
    32  	"vitess.io/vitess/go/vt/vttls"
    33  )
    34  
    35  func TestClientServerWithoutCombineCerts(t *testing.T) {
    36  	testClientServer(t, false)
    37  }
    38  
    39  func TestClientServerWithCombineCerts(t *testing.T) {
    40  	testClientServer(t, true)
    41  }
    42  
    43  // testClientServer generates:
    44  // - a root CA
    45  // - a server intermediate CA, with a server.
    46  // - a client intermediate CA, with a client.
    47  // And then performs a few tests on them.
    48  func testClientServer(t *testing.T, combineCerts bool) {
    49  	// Our test root.
    50  	root := t.TempDir()
    51  
    52  	clientServerKeyPairs := CreateClientServerCertPairs(root)
    53  	serverCA := ""
    54  
    55  	if combineCerts {
    56  		serverCA = clientServerKeyPairs.ServerCA
    57  	}
    58  
    59  	serverConfig, err := vttls.ServerConfig(
    60  		clientServerKeyPairs.ServerCert,
    61  		clientServerKeyPairs.ServerKey,
    62  		clientServerKeyPairs.ClientCA,
    63  		clientServerKeyPairs.ClientCRL,
    64  		serverCA,
    65  		tls.VersionTLS12)
    66  	if err != nil {
    67  		t.Fatalf("TLSServerConfig failed: %v", err)
    68  	}
    69  	clientConfig, err := vttls.ClientConfig(
    70  		vttls.VerifyIdentity,
    71  		clientServerKeyPairs.ClientCert,
    72  		clientServerKeyPairs.ClientKey,
    73  		clientServerKeyPairs.ServerCA,
    74  		clientServerKeyPairs.ServerCRL,
    75  		clientServerKeyPairs.ServerName,
    76  		tls.VersionTLS12)
    77  	if err != nil {
    78  		t.Fatalf("TLSClientConfig failed: %v", err)
    79  	}
    80  
    81  	// Create a TLS server listener.
    82  	listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig)
    83  	if err != nil {
    84  		t.Fatalf("Listen failed: %v", err)
    85  	}
    86  	addr := listener.Addr().String()
    87  	defer listener.Close()
    88  	// create a dialer with timeout
    89  	dialer := new(net.Dialer)
    90  	dialer.Timeout = 10 * time.Second
    91  
    92  	wg := sync.WaitGroup{}
    93  
    94  	//
    95  	// Positive case: accept on server side, connect a client, send data.
    96  	//
    97  	var clientErr error
    98  	wg.Add(1)
    99  	go func() {
   100  		defer wg.Done()
   101  		clientConn, clientErr := tls.DialWithDialer(dialer, "tcp", addr, clientConfig)
   102  		if clientErr == nil {
   103  			_, _ = clientConn.Write([]byte{42})
   104  			clientConn.Close()
   105  		}
   106  	}()
   107  
   108  	serverConn, err := listener.Accept()
   109  	if err != nil {
   110  		t.Fatalf("Accept failed: %v", err)
   111  	}
   112  
   113  	result := make([]byte, 1)
   114  	if n, err := serverConn.Read(result); (err != nil && err != io.EOF) || n != 1 {
   115  		t.Fatalf("Read failed: %v %v", n, err)
   116  	}
   117  	if result[0] != 42 {
   118  		t.Fatalf("Read returned wrong result: %v", result)
   119  	}
   120  	serverConn.Close()
   121  
   122  	wg.Wait()
   123  
   124  	if clientErr != nil {
   125  		t.Fatalf("Dial failed: %v", clientErr)
   126  	}
   127  
   128  	//
   129  	// Negative case: connect a client with wrong cert (using the
   130  	// server cert on the client side).
   131  	//
   132  
   133  	badClientConfig, err := vttls.ClientConfig(
   134  		vttls.VerifyIdentity,
   135  		clientServerKeyPairs.ServerCert,
   136  		clientServerKeyPairs.ServerKey,
   137  		clientServerKeyPairs.ServerCA,
   138  		clientServerKeyPairs.ServerCRL,
   139  		clientServerKeyPairs.ServerName,
   140  		tls.VersionTLS12)
   141  	if err != nil {
   142  		t.Fatalf("TLSClientConfig failed: %v", err)
   143  	}
   144  
   145  	var serverErr error
   146  	wg.Add(1)
   147  	go func() {
   148  		// We expect the Accept to work, but the first read to fail.
   149  		defer wg.Done()
   150  		serverConn, serverErr := listener.Accept()
   151  		// This will fail.
   152  		if serverErr == nil {
   153  			result := make([]byte, 1)
   154  			if n, err := serverConn.Read(result); err == nil {
   155  				fmt.Printf("Was able to read from server: %v\n", n)
   156  			}
   157  			serverConn.Close()
   158  		}
   159  	}()
   160  
   161  	// When using TLS 1.2, the Dial will fail.
   162  	// With TLS 1.3, the Dial will succeed and the first Read will fail.
   163  	clientConn, err := tls.DialWithDialer(dialer, "tcp", addr, badClientConfig)
   164  	if err != nil {
   165  		if !strings.Contains(err.Error(), "bad certificate") {
   166  			t.Errorf("Wrong error returned: %v", err)
   167  		}
   168  		return
   169  	}
   170  	wg.Wait()
   171  	if serverErr != nil {
   172  		t.Fatalf("Connection failed: %v", serverErr)
   173  	}
   174  
   175  	data := make([]byte, 1)
   176  	_, err = clientConn.Read(data)
   177  	if err == nil {
   178  		t.Fatalf("Dial or first Read was expected to fail")
   179  	}
   180  	if !strings.Contains(err.Error(), "bad certificate") {
   181  		t.Errorf("Wrong error returned: %v", err)
   182  	}
   183  }
   184  
   185  func getServerConfigWithoutCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) {
   186  	return vttls.ServerConfig(
   187  		keypairs.ServerCert,
   188  		keypairs.ServerKey,
   189  		keypairs.ClientCA,
   190  		keypairs.ClientCRL,
   191  		"",
   192  		tls.VersionTLS12)
   193  }
   194  
   195  func getServerConfigWithCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) {
   196  	return vttls.ServerConfig(
   197  		keypairs.ServerCert,
   198  		keypairs.ServerKey,
   199  		keypairs.ClientCA,
   200  		keypairs.ClientCRL,
   201  		keypairs.ServerCA,
   202  		tls.VersionTLS12)
   203  }
   204  
   205  func getClientConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) {
   206  	return vttls.ClientConfig(
   207  		vttls.VerifyIdentity,
   208  		keypairs.ClientCert,
   209  		keypairs.ClientKey,
   210  		keypairs.ServerCA,
   211  		keypairs.ServerCRL,
   212  		keypairs.ServerName,
   213  		tls.VersionTLS12)
   214  }
   215  
   216  func testServerTLSConfigCaching(t *testing.T, getServerConfig func(ClientServerKeyPairs) (*tls.Config, error)) {
   217  	testConfigGeneration(t, "servertlstest", getServerConfig, func(config *tls.Config) *x509.CertPool {
   218  		return config.ClientCAs
   219  	})
   220  }
   221  
   222  func TestServerTLSConfigCachingWithoutCombinedCerts(t *testing.T) {
   223  	testServerTLSConfigCaching(t, getServerConfigWithoutCombinedCerts)
   224  }
   225  
   226  func TestServerTLSConfigCachingWithCombinedCerts(t *testing.T) {
   227  	testServerTLSConfigCaching(t, getServerConfigWithCombinedCerts)
   228  }
   229  
   230  func TestClientTLSConfigCaching(t *testing.T) {
   231  	testConfigGeneration(t, "clienttlstest", getClientConfig, func(config *tls.Config) *x509.CertPool {
   232  		return config.RootCAs
   233  	})
   234  }
   235  
   236  func testConfigGeneration(t *testing.T, rootPrefix string, generateConfig func(ClientServerKeyPairs) (*tls.Config, error), getCertPool func(tlsConfig *tls.Config) *x509.CertPool) {
   237  	// Our test root.
   238  	root := t.TempDir()
   239  
   240  	const configsToGenerate = 1
   241  
   242  	firstClientServerKeyPairs := CreateClientServerCertPairs(root)
   243  	secondClientServerKeyPairs := CreateClientServerCertPairs(root)
   244  
   245  	firstExpectedConfig, _ := generateConfig(firstClientServerKeyPairs)
   246  	secondExpectedConfig, _ := generateConfig(secondClientServerKeyPairs)
   247  	firstConfigChannel := make(chan *tls.Config, configsToGenerate)
   248  	secondConfigChannel := make(chan *tls.Config, configsToGenerate)
   249  
   250  	var configCounter = 0
   251  
   252  	for i := 1; i <= configsToGenerate; i++ {
   253  		go func() {
   254  			firstConfig, _ := generateConfig(firstClientServerKeyPairs)
   255  			firstConfigChannel <- firstConfig
   256  			secondConfig, _ := generateConfig(secondClientServerKeyPairs)
   257  			secondConfigChannel <- secondConfig
   258  		}()
   259  	}
   260  
   261  	for {
   262  		select {
   263  		case firstConfig := <-firstConfigChannel:
   264  			assert.Equal(t, &firstExpectedConfig.Certificates, &firstConfig.Certificates)
   265  			assert.Equal(t, getCertPool(firstExpectedConfig), getCertPool(firstConfig))
   266  		case secondConfig := <-secondConfigChannel:
   267  			assert.Equal(t, &secondExpectedConfig.Certificates, &secondConfig.Certificates)
   268  			assert.Equal(t, getCertPool(secondExpectedConfig), getCertPool(secondConfig))
   269  		}
   270  		configCounter = configCounter + 1
   271  
   272  		if configCounter >= 2*configsToGenerate {
   273  			break
   274  		}
   275  	}
   276  
   277  }
   278  
   279  func testNumberOfCertsWithOrWithoutCombining(t *testing.T, numCertsExpected int, combine bool) {
   280  	// Our test root.
   281  	root := t.TempDir()
   282  
   283  	clientServerKeyPairs := CreateClientServerCertPairs(root)
   284  	serverCA := ""
   285  	if combine {
   286  		serverCA = clientServerKeyPairs.ServerCA
   287  	}
   288  
   289  	serverConfig, err := vttls.ServerConfig(
   290  		clientServerKeyPairs.ServerCert,
   291  		clientServerKeyPairs.ServerKey,
   292  		clientServerKeyPairs.ClientCA,
   293  		clientServerKeyPairs.ClientCRL,
   294  		serverCA,
   295  		tls.VersionTLS12)
   296  
   297  	if err != nil {
   298  		t.Fatalf("TLSServerConfig failed: %v", err)
   299  	}
   300  	assert.Equal(t, numCertsExpected, len(serverConfig.Certificates[0].Certificate))
   301  }
   302  
   303  func TestNumberOfCertsWithoutCombining(t *testing.T) {
   304  	testNumberOfCertsWithOrWithoutCombining(t, 1, false)
   305  }
   306  
   307  func TestNumberOfCertsWithCombining(t *testing.T) {
   308  	testNumberOfCertsWithOrWithoutCombining(t, 2, true)
   309  }
   310  
   311  func assertTLSHandshakeFails(t *testing.T, serverConfig, clientConfig *tls.Config) {
   312  	// Create a TLS server listener.
   313  	listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig)
   314  	if err != nil {
   315  		t.Fatalf("Listen failed: %v", err)
   316  	}
   317  	addr := listener.Addr().String()
   318  	defer listener.Close()
   319  	// create a dialer with timeout
   320  	dialer := new(net.Dialer)
   321  	dialer.Timeout = 10 * time.Second
   322  
   323  	wg := sync.WaitGroup{}
   324  
   325  	var clientErr error
   326  	wg.Add(1)
   327  	go func() {
   328  		defer wg.Done()
   329  		var clientConn *tls.Conn
   330  		clientConn, clientErr = tls.DialWithDialer(dialer, "tcp", addr, clientConfig)
   331  		if clientErr == nil {
   332  			clientConn.Close()
   333  		}
   334  	}()
   335  
   336  	serverConn, err := listener.Accept()
   337  	if err != nil {
   338  		// We should always be able to accept on the socket
   339  		t.Fatalf("Accept failed: %v", err)
   340  	}
   341  
   342  	err = serverConn.(*tls.Conn).Handshake()
   343  	if err != nil {
   344  		if !(strings.Contains(err.Error(), "Certificate revoked: CommonName=") ||
   345  			strings.Contains(err.Error(), "remote error: tls: bad certificate")) {
   346  			t.Fatalf("Wrong error returned: %v", err)
   347  		}
   348  	} else {
   349  		t.Fatal("Server should have failed the TLS handshake but it did not")
   350  	}
   351  	serverConn.Close()
   352  	wg.Wait()
   353  }
   354  
   355  func TestClientServerWithRevokedServerCert(t *testing.T) {
   356  	root := t.TempDir()
   357  
   358  	clientServerKeyPairs := CreateClientServerCertPairs(root)
   359  
   360  	serverConfig, err := vttls.ServerConfig(
   361  		clientServerKeyPairs.RevokedServerCert,
   362  		clientServerKeyPairs.RevokedServerKey,
   363  		clientServerKeyPairs.ClientCA,
   364  		clientServerKeyPairs.ClientCRL,
   365  		"",
   366  		tls.VersionTLS12)
   367  	if err != nil {
   368  		t.Fatalf("TLSServerConfig failed: %v", err)
   369  	}
   370  
   371  	clientConfig, err := vttls.ClientConfig(
   372  		vttls.VerifyIdentity,
   373  		clientServerKeyPairs.ClientCert,
   374  		clientServerKeyPairs.ClientKey,
   375  		clientServerKeyPairs.ServerCA,
   376  		clientServerKeyPairs.ServerCRL,
   377  		clientServerKeyPairs.RevokedServerName,
   378  		tls.VersionTLS12)
   379  	if err != nil {
   380  		t.Fatalf("TLSClientConfig failed: %v", err)
   381  	}
   382  
   383  	assertTLSHandshakeFails(t, serverConfig, clientConfig)
   384  
   385  	serverConfig, err = vttls.ServerConfig(
   386  		clientServerKeyPairs.RevokedServerCert,
   387  		clientServerKeyPairs.RevokedServerKey,
   388  		clientServerKeyPairs.ClientCA,
   389  		clientServerKeyPairs.CombinedCRL,
   390  		"",
   391  		tls.VersionTLS12)
   392  	if err != nil {
   393  		t.Fatalf("TLSServerConfig failed: %v", err)
   394  	}
   395  
   396  	clientConfig, err = vttls.ClientConfig(
   397  		vttls.VerifyIdentity,
   398  		clientServerKeyPairs.ClientCert,
   399  		clientServerKeyPairs.ClientKey,
   400  		clientServerKeyPairs.ServerCA,
   401  		clientServerKeyPairs.CombinedCRL,
   402  		clientServerKeyPairs.RevokedServerName,
   403  		tls.VersionTLS12)
   404  	if err != nil {
   405  		t.Fatalf("TLSClientConfig failed: %v", err)
   406  	}
   407  
   408  	assertTLSHandshakeFails(t, serverConfig, clientConfig)
   409  }
   410  
   411  func TestClientServerWithRevokedClientCert(t *testing.T) {
   412  	root := t.TempDir()
   413  
   414  	clientServerKeyPairs := CreateClientServerCertPairs(root)
   415  
   416  	// Single CRL
   417  
   418  	serverConfig, err := vttls.ServerConfig(
   419  		clientServerKeyPairs.ServerCert,
   420  		clientServerKeyPairs.ServerKey,
   421  		clientServerKeyPairs.ClientCA,
   422  		clientServerKeyPairs.ClientCRL,
   423  		"",
   424  		tls.VersionTLS12)
   425  	if err != nil {
   426  		t.Fatalf("TLSServerConfig failed: %v", err)
   427  	}
   428  
   429  	clientConfig, err := vttls.ClientConfig(
   430  		vttls.VerifyIdentity,
   431  		clientServerKeyPairs.RevokedClientCert,
   432  		clientServerKeyPairs.RevokedClientKey,
   433  		clientServerKeyPairs.ServerCA,
   434  		clientServerKeyPairs.ServerCRL,
   435  		clientServerKeyPairs.ServerName,
   436  		tls.VersionTLS12)
   437  	if err != nil {
   438  		t.Fatalf("TLSClientConfig failed: %v", err)
   439  	}
   440  
   441  	assertTLSHandshakeFails(t, serverConfig, clientConfig)
   442  
   443  	// CombinedCRL
   444  
   445  	serverConfig, err = vttls.ServerConfig(
   446  		clientServerKeyPairs.ServerCert,
   447  		clientServerKeyPairs.ServerKey,
   448  		clientServerKeyPairs.ClientCA,
   449  		clientServerKeyPairs.CombinedCRL,
   450  		"",
   451  		tls.VersionTLS12)
   452  	if err != nil {
   453  		t.Fatalf("TLSServerConfig failed: %v", err)
   454  	}
   455  
   456  	clientConfig, err = vttls.ClientConfig(
   457  		vttls.VerifyIdentity,
   458  		clientServerKeyPairs.RevokedClientCert,
   459  		clientServerKeyPairs.RevokedClientKey,
   460  		clientServerKeyPairs.ServerCA,
   461  		clientServerKeyPairs.CombinedCRL,
   462  		clientServerKeyPairs.ServerName,
   463  		tls.VersionTLS12)
   464  	if err != nil {
   465  		t.Fatalf("TLSClientConfig failed: %v", err)
   466  	}
   467  
   468  	assertTLSHandshakeFails(t, serverConfig, clientConfig)
   469  }