github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/security/auth_test.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package security_test
    12  
    13  import (
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"crypto/x509/pkix"
    17  	"strings"
    18  	"testing"
    19  
    20  	"github.com/cockroachdb/cockroach/pkg/security"
    21  	"github.com/cockroachdb/cockroach/pkg/testutils"
    22  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  // Construct a fake tls.ConnectionState object. The spec is a semicolon
    27  // separated list if peer certificate specifications. Each peer certificate
    28  // specification is a comma separated list of names where the first name is the
    29  // CommonName and the remaining names are SubjectAlternateNames. For example,
    30  // "foo" creates a single peer certificate with the CommonName "foo". The spec
    31  // "foo,bar" creates a single peer certificate with the CommonName "foo" and a
    32  // single SubjectAlternateName "bar". Contrast that with "foo;bar" which
    33  // creates two peer certificates with the CommonNames "foo" and "bar"
    34  // respectively.
    35  func makeFakeTLSState(spec string) *tls.ConnectionState {
    36  	tls := &tls.ConnectionState{}
    37  	if spec != "" {
    38  		for _, peerSpec := range strings.Split(spec, ";") {
    39  			names := strings.Split(peerSpec, ",")
    40  			if len(names) == 0 {
    41  				continue
    42  			}
    43  			peerCert := &x509.Certificate{}
    44  			peerCert.Subject = pkix.Name{CommonName: names[0]}
    45  			peerCert.DNSNames = names[1:]
    46  			tls.PeerCertificates = append(tls.PeerCertificates, peerCert)
    47  		}
    48  	}
    49  	return tls
    50  }
    51  
    52  func TestGetCertificateUsers(t *testing.T) {
    53  	defer leaktest.AfterTest(t)()
    54  	// Nil TLS state.
    55  	if _, err := security.GetCertificateUsers(nil); err == nil {
    56  		t.Error("unexpected success")
    57  	}
    58  
    59  	// No certificates.
    60  	if _, err := security.GetCertificateUsers(makeFakeTLSState("")); err == nil {
    61  		t.Error("unexpected success")
    62  	}
    63  
    64  	// Good request: single certificate.
    65  	if names, err := security.GetCertificateUsers(makeFakeTLSState("foo")); err != nil {
    66  		t.Error(err)
    67  	} else {
    68  		require.EqualValues(t, names, []string{"foo"})
    69  	}
    70  
    71  	// Request with multiple certs, but only one chain (eg: origin certs are client and CA).
    72  	if names, err := security.GetCertificateUsers(makeFakeTLSState("foo;CA")); err != nil {
    73  		t.Error(err)
    74  	} else {
    75  		require.EqualValues(t, names, []string{"foo"})
    76  	}
    77  
    78  	// Always use the first certificate.
    79  	if names, err := security.GetCertificateUsers(makeFakeTLSState("foo;bar")); err != nil {
    80  		t.Error(err)
    81  	} else {
    82  		require.EqualValues(t, names, []string{"foo"})
    83  	}
    84  
    85  	// Extract all of the principals from the first certificate.
    86  	if names, err := security.GetCertificateUsers(makeFakeTLSState("foo,bar,blah;CA")); err != nil {
    87  		t.Error(err)
    88  	} else {
    89  		require.EqualValues(t, names, []string{"foo", "bar", "blah"})
    90  	}
    91  }
    92  
    93  func TestSetCertPrincipalMap(t *testing.T) {
    94  	defer leaktest.AfterTest(t)()
    95  	defer func() { _ = security.SetCertPrincipalMap(nil) }()
    96  
    97  	testCases := []struct {
    98  		vals     []string
    99  		expected string
   100  	}{
   101  		{[]string{}, ""},
   102  		{[]string{"foo"}, "invalid <cert-principal>:<db-principal> mapping:"},
   103  		{[]string{"foo:bar"}, ""},
   104  		{[]string{"foo:bar", "blah:blah"}, ""},
   105  	}
   106  	for _, c := range testCases {
   107  		t.Run("", func(t *testing.T) {
   108  			err := security.SetCertPrincipalMap(c.vals)
   109  			if !testutils.IsError(err, c.expected) {
   110  				t.Fatalf("expected %q, but found %v", c.expected, err)
   111  			}
   112  		})
   113  	}
   114  }
   115  
   116  func TestGetCertificateUsersMapped(t *testing.T) {
   117  	defer leaktest.AfterTest(t)()
   118  	defer func() { _ = security.SetCertPrincipalMap(nil) }()
   119  
   120  	testCases := []struct {
   121  		spec     string
   122  		val      string
   123  		expected string
   124  	}{
   125  		// No mapping present.
   126  		{"foo", "", "foo"},
   127  		// The basic mapping case.
   128  		{"foo", "foo:bar", "bar"},
   129  		// Identity mapping.
   130  		{"foo", "foo:foo", "foo"},
   131  		// Mapping does not apply to cert principals.
   132  		{"foo", "bar:bar", "foo"},
   133  		// The last mapping for a principal takes precedence.
   134  		{"foo", "foo:bar,foo:blah", "blah"},
   135  		// First principal mapped, second principal unmapped.
   136  		{"foo,bar", "foo:blah", "blah,bar"},
   137  		// First principal unmapped, second principal mapped.
   138  		{"bar,foo", "foo:blah", "bar,blah"},
   139  		// Both principals mapped.
   140  		{"foo,bar", "foo:bar,bar:foo", "bar,foo"},
   141  	}
   142  	for _, c := range testCases {
   143  		t.Run("", func(t *testing.T) {
   144  			vals := strings.Split(c.val, ",")
   145  			if err := security.SetCertPrincipalMap(vals); err != nil {
   146  				t.Fatal(err)
   147  			}
   148  			names, err := security.GetCertificateUsers(makeFakeTLSState(c.spec))
   149  			if err != nil {
   150  				t.Fatal(err)
   151  			}
   152  			require.EqualValues(t, strings.Join(names, ","), c.expected)
   153  		})
   154  	}
   155  }
   156  
   157  func TestAuthenticationHook(t *testing.T) {
   158  	defer leaktest.AfterTest(t)()
   159  	defer func() { _ = security.SetCertPrincipalMap(nil) }()
   160  
   161  	testCases := []struct {
   162  		insecure           bool
   163  		tlsSpec            string
   164  		username           string
   165  		principalMap       string
   166  		buildHookSuccess   bool
   167  		publicHookSuccess  bool
   168  		privateHookSuccess bool
   169  	}{
   170  		// Insecure mode, empty username.
   171  		{true, "", "", "", true, false, false},
   172  		// Insecure mode, non-empty username.
   173  		{true, "", "foo", "", true, true, false},
   174  		// Secure mode, no TLS state.
   175  		{false, "", "", "", false, false, false},
   176  		// Secure mode, bad user.
   177  		{false, "foo", "node", "", true, false, false},
   178  		// Secure mode, node user.
   179  		{false, security.NodeUser, "node", "", true, true, true},
   180  		// Secure mode, root user.
   181  		{false, security.RootUser, "node", "", true, false, false},
   182  		// Secure mode, multiple cert principals.
   183  		{false, "foo,bar", "foo", "", true, true, false},
   184  		{false, "foo,bar", "bar", "", true, true, false},
   185  		// Secure mode, principal map.
   186  		{false, "foo,bar", "blah", "foo:blah", true, true, false},
   187  		{false, "foo,bar", "blah", "bar:blah", true, true, false},
   188  	}
   189  
   190  	for _, tc := range testCases {
   191  		t.Run("", func(t *testing.T) {
   192  			err := security.SetCertPrincipalMap(strings.Split(tc.principalMap, ","))
   193  			if err != nil {
   194  				t.Fatal(err)
   195  			}
   196  			hook, err := security.UserAuthCertHook(tc.insecure, makeFakeTLSState(tc.tlsSpec))
   197  			if (err == nil) != tc.buildHookSuccess {
   198  				t.Fatalf("expected success=%t, got err=%v", tc.buildHookSuccess, err)
   199  			}
   200  			if err != nil {
   201  				return
   202  			}
   203  			_, err = hook(tc.username, true /* clientConnection */)
   204  			if (err == nil) != tc.publicHookSuccess {
   205  				t.Fatalf("expected success=%t, got err=%v", tc.publicHookSuccess, err)
   206  			}
   207  			_, err = hook(tc.username, false /* clientConnection */)
   208  			if (err == nil) != tc.privateHookSuccess {
   209  				t.Fatalf("expected success=%t, got err=%v", tc.privateHookSuccess, err)
   210  			}
   211  		})
   212  	}
   213  }