github.com/IBM-Blockchain/fabric-operator@v1.0.4/pkg/certificate/certificate_test.go (about)

     1  /*
     2   * Copyright contributors to the Hyperledger Fabric Operator project
     3   *
     4   * SPDX-License-Identifier: Apache-2.0
     5   *
     6   * Licensed under the Apache License, Version 2.0 (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at:
     9   *
    10   * 	  http://www.apache.org/licenses/LICENSE-2.0
    11   *
    12   * Unless required by applicable law or agreed to in writing, software
    13   * distributed under the License is distributed on an "AS IS" BASIS,
    14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15   * See the License for the specific language governing permissions and
    16   * limitations under the License.
    17   */
    18  
    19  package certificate_test
    20  
    21  import (
    22  	"context"
    23  	"crypto/ecdsa"
    24  	"crypto/elliptic"
    25  	"crypto/rand"
    26  	"crypto/x509"
    27  	"encoding/pem"
    28  	"errors"
    29  	"math/big"
    30  	"time"
    31  
    32  	current "github.com/IBM-Blockchain/fabric-operator/api/v1beta1"
    33  	controllermocks "github.com/IBM-Blockchain/fabric-operator/controllers/mocks"
    34  	"github.com/IBM-Blockchain/fabric-operator/pkg/certificate"
    35  	"github.com/IBM-Blockchain/fabric-operator/pkg/certificate/mocks"
    36  	"github.com/IBM-Blockchain/fabric-operator/pkg/initializer/common"
    37  	"github.com/IBM-Blockchain/fabric-operator/pkg/initializer/common/config"
    38  	. "github.com/onsi/ginkgo/v2"
    39  	. "github.com/onsi/gomega"
    40  	corev1 "k8s.io/api/core/v1"
    41  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    42  	v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    43  	"k8s.io/apimachinery/pkg/runtime"
    44  	"k8s.io/apimachinery/pkg/types"
    45  	"sigs.k8s.io/controller-runtime/pkg/client"
    46  )
    47  
    48  var _ = Describe("Certificate", func() {
    49  	var (
    50  		certificateManager *certificate.CertificateManager
    51  		mockClient         *controllermocks.Client
    52  		mockEnroller       *mocks.Reenroller
    53  		instance           v1.Object
    54  
    55  		certBytes []byte
    56  	)
    57  
    58  	BeforeEach(func() {
    59  		mockClient = &controllermocks.Client{}
    60  		mockEnroller = &mocks.Reenroller{}
    61  
    62  		certificateManager = certificate.New(mockClient, &runtime.Scheme{})
    63  
    64  		instance = &current.IBPPeer{
    65  			ObjectMeta: metav1.ObjectMeta{
    66  				Name:      "peer-1",
    67  				Namespace: "peer-namespace",
    68  				Labels:    map[string]string{},
    69  			},
    70  		}
    71  
    72  		certBytes = createCert(time.Now().Add(time.Hour * 24 * 30)) // expires in 30 days
    73  
    74  		reenrollResponse := &config.Response{
    75  			SignCert: []byte("cert"),
    76  			Keystore: []byte("key"),
    77  		}
    78  
    79  		mockEnroller.ReenrollReturns(reenrollResponse, nil)
    80  		mockClient.UpdateReturns(nil)
    81  
    82  		mockClient.GetStub = func(ctx context.Context, types types.NamespacedName, obj client.Object) error {
    83  			o := obj.(*corev1.Secret)
    84  			switch types.Name {
    85  			case "tls-" + instance.GetName() + "-signcert":
    86  				o.Name = "tls-" + instance.GetName() + "-signcert"
    87  				o.Namespace = instance.GetNamespace()
    88  				o.Data = map[string][]byte{"cert.pem": certBytes}
    89  			case "tls-" + instance.GetName() + "-keystore":
    90  				o.Name = "tls-" + instance.GetName() + "-keystore"
    91  				o.Namespace = instance.GetNamespace()
    92  				o.Data = map[string][]byte{"key.pem": []byte("key")}
    93  			case "ecert-" + instance.GetName() + "-signcert":
    94  				o.Name = "ecert-" + instance.GetName() + "-signcert"
    95  				o.Namespace = instance.GetNamespace()
    96  				o.Data = map[string][]byte{"cert.pem": certBytes}
    97  			case "ecert-" + instance.GetName() + "-keystore":
    98  				o.Name = "ecert-" + instance.GetName() + "-keystore"
    99  				o.Namespace = instance.GetNamespace()
   100  				o.Data = map[string][]byte{"key.pem": []byte("key")}
   101  			}
   102  			return nil
   103  		}
   104  	})
   105  
   106  	Context("get expire date", func() {
   107  		It("returns error if fails to read certificate", func() {
   108  			certbytes := []byte("invalid")
   109  			_, err := certificateManager.GetExpireDate(certbytes)
   110  			Expect(err).To(HaveOccurred())
   111  			Expect(err.Error()).To(ContainSubstring("failed to get certificate from bytes"))
   112  		})
   113  
   114  		It("returns expire date of certificate", func() {
   115  			expectedtime := time.Now().Add(time.Hour * 24 * 30).UTC()
   116  			expireDate, err := certificateManager.GetExpireDate(certBytes)
   117  			Expect(err).NotTo(HaveOccurred())
   118  			Expect(expireDate.Month()).To(Equal(expectedtime.Month()))
   119  			Expect(expireDate.Day()).To(Equal(expectedtime.Day()))
   120  			Expect(expireDate.Year()).To(Equal(expectedtime.Year()))
   121  		})
   122  	})
   123  
   124  	Context("get duration to next renewal", func() {
   125  		It("returns error if fails to get expire date", func() {
   126  			mockClient.GetStub = func(ctx context.Context, types types.NamespacedName, obj client.Object) error {
   127  				o := obj.(*corev1.Secret)
   128  				o.Name = "tls-" + instance.GetName() + "-signcert"
   129  				o.Namespace = instance.GetNamespace()
   130  				o.Data = map[string][]byte{"cert.pem": []byte("invalid")}
   131  				return nil
   132  			}
   133  			thirtyDaysToSeconds := int64(30 * 24 * 60 * 60)
   134  			_, err := certificateManager.GetDurationToNextRenewal(common.TLS, instance, thirtyDaysToSeconds)
   135  			Expect(err).To(HaveOccurred())
   136  			Expect(err.Error()).To(ContainSubstring("failed to get certificate from bytes"))
   137  		})
   138  
   139  		It("gets duration until next renewal 10 days before expire", func() {
   140  			tenDaysToSeconds := int64(10 * 24 * 60 * 60)
   141  			duration, err := certificateManager.GetDurationToNextRenewal(common.TLS, instance, tenDaysToSeconds)
   142  			Expect(err).NotTo(HaveOccurred())
   143  			Expect(duration.Round(time.Hour)).To(Equal(time.Hour * 24 * 20)) // 10 days before cert that expires in 30 days = 20 days until next renewal
   144  		})
   145  
   146  		It("gets duration until next renewal 31 days before expire", func() {
   147  			thiryOneDaysToSeconds := int64(31 * 24 * 60 * 60)
   148  			duration, err := certificateManager.GetDurationToNextRenewal(common.TLS, instance, thiryOneDaysToSeconds)
   149  			Expect(err).NotTo(HaveOccurred())
   150  			Expect(duration.Round(time.Hour)).To(Equal(time.Duration(0))) // 31 days before cert that expires in 30 days = -1 days until next renewal, so should return 0
   151  		})
   152  	})
   153  
   154  	Context("certificate expiring", func() {
   155  		It("returns false if not expiring", func() {
   156  			tenDaysToSeconds := int64(10 * 24 * 60 * 60)
   157  			expiring, _, err := certificateManager.CertificateExpiring(common.TLS, instance, tenDaysToSeconds)
   158  			Expect(err).NotTo(HaveOccurred())
   159  			Expect(expiring).To(Equal(false))
   160  		})
   161  
   162  		It("returns true if expiring", func() {
   163  			thirtyDaysToSeconds := int64(30 * 24 * 60 * 60)
   164  			expiring, _, err := certificateManager.CertificateExpiring(common.TLS, instance, thirtyDaysToSeconds)
   165  			Expect(err).NotTo(HaveOccurred())
   166  			Expect(expiring).To(Equal(true))
   167  		})
   168  	})
   169  
   170  	Context("check certificates for expire", func() {
   171  		var (
   172  			expiredCert []byte
   173  		)
   174  		BeforeEach(func() {
   175  			expiredCert = createCert(time.Now().Add(-30 * time.Second)) // expired 30 seconds ago
   176  		})
   177  
   178  		It("returns error if fails to get tls signcert expiry info", func() {
   179  			mockClient.GetReturns(errors.New("fake error"))
   180  			_, _, err := certificateManager.CheckCertificatesForExpire(instance, 0)
   181  			Expect(err).To(HaveOccurred())
   182  			Expect(err.Error()).To(ContainSubstring("failed to get tls signcert expiry info"))
   183  		})
   184  
   185  		It("returns deployed status if neither tls nor ecert signcerts are expiring", func() {
   186  			tenDaysToSeconds := int64(10 * 24 * 60 * 60)
   187  			status, message, err := certificateManager.CheckCertificatesForExpire(instance, tenDaysToSeconds)
   188  			Expect(err).NotTo(HaveOccurred())
   189  			Expect(status).To(Equal(current.Deployed))
   190  			Expect(message).To(Equal(""))
   191  		})
   192  
   193  		It("returns warning status if either tls or ecert signcert is expiring", func() {
   194  			thirtyDaysToSeconds := int64(30 * 24 * 60 * 60)
   195  			status, message, err := certificateManager.CheckCertificatesForExpire(instance, thirtyDaysToSeconds)
   196  			Expect(err).NotTo(HaveOccurred())
   197  			Expect(status).To(Equal(current.Warning))
   198  			Expect(message).To(ContainSubstring("tls-peer-1-signcert expires on"))
   199  			Expect(message).To(ContainSubstring("ecert-peer-1-signcert expires on"))
   200  		})
   201  
   202  		It("returns error status if either tls or ecert signcert has expired", func() {
   203  			mockClient.GetStub = func(ctx context.Context, types types.NamespacedName, obj client.Object) error {
   204  				o := obj.(*corev1.Secret)
   205  				switch types.Name {
   206  				case "tls-" + instance.GetName() + "-signcert":
   207  					o.Name = "tls-" + instance.GetName() + "-signcert"
   208  					o.Namespace = instance.GetNamespace()
   209  					o.Data = map[string][]byte{"cert.pem": expiredCert}
   210  				case "ecert-" + instance.GetName() + "-signcert":
   211  					o.Name = "ecert-" + instance.GetName() + "-signcert"
   212  					o.Namespace = instance.GetNamespace()
   213  					o.Data = map[string][]byte{"cert.pem": certBytes}
   214  				}
   215  				return nil
   216  			}
   217  			thirtyDaysToSeconds := int64(30 * 24 * 60 * 60)
   218  			status, message, err := certificateManager.CheckCertificatesForExpire(instance, thirtyDaysToSeconds)
   219  			Expect(err).NotTo(HaveOccurred())
   220  			Expect(status).To(Equal(current.Error))
   221  			Expect(message).To(ContainSubstring("tls-peer-1-signcert has expired"))
   222  			Expect(message).To(ContainSubstring("ecert-peer-1-signcert expires on"))
   223  		})
   224  	})
   225  
   226  	Context("reenroll cert", func() {
   227  		When("not using HSM", func() {
   228  			It("returns error if enroller not passed", func() {
   229  				err := certificateManager.ReenrollCert("tls", nil, instance, false)
   230  				Expect(err).To(HaveOccurred())
   231  				Expect(err.Error()).To(Equal("reenroller not passed"))
   232  			})
   233  
   234  			It("returns error if reenroll returns error", func() {
   235  				mockEnroller.ReenrollReturns(nil, errors.New("fake error"))
   236  				err := certificateManager.ReenrollCert("tls", mockEnroller, instance, false)
   237  				Expect(err).To(HaveOccurred())
   238  				Expect(err.Error()).To(Equal("failed to renew tls certificate for instance 'peer-1': fake error"))
   239  			})
   240  
   241  			It("returns error if failed to update signcert secret", func() {
   242  				mockClient.UpdateReturns(errors.New("fake error"))
   243  				err := certificateManager.ReenrollCert("tls", mockEnroller, instance, false)
   244  				Expect(err).To(HaveOccurred())
   245  				Expect(err.Error()).To(Equal("failed to update signcert secret for instance 'peer-1': fake error"))
   246  			})
   247  
   248  			It("returns error if failed to update keystore secret", func() {
   249  				mockClient.UpdateReturnsOnCall(1, errors.New("fake error"))
   250  				err := certificateManager.ReenrollCert("tls", mockEnroller, instance, false)
   251  				Expect(err).To(HaveOccurred())
   252  				Expect(err.Error()).To(Equal("failed to update keystore secret for instance 'peer-1': fake error"))
   253  			})
   254  
   255  			It("renews certificate", func() {
   256  				err := certificateManager.ReenrollCert("tls", mockEnroller, instance, false)
   257  				Expect(err).NotTo(HaveOccurred())
   258  
   259  				By("updating cert and key secret", func() {
   260  					Expect(mockClient.UpdateCallCount()).To(Equal(2))
   261  				})
   262  			})
   263  		})
   264  
   265  		When("using HSM", func() {
   266  			It("only updates cert secret", func() {
   267  				err := certificateManager.ReenrollCert("tls", mockEnroller, instance, true)
   268  				Expect(err).NotTo(HaveOccurred())
   269  				Expect(mockClient.UpdateCallCount()).To(Equal(1))
   270  			})
   271  		})
   272  	})
   273  
   274  	Context("update signcert", func() {
   275  		It("returns error if client fails to update secret", func() {
   276  			mockClient.UpdateReturns(errors.New("fake error"))
   277  			err := certificateManager.UpdateSignCert("secret-name", []byte("cert"), instance)
   278  			Expect(err).To(HaveOccurred())
   279  			Expect(err.Error()).To(Equal("fake error"))
   280  		})
   281  
   282  		It("updates signcert secret", func() {
   283  			err := certificateManager.UpdateSignCert("secret-name", []byte("cert"), instance)
   284  			Expect(err).NotTo(HaveOccurred())
   285  		})
   286  	})
   287  
   288  	Context("update key", func() {
   289  		It("returns error if client fails to update secret", func() {
   290  			mockClient.UpdateReturns(errors.New("fake error"))
   291  			err := certificateManager.UpdateKey("secret-name", []byte("cert"), instance)
   292  			Expect(err).To(HaveOccurred())
   293  			Expect(err.Error()).To(Equal("fake error"))
   294  		})
   295  
   296  		It("updates keystore secret", func() {
   297  			err := certificateManager.UpdateKey("secret-name", []byte("cert"), instance)
   298  			Expect(err).NotTo(HaveOccurred())
   299  		})
   300  	})
   301  
   302  	Context("update secret", func() {
   303  		It("returns error if client call for update fails", func() {
   304  			mockClient.UpdateReturns(errors.New("fake error"))
   305  			err := certificateManager.UpdateSecret(instance, "secret-name", map[string][]byte{})
   306  			Expect(err).To(HaveOccurred())
   307  			Expect(err.Error()).To(Equal("fake error"))
   308  		})
   309  
   310  		It("updates secret", func() {
   311  			err := certificateManager.UpdateSecret(instance, "secret-name", map[string][]byte{})
   312  			Expect(err).NotTo(HaveOccurred())
   313  		})
   314  	})
   315  
   316  	Context("get signcert and key", func() {
   317  		When("not using HSM", func() {
   318  			It("returns an error if fails to get secret", func() {
   319  				mockClient.GetReturns(errors.New("fake error"))
   320  				_, _, err := certificateManager.GetSignCertAndKey("tls", instance, false)
   321  				Expect(err).To(HaveOccurred())
   322  				Expect(err.Error()).To(Equal("fake error"))
   323  			})
   324  
   325  			It("gets signcert and key", func() {
   326  				cert, key, err := certificateManager.GetSignCertAndKey("tls", instance, false)
   327  				Expect(err).NotTo(HaveOccurred())
   328  				Expect(cert).NotTo(BeNil())
   329  				Expect(key).NotTo(BeNil())
   330  			})
   331  		})
   332  
   333  		When("using HSM", func() {
   334  			It("gets signcert and empty key", func() {
   335  				cert, key, err := certificateManager.GetSignCertAndKey("tls", instance, true)
   336  				Expect(err).NotTo(HaveOccurred())
   337  				Expect(cert).NotTo(BeNil())
   338  				Expect(len(key)).To(Equal(0))
   339  			})
   340  		})
   341  	})
   342  })
   343  
   344  func createCert(expireDate time.Time) []byte {
   345  	certtemplate := x509.Certificate{
   346  		SerialNumber: big.NewInt(1),
   347  		NotAfter:     expireDate,
   348  	}
   349  
   350  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   351  	Expect(err).NotTo(HaveOccurred())
   352  
   353  	cert, err := x509.CreateCertificate(rand.Reader, &certtemplate, &certtemplate, &priv.PublicKey, priv)
   354  	Expect(err).NotTo(HaveOccurred())
   355  
   356  	block := &pem.Block{
   357  		Type:  "CERTIFICATE",
   358  		Bytes: cert,
   359  	}
   360  
   361  	return pem.EncodeToMemory(block)
   362  }