k8s.io/kubernetes@v1.31.0-alpha.0.0.20240520171757-56147500dadc/cmd/kubeadm/app/phases/certs/renewal/manager_test.go (about)

     1  /*
     2  Copyright 2019 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package renewal
    18  
    19  import (
    20  	"crypto"
    21  	"crypto/x509"
    22  	"crypto/x509/pkix"
    23  	"fmt"
    24  	"net"
    25  	"os"
    26  	"path/filepath"
    27  	"reflect"
    28  	"testing"
    29  	"time"
    30  
    31  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    32  	certutil "k8s.io/client-go/util/cert"
    33  	netutils "k8s.io/utils/net"
    34  
    35  	kubeadmapi "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm"
    36  	kubeadmconstants "k8s.io/kubernetes/cmd/kubeadm/app/constants"
    37  	kubeadmutil "k8s.io/kubernetes/cmd/kubeadm/app/util"
    38  	certtestutil "k8s.io/kubernetes/cmd/kubeadm/app/util/certs"
    39  	"k8s.io/kubernetes/cmd/kubeadm/app/util/pkiutil"
    40  	testutil "k8s.io/kubernetes/cmd/kubeadm/test"
    41  )
    42  
    43  var (
    44  	testCACertCfg = &pkiutil.CertConfig{
    45  		Config: certutil.Config{CommonName: "kubernetes"},
    46  	}
    47  
    48  	testCACert, testCAKey, _ = pkiutil.NewCertificateAuthority(testCACertCfg)
    49  
    50  	testCertOrganization = []string{"sig-cluster-lifecycle"}
    51  
    52  	testCertCfg = makeTestCertConfig(testCertOrganization, time.Time{}, time.Time{})
    53  )
    54  
    55  type fakecertificateReadWriter struct {
    56  	exist bool
    57  	cert  *x509.Certificate
    58  }
    59  
    60  func (cr fakecertificateReadWriter) Exists() (bool, error) {
    61  	return cr.exist, nil
    62  }
    63  
    64  func (cr fakecertificateReadWriter) Read() (*x509.Certificate, error) {
    65  	return cr.cert, nil
    66  }
    67  
    68  func (cr fakecertificateReadWriter) Write(*x509.Certificate, crypto.Signer) error {
    69  	return nil
    70  }
    71  
    72  func TestNewManager(t *testing.T) {
    73  	tests := []struct {
    74  		name                 string
    75  		cfg                  *kubeadmapi.ClusterConfiguration
    76  		expectedCertificates int
    77  	}{
    78  		{
    79  			name:                 "cluster with local etcd",
    80  			cfg:                  &kubeadmapi.ClusterConfiguration{},
    81  			expectedCertificates: 11, // [admin super-admin apiserver apiserver-etcd-client apiserver-kubelet-client controller-manager etcd/healthcheck-client etcd/peer etcd/server front-proxy-client scheduler]
    82  		},
    83  		{
    84  			name: "cluster with external etcd",
    85  			cfg: &kubeadmapi.ClusterConfiguration{
    86  				Etcd: kubeadmapi.Etcd{
    87  					External: &kubeadmapi.ExternalEtcd{},
    88  				},
    89  			},
    90  			expectedCertificates: 7, // [admin super-admin apiserver apiserver-kubelet-client controller-manager front-proxy-client scheduler]
    91  		},
    92  	}
    93  
    94  	for _, test := range tests {
    95  		t.Run(test.name, func(t *testing.T) {
    96  			rm, err := NewManager(test.cfg, "")
    97  			if err != nil {
    98  				t.Fatalf("Failed to create the certificate renewal manager: %v", err)
    99  			}
   100  
   101  			if len(rm.Certificates()) != test.expectedCertificates {
   102  				t.Errorf("Expected %d certificates, saw %d", test.expectedCertificates, len(rm.Certificates()))
   103  			}
   104  		})
   105  	}
   106  }
   107  
   108  func TestRenewUsingLocalCA(t *testing.T) {
   109  	dir := testutil.SetupTempDir(t)
   110  	defer os.RemoveAll(dir)
   111  
   112  	if err := pkiutil.WriteCertAndKey(dir, "ca", testCACert, testCAKey); err != nil {
   113  		t.Fatalf("couldn't write out CA certificate to %s", dir)
   114  	}
   115  
   116  	etcdDir := filepath.Join(dir, "etcd")
   117  	if err := pkiutil.WriteCertAndKey(etcdDir, "ca", testCACert, testCAKey); err != nil {
   118  		t.Fatalf("couldn't write out CA certificate to %s", etcdDir)
   119  	}
   120  
   121  	cfg := &kubeadmapi.ClusterConfiguration{
   122  		CertificatesDir: dir,
   123  		CertificateValidityPeriod: &metav1.Duration{
   124  			Duration: time.Hour * 10,
   125  		},
   126  	}
   127  	rm, err := NewManager(cfg, dir)
   128  	if err != nil {
   129  		t.Fatalf("Failed to create the certificate renewal manager: %v", err)
   130  	}
   131  
   132  	// Prepare test certs with a past validity.
   133  	startTime := kubeadmutil.StartTimeUTC()
   134  
   135  	fmt.Println("START TIME TEST", startTime)
   136  
   137  	notBefore := startTime.Add(-rm.cfg.CertificateValidityPeriod.Duration * 2)
   138  	notAfter := startTime.Add(-rm.cfg.CertificateValidityPeriod.Duration)
   139  
   140  	tests := []struct {
   141  		name                 string
   142  		certName             string
   143  		createCertFunc       func() *x509.Certificate
   144  		expectedOrganization []string
   145  	}{
   146  		{
   147  			name:     "Certificate renewal for a PKI certificate",
   148  			certName: "apiserver",
   149  			createCertFunc: func() *x509.Certificate {
   150  				return writeTestCertificate(t, dir, "apiserver", testCACert, testCAKey, testCertOrganization, notBefore, notAfter)
   151  			},
   152  			expectedOrganization: testCertOrganization,
   153  		},
   154  		{
   155  			name:     "Certificate renewal for a certificate embedded in a kubeconfig file",
   156  			certName: "admin.conf",
   157  			createCertFunc: func() *x509.Certificate {
   158  				return writeTestKubeconfig(t, dir, "admin.conf", testCACert, testCAKey, notBefore, notAfter)
   159  			},
   160  			expectedOrganization: testCertOrganization,
   161  		},
   162  	}
   163  
   164  	for _, test := range tests {
   165  		t.Run(test.name, func(t *testing.T) {
   166  			cert := test.createCertFunc()
   167  
   168  			notBefore := startTime.Add(-kubeadmconstants.CertificateBackdate)
   169  			notAfter := startTime.Add(rm.cfg.CertificateValidityPeriod.Duration)
   170  			testCertCfg := makeTestCertConfig(testCertOrganization, notBefore, notAfter)
   171  
   172  			_, err := rm.RenewUsingLocalCA(test.certName)
   173  			if err != nil {
   174  				t.Fatalf("error renewing certificate: %v", err)
   175  			}
   176  
   177  			newCert, err := rm.certificates[test.certName].readwriter.Read()
   178  			if err != nil {
   179  				t.Fatalf("error reading renewed certificate: %v", err)
   180  			}
   181  
   182  			if newCert.SerialNumber.Cmp(cert.SerialNumber) == 0 {
   183  				t.Fatal("expected new certificate, but renewed certificate has same serial number")
   184  			}
   185  
   186  			if !newCert.NotAfter.After(cert.NotAfter) {
   187  				t.Fatalf("expected new certificate with updated expiration, but renewed certificate has same NotAfter value: saw %s, expected greather than %s", newCert.NotAfter, cert.NotAfter)
   188  			}
   189  
   190  			certtestutil.AssertCertificateIsSignedByCa(t, newCert, testCACert)
   191  			certtestutil.AssertCertificateHasClientAuthUsage(t, newCert)
   192  			certtestutil.AssertCertificateHasOrganizations(t, newCert, test.expectedOrganization...)
   193  			certtestutil.AssertCertificateHasCommonName(t, newCert, testCertCfg.CommonName)
   194  			certtestutil.AssertCertificateHasDNSNames(t, newCert, testCertCfg.AltNames.DNSNames...)
   195  			certtestutil.AssertCertificateHasIPAddresses(t, newCert, testCertCfg.AltNames.IPs...)
   196  			certtestutil.AssertCertificateHasNotBefore(t, newCert, testCertCfg.NotBefore)
   197  			certtestutil.AssertCertificateHasNotAfter(t, newCert, testCertCfg.NotAfter)
   198  		})
   199  	}
   200  }
   201  
   202  func TestCreateRenewCSR(t *testing.T) {
   203  	dir := testutil.SetupTempDir(t)
   204  	defer os.RemoveAll(dir)
   205  
   206  	outdir := filepath.Join(dir, "out")
   207  
   208  	if err := os.MkdirAll(outdir, 0755); err != nil {
   209  		t.Fatalf("couldn't create %s", outdir)
   210  	}
   211  
   212  	if err := pkiutil.WriteCertAndKey(dir, "ca", testCACert, testCAKey); err != nil {
   213  		t.Fatalf("couldn't write out CA certificate to %s", dir)
   214  	}
   215  
   216  	cfg := &kubeadmapi.ClusterConfiguration{
   217  		CertificatesDir: dir,
   218  	}
   219  	rm, err := NewManager(cfg, dir)
   220  	if err != nil {
   221  		t.Fatalf("Failed to create the certificate renewal manager: %v", err)
   222  	}
   223  
   224  	tests := []struct {
   225  		name           string
   226  		certName       string
   227  		createCertFunc func() *x509.Certificate
   228  	}{
   229  		{
   230  			name:     "Creation of a CSR request for renewal of a PKI certificate",
   231  			certName: "apiserver",
   232  			createCertFunc: func() *x509.Certificate {
   233  				return writeTestCertificate(t, dir, "apiserver", testCACert, testCAKey, testCertOrganization, time.Time{}, time.Time{})
   234  			},
   235  		},
   236  		{
   237  			name:     "Creation of a CSR request for renewal of a certificate embedded in a kubeconfig file",
   238  			certName: "admin.conf",
   239  			createCertFunc: func() *x509.Certificate {
   240  				return writeTestKubeconfig(t, dir, "admin.conf", testCACert, testCAKey, time.Time{}, time.Time{})
   241  			},
   242  		},
   243  	}
   244  
   245  	for _, test := range tests {
   246  		t.Run(test.name, func(t *testing.T) {
   247  			test.createCertFunc()
   248  
   249  			time.Sleep(1 * time.Second)
   250  
   251  			err := rm.CreateRenewCSR(test.certName, outdir)
   252  			if err != nil {
   253  				t.Fatalf("error renewing certificate: %v", err)
   254  			}
   255  
   256  			file := fmt.Sprintf("%s.key", test.certName)
   257  			if _, err := os.Stat(filepath.Join(outdir, file)); os.IsNotExist(err) {
   258  				t.Errorf("Expected file %s does not exist", file)
   259  			}
   260  
   261  			file = fmt.Sprintf("%s.csr", test.certName)
   262  			if _, err := os.Stat(filepath.Join(outdir, file)); os.IsNotExist(err) {
   263  				t.Errorf("Expected file %s does not exist", file)
   264  			}
   265  		})
   266  	}
   267  
   268  }
   269  
   270  func TestCertToConfig(t *testing.T) {
   271  	expectedConfig := &certutil.Config{
   272  		CommonName:   "test-common-name",
   273  		Organization: testCertOrganization,
   274  		AltNames: certutil.AltNames{
   275  			IPs:      []net.IP{netutils.ParseIPSloppy("10.100.0.1")},
   276  			DNSNames: []string{"test-domain.space"},
   277  		},
   278  		Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
   279  	}
   280  
   281  	cert := &x509.Certificate{
   282  		Subject: pkix.Name{
   283  			CommonName:   "test-common-name",
   284  			Organization: testCertOrganization,
   285  		},
   286  		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
   287  		DNSNames:    []string{"test-domain.space"},
   288  		IPAddresses: []net.IP{netutils.ParseIPSloppy("10.100.0.1")},
   289  	}
   290  
   291  	cfg := certToConfig(cert)
   292  
   293  	if cfg.CommonName != expectedConfig.CommonName {
   294  		t.Errorf("expected common name %q, got %q", expectedConfig.CommonName, cfg.CommonName)
   295  	}
   296  
   297  	if len(cfg.Organization) != 1 || cfg.Organization[0] != expectedConfig.Organization[0] {
   298  		t.Errorf("expected organization %v, got %v", expectedConfig.Organization, cfg.Organization)
   299  
   300  	}
   301  
   302  	if len(cfg.Usages) != 1 || cfg.Usages[0] != expectedConfig.Usages[0] {
   303  		t.Errorf("expected ext key usage %v, got %v", expectedConfig.Usages, cfg.Usages)
   304  	}
   305  
   306  	if len(cfg.AltNames.IPs) != 1 || cfg.AltNames.IPs[0].String() != expectedConfig.AltNames.IPs[0].String() {
   307  		t.Errorf("expected SAN IPs %v, got %v", expectedConfig.AltNames.IPs, cfg.AltNames.IPs)
   308  	}
   309  
   310  	if len(cfg.AltNames.DNSNames) != 1 || cfg.AltNames.DNSNames[0] != expectedConfig.AltNames.DNSNames[0] {
   311  		t.Errorf("expected SAN DNSNames %v, got %v", expectedConfig.AltNames.DNSNames, cfg.AltNames.DNSNames)
   312  	}
   313  }
   314  
   315  func makeTestCertConfig(organization []string, notBefore, notAfter time.Time) *pkiutil.CertConfig {
   316  	return &pkiutil.CertConfig{
   317  		Config: certutil.Config{
   318  			CommonName:   "test-common-name",
   319  			Organization: organization,
   320  			AltNames: certutil.AltNames{
   321  				IPs:      []net.IP{netutils.ParseIPSloppy("10.100.0.1")},
   322  				DNSNames: []string{"test-domain.space"},
   323  			},
   324  			Usages:    []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
   325  			NotBefore: notBefore,
   326  		},
   327  		NotAfter: notAfter,
   328  	}
   329  }
   330  
   331  func TestManagerCAs(t *testing.T) {
   332  	tests := []struct {
   333  		name string
   334  		cas  map[string]*CAExpirationHandler
   335  		want []*CAExpirationHandler
   336  	}{
   337  		{
   338  			name: "CAExpirationHandler is sequential",
   339  			cas: map[string]*CAExpirationHandler{
   340  				"foo": {
   341  					Name: "1",
   342  				},
   343  				"bar": {
   344  					Name: "2",
   345  				},
   346  			},
   347  			want: []*CAExpirationHandler{
   348  				{
   349  					Name: "1",
   350  				},
   351  				{
   352  					Name: "2",
   353  				},
   354  			},
   355  		},
   356  		{
   357  			name: "CAExpirationHandler is in reverse order",
   358  			cas: map[string]*CAExpirationHandler{
   359  				"foo": {
   360  					Name: "2",
   361  				},
   362  				"bar": {
   363  					Name: "1",
   364  				},
   365  			},
   366  			want: []*CAExpirationHandler{
   367  				{
   368  					Name: "1",
   369  				},
   370  				{
   371  					Name: "2",
   372  				},
   373  			},
   374  		},
   375  	}
   376  	for _, tt := range tests {
   377  		t.Run(tt.name, func(t *testing.T) {
   378  			rm := &Manager{
   379  				cas: tt.cas,
   380  			}
   381  			if got := rm.CAs(); !reflect.DeepEqual(got, tt.want) {
   382  				t.Errorf("Manager.CAs() = %v, want %v", got, tt.want)
   383  			}
   384  		})
   385  	}
   386  }
   387  
   388  func TestManagerCAExists(t *testing.T) {
   389  	certificateReadWriterExist := fakecertificateReadWriter{
   390  		exist: true,
   391  	}
   392  	certificateReadWriterMissing := fakecertificateReadWriter{
   393  		exist: false,
   394  	}
   395  	tests := []struct {
   396  		name    string
   397  		cas     map[string]*CAExpirationHandler
   398  		caName  string
   399  		want    bool
   400  		wantErr bool
   401  	}{
   402  		{
   403  			name:    "caName does not exist in cas list",
   404  			cas:     map[string]*CAExpirationHandler{},
   405  			caName:  "foo",
   406  			want:    false,
   407  			wantErr: true,
   408  		},
   409  		{
   410  			name: "ca exists",
   411  			cas: map[string]*CAExpirationHandler{
   412  				"foo": {
   413  					Name:       "foo",
   414  					FileName:   "test",
   415  					readwriter: certificateReadWriterExist,
   416  				},
   417  			},
   418  			caName:  "foo",
   419  			want:    true,
   420  			wantErr: false,
   421  		},
   422  		{
   423  			name: "ca does not exist",
   424  			cas: map[string]*CAExpirationHandler{
   425  				"foo": {
   426  					Name:       "foo",
   427  					FileName:   "test",
   428  					readwriter: certificateReadWriterMissing,
   429  				},
   430  			},
   431  			caName:  "foo",
   432  			want:    false,
   433  			wantErr: false,
   434  		},
   435  	}
   436  	for _, tt := range tests {
   437  		t.Run(tt.name, func(t *testing.T) {
   438  			rm := &Manager{
   439  				cas: tt.cas,
   440  			}
   441  			got, err := rm.CAExists(tt.caName)
   442  			if (err != nil) != tt.wantErr {
   443  				t.Errorf("Manager.CAExists() error = %v, wantErr %v", err, tt.wantErr)
   444  				return
   445  			}
   446  			if got != tt.want {
   447  				t.Errorf("Manager.CAExists() = %v, want %v", got, tt.want)
   448  			}
   449  		})
   450  	}
   451  }
   452  
   453  func TestManagerCertificateExists(t *testing.T) {
   454  	certificateReadWriterExist := fakecertificateReadWriter{
   455  		exist: true,
   456  	}
   457  	certificateReadWriterMissing := fakecertificateReadWriter{
   458  		exist: false,
   459  	}
   460  	tests := []struct {
   461  		name         string
   462  		certificates map[string]*CertificateRenewHandler
   463  		certName     string
   464  		want         bool
   465  		wantErr      bool
   466  	}{
   467  		{
   468  			name:         "certName does not exist in certificate list",
   469  			certificates: map[string]*CertificateRenewHandler{},
   470  			certName:     "foo",
   471  			want:         false,
   472  			wantErr:      true,
   473  		},
   474  		{
   475  			name: "certificate exists",
   476  			certificates: map[string]*CertificateRenewHandler{
   477  				"foo": {
   478  					Name:       "foo",
   479  					readwriter: certificateReadWriterExist,
   480  				},
   481  			},
   482  			certName: "foo",
   483  			want:     true,
   484  			wantErr:  false,
   485  		},
   486  		{
   487  			name: "certificate does not exist",
   488  			certificates: map[string]*CertificateRenewHandler{
   489  				"foo": {
   490  					Name:       "foo",
   491  					readwriter: certificateReadWriterMissing,
   492  				},
   493  			},
   494  			certName: "foo",
   495  			want:     false,
   496  			wantErr:  false,
   497  		},
   498  	}
   499  	for _, tt := range tests {
   500  		t.Run(tt.name, func(t *testing.T) {
   501  			rm := &Manager{
   502  				certificates: tt.certificates,
   503  			}
   504  			got, err := rm.CertificateExists(tt.certName)
   505  			if (err != nil) != tt.wantErr {
   506  				t.Errorf("Manager.CertificateExists() error = %v, wantErr %v", err, tt.wantErr)
   507  				return
   508  			}
   509  			if got != tt.want {
   510  				t.Errorf("Manager.CertificateExists() = %v, want %v", got, tt.want)
   511  			}
   512  		})
   513  	}
   514  }