github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/ca/server_test.go (about)

     1  package ca_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/cloudflare/cfssl/helpers"
    17  	"github.com/docker/swarmkit/api"
    18  	"github.com/docker/swarmkit/api/equality"
    19  	"github.com/docker/swarmkit/ca"
    20  	cautils "github.com/docker/swarmkit/ca/testutils"
    21  	"github.com/docker/swarmkit/log"
    22  	"github.com/docker/swarmkit/manager/state/store"
    23  	"github.com/docker/swarmkit/testutils"
    24  	"github.com/opencontainers/go-digest"
    25  	"github.com/pkg/errors"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"google.golang.org/grpc/codes"
    30  )
    31  
    32  var _ api.CAServer = &ca.Server{}
    33  var _ api.NodeCAServer = &ca.Server{}
    34  
    35  func TestGetRootCACertificate(t *testing.T) {
    36  	tc := cautils.NewTestCA(t)
    37  	defer tc.Stop()
    38  
    39  	resp, err := tc.CAClients[0].GetRootCACertificate(tc.Context, &api.GetRootCACertificateRequest{})
    40  	assert.NoError(t, err)
    41  	assert.NotEmpty(t, resp.Certificate)
    42  }
    43  
    44  func TestRestartRootCA(t *testing.T) {
    45  	tc := cautils.NewTestCA(t)
    46  	defer tc.Stop()
    47  
    48  	_, err := tc.NodeCAClients[0].NodeCertificateStatus(tc.Context, &api.NodeCertificateStatusRequest{NodeID: "foo"})
    49  	assert.Error(t, err)
    50  	assert.Equal(t, codes.NotFound, testutils.ErrorCode(err))
    51  
    52  	tc.CAServer.Stop()
    53  	go tc.CAServer.Run(tc.Context)
    54  
    55  	<-tc.CAServer.Ready()
    56  
    57  	_, err = tc.NodeCAClients[0].NodeCertificateStatus(tc.Context, &api.NodeCertificateStatusRequest{NodeID: "foo"})
    58  	assert.Error(t, err)
    59  	assert.Equal(t, codes.NotFound, testutils.ErrorCode(err))
    60  }
    61  
    62  func TestIssueNodeCertificate(t *testing.T) {
    63  	tc := cautils.NewTestCA(t)
    64  	defer tc.Stop()
    65  
    66  	csr, _, err := ca.GenerateNewCSR()
    67  	assert.NoError(t, err)
    68  
    69  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Token: tc.WorkerToken}
    70  	issueResponse, err := tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
    71  	assert.NoError(t, err)
    72  	assert.NotNil(t, issueResponse.NodeID)
    73  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
    74  
    75  	statusRequest := &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
    76  	statusResponse, err := tc.NodeCAClients[0].NodeCertificateStatus(tc.Context, statusRequest)
    77  	require.NoError(t, err)
    78  	assert.Equal(t, api.IssuanceStateIssued, statusResponse.Status.State)
    79  	assert.NotNil(t, statusResponse.Certificate.Certificate)
    80  	assert.Equal(t, api.NodeRoleWorker, statusResponse.Certificate.Role)
    81  }
    82  
    83  func TestForceRotationIsNoop(t *testing.T) {
    84  	tc := cautils.NewTestCA(t)
    85  	defer tc.Stop()
    86  
    87  	// Get a new Certificate issued
    88  	csr, _, err := ca.GenerateNewCSR()
    89  	assert.NoError(t, err)
    90  
    91  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Token: tc.WorkerToken}
    92  	issueResponse, err := tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
    93  	assert.NoError(t, err)
    94  	assert.NotNil(t, issueResponse.NodeID)
    95  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
    96  
    97  	// Check that the Certificate is successfully issued
    98  	statusRequest := &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
    99  	statusResponse, err := tc.NodeCAClients[0].NodeCertificateStatus(tc.Context, statusRequest)
   100  	require.NoError(t, err)
   101  	assert.Equal(t, api.IssuanceStateIssued, statusResponse.Status.State)
   102  	assert.NotNil(t, statusResponse.Certificate.Certificate)
   103  	assert.Equal(t, api.NodeRoleWorker, statusResponse.Certificate.Role)
   104  
   105  	// Update the certificate status to IssuanceStateRotate which should be a server-side noop
   106  	err = tc.MemoryStore.Update(func(tx store.Tx) error {
   107  		// Attempt to retrieve the node with nodeID
   108  		node := store.GetNode(tx, issueResponse.NodeID)
   109  		assert.NotNil(t, node)
   110  
   111  		node.Certificate.Status.State = api.IssuanceStateRotate
   112  		return store.UpdateNode(tx, node)
   113  	})
   114  	assert.NoError(t, err)
   115  
   116  	// Wait a bit and check that the certificate hasn't changed/been reissued
   117  	time.Sleep(250 * time.Millisecond)
   118  
   119  	statusNewResponse, err := tc.NodeCAClients[0].NodeCertificateStatus(tc.Context, statusRequest)
   120  	require.NoError(t, err)
   121  	assert.Equal(t, statusResponse.Certificate.Certificate, statusNewResponse.Certificate.Certificate)
   122  	assert.Equal(t, api.IssuanceStateRotate, statusNewResponse.Certificate.Status.State)
   123  	assert.Equal(t, api.NodeRoleWorker, statusNewResponse.Certificate.Role)
   124  }
   125  
   126  func TestIssueNodeCertificateBrokenCA(t *testing.T) {
   127  	if !cautils.External {
   128  		t.Skip("test only applicable for external CA configuration")
   129  	}
   130  
   131  	tc := cautils.NewTestCA(t)
   132  	defer tc.Stop()
   133  
   134  	csr, _, err := ca.GenerateNewCSR()
   135  	assert.NoError(t, err)
   136  
   137  	tc.ExternalSigningServer.Flake()
   138  
   139  	go func() {
   140  		time.Sleep(250 * time.Millisecond)
   141  		tc.ExternalSigningServer.Deflake()
   142  	}()
   143  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Token: tc.WorkerToken}
   144  	issueResponse, err := tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   145  	assert.NoError(t, err)
   146  	assert.NotNil(t, issueResponse.NodeID)
   147  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
   148  
   149  	statusRequest := &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
   150  	statusResponse, err := tc.NodeCAClients[0].NodeCertificateStatus(tc.Context, statusRequest)
   151  	require.NoError(t, err)
   152  	assert.Equal(t, api.IssuanceStateIssued, statusResponse.Status.State)
   153  	assert.NotNil(t, statusResponse.Certificate.Certificate)
   154  	assert.Equal(t, api.NodeRoleWorker, statusResponse.Certificate.Role)
   155  
   156  }
   157  
   158  func TestIssueNodeCertificateWithInvalidCSR(t *testing.T) {
   159  	tc := cautils.NewTestCA(t)
   160  	defer tc.Stop()
   161  
   162  	issueRequest := &api.IssueNodeCertificateRequest{CSR: []byte("random garbage"), Token: tc.WorkerToken}
   163  	issueResponse, err := tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   164  	assert.NoError(t, err)
   165  	assert.NotNil(t, issueResponse.NodeID)
   166  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
   167  
   168  	statusRequest := &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
   169  	statusResponse, err := tc.NodeCAClients[0].NodeCertificateStatus(tc.Context, statusRequest)
   170  	require.NoError(t, err)
   171  	assert.Equal(t, api.IssuanceStateFailed, statusResponse.Status.State)
   172  	assert.Contains(t, statusResponse.Status.Err, "CSR Decode failed")
   173  	assert.Nil(t, statusResponse.Certificate.Certificate)
   174  }
   175  
   176  func TestIssueNodeCertificateWorkerRenewal(t *testing.T) {
   177  	tc := cautils.NewTestCA(t)
   178  	defer tc.Stop()
   179  
   180  	csr, _, err := ca.GenerateNewCSR()
   181  	assert.NoError(t, err)
   182  
   183  	role := api.NodeRoleWorker
   184  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Role: role}
   185  	issueResponse, err := tc.NodeCAClients[1].IssueNodeCertificate(tc.Context, issueRequest)
   186  	assert.NoError(t, err)
   187  	assert.NotNil(t, issueResponse.NodeID)
   188  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
   189  
   190  	statusRequest := &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
   191  	statusResponse, err := tc.NodeCAClients[1].NodeCertificateStatus(tc.Context, statusRequest)
   192  	require.NoError(t, err)
   193  	assert.Equal(t, api.IssuanceStateIssued, statusResponse.Status.State)
   194  	assert.NotNil(t, statusResponse.Certificate.Certificate)
   195  	assert.Equal(t, role, statusResponse.Certificate.Role)
   196  }
   197  
   198  func TestIssueNodeCertificateManagerRenewal(t *testing.T) {
   199  	tc := cautils.NewTestCA(t)
   200  	defer tc.Stop()
   201  
   202  	csr, _, err := ca.GenerateNewCSR()
   203  	assert.NoError(t, err)
   204  	assert.NotNil(t, csr)
   205  
   206  	role := api.NodeRoleManager
   207  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Role: role}
   208  	issueResponse, err := tc.NodeCAClients[2].IssueNodeCertificate(tc.Context, issueRequest)
   209  	require.NoError(t, err)
   210  	assert.NotNil(t, issueResponse.NodeID)
   211  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
   212  
   213  	statusRequest := &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
   214  	statusResponse, err := tc.NodeCAClients[2].NodeCertificateStatus(tc.Context, statusRequest)
   215  	require.NoError(t, err)
   216  	assert.Equal(t, api.IssuanceStateIssued, statusResponse.Status.State)
   217  	assert.NotNil(t, statusResponse.Certificate.Certificate)
   218  	assert.Equal(t, role, statusResponse.Certificate.Role)
   219  }
   220  
   221  func TestIssueNodeCertificateWorkerFromDifferentOrgRenewal(t *testing.T) {
   222  	tc := cautils.NewTestCA(t)
   223  	defer tc.Stop()
   224  
   225  	csr, _, err := ca.GenerateNewCSR()
   226  	assert.NoError(t, err)
   227  
   228  	// Since we're using a client that has a different Organization, this request will be treated
   229  	// as a new certificate request, not allowing auto-renewal. Therefore, the request will fail.
   230  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr}
   231  	_, err = tc.NodeCAClients[3].IssueNodeCertificate(tc.Context, issueRequest)
   232  	assert.Error(t, err)
   233  }
   234  
   235  func TestNodeCertificateRenewalsDoNotRequireToken(t *testing.T) {
   236  	tc := cautils.NewTestCA(t)
   237  	defer tc.Stop()
   238  
   239  	csr, _, err := ca.GenerateNewCSR()
   240  	assert.NoError(t, err)
   241  
   242  	role := api.NodeRoleManager
   243  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Role: role}
   244  	issueResponse, err := tc.NodeCAClients[2].IssueNodeCertificate(tc.Context, issueRequest)
   245  	assert.NoError(t, err)
   246  	assert.NotNil(t, issueResponse.NodeID)
   247  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
   248  
   249  	statusRequest := &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
   250  	statusResponse, err := tc.NodeCAClients[2].NodeCertificateStatus(tc.Context, statusRequest)
   251  	assert.NoError(t, err)
   252  	assert.Equal(t, api.IssuanceStateIssued, statusResponse.Status.State)
   253  	assert.NotNil(t, statusResponse.Certificate.Certificate)
   254  	assert.Equal(t, role, statusResponse.Certificate.Role)
   255  
   256  	role = api.NodeRoleWorker
   257  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role}
   258  	issueResponse, err = tc.NodeCAClients[1].IssueNodeCertificate(tc.Context, issueRequest)
   259  	require.NoError(t, err)
   260  	assert.NotNil(t, issueResponse.NodeID)
   261  	assert.Equal(t, api.NodeMembershipAccepted, issueResponse.NodeMembership)
   262  
   263  	statusRequest = &api.NodeCertificateStatusRequest{NodeID: issueResponse.NodeID}
   264  	statusResponse, err = tc.NodeCAClients[2].NodeCertificateStatus(tc.Context, statusRequest)
   265  	require.NoError(t, err)
   266  	assert.Equal(t, api.IssuanceStateIssued, statusResponse.Status.State)
   267  	assert.NotNil(t, statusResponse.Certificate.Certificate)
   268  	assert.Equal(t, role, statusResponse.Certificate.Role)
   269  }
   270  
   271  func TestNewNodeCertificateRequiresToken(t *testing.T) {
   272  	t.Parallel()
   273  
   274  	tc := cautils.NewTestCA(t)
   275  	defer tc.Stop()
   276  
   277  	csr, _, err := ca.GenerateNewCSR()
   278  	assert.NoError(t, err)
   279  
   280  	// Issuance fails if no secret is provided
   281  	role := api.NodeRoleManager
   282  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Role: role}
   283  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   284  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   285  
   286  	role = api.NodeRoleWorker
   287  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role}
   288  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   289  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   290  
   291  	// Issuance fails if wrong secret is provided
   292  	role = api.NodeRoleManager
   293  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: "invalid-secret"}
   294  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   295  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   296  
   297  	role = api.NodeRoleWorker
   298  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: "invalid-secret"}
   299  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   300  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   301  
   302  	// Issuance succeeds if correct token is provided
   303  	role = api.NodeRoleManager
   304  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: tc.ManagerToken}
   305  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   306  	assert.NoError(t, err)
   307  
   308  	role = api.NodeRoleWorker
   309  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: tc.WorkerToken}
   310  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   311  	assert.NoError(t, err)
   312  
   313  	// Rotate manager and worker tokens
   314  	var (
   315  		newManagerToken string
   316  		newWorkerToken  string
   317  	)
   318  	assert.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
   319  		clusters, _ := store.FindClusters(tx, store.ByName(store.DefaultClusterName))
   320  		newWorkerToken = ca.GenerateJoinToken(&tc.RootCA, false)
   321  		clusters[0].RootCA.JoinTokens.Worker = newWorkerToken
   322  		newManagerToken = ca.GenerateJoinToken(&tc.RootCA, false)
   323  		clusters[0].RootCA.JoinTokens.Manager = newManagerToken
   324  		return store.UpdateCluster(tx, clusters[0])
   325  	}))
   326  
   327  	// updating the join token may take a little bit in order to register on the CA server, so poll
   328  	assert.NoError(t, testutils.PollFunc(nil, func() error {
   329  		// Old token should fail
   330  		role = api.NodeRoleManager
   331  		issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: tc.ManagerToken}
   332  		_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   333  		if err == nil {
   334  			return fmt.Errorf("join token not updated yet")
   335  		}
   336  		return nil
   337  	}))
   338  
   339  	// Old token should fail
   340  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   341  
   342  	role = api.NodeRoleWorker
   343  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: tc.WorkerToken}
   344  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   345  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   346  
   347  	// New token should succeed
   348  	role = api.NodeRoleManager
   349  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: newManagerToken}
   350  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   351  	assert.NoError(t, err)
   352  
   353  	role = api.NodeRoleWorker
   354  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: newWorkerToken}
   355  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   356  	assert.NoError(t, err)
   357  }
   358  
   359  func TestNewNodeCertificateBadToken(t *testing.T) {
   360  	tc := cautils.NewTestCA(t)
   361  	defer tc.Stop()
   362  
   363  	csr, _, err := ca.GenerateNewCSR()
   364  	assert.NoError(t, err)
   365  
   366  	// Issuance fails if wrong secret is provided
   367  	role := api.NodeRoleManager
   368  	issueRequest := &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: "invalid-secret"}
   369  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   370  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   371  
   372  	role = api.NodeRoleWorker
   373  	issueRequest = &api.IssueNodeCertificateRequest{CSR: csr, Role: role, Token: "invalid-secret"}
   374  	_, err = tc.NodeCAClients[0].IssueNodeCertificate(tc.Context, issueRequest)
   375  	assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = A valid join token is necessary to join this cluster")
   376  }
   377  
   378  func TestGetUnlockKey(t *testing.T) {
   379  	t.Parallel()
   380  
   381  	tc := cautils.NewTestCA(t)
   382  	defer tc.Stop()
   383  
   384  	var cluster *api.Cluster
   385  	tc.MemoryStore.View(func(tx store.ReadTx) {
   386  		clusters, err := store.FindClusters(tx, store.ByName(store.DefaultClusterName))
   387  		require.NoError(t, err)
   388  		cluster = clusters[0]
   389  	})
   390  
   391  	resp, err := tc.CAClients[0].GetUnlockKey(tc.Context, &api.GetUnlockKeyRequest{})
   392  	require.NoError(t, err)
   393  	require.Nil(t, resp.UnlockKey)
   394  	require.Equal(t, cluster.Meta.Version, resp.Version)
   395  
   396  	// Update the unlock key
   397  	require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
   398  		cluster = store.GetCluster(tx, cluster.ID)
   399  		cluster.Spec.EncryptionConfig.AutoLockManagers = true
   400  		cluster.UnlockKeys = []*api.EncryptionKey{{
   401  			Subsystem: ca.ManagerRole,
   402  			Key:       []byte("secret"),
   403  		}}
   404  		return store.UpdateCluster(tx, cluster)
   405  	}))
   406  
   407  	tc.MemoryStore.View(func(tx store.ReadTx) {
   408  		cluster = store.GetCluster(tx, cluster.ID)
   409  	})
   410  
   411  	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
   412  		resp, err = tc.CAClients[0].GetUnlockKey(tc.Context, &api.GetUnlockKeyRequest{})
   413  		if err != nil {
   414  			return fmt.Errorf("get unlock key: %v", err)
   415  		}
   416  		if !bytes.Equal(resp.UnlockKey, []byte("secret")) {
   417  			return fmt.Errorf("secret hasn't rotated yet")
   418  		}
   419  		if cluster.Meta.Version.Index > resp.Version.Index {
   420  			return fmt.Errorf("hasn't updated to the right version yet")
   421  		}
   422  		return nil
   423  	}, 250*time.Millisecond))
   424  }
   425  
// clusterObjToUpdate is one test case for TestCAServerUpdateRootCA: a cluster
// object to hand to the CA server's UpdateRootCA, plus the state expected afterwards.
type clusterObjToUpdate struct {
	// clusterObj is the cluster passed to UpdateRootCA.
	clusterObj *api.Cluster
	// rootCARoots is the expected value of the server root CA's Certs.
	rootCARoots []byte
	// rootCASigningCert is the expected signer certificate; nil if the root CA has no signer.
	rootCASigningCert []byte
	// rootCASigningKey is the expected signer key; nil if the root CA has no signer.
	rootCASigningKey []byte
	// rootCAIntermediates is the expected value of the server root CA's Intermediates.
	rootCAIntermediates []byte
	// externalCertSignedBy holds the root certs that a certificate signed by the
	// external CA must verify against; nil means signing is expected to fail
	// with ca.ErrNoExternalCAURLs.
	externalCertSignedBy []byte
}
   434  
   435  // When the SecurityConfig is updated with a new TLS keypair, the server automatically uses that keypair to contact
   436  // the external CA
   437  func TestServerExternalCAGetsTLSKeypairUpdates(t *testing.T) {
   438  	t.Parallel()
   439  
   440  	// this one needs the external CA server for testing
   441  	if !cautils.External {
   442  		return
   443  	}
   444  
   445  	tc := cautils.NewTestCA(t)
   446  	defer tc.Stop()
   447  
   448  	// show that we can connect to the external CA using our original creds
   449  	csr, _, err := ca.GenerateNewCSR()
   450  	require.NoError(t, err)
   451  	req := ca.PrepareCSR(csr, "cn", ca.ManagerRole, tc.Organization)
   452  
   453  	externalCA := tc.CAServer.ExternalCA()
   454  	extSignedCert, err := externalCA.Sign(tc.Context, req)
   455  	require.NoError(t, err)
   456  	require.NotNil(t, extSignedCert)
   457  
   458  	// get a new cert and make it expired
   459  	_, issuerInfo, err := tc.RootCA.IssueAndSaveNewCertificates(
   460  		tc.KeyReadWriter, tc.ServingSecurityConfig.ClientTLSCreds.NodeID(), ca.ManagerRole, tc.Organization)
   461  	require.NoError(t, err)
   462  	cert, key, err := tc.KeyReadWriter.Read()
   463  	require.NoError(t, err)
   464  
   465  	s, err := tc.RootCA.Signer()
   466  	require.NoError(t, err)
   467  	cert = cautils.ReDateCert(t, cert, s.Cert, s.Key, time.Now().Add(-5*time.Hour), time.Now().Add(-3*time.Hour))
   468  
   469  	// we have to create the keypair and update the security config manually, because all the renew functions check for
   470  	// expiry
   471  	tlsKeyPair, err := tls.X509KeyPair(cert, key)
   472  	require.NoError(t, err)
   473  	require.NoError(t, tc.ServingSecurityConfig.UpdateTLSCredentials(&tlsKeyPair, issuerInfo))
   474  
   475  	// show that we now cannot connect to the external CA using our original creds
   476  	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
   477  		externalCA := tc.CAServer.ExternalCA()
   478  		// wait for the credentials for the external CA to update
   479  		if _, err = externalCA.Sign(tc.Context, req); err == nil {
   480  			return errors.New("external CA creds haven't updated yet to be invalid")
   481  		}
   482  		return nil
   483  	}, 2*time.Second))
   484  	require.Contains(t, errors.Cause(err).Error(), "remote error: tls: bad certificate")
   485  }
   486  
   487  func TestCAServerUpdateRootCA(t *testing.T) {
   488  	// this one needs both external CA servers for testing
   489  	if !cautils.External {
   490  		return
   491  	}
   492  
   493  	fakeClusterSpec := func(rootCerts, key []byte, rotation *api.RootRotation, externalCAs []*api.ExternalCA) *api.Cluster {
   494  		return &api.Cluster{
   495  			RootCA: api.RootCA{
   496  				CACert:     rootCerts,
   497  				CAKey:      key,
   498  				CACertHash: "hash",
   499  				JoinTokens: api.JoinTokens{
   500  					Worker:  "SWMTKN-1-worker",
   501  					Manager: "SWMTKN-1-manager",
   502  				},
   503  				RootRotation: rotation,
   504  			},
   505  			Spec: api.ClusterSpec{
   506  				CAConfig: api.CAConfig{
   507  					ExternalCAs: externalCAs,
   508  				},
   509  			},
   510  		}
   511  	}
   512  
   513  	tc := cautils.NewTestCA(t)
   514  	require.NoError(t, tc.CAServer.Stop())
   515  	defer tc.Stop()
   516  
   517  	cert, key, err := cautils.CreateRootCertAndKey("new root to rotate to")
   518  	require.NoError(t, err)
   519  	newRootCA, err := ca.NewRootCA(append(tc.RootCA.Certs, cert...), cert, key, ca.DefaultNodeCertExpiration, nil)
   520  	require.NoError(t, err)
   521  	externalServer, err := cautils.NewExternalSigningServer(newRootCA, tc.TempDir)
   522  	require.NoError(t, err)
   523  	defer externalServer.Stop()
   524  	crossSigned, err := tc.RootCA.CrossSignCACertificate(cert)
   525  	require.NoError(t, err)
   526  
   527  	for i, testCase := range []clusterObjToUpdate{
   528  		{
   529  			clusterObj: fakeClusterSpec(tc.RootCA.Certs, nil, nil, []*api.ExternalCA{{
   530  				Protocol: api.ExternalCA_CAProtocolCFSSL,
   531  				URL:      tc.ExternalSigningServer.URL,
   532  				// without a CA cert, the URL gets successfully added, and there should be no error connecting to it
   533  			}}),
   534  			rootCARoots:          tc.RootCA.Certs,
   535  			externalCertSignedBy: tc.RootCA.Certs,
   536  		},
   537  		{
   538  			clusterObj: fakeClusterSpec(tc.RootCA.Certs, nil, &api.RootRotation{
   539  				CACert:            cert,
   540  				CAKey:             key,
   541  				CrossSignedCACert: crossSigned,
   542  			}, []*api.ExternalCA{
   543  				{
   544  					Protocol: api.ExternalCA_CAProtocolCFSSL,
   545  					URL:      tc.ExternalSigningServer.URL,
   546  					// without a CA cert, we count this as the old tc.RootCA.Certs, and this should be ignored because we want the new root
   547  				},
   548  			}),
   549  			rootCARoots:         tc.RootCA.Certs,
   550  			rootCASigningCert:   crossSigned,
   551  			rootCASigningKey:    key,
   552  			rootCAIntermediates: crossSigned,
   553  		},
   554  		{
   555  			clusterObj: fakeClusterSpec(tc.RootCA.Certs, nil, &api.RootRotation{
   556  				CACert:            cert,
   557  				CrossSignedCACert: crossSigned,
   558  			}, []*api.ExternalCA{
   559  				{
   560  					Protocol: api.ExternalCA_CAProtocolCFSSL,
   561  					URL:      tc.ExternalSigningServer.URL,
   562  					// without a CA cert, we count this as the old tc.RootCA.Certs
   563  				},
   564  				{
   565  					Protocol: api.ExternalCA_CAProtocolCFSSL,
   566  					URL:      externalServer.URL,
   567  					CACert:   append(cert, '\n'),
   568  				},
   569  			}),
   570  			rootCARoots:          tc.RootCA.Certs,
   571  			rootCAIntermediates:  crossSigned,
   572  			externalCertSignedBy: cert,
   573  		},
   574  	} {
   575  		require.NoError(t, tc.CAServer.UpdateRootCA(tc.Context, testCase.clusterObj, nil))
   576  
   577  		rootCA := tc.CAServer.RootCA()
   578  		require.Equal(t, testCase.rootCARoots, rootCA.Certs)
   579  		var signingCert, signingKey []byte
   580  		if s, err := rootCA.Signer(); err == nil {
   581  			signingCert, signingKey = s.Cert, s.Key
   582  		}
   583  		require.Equal(t, testCase.rootCARoots, rootCA.Certs)
   584  		require.Equal(t, testCase.rootCASigningCert, signingCert, "%d", i)
   585  		require.Equal(t, testCase.rootCASigningKey, signingKey, "%d", i)
   586  		require.Equal(t, testCase.rootCAIntermediates, rootCA.Intermediates)
   587  
   588  		externalCA := tc.CAServer.ExternalCA()
   589  		csr, _, err := ca.GenerateNewCSR()
   590  		require.NoError(t, err)
   591  		signedCert, err := externalCA.Sign(tc.Context, ca.PrepareCSR(csr, "cn", ca.ManagerRole, tc.Organization))
   592  
   593  		if testCase.externalCertSignedBy != nil {
   594  			require.NoError(t, err)
   595  			parsed, err := helpers.ParseCertificatesPEM(signedCert)
   596  			require.NoError(t, err)
   597  			rootPool := x509.NewCertPool()
   598  			rootPool.AppendCertsFromPEM(testCase.externalCertSignedBy)
   599  			var intermediatePool *x509.CertPool
   600  			if len(parsed) > 1 {
   601  				intermediatePool = x509.NewCertPool()
   602  				for _, cert := range parsed[1:] {
   603  					intermediatePool.AddCert(cert)
   604  				}
   605  			}
   606  			_, err = parsed[0].Verify(x509.VerifyOptions{Roots: rootPool, Intermediates: intermediatePool})
   607  			require.NoError(t, err)
   608  		} else {
   609  			require.Equal(t, ca.ErrNoExternalCAURLs, err)
   610  		}
   611  	}
   612  }
   613  
// rootRotationTester bundles a test CA with the running test so that its
// helper methods can modify the CA's memory store and fail the test on error.
type rootRotationTester struct {
	tc *cautils.TestCA // test CA whose MemoryStore the helpers update
	t  *testing.T      // test used for require assertions in the helpers
}
   618  
   619  // go through all the nodes and update/create the ones we want, and delete the ones
   620  // we don't
   621  func (r *rootRotationTester) convergeWantedNodes(wantNodes map[string]*api.Node, descr string) {
   622  	// update existing and create new nodes first before deleting nodes, else a root rotation
   623  	// may finish early if all the nodes get deleted when the root rotation happens
   624  	require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error {
   625  		for nodeID, wanted := range wantNodes {
   626  			node := store.GetNode(tx, nodeID)
   627  			if node == nil {
   628  				if err := store.CreateNode(tx, wanted); err != nil {
   629  					return err
   630  				}
   631  				continue
   632  			}
   633  			node.Description = wanted.Description
   634  			node.Certificate = wanted.Certificate
   635  			if err := store.UpdateNode(tx, node); err != nil {
   636  				return err
   637  			}
   638  		}
   639  		nodes, err := store.FindNodes(tx, store.All)
   640  		if err != nil {
   641  			return err
   642  		}
   643  		for _, node := range nodes {
   644  			if _, inWanted := wantNodes[node.ID]; !inWanted {
   645  				if err := store.DeleteNode(tx, node.ID); err != nil {
   646  					return err
   647  				}
   648  			}
   649  		}
   650  		return nil
   651  	}), descr)
   652  }
   653  
   654  func (r *rootRotationTester) convergeRootCA(wantRootCA *api.RootCA, descr string) {
   655  	require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error {
   656  		clusters, err := store.FindClusters(tx, store.All)
   657  		if err != nil || len(clusters) != 1 {
   658  			return errors.Wrap(err, "unable to find cluster")
   659  		}
   660  		clusters[0].RootCA = *wantRootCA
   661  		return store.UpdateCluster(tx, clusters[0])
   662  	}), descr)
   663  }
   664  
   665  func getFakeAPINode(t *testing.T, id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node {
   666  	node := &api.Node{
   667  		ID: id,
   668  		Certificate: api.Certificate{
   669  			Status: api.IssuanceStatus{
   670  				State: state,
   671  			},
   672  		},
   673  		Spec: api.NodeSpec{
   674  			Membership: api.NodeMembershipAccepted,
   675  		},
   676  	}
   677  	if !member {
   678  		node.Spec.Membership = api.NodeMembershipPending
   679  	}
   680  	// the CA server will immediately pick these up, so generate CSRs for the CA server to sign
   681  	if state == api.IssuanceStateRenew || state == api.IssuanceStatePending {
   682  		csr, _, err := ca.GenerateNewCSR()
   683  		require.NoError(t, err)
   684  		node.Certificate.CSR = csr
   685  	}
   686  	if tlsInfo != nil {
   687  		node.Description = &api.NodeDescription{TLSInfo: tlsInfo}
   688  	}
   689  	return node
   690  }
   691  
   692  func startCAServer(ctx context.Context, caServer *ca.Server) {
   693  	alreadyRunning := make(chan struct{})
   694  	go func() {
   695  		if err := caServer.Run(ctx); err != nil {
   696  			close(alreadyRunning)
   697  		}
   698  	}()
   699  	select {
   700  	case <-caServer.Ready():
   701  	case <-alreadyRunning:
   702  	}
   703  }
   704  
   705  func getRotationInfo(t *testing.T, rotationCert []byte, rootCA *ca.RootCA) ([]byte, *api.NodeTLSInfo) {
   706  	parsedNewRoot, err := helpers.ParseCertificatePEM(rotationCert)
   707  	require.NoError(t, err)
   708  	crossSigned, err := rootCA.CrossSignCACertificate(rotationCert)
   709  	require.NoError(t, err)
   710  	return crossSigned, &api.NodeTLSInfo{
   711  		TrustRoot:           rootCA.Certs,
   712  		CertIssuerPublicKey: parsedNewRoot.RawSubjectPublicKeyInfo,
   713  		CertIssuerSubject:   parsedNewRoot.RawSubject,
   714  	}
   715  }
   716  
   717  // These are the root rotation test cases where we expect there to be a change in the FindNodes
   718  // or root CA values after converging.
   719  func TestRootRotationReconciliationWithChanges(t *testing.T) {
   720  	t.Parallel()
   721  	if cautils.External {
   722  		// the external CA functionality is unrelated to testing the reconciliation loop
   723  		return
   724  	}
   725  
   726  	tc := cautils.NewTestCA(t)
   727  	defer tc.Stop()
   728  	rt := rootRotationTester{
   729  		tc: tc,
   730  		t:  t,
   731  	}
   732  
   733  	rotationCerts := [][]byte{cautils.ECDSA256SHA256Cert, cautils.ECDSACertChain[2]}
   734  	rotationKeys := [][]byte{cautils.ECDSA256Key, cautils.ECDSACertChainKeys[2]}
   735  	var (
   736  		rotationCrossSigned [][]byte
   737  		rotationTLSInfo     []*api.NodeTLSInfo
   738  	)
   739  	for _, cert := range rotationCerts {
   740  		cross, info := getRotationInfo(t, cert, &tc.RootCA)
   741  		rotationCrossSigned = append(rotationCrossSigned, cross)
   742  		rotationTLSInfo = append(rotationTLSInfo, info)
   743  	}
   744  
   745  	oldNodeTLSInfo := &api.NodeTLSInfo{
   746  		TrustRoot:           tc.RootCA.Certs,
   747  		CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey,
   748  		CertIssuerSubject:   tc.ServingSecurityConfig.IssuerInfo().Subject,
   749  	}
   750  
   751  	var startCluster *api.Cluster
   752  	tc.MemoryStore.View(func(tx store.ReadTx) {
   753  		startCluster = store.GetCluster(tx, tc.Organization)
   754  	})
   755  	require.NotNil(t, startCluster)
   756  
   757  	testcases := []struct {
   758  		nodes           map[string]*api.Node // what nodes we should start with
   759  		rootCA          *api.RootCA          // what root CA we should start with
   760  		expectedNodes   map[string]*api.Node // what nodes we expect in the end, if nil, then unchanged from the start
   761  		expectedRootCA  *api.RootCA          // what root CA we expect in the end, if nil, then unchanged from the start
   762  		caServerRestart bool                 // whether to stop the CA server before making the node and root changes and restart after
   763  		descr           string
   764  	}{
   765  		{
   766  			descr: ("If there is no TLS info, the reconciliation cycle tells the nodes to rotate if they're not already getting " +
   767  				"a new cert.  Any renew/pending nodes will have certs issued, but because the TLS info is nil, they will " +
   768  				`go "rotate" state`),
   769  			nodes: map[string]*api.Node{
   770  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   771  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, nil, true),
   772  				"2": getFakeAPINode(t, "2", api.IssuanceStateRenew, nil, true),
   773  				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true),
   774  				"4": getFakeAPINode(t, "4", api.IssuanceStatePending, nil, true),
   775  				"5": getFakeAPINode(t, "5", api.IssuanceStateFailed, nil, true),
   776  				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, nil, false),
   777  			},
   778  			rootCA: &api.RootCA{
   779  				CACert:     startCluster.RootCA.CACert,
   780  				CAKey:      startCluster.RootCA.CAKey,
   781  				CACertHash: startCluster.RootCA.CACertHash,
   782  				RootRotation: &api.RootRotation{
   783  					CACert:            rotationCerts[0],
   784  					CAKey:             rotationKeys[0],
   785  					CrossSignedCACert: rotationCrossSigned[0],
   786  				},
   787  			},
   788  			expectedNodes: map[string]*api.Node{
   789  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   790  				"1": getFakeAPINode(t, "1", api.IssuanceStateRotate, nil, true),
   791  				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, nil, true),
   792  				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true),
   793  				"4": getFakeAPINode(t, "4", api.IssuanceStateRotate, nil, true),
   794  				"5": getFakeAPINode(t, "5", api.IssuanceStateRotate, nil, true),
   795  				"6": getFakeAPINode(t, "6", api.IssuanceStateRotate, nil, false),
   796  			},
   797  		},
   798  		{
   799  			descr: ("Assume all of the nodes have gotten certs, but some of them are the wrong cert " +
   800  				"(going by the TLS info), which shouldn't really happen.  the rotation reconciliation " +
   801  				"will tell the wrong ones to rotate a second time"),
   802  			nodes: map[string]*api.Node{
   803  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   804  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   805  				"2": getFakeAPINode(t, "2", api.IssuanceStateIssued, oldNodeTLSInfo, true),
   806  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   807  				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   808  				"5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true),
   809  				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, oldNodeTLSInfo, false),
   810  			},
   811  			rootCA: &api.RootCA{ // no change in root CA from previous
   812  				CACert:     startCluster.RootCA.CACert,
   813  				CAKey:      startCluster.RootCA.CAKey,
   814  				CACertHash: startCluster.RootCA.CACertHash,
   815  				RootRotation: &api.RootRotation{
   816  					CACert:            rotationCerts[0],
   817  					CAKey:             rotationKeys[0],
   818  					CrossSignedCACert: rotationCrossSigned[0],
   819  				},
   820  			},
   821  			expectedNodes: map[string]*api.Node{
   822  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   823  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   824  				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
   825  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   826  				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   827  				"5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true),
   828  				"6": getFakeAPINode(t, "6", api.IssuanceStateRotate, oldNodeTLSInfo, false),
   829  			},
   830  		},
   831  		{
   832  			descr: ("New nodes that are added will also be picked up and told to rotate"),
   833  			nodes: map[string]*api.Node{
   834  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   835  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   836  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   837  				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   838  				"5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   839  				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[0], false),
   840  				"7": getFakeAPINode(t, "7", api.IssuanceStateRenew, nil, true),
   841  			},
   842  			rootCA: &api.RootCA{ // no change in root CA from previous
   843  				CACert:     startCluster.RootCA.CACert,
   844  				CAKey:      startCluster.RootCA.CAKey,
   845  				CACertHash: startCluster.RootCA.CACertHash,
   846  				RootRotation: &api.RootRotation{
   847  					CACert:            rotationCerts[0],
   848  					CAKey:             rotationKeys[0],
   849  					CrossSignedCACert: rotationCrossSigned[0],
   850  				},
   851  			},
   852  			expectedNodes: map[string]*api.Node{
   853  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   854  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   855  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   856  				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   857  				"5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   858  				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[0], false),
   859  				"7": getFakeAPINode(t, "7", api.IssuanceStateRotate, nil, true),
   860  			},
   861  		},
   862  		{
   863  			descr: ("Even if root rotation isn't finished, if the root changes again to a " +
   864  				"different cert, all the nodes with the old root rotation cert will be told " +
   865  				"to rotate again."),
   866  			nodes: map[string]*api.Node{
   867  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   868  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   869  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   870  				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   871  				"5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true),
   872  				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[0], true),
   873  				"7": getFakeAPINode(t, "7", api.IssuanceStateIssued, rotationTLSInfo[0], false),
   874  			},
   875  			rootCA: &api.RootCA{ // new root rotation
   876  				CACert:     startCluster.RootCA.CACert,
   877  				CAKey:      startCluster.RootCA.CAKey,
   878  				CACertHash: startCluster.RootCA.CACertHash,
   879  				RootRotation: &api.RootRotation{
   880  					CACert:            rotationCerts[1],
   881  					CAKey:             rotationKeys[1],
   882  					CrossSignedCACert: rotationCrossSigned[1],
   883  				},
   884  			},
   885  			expectedNodes: map[string]*api.Node{
   886  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   887  				"1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[0], true),
   888  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   889  				"4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[0], true),
   890  				"5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true),
   891  				"6": getFakeAPINode(t, "6", api.IssuanceStateRotate, rotationTLSInfo[0], true),
   892  				"7": getFakeAPINode(t, "7", api.IssuanceStateRotate, rotationTLSInfo[0], false),
   893  			},
   894  		},
   895  		{
   896  			descr: ("Once all nodes have rotated to their desired TLS info (even if it's because " +
   897  				"a node with the wrong TLS info has been removed, the root rotation is completed."),
   898  			nodes: map[string]*api.Node{
   899  				"0": getFakeAPINode(t, "0", api.IssuanceStateIssued, rotationTLSInfo[1], false),
   900  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   901  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   902  				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   903  				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   904  			},
   905  			rootCA: &api.RootCA{
   906  				// no change in root CA from previous - even if root rotation gets completed after
   907  				// the nodes are first set, and we just add the root rotation again because of this
   908  				// test order, because the TLS info is correct for all nodes it will be completed again
   909  				// anyway)
   910  				CACert:     startCluster.RootCA.CACert,
   911  				CAKey:      startCluster.RootCA.CAKey,
   912  				CACertHash: startCluster.RootCA.CACertHash,
   913  				RootRotation: &api.RootRotation{
   914  					CACert:            rotationCerts[1],
   915  					CAKey:             rotationKeys[1],
   916  					CrossSignedCACert: rotationCrossSigned[1],
   917  				},
   918  			},
   919  			expectedRootCA: &api.RootCA{
   920  				CACert:     rotationCerts[1],
   921  				CAKey:      rotationKeys[1],
   922  				CACertHash: digest.FromBytes(rotationCerts[1]).String(),
   923  				// ignore the join tokens - we aren't comparing them
   924  			},
   925  		},
   926  		{
   927  			descr: ("If a root rotation happens when the CA server is down, so long as it saw the change " +
   928  				"it will start reconciling the nodes as soon as it's started up again"),
   929  			caServerRestart: true,
   930  			nodes: map[string]*api.Node{
   931  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   932  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   933  				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   934  				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   935  				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true),
   936  				"7": getFakeAPINode(t, "7", api.IssuanceStateIssued, rotationTLSInfo[1], false),
   937  			},
   938  			rootCA: &api.RootCA{
   939  				CACert:     startCluster.RootCA.CACert,
   940  				CAKey:      startCluster.RootCA.CAKey,
   941  				CACertHash: startCluster.RootCA.CACertHash,
   942  				RootRotation: &api.RootRotation{
   943  					CACert:            rotationCerts[0],
   944  					CAKey:             rotationKeys[0],
   945  					CrossSignedCACert: rotationCrossSigned[0],
   946  				},
   947  			},
   948  			expectedNodes: map[string]*api.Node{
   949  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
   950  				"1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[1], true),
   951  				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo[1], true),
   952  				"4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[1], true),
   953  				"6": getFakeAPINode(t, "6", api.IssuanceStateRotate, rotationTLSInfo[1], true),
   954  				"7": getFakeAPINode(t, "7", api.IssuanceStateRotate, rotationTLSInfo[1], false),
   955  			},
   956  		},
   957  	}
   958  
   959  	for _, testcase := range testcases {
   960  		// stop the CA server, get the cluster to the state we want (correct root CA, correct nodes, etc.)
   961  		rt.tc.CAServer.Stop()
   962  		rt.convergeWantedNodes(testcase.nodes, testcase.descr)
   963  
   964  		if testcase.caServerRestart {
   965  			// if we want to simulate restarting the CA server with a root rotation already done, set the rootCA to
   966  			// have a root rotation, then start the CA
   967  			rt.convergeRootCA(testcase.rootCA, testcase.descr)
   968  			startCAServer(rt.tc.Context, rt.tc.CAServer)
   969  		} else {
   970  			// otherwise, start the CA in the state where there is no root rotation, and start a root rotation
   971  			rt.convergeRootCA(&startCluster.RootCA, testcase.descr) // no root rotation
   972  			startCAServer(rt.tc.Context, rt.tc.CAServer)
   973  			rt.convergeRootCA(testcase.rootCA, testcase.descr)
   974  		}
   975  
   976  		if testcase.expectedNodes == nil {
   977  			testcase.expectedNodes = testcase.nodes
   978  		}
   979  		if testcase.expectedRootCA == nil {
   980  			testcase.expectedRootCA = testcase.rootCA
   981  		}
   982  
   983  		require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
   984  			var (
   985  				nodes   []*api.Node
   986  				cluster *api.Cluster
   987  				err     error
   988  			)
   989  			tc.MemoryStore.View(func(tx store.ReadTx) {
   990  				nodes, err = store.FindNodes(tx, store.All)
   991  				cluster = store.GetCluster(tx, tc.Organization)
   992  			})
   993  			if err != nil {
   994  				return err
   995  			}
   996  			if cluster == nil {
   997  				return errors.New("no cluster found")
   998  			}
   999  
  1000  			if !equality.RootCAEqualStable(&cluster.RootCA, testcase.expectedRootCA) {
  1001  				return fmt.Errorf("root CAs not equal:\n\texpected: %v\n\tactual: %v", *testcase.expectedRootCA, cluster.RootCA)
  1002  			}
  1003  			if len(nodes) != len(testcase.expectedNodes) {
  1004  				return fmt.Errorf("number of expected nodes (%d) does not equal number of actual nodes (%d)",
  1005  					len(testcase.expectedNodes), len(nodes))
  1006  			}
  1007  			for _, node := range nodes {
  1008  				expected, ok := testcase.expectedNodes[node.ID]
  1009  				if !ok {
  1010  					return fmt.Errorf("node %s is present and was unexpected", node.ID)
  1011  				}
  1012  				if !reflect.DeepEqual(expected.Description, node.Description) {
  1013  					return fmt.Errorf("the node description of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID,
  1014  						expected.Description, node.Description)
  1015  				}
  1016  				if !reflect.DeepEqual(expected.Certificate.Status, node.Certificate.Status) {
  1017  					return fmt.Errorf("the certificate status of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID,
  1018  						expected.Certificate, node.Certificate)
  1019  				}
  1020  
  1021  				// ensure that the security config's root CA object has the same expected key
  1022  				expectedKey := testcase.expectedRootCA.CAKey
  1023  				if testcase.expectedRootCA.RootRotation != nil {
  1024  					expectedKey = testcase.expectedRootCA.RootRotation.CAKey
  1025  				}
  1026  				s, err := rt.tc.CAServer.RootCA().Signer()
  1027  				if err != nil {
  1028  					return err
  1029  				}
  1030  				if !bytes.Equal(s.Key, expectedKey) {
  1031  					return fmt.Errorf("the CA Server's root CA has not been updated correctly")
  1032  				}
  1033  			}
  1034  			return nil
  1035  		}, 5*time.Second), testcase.descr)
  1036  	}
  1037  }
  1038  
  1039  // These are the root rotation test cases where we expect there to be no changes made to either
  1040  // the nodes or the root CA object, although the server's signing root CA may change.
  1041  func TestRootRotationReconciliationNoChanges(t *testing.T) {
  1042  	t.Parallel()
  1043  	if cautils.External {
  1044  		// the external CA functionality is unrelated to testing the reconciliation loop
  1045  		return
  1046  	}
  1047  
  1048  	tc := cautils.NewTestCA(t)
  1049  	defer tc.Stop()
  1050  	rt := rootRotationTester{
  1051  		tc: tc,
  1052  		t:  t,
  1053  	}
  1054  
  1055  	rotationCert := cautils.ECDSA256SHA256Cert
  1056  	rotationKey := cautils.ECDSA256Key
  1057  	rotationCrossSigned, rotationTLSInfo := getRotationInfo(t, rotationCert, &tc.RootCA)
  1058  
  1059  	oldNodeTLSInfo := &api.NodeTLSInfo{
  1060  		TrustRoot:           tc.RootCA.Certs,
  1061  		CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey,
  1062  		CertIssuerSubject:   tc.ServingSecurityConfig.IssuerInfo().Subject,
  1063  	}
  1064  
  1065  	var startCluster *api.Cluster
  1066  	tc.MemoryStore.View(func(tx store.ReadTx) {
  1067  		startCluster = store.GetCluster(tx, tc.Organization)
  1068  	})
  1069  	require.NotNil(t, startCluster)
  1070  
  1071  	testcases := []struct {
  1072  		nodes  map[string]*api.Node // what nodes we should start with
  1073  		rootCA *api.RootCA          // what root CA we should start with
  1074  		descr  string
  1075  	}{
  1076  		{
  1077  			descr: ("If all nodes have the right TLS info or are already rotated, rotating, or pending, " +
  1078  				"there will be no changes needed"),
  1079  			nodes: map[string]*api.Node{
  1080  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
  1081  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true),
  1082  				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
  1083  				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo, false),
  1084  			},
  1085  			rootCA: &api.RootCA{ // no change in root CA from previous
  1086  				CACert:     startCluster.RootCA.CACert,
  1087  				CAKey:      startCluster.RootCA.CAKey,
  1088  				CACertHash: startCluster.RootCA.CACertHash,
  1089  				RootRotation: &api.RootRotation{
  1090  					CACert:            rotationCert,
  1091  					CAKey:             rotationKey,
  1092  					CrossSignedCACert: rotationCrossSigned,
  1093  				},
  1094  			},
  1095  		},
  1096  		{
  1097  			descr: ("Nodes already in rotate state, even if they currently have the correct TLS issuer, will be " +
  1098  				"left in the rotate state even if root rotation is aborted because we don't know if they're already " +
  1099  				"in the process of getting a new cert.  Even if they're issued by a different issuer, they will be " +
  1100  				"left alone because they'll have an interemdiate that chains up to the old issuer."),
  1101  			nodes: map[string]*api.Node{
  1102  				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
  1103  				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true),
  1104  				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
  1105  				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, oldNodeTLSInfo, false),
  1106  			},
  1107  			rootCA: &api.RootCA{ // no change in root CA from previous
  1108  				CACert:     startCluster.RootCA.CACert,
  1109  				CAKey:      startCluster.RootCA.CAKey,
  1110  				CACertHash: startCluster.RootCA.CACertHash,
  1111  			},
  1112  		},
  1113  	}
  1114  
  1115  	for _, testcase := range testcases {
  1116  		// stop the CA server, get the cluster to the state we want (correct root CA, correct nodes, etc.)
  1117  		rt.tc.CAServer.Stop()
  1118  		rt.convergeWantedNodes(testcase.nodes, testcase.descr)
  1119  		rt.convergeRootCA(&startCluster.RootCA, testcase.descr) // no root rotation
  1120  		startCAServer(rt.tc.Context, rt.tc.CAServer)
  1121  		rt.convergeRootCA(testcase.rootCA, testcase.descr)
  1122  
  1123  		time.Sleep(500 * time.Millisecond)
  1124  
  1125  		var (
  1126  			nodes   []*api.Node
  1127  			cluster *api.Cluster
  1128  			err     error
  1129  		)
  1130  
  1131  		tc.MemoryStore.View(func(tx store.ReadTx) {
  1132  			nodes, err = store.FindNodes(tx, store.All)
  1133  			cluster = store.GetCluster(tx, tc.Organization)
  1134  		})
  1135  		require.NoError(t, err)
  1136  		require.NotNil(t, cluster)
  1137  		require.Equal(t, cluster.RootCA, *testcase.rootCA, testcase.descr)
  1138  		require.Len(t, nodes, len(testcase.nodes), testcase.descr)
  1139  		for _, node := range nodes {
  1140  			expected, ok := testcase.nodes[node.ID]
  1141  			require.True(t, ok, "node %s: %s", node.ID, testcase.descr)
  1142  			require.Equal(t, expected.Description, node.Description, "node %s: %s", node.ID, testcase.descr)
  1143  			require.Equal(t, expected.Certificate.Status, node.Certificate.Status, "node %s: %s", node.ID, testcase.descr)
  1144  		}
  1145  
  1146  		// ensure that the server's root CA object has the same expected key
  1147  		expectedKey := testcase.rootCA.CAKey
  1148  		if testcase.rootCA.RootRotation != nil {
  1149  			expectedKey = testcase.rootCA.RootRotation.CAKey
  1150  		}
  1151  		s, err := rt.tc.CAServer.RootCA().Signer()
  1152  		require.NoError(t, err, testcase.descr)
  1153  		require.Equal(t, s.Key, expectedKey, testcase.descr)
  1154  	}
  1155  }
  1156  
  1157  // Tests if the root rotation changes while the reconciliation loop is going, eventually the root rotation will finish
  1158  // successfully (even if there's a competing reconciliation loop, for instance if there's a bug during leadership handoff).
  1159  func TestRootRotationReconciliationRace(t *testing.T) {
  1160  	t.Parallel()
  1161  	if cautils.External {
  1162  		// the external CA functionality is unrelated to testing the reconciliation loop
  1163  		return
  1164  	}
  1165  
  1166  	tc := cautils.NewTestCA(t)
  1167  	defer tc.Stop()
  1168  	tc.CAServer.Stop() // we can't use the testCA's CA server because we need to inject extra behavior into the control loop
  1169  	rt := rootRotationTester{
  1170  		tc: tc,
  1171  		t:  t,
  1172  	}
  1173  
  1174  	tempDir, err := ioutil.TempDir("", "competing-ca-server")
  1175  	require.NoError(t, err)
  1176  	defer os.RemoveAll(tempDir)
  1177  
  1178  	var (
  1179  		otherServers   = make([]*ca.Server, 5)
  1180  		serverContexts = make([]context.Context, 5)
  1181  		paths          = make([]*ca.SecurityConfigPaths, 5)
  1182  	)
  1183  
  1184  	for i := 0; i < 5; i++ { // to make sure we get some collision
  1185  		// start a competing CA server
  1186  		paths[i] = ca.NewConfigPaths(filepath.Join(tempDir, fmt.Sprintf("%d", i)))
  1187  
  1188  		// the sec config is only used to get the organization, the initial root CA copy, and any updates to
  1189  		// TLS certificates, so all the servers can share the same one
  1190  		otherServers[i] = ca.NewServer(tc.MemoryStore, tc.ServingSecurityConfig)
  1191  
  1192  		// offset each server's reconciliation interval somewhat so that some will
  1193  		// pre-empt others
  1194  		otherServers[i].SetRootReconciliationInterval(time.Millisecond * time.Duration((i+1)*10))
  1195  		serverContexts[i] = log.WithLogger(tc.Context, log.G(tc.Context).WithFields(logrus.Fields{
  1196  			"otherCAServer": i,
  1197  		}))
  1198  		startCAServer(serverContexts[i], otherServers[i])
  1199  		defer otherServers[i].Stop()
  1200  	}
  1201  
  1202  	oldNodeTLSInfo := &api.NodeTLSInfo{
  1203  		TrustRoot:           tc.RootCA.Certs,
  1204  		CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey,
  1205  		CertIssuerSubject:   tc.ServingSecurityConfig.IssuerInfo().Subject,
  1206  	}
  1207  
  1208  	nodes := make(map[string]*api.Node)
  1209  	for i := 0; i < 5; i++ {
  1210  		nodeID := fmt.Sprintf("%d", i)
  1211  		nodes[nodeID] = getFakeAPINode(t, nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true)
  1212  	}
  1213  	rt.convergeWantedNodes(nodes, "setting up nodes for root rotation race condition test")
  1214  
  1215  	var rotationCert, rotationKey []byte
  1216  	for i := 0; i < 10; i++ {
  1217  		var (
  1218  			rotationCrossSigned []byte
  1219  			rotationTLSInfo     *api.NodeTLSInfo
  1220  			caRootCA            ca.RootCA
  1221  		)
  1222  		rotationCert, rotationKey, err = cautils.CreateRootCertAndKey(fmt.Sprintf("root cn %d", i))
  1223  		require.NoError(t, err)
  1224  		require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
  1225  			cluster := store.GetCluster(tx, tc.Organization)
  1226  			if cluster == nil {
  1227  				return errors.New("cluster has disappeared")
  1228  			}
  1229  			rootCA := cluster.RootCA.Copy()
  1230  			caRootCA, err = ca.NewRootCA(rootCA.CACert, rootCA.CACert, rootCA.CAKey, ca.DefaultNodeCertExpiration, nil)
  1231  			if err != nil {
  1232  				return err
  1233  			}
  1234  			rotationCrossSigned, rotationTLSInfo = getRotationInfo(t, rotationCert, &caRootCA)
  1235  			rootCA.RootRotation = &api.RootRotation{
  1236  				CACert:            rotationCert,
  1237  				CAKey:             rotationKey,
  1238  				CrossSignedCACert: rotationCrossSigned,
  1239  			}
  1240  			cluster.RootCA = *rootCA
  1241  			return store.UpdateCluster(tx, cluster)
  1242  		}))
  1243  		for _, node := range nodes {
  1244  			node.Description.TLSInfo = rotationTLSInfo
  1245  		}
  1246  		rt.convergeWantedNodes(nodes, fmt.Sprintf("iteration %d", i))
  1247  	}
  1248  
  1249  	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
  1250  		var cluster *api.Cluster
  1251  		tc.MemoryStore.View(func(tx store.ReadTx) {
  1252  			cluster = store.GetCluster(tx, tc.Organization)
  1253  		})
  1254  		if cluster == nil {
  1255  			return errors.New("cluster has disappeared")
  1256  		}
  1257  		if cluster.RootCA.RootRotation != nil {
  1258  			return errors.New("root rotation is still present")
  1259  		}
  1260  		if !bytes.Equal(cluster.RootCA.CACert, rotationCert) {
  1261  			return errors.New("expected root cert is wrong")
  1262  		}
  1263  		if !bytes.Equal(cluster.RootCA.CAKey, rotationKey) {
  1264  			return errors.New("expected root key is wrong")
  1265  		}
  1266  		for i, server := range otherServers {
  1267  			s, err := server.RootCA().Signer()
  1268  			if err != nil {
  1269  				return err
  1270  			}
  1271  			if !bytes.Equal(s.Key, rotationKey) {
  1272  				return errors.Errorf("server %d's root CAs hasn't been updated yet", i)
  1273  			}
  1274  		}
  1275  		return nil
  1276  	}, 5*time.Second))
  1277  
  1278  	// all of the ca servers have the appropriate cert and key
  1279  }
  1280  
  1281  // If there are a lot of nodes, we only update a small number of them at once.
  1282  func TestRootRotationReconciliationThrottled(t *testing.T) {
  1283  	t.Parallel()
  1284  	if cautils.External {
  1285  		// the external CA functionality is unrelated to testing the reconciliation loop
  1286  		return
  1287  	}
  1288  
  1289  	tc := cautils.NewTestCA(t)
  1290  	defer tc.Stop()
  1291  	// immediately stop the CA server - we want to run our own
  1292  	tc.CAServer.Stop()
  1293  
  1294  	caServer := ca.NewServer(tc.MemoryStore, tc.ServingSecurityConfig)
  1295  	// set the reconciliation interval to something ridiculous, so we can make sure the first
  1296  	// batch does update all of them
  1297  	caServer.SetRootReconciliationInterval(time.Hour)
  1298  	startCAServer(tc.Context, caServer)
  1299  	defer caServer.Stop()
  1300  
  1301  	var (
  1302  		nodes []*api.Node
  1303  		err   error
  1304  	)
  1305  	tc.MemoryStore.View(func(tx store.ReadTx) {
  1306  		nodes, err = store.FindNodes(tx, store.All)
  1307  	})
  1308  	require.NoError(t, err)
  1309  
  1310  	// create twice the batch size of nodes
  1311  	err = tc.MemoryStore.Batch(func(batch *store.Batch) error {
  1312  		for i := len(nodes); i < ca.IssuanceStateRotateMaxBatchSize*2; i++ {
  1313  			nodeID := fmt.Sprintf("%d", i)
  1314  			err := batch.Update(func(tx store.Tx) error {
  1315  				return store.CreateNode(tx, getFakeAPINode(t, nodeID, api.IssuanceStateIssued, nil, true))
  1316  			})
  1317  			if err != nil {
  1318  				return err
  1319  			}
  1320  		}
  1321  		return nil
  1322  	})
  1323  	require.NoError(t, err)
  1324  
  1325  	rotationCert := cautils.ECDSA256SHA256Cert
  1326  	rotationKey := cautils.ECDSA256Key
  1327  	rotationCrossSigned, _ := getRotationInfo(t, rotationCert, &tc.RootCA)
  1328  
  1329  	require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
  1330  		cluster := store.GetCluster(tx, tc.Organization)
  1331  		if cluster == nil {
  1332  			return errors.New("cluster has disappeared")
  1333  		}
  1334  		rootCA := cluster.RootCA.Copy()
  1335  		rootCA.RootRotation = &api.RootRotation{
  1336  			CACert:            rotationCert,
  1337  			CAKey:             rotationKey,
  1338  			CrossSignedCACert: rotationCrossSigned,
  1339  		}
  1340  		cluster.RootCA = *rootCA
  1341  		return store.UpdateCluster(tx, cluster)
  1342  	}))
  1343  
  1344  	checkRotationNumber := func() error {
  1345  		tc.MemoryStore.View(func(tx store.ReadTx) {
  1346  			nodes, err = store.FindNodes(tx, store.All)
  1347  		})
  1348  		var issuanceRotate int
  1349  		for _, n := range nodes {
  1350  			if n.Certificate.Status.State == api.IssuanceStateRotate {
  1351  				issuanceRotate += 1
  1352  			}
  1353  		}
  1354  		if issuanceRotate != ca.IssuanceStateRotateMaxBatchSize {
  1355  			return fmt.Errorf("expected %d, got %d", ca.IssuanceStateRotateMaxBatchSize, issuanceRotate)
  1356  		}
  1357  		return nil
  1358  	}
  1359  
  1360  	require.NoError(t, testutils.PollFuncWithTimeout(nil, checkRotationNumber, 5*time.Second))
  1361  	// prove that it's not just because the updates haven't finished
  1362  	time.Sleep(time.Second)
  1363  	require.NoError(t, checkRotationNumber())
  1364  }