github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/security/certs_test.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package security_test
    12  
import (
	"context"
	gosql "database/sql"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/cockroachdb/cockroach/pkg/base"
	"github.com/cockroachdb/cockroach/pkg/security"
	"github.com/cockroachdb/cockroach/pkg/testutils"
	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
	"github.com/cockroachdb/errors"
)
    31  
    32  const testKeySize = 1024
    33  
    34  func TestGenerateCACert(t *testing.T) {
    35  	defer leaktest.AfterTest(t)()
    36  	// Do not mock cert access for this test.
    37  	security.ResetAssetLoader()
    38  	defer ResetTest()
    39  
    40  	certsDir, err := ioutil.TempDir("", "certs_test")
    41  	if err != nil {
    42  		t.Fatal(err)
    43  	}
    44  	defer func() {
    45  		if err := os.RemoveAll(certsDir); err != nil {
    46  			t.Fatal(err)
    47  		}
    48  	}()
    49  
    50  	cm, err := security.NewCertificateManager(certsDir)
    51  	if err != nil {
    52  		t.Fatalf("unexpected error: %v", err)
    53  	}
    54  
    55  	keyPath := filepath.Join(certsDir, "ca.key")
    56  
    57  	testCases := []struct {
    58  		certsDir, caKey       string
    59  		allowReuse, overwrite bool
    60  		errStr                string // error string for CreateCAPair, empty for nil.
    61  		numCerts              int    // number of certificates found in ca.crt
    62  	}{
    63  		{"", "ca.key", false, false, "the path to the certs directory is required", 0},
    64  		{certsDir, "", false, false, "the path to the CA key is required", 0},
    65  		// New CA key/cert.
    66  		{certsDir, keyPath, false, false, "", 1},
    67  		// Files exist, but reuse is disabled.
    68  		{certsDir, keyPath, false, false, "exists, but key reuse is disabled", 2},
    69  		// Files exist, but overwrite is off.
    70  		{certsDir, keyPath, true, false, "file exists", 2},
    71  		// Files exist and reuse/overwrite is enabled.
    72  		{certsDir, keyPath, true, true, "", 2},
    73  		// Cert exists and overwrite is enabled.
    74  		{certsDir, keyPath + "2", false, true, "", 3}, // Using a new key still keeps the ca.crt
    75  	}
    76  
    77  	for i, tc := range testCases {
    78  		err := security.CreateCAPair(tc.certsDir, tc.caKey, testKeySize,
    79  			time.Hour*48, tc.allowReuse, tc.overwrite)
    80  		if !testutils.IsError(err, tc.errStr) {
    81  			t.Errorf("#%d: expected error %s but got %+v", i, tc.errStr, err)
    82  			continue
    83  		}
    84  
    85  		if err != nil {
    86  			continue
    87  		}
    88  
    89  		// No failures on CreateCAPair, we expect a valid CA cert.
    90  		err = cm.LoadCertificates()
    91  		if err != nil {
    92  			t.Fatalf("#%d: unexpected failure: %v", i, err)
    93  		}
    94  
    95  		ci := cm.CACert()
    96  		if ci == nil {
    97  			t.Fatalf("#%d: no CA cert found", i)
    98  		}
    99  
   100  		certs, err := security.PEMToCertificates(ci.FileContents)
   101  		if err != nil {
   102  			t.Fatalf("#%d: unexpected parsing error for %+v: %v", i, ci, err)
   103  		}
   104  
   105  		if actual := len(certs); actual != tc.numCerts {
   106  			t.Errorf("#%d: expected %d certificates, found %d", i, tc.numCerts, actual)
   107  		}
   108  	}
   109  }
   110  
   111  func TestGenerateNodeCerts(t *testing.T) {
   112  	defer leaktest.AfterTest(t)()
   113  	// Do not mock cert access for this test.
   114  	security.ResetAssetLoader()
   115  	defer ResetTest()
   116  
   117  	certsDir, err := ioutil.TempDir("", "certs_test")
   118  	if err != nil {
   119  		t.Fatal(err)
   120  	}
   121  	defer func() {
   122  		if err := os.RemoveAll(certsDir); err != nil {
   123  			t.Fatal(err)
   124  		}
   125  	}()
   126  
   127  	// Try generating node certs without CA certs present.
   128  	if err := security.CreateNodePair(
   129  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey),
   130  		testKeySize, time.Hour*48, false, []string{"localhost"},
   131  	); err == nil {
   132  		t.Fatalf("Expected error, but got none")
   133  	}
   134  
   135  	// Now try in the proper order.
   136  	if err := security.CreateCAPair(
   137  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey), testKeySize, time.Hour*96, false, false,
   138  	); err != nil {
   139  		t.Fatalf("Expected success, got %v", err)
   140  	}
   141  
   142  	if err := security.CreateNodePair(
   143  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey),
   144  		testKeySize, time.Hour*48, false, []string{"localhost"},
   145  	); err != nil {
   146  		t.Fatalf("Expected success, got %v", err)
   147  	}
   148  }
   149  
   150  // Generate basic certs:
   151  // ca.crt: CA certificate
   152  // node.crt: dual-purpose node certificate
   153  // client.root.crt: client certificate for the root user.
   154  func generateBaseCerts(certsDir string) error {
   155  	if err := security.CreateCAPair(
   156  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey),
   157  		testKeySize, time.Hour*96, true, true,
   158  	); err != nil {
   159  		return errors.Errorf("could not generate CA pair: %v", err)
   160  	}
   161  
   162  	if err := security.CreateNodePair(
   163  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey),
   164  		testKeySize, time.Hour*48, true, []string{"127.0.0.1"},
   165  	); err != nil {
   166  		return errors.Errorf("could not generate Node pair: %v", err)
   167  	}
   168  
   169  	if err := security.CreateClientPair(
   170  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey),
   171  		testKeySize, time.Hour*48, true, security.RootUser, false,
   172  	); err != nil {
   173  		return errors.Errorf("could not generate Client pair: %v", err)
   174  	}
   175  
   176  	return nil
   177  }
   178  
   179  // Generate certificates with separate CAs:
   180  // ca.crt: CA certificate
   181  // ca-client.crt: CA certificate to verify client certs
   182  // node.crt: node server cert: signed by ca.crt
   183  // client.node.crt: node client cert: signed by ca-client.crt
   184  // client.root.crt: root client cert: signed by ca-client.crt
   185  func generateSplitCACerts(certsDir string) error {
   186  	if err := security.CreateCAPair(
   187  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey),
   188  		testKeySize, time.Hour*96, true, true,
   189  	); err != nil {
   190  		return errors.Errorf("could not generate CA pair: %v", err)
   191  	}
   192  
   193  	if err := security.CreateNodePair(
   194  		certsDir, filepath.Join(certsDir, security.EmbeddedCAKey),
   195  		testKeySize, time.Hour*48, true, []string{"127.0.0.1"},
   196  	); err != nil {
   197  		return errors.Errorf("could not generate Node pair: %v", err)
   198  	}
   199  
   200  	if err := security.CreateClientCAPair(
   201  		certsDir, filepath.Join(certsDir, security.EmbeddedClientCAKey),
   202  		testKeySize, time.Hour*96, true, true,
   203  	); err != nil {
   204  		return errors.Errorf("could not generate client CA pair: %v", err)
   205  	}
   206  
   207  	if err := security.CreateClientPair(
   208  		certsDir, filepath.Join(certsDir, security.EmbeddedClientCAKey),
   209  		testKeySize, time.Hour*48, true, security.NodeUser, false,
   210  	); err != nil {
   211  		return errors.Errorf("could not generate Client pair: %v", err)
   212  	}
   213  
   214  	if err := security.CreateClientPair(
   215  		certsDir, filepath.Join(certsDir, security.EmbeddedClientCAKey),
   216  		testKeySize, time.Hour*48, true, security.RootUser, false,
   217  	); err != nil {
   218  		return errors.Errorf("could not generate Client pair: %v", err)
   219  	}
   220  
   221  	if err := security.CreateUICAPair(
   222  		certsDir, filepath.Join(certsDir, security.EmbeddedUICAKey),
   223  		testKeySize, time.Hour*96, true, true,
   224  	); err != nil {
   225  		return errors.Errorf("could not generate UI CA pair: %v", err)
   226  	}
   227  
   228  	if err := security.CreateUIPair(
   229  		certsDir, filepath.Join(certsDir, security.EmbeddedUICAKey),
   230  		testKeySize, time.Hour*48, true, []string{"127.0.0.1"},
   231  	); err != nil {
   232  		return errors.Errorf("could not generate UI pair: %v", err)
   233  	}
   234  
   235  	return nil
   236  }
   237  
   238  // This is a fairly high-level test of CA and node certificates.
   239  // We construct SSL server and clients and use the generated certs.
   240  func TestUseCerts(t *testing.T) {
   241  	defer leaktest.AfterTest(t)()
   242  	// Do not mock cert access for this test.
   243  	security.ResetAssetLoader()
   244  	defer ResetTest()
   245  	certsDir, err := ioutil.TempDir("", "certs_test")
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  	defer func() {
   250  		if err := os.RemoveAll(certsDir); err != nil {
   251  			t.Fatal(err)
   252  		}
   253  	}()
   254  
   255  	if err := generateBaseCerts(certsDir); err != nil {
   256  		t.Fatal(err)
   257  	}
   258  
   259  	// Load TLS Configs. This is what TestServer and HTTPClient do internally.
   260  	if _, err := security.LoadServerTLSConfig(
   261  		filepath.Join(certsDir, security.EmbeddedCACert),
   262  		filepath.Join(certsDir, security.EmbeddedCACert),
   263  		filepath.Join(certsDir, security.EmbeddedNodeCert),
   264  		filepath.Join(certsDir, security.EmbeddedNodeKey),
   265  	); err != nil {
   266  		t.Fatalf("Expected success, got %v", err)
   267  	}
   268  	if _, err := security.LoadClientTLSConfig(
   269  		filepath.Join(certsDir, security.EmbeddedCACert),
   270  		filepath.Join(certsDir, security.EmbeddedNodeCert),
   271  		filepath.Join(certsDir, security.EmbeddedNodeKey),
   272  	); err != nil {
   273  		t.Fatalf("Expected success, got %v", err)
   274  	}
   275  
   276  	// Start a test server and override certs.
   277  	// We use a real context since we want generated certs.
   278  	// Web session authentication is disabled in order to avoid the need to
   279  	// authenticate the individual clients being instantiated (session auth has
   280  	// no effect on what is being tested here).
   281  	params := base.TestServerArgs{
   282  		SSLCertsDir:                     certsDir,
   283  		DisableWebSessionAuthentication: true,
   284  	}
   285  	s, _, db := serverutils.StartServer(t, params)
   286  	defer s.Stopper().Stop(context.Background())
   287  
   288  	// Insecure mode.
   289  	clientContext := testutils.NewNodeTestBaseContext()
   290  	clientContext.Insecure = true
   291  	httpClient, err := clientContext.GetHTTPClient()
   292  	if err != nil {
   293  		t.Fatal(err)
   294  	}
   295  	req, err := http.NewRequest("GET", s.AdminURL()+"/_status/metrics/local", nil)
   296  	if err != nil {
   297  		t.Fatalf("could not create request: %v", err)
   298  	}
   299  	resp, err := httpClient.Do(req)
   300  	if err == nil {
   301  		defer resp.Body.Close()
   302  		body, _ := ioutil.ReadAll(resp.Body)
   303  		t.Fatalf("Expected SSL error, got success: %s", body)
   304  	}
   305  
   306  	// New client. With certs this time.
   307  	clientContext = testutils.NewNodeTestBaseContext()
   308  	clientContext.SSLCertsDir = certsDir
   309  	httpClient, err = clientContext.GetHTTPClient()
   310  	if err != nil {
   311  		t.Fatalf("Expected success, got %v", err)
   312  	}
   313  	req, err = http.NewRequest("GET", s.AdminURL()+"/_status/metrics/local", nil)
   314  	if err != nil {
   315  		t.Fatalf("could not create request: %v", err)
   316  	}
   317  	resp, err = httpClient.Do(req)
   318  	if err != nil {
   319  		t.Fatalf("Expected success, got %v", err)
   320  	}
   321  	defer resp.Body.Close()
   322  	if resp.StatusCode != http.StatusOK {
   323  		body, _ := ioutil.ReadAll(resp.Body)
   324  		t.Fatalf("Expected OK, got %q with body: %s", resp.Status, body)
   325  	}
   326  
   327  	// Check KV connection.
   328  	if err := db.Put(context.Background(), "foo", "bar"); err != nil {
   329  		t.Error(err)
   330  	}
   331  }
   332  
   333  func makeSecurePGUrl(addr, user, certsDir, caName, certName, keyName string) string {
   334  	return fmt.Sprintf("postgresql://%s@%s/?sslmode=verify-full&sslrootcert=%s&sslcert=%s&sslkey=%s",
   335  		user, addr,
   336  		filepath.Join(certsDir, caName),
   337  		filepath.Join(certsDir, certName),
   338  		filepath.Join(certsDir, keyName))
   339  }
   340  
   341  // This is a fairly high-level test of CA and node certificates.
   342  // We construct SSL server and clients and use the generated certs.
   343  func TestUseSplitCACerts(t *testing.T) {
   344  	defer leaktest.AfterTest(t)()
   345  	// Do not mock cert access for this test.
   346  	security.ResetAssetLoader()
   347  	defer ResetTest()
   348  	certsDir, err := ioutil.TempDir("", "certs_test")
   349  	if err != nil {
   350  		t.Fatal(err)
   351  	}
   352  	defer func() {
   353  		if err := os.RemoveAll(certsDir); err != nil {
   354  			t.Fatal(err)
   355  		}
   356  	}()
   357  
   358  	if err := generateSplitCACerts(certsDir); err != nil {
   359  		t.Fatal(err)
   360  	}
   361  
   362  	// Start a test server and override certs.
   363  	// We use a real context since we want generated certs.
   364  	// Web session authentication is disabled in order to avoid the need to
   365  	// authenticate the individual clients being instantiated (session auth has
   366  	// no effect on what is being tested here).
   367  	params := base.TestServerArgs{
   368  		SSLCertsDir:                     certsDir,
   369  		DisableWebSessionAuthentication: true,
   370  	}
   371  	s, _, db := serverutils.StartServer(t, params)
   372  	defer s.Stopper().Stop(context.Background())
   373  
   374  	// Insecure mode.
   375  	clientContext := testutils.NewNodeTestBaseContext()
   376  	clientContext.Insecure = true
   377  	httpClient, err := clientContext.GetHTTPClient()
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  	req, err := http.NewRequest("GET", s.AdminURL()+"/_status/metrics/local", nil)
   382  	if err != nil {
   383  		t.Fatalf("could not create request: %v", err)
   384  	}
   385  	resp, err := httpClient.Do(req)
   386  	if err == nil {
   387  		defer resp.Body.Close()
   388  		body, _ := ioutil.ReadAll(resp.Body)
   389  		t.Fatalf("Expected SSL error, got success: %s", body)
   390  	}
   391  
   392  	// New client. With certs this time.
   393  	clientContext = testutils.NewNodeTestBaseContext()
   394  	clientContext.SSLCertsDir = certsDir
   395  	httpClient, err = clientContext.GetHTTPClient()
   396  	if err != nil {
   397  		t.Fatalf("Expected success, got %v", err)
   398  	}
   399  	req, err = http.NewRequest("GET", s.AdminURL()+"/_status/metrics/local", nil)
   400  	if err != nil {
   401  		t.Fatalf("could not create request: %v", err)
   402  	}
   403  	resp, err = httpClient.Do(req)
   404  	if err != nil {
   405  		t.Fatalf("Expected success, got %v", err)
   406  	}
   407  	defer resp.Body.Close()
   408  	if resp.StatusCode != http.StatusOK {
   409  		body, _ := ioutil.ReadAll(resp.Body)
   410  		t.Fatalf("Expected OK, got %q with body: %s", resp.Status, body)
   411  	}
   412  
   413  	// Check KV connection.
   414  	if err := db.Put(context.Background(), "foo", "bar"); err != nil {
   415  		t.Error(err)
   416  	}
   417  
   418  	// Test a SQL client with various certificates.
   419  	testCases := []struct {
   420  		user, caName, certPrefix string
   421  		expectedError            string
   422  	}{
   423  		// Success, but "node" is not a sql user.
   424  		{"node", security.EmbeddedCACert, "client.node", "pq: password authentication failed for user node"},
   425  		// Success!
   426  		{"root", security.EmbeddedCACert, "client.root", ""},
   427  		// Bad server CA: can't verify server certificate.
   428  		{"root", security.EmbeddedClientCACert, "client.root", "certificate signed by unknown authority"},
   429  		// Bad client cert: we're using the node cert but it's not signed by the client CA.
   430  		{"node", security.EmbeddedCACert, "node", "tls: bad certificate"},
   431  		// We can't verify the node certificate using the UI cert.
   432  		{"node", security.EmbeddedUICACert, "node", "certificate signed by unknown authority"},
   433  		// And the SQL server doesn't know what the ui.crt is.
   434  		{"node", security.EmbeddedCACert, "ui", "tls: bad certificate"},
   435  	}
   436  
   437  	for i, tc := range testCases {
   438  		pgUrl := makeSecurePGUrl(s.ServingSQLAddr(), tc.user, certsDir, tc.caName, tc.certPrefix+".crt", tc.certPrefix+".key")
   439  		goDB, err := gosql.Open("postgres", pgUrl)
   440  		if err != nil {
   441  			t.Fatal(err)
   442  		}
   443  		defer goDB.Close()
   444  
   445  		_, err = goDB.Exec("SELECT 1")
   446  		if !testutils.IsError(err, tc.expectedError) {
   447  			t.Errorf("#%d: expected error %v, got %v", i, tc.expectedError, err)
   448  		}
   449  	}
   450  }
   451  
   452  // This is a fairly high-level test of CA and node certificates.
   453  // We construct SSL server and clients and use the generated certs.
   454  func TestUseWrongSplitCACerts(t *testing.T) {
   455  	defer leaktest.AfterTest(t)()
   456  	// Do not mock cert access for this test.
   457  	security.ResetAssetLoader()
   458  	defer ResetTest()
   459  	certsDir, err := ioutil.TempDir("", "certs_test")
   460  	if err != nil {
   461  		t.Fatal(err)
   462  	}
   463  	defer func() {
   464  		if err := os.RemoveAll(certsDir); err != nil {
   465  			t.Fatal(err)
   466  		}
   467  	}()
   468  
   469  	if err := generateSplitCACerts(certsDir); err != nil {
   470  		t.Fatal(err)
   471  	}
   472  
   473  	// Delete ca-client.crt and ca-ui.crt before starting the node.
   474  	// This will make the server fall back on using ca.crt.
   475  	if err := os.Remove(filepath.Join(certsDir, "ca-client.crt")); err != nil {
   476  		t.Fatal(err)
   477  	}
   478  	if err := os.Remove(filepath.Join(certsDir, "ca-ui.crt")); err != nil {
   479  		t.Fatal(err)
   480  	}
   481  
   482  	// Start a test server and override certs.
   483  	// We use a real context since we want generated certs.
   484  	// Web session authentication is disabled in order to avoid the need to
   485  	// authenticate the individual clients being instantiated (session auth has
   486  	// no effect on what is being tested here).
   487  	params := base.TestServerArgs{
   488  		SSLCertsDir:                     certsDir,
   489  		DisableWebSessionAuthentication: true,
   490  	}
   491  	s, _, db := serverutils.StartServer(t, params)
   492  	defer s.Stopper().Stop(context.Background())
   493  
   494  	// Insecure mode.
   495  	clientContext := testutils.NewNodeTestBaseContext()
   496  	clientContext.Insecure = true
   497  	httpClient, err := clientContext.GetHTTPClient()
   498  	if err != nil {
   499  		t.Fatal(err)
   500  	}
   501  	req, err := http.NewRequest("GET", s.AdminURL()+"/_status/metrics/local", nil)
   502  	if err != nil {
   503  		t.Fatalf("could not create request: %v", err)
   504  	}
   505  	resp, err := httpClient.Do(req)
   506  	if err == nil {
   507  		defer resp.Body.Close()
   508  		body, _ := ioutil.ReadAll(resp.Body)
   509  		t.Fatalf("Expected SSL error, got success: %s", body)
   510  	}
   511  
   512  	// New client with certs, but the UI CA is gone, we have no way to verify the Admin UI cert.
   513  	clientContext = testutils.NewNodeTestBaseContext()
   514  	clientContext.SSLCertsDir = certsDir
   515  	httpClient, err = clientContext.GetHTTPClient()
   516  	if err != nil {
   517  		t.Fatalf("Expected success, got %v", err)
   518  	}
   519  	req, err = http.NewRequest("GET", s.AdminURL()+"/_status/metrics/local", nil)
   520  	if err != nil {
   521  		t.Fatalf("could not create request: %v", err)
   522  	}
   523  
   524  	_, err = httpClient.Do(req)
   525  	if expected := "certificate signed by unknown authority"; !testutils.IsError(err, expected) {
   526  		t.Fatalf("Expected error %q, got %v", expected, err)
   527  	}
   528  
   529  	// Check KV connection.
   530  	if err := db.Put(context.Background(), "foo", "bar"); err != nil {
   531  		t.Error(err)
   532  	}
   533  
   534  	// Try with various certificates.
   535  	testCases := []struct {
   536  		user, caName, certPrefix string
   537  		expectedError            string
   538  	}{
   539  		// Certificate signed by wrong client CA.
   540  		{"root", security.EmbeddedCACert, "client.root", "tls: bad certificate"},
   541  		// Success! The node certificate still contains "CN=node" and is signed by ca.crt.
   542  		{"node", security.EmbeddedCACert, "node", "pq: password authentication failed for user node"},
   543  	}
   544  
   545  	for i, tc := range testCases {
   546  		pgUrl := makeSecurePGUrl(s.ServingSQLAddr(), tc.user, certsDir, tc.caName, tc.certPrefix+".crt", tc.certPrefix+".key")
   547  		goDB, err := gosql.Open("postgres", pgUrl)
   548  		if err != nil {
   549  			t.Fatal(err)
   550  		}
   551  		defer goDB.Close()
   552  
   553  		_, err = goDB.Exec("SELECT 1")
   554  		if !testutils.IsError(err, tc.expectedError) {
   555  			t.Errorf("#%d: expected error %v, got %v", i, tc.expectedError, err)
   556  		}
   557  	}
   558  }