istio.io/istio@v0.0.0-20240520182934-d79c90f27776/security/pkg/pki/util/keycertbundle.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Provides utility methods to generate X.509 certificates with different
// options. This implementation is largely inspired by
// https://golang.org/src/crypto/tls/generate_cert.go.
    18  
    19  package util
    20  
    21  import (
    22  	"crypto"
    23  	"crypto/ecdsa"
    24  	"crypto/rsa"
    25  	"crypto/tls"
    26  	"crypto/x509"
    27  	"errors"
    28  	"fmt"
    29  	"os"
    30  	"sync"
    31  	"time"
    32  )
    33  
// KeyCertBundle stores the cert, private key, cert chain and root cert for an entity. It is thread safe:
// every accessor in this file reads or writes the fields while holding mutex.
// The cert and privKey should be a public/private key pair.
// The cert should be verifiable from the rootCert through the certChain.
// cert and privKey are pointers to the cert/key parsed from certBytes/privKeyBytes.
// The *Bytes fields hold PEM data; setters and getters copy them, so callers never share
// the backing arrays with the bundle.
// NOTE: cert may be nil and privKey may be nil (or point to a nil interface). This happens when
// the bundle was built by NewKeyCertBundleWithRootCertFromFile, or when setAllFromPem was
// given PEM that failed to parse (parse errors are ignored there).
type KeyCertBundle struct {
	certBytes      []byte
	cert           *x509.Certificate
	privKeyBytes   []byte
	privKey        *crypto.PrivateKey
	certChainBytes []byte
	rootCertBytes  []byte
	// mutex protects the R/W to all keys and certs.
	mutex sync.RWMutex
}
    48  
    49  // NewKeyCertBundleFromPem returns a new KeyCertBundle, regardless of whether or not the key can be correctly parsed.
    50  func NewKeyCertBundleFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) *KeyCertBundle {
    51  	bundle := &KeyCertBundle{}
    52  	bundle.setAllFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes)
    53  	return bundle
    54  }
    55  
    56  // NewVerifiedKeyCertBundleFromPem returns a new KeyCertBundle, or error if the provided certs failed the
    57  // verification.
    58  func NewVerifiedKeyCertBundleFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) (
    59  	*KeyCertBundle, error,
    60  ) {
    61  	bundle := &KeyCertBundle{}
    62  	if err := bundle.VerifyAndSetAll(certBytes, privKeyBytes, certChainBytes, rootCertBytes); err != nil {
    63  		return nil, err
    64  	}
    65  	return bundle, nil
    66  }
    67  
    68  // NewVerifiedKeyCertBundleFromFile returns a new KeyCertBundle, or error if the provided certs failed the
    69  // verification.
    70  func NewVerifiedKeyCertBundleFromFile(certFile string, privKeyFile string, certChainFiles []string, rootCertFile string) (
    71  	*KeyCertBundle, error,
    72  ) {
    73  	certBytes, err := os.ReadFile(certFile)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  	privKeyBytes, err := os.ReadFile(privKeyFile)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	var certChainBytes []byte
    82  	if len(certChainFiles) > 0 {
    83  		for _, f := range certChainFiles {
    84  			var b []byte
    85  
    86  			if b, err = os.ReadFile(f); err != nil {
    87  				return nil, err
    88  			}
    89  
    90  			certChainBytes = append(certChainBytes, b...)
    91  		}
    92  	}
    93  	rootCertBytes, err := os.ReadFile(rootCertFile)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	return NewVerifiedKeyCertBundleFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes)
    98  }
    99  
   100  // NewKeyCertBundleWithRootCertFromFile returns a new KeyCertBundle with the root cert without verification.
   101  func NewKeyCertBundleWithRootCertFromFile(rootCertFile string) (*KeyCertBundle, error) {
   102  	var rootCertBytes []byte
   103  	var err error
   104  	if rootCertFile == "" {
   105  		rootCertBytes = []byte{}
   106  	} else {
   107  		rootCertBytes, err = os.ReadFile(rootCertFile)
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  	}
   112  	return &KeyCertBundle{
   113  		certBytes:      []byte{},
   114  		cert:           nil,
   115  		privKeyBytes:   []byte{},
   116  		privKey:        nil,
   117  		certChainBytes: []byte{},
   118  		rootCertBytes:  rootCertBytes,
   119  	}, nil
   120  }
   121  
   122  // GetAllPem returns all key/cert PEMs in KeyCertBundle together. Getting all values together avoids inconsistency.
   123  func (b *KeyCertBundle) GetAllPem() (certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) {
   124  	b.mutex.RLock()
   125  	certBytes = copyBytes(b.certBytes)
   126  	privKeyBytes = copyBytes(b.privKeyBytes)
   127  	certChainBytes = copyBytes(b.certChainBytes)
   128  	rootCertBytes = copyBytes(b.rootCertBytes)
   129  	b.mutex.RUnlock()
   130  	return
   131  }
   132  
   133  // GetAll returns all key/cert in KeyCertBundle together. Getting all values together avoids inconsistency.
   134  // NOTE: Callers should not modify the content of cert and privKey.
   135  func (b *KeyCertBundle) GetAll() (cert *x509.Certificate, privKey *crypto.PrivateKey, certChainBytes,
   136  	rootCertBytes []byte,
   137  ) {
   138  	b.mutex.RLock()
   139  	cert = b.cert
   140  	privKey = b.privKey
   141  	certChainBytes = copyBytes(b.certChainBytes)
   142  	rootCertBytes = copyBytes(b.rootCertBytes)
   143  	b.mutex.RUnlock()
   144  	return
   145  }
   146  
   147  // GetCertChainPem returns the certificate chain PEM.
   148  func (b *KeyCertBundle) GetCertChainPem() []byte {
   149  	b.mutex.RLock()
   150  	defer b.mutex.RUnlock()
   151  	return copyBytes(b.certChainBytes)
   152  }
   153  
   154  // GetRootCertPem returns the root certificate PEM.
   155  func (b *KeyCertBundle) GetRootCertPem() []byte {
   156  	b.mutex.RLock()
   157  	defer b.mutex.RUnlock()
   158  	return copyBytes(b.rootCertBytes)
   159  }
   160  
   161  // VerifyAndSetAll verifies the key/certs, and sets all key/certs in KeyCertBundle together.
   162  // Setting all values together avoids inconsistency.
   163  func (b *KeyCertBundle) VerifyAndSetAll(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) error {
   164  	if err := Verify(certBytes, privKeyBytes, certChainBytes, rootCertBytes); err != nil {
   165  		return err
   166  	}
   167  	b.setAllFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes)
   168  	return nil
   169  }
   170  
   171  // Setting all values together avoids inconsistency.
   172  func (b *KeyCertBundle) setAllFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) {
   173  	b.mutex.Lock()
   174  	b.certBytes = copyBytes(certBytes)
   175  	b.privKeyBytes = copyBytes(privKeyBytes)
   176  	b.certChainBytes = copyBytes(certChainBytes)
   177  	b.rootCertBytes = copyBytes(rootCertBytes)
   178  	// cert and privKey are always reset to point to new addresses. This avoids modifying the pointed structs that
   179  	// could be still used outside of the class.
   180  	b.cert, _ = ParsePemEncodedCertificate(certBytes)
   181  	privKey, _ := ParsePemEncodedKey(privKeyBytes)
   182  	b.privKey = &privKey
   183  	b.mutex.Unlock()
   184  }
   185  
   186  // CertOptions returns the certificate config based on currently stored cert.
   187  func (b *KeyCertBundle) CertOptions() (*CertOptions, error) {
   188  	b.mutex.RLock()
   189  	defer b.mutex.RUnlock()
   190  	ids, err := ExtractIDs(b.cert.Extensions)
   191  	if err != nil {
   192  		return nil, fmt.Errorf("failed to extract id %v", err)
   193  	}
   194  	if len(ids) != 1 {
   195  		return nil, fmt.Errorf("expect single id from the cert, found %v", ids)
   196  	}
   197  
   198  	opts := &CertOptions{
   199  		Host:      ids[0],
   200  		Org:       b.cert.Issuer.Organization[0],
   201  		IsCA:      b.cert.IsCA,
   202  		TTL:       b.cert.NotAfter.Sub(b.cert.NotBefore),
   203  		IsDualUse: ids[0] == b.cert.Subject.CommonName,
   204  	}
   205  
   206  	switch (*b.privKey).(type) {
   207  	case *rsa.PrivateKey:
   208  		size, err := GetRSAKeySize(*b.privKey)
   209  		if err != nil {
   210  			return nil, fmt.Errorf("failed to get RSA key size: %v", err)
   211  		}
   212  		opts.RSAKeySize = size
   213  	case *ecdsa.PrivateKey:
   214  		opts.ECSigAlg = EcdsaSigAlg
   215  	default:
   216  		return nil, errors.New("unknown private key type")
   217  	}
   218  
   219  	return opts, nil
   220  }
   221  
   222  // UpdateVerifiedKeyCertBundleFromFile Verifies and updates KeyCertBundle with new certs
   223  func (b *KeyCertBundle) UpdateVerifiedKeyCertBundleFromFile(certFile string, privKeyFile string, certChainFiles []string, rootCertFile string) error {
   224  	certBytes, err := os.ReadFile(certFile)
   225  	if err != nil {
   226  		return err
   227  	}
   228  	privKeyBytes, err := os.ReadFile(privKeyFile)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	certChainBytes := []byte{}
   233  	if len(certChainFiles) != 0 {
   234  		for _, f := range certChainFiles {
   235  			var b []byte
   236  			if b, err = os.ReadFile(f); err != nil {
   237  				return err
   238  			}
   239  
   240  			certChainBytes = append(certChainBytes, b...)
   241  		}
   242  	}
   243  	rootCertBytes, err := os.ReadFile(rootCertFile)
   244  	if err != nil {
   245  		return err
   246  	}
   247  
   248  	err = b.VerifyAndSetAll(certBytes, privKeyBytes, certChainBytes, rootCertBytes)
   249  	if err != nil {
   250  		return err
   251  	}
   252  
   253  	return nil
   254  }
   255  
   256  // ExtractRootCertExpiryTimestamp returns the unix timestamp when the root becomes expires.
   257  func (b *KeyCertBundle) ExtractRootCertExpiryTimestamp() (float64, error) {
   258  	return extractCertExpiryTimestamp("root cert", b.GetRootCertPem())
   259  }
   260  
   261  // ExtractCACertExpiryTimestamp returns the unix timestamp when the cert chain becomes expires.
   262  func (b *KeyCertBundle) ExtractCACertExpiryTimestamp() (float64, error) {
   263  	return extractCertExpiryTimestamp("CA cert", b.GetCertChainPem())
   264  }
   265  
   266  // TimeBeforeCertExpires returns the time duration before the cert gets expired.
   267  // It returns an error if it failed to extract the cert expiration timestamp.
   268  // The returned time duration could be a negative value indicating the cert has already been expired.
   269  func TimeBeforeCertExpires(certBytes []byte, now time.Time) (time.Duration, error) {
   270  	if len(certBytes) == 0 {
   271  		return 0, fmt.Errorf("no certificate found")
   272  	}
   273  
   274  	certExpiryTimestamp, err := extractCertExpiryTimestamp("cert", certBytes)
   275  	if err != nil {
   276  		return 0, fmt.Errorf("failed to extract cert expiration timestamp: %v", err)
   277  	}
   278  
   279  	certExpiry := time.Duration(certExpiryTimestamp-float64(now.Unix())) * time.Second
   280  	return certExpiry, nil
   281  }
   282  
   283  // Verify that the cert chain, root cert and key/cert match.
   284  func Verify(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) error {
   285  	// Verify the cert can be verified from the root cert through the cert chain.
   286  	rcp := x509.NewCertPool()
   287  	rcp.AppendCertsFromPEM(rootCertBytes)
   288  
   289  	icp := x509.NewCertPool()
   290  	icp.AppendCertsFromPEM(certChainBytes)
   291  
   292  	opts := x509.VerifyOptions{
   293  		Intermediates: icp,
   294  		Roots:         rcp,
   295  	}
   296  	cert, err := ParsePemEncodedCertificate(certBytes)
   297  	if err != nil {
   298  		return fmt.Errorf("failed to parse cert PEM: %v", err)
   299  	}
   300  	chains, err := cert.Verify(opts)
   301  
   302  	if len(chains) == 0 || err != nil {
   303  		return fmt.Errorf(
   304  			"cannot verify the cert with the provided root chain and cert "+
   305  				"pool with error: %v", err)
   306  	}
   307  
   308  	// Verify that the key can be correctly parsed.
   309  	if _, err = ParsePemEncodedKey(privKeyBytes); err != nil {
   310  		return fmt.Errorf("failed to parse private key PEM: %v", err)
   311  	}
   312  
   313  	// Verify the cert and key match.
   314  	if _, err := tls.X509KeyPair(certBytes, privKeyBytes); err != nil {
   315  		return fmt.Errorf("the cert does not match the key: %v", err)
   316  	}
   317  
   318  	return nil
   319  }
   320  
   321  func extractCertExpiryTimestamp(certType string, certPem []byte) (float64, error) {
   322  	cert, err := ParsePemEncodedCertificate(certPem)
   323  	if err != nil {
   324  		return -1, fmt.Errorf("failed to parse the %s: %v", certType, err)
   325  	}
   326  
   327  	end := cert.NotAfter
   328  	expiryTimestamp := float64(end.Unix())
   329  	if end.Before(time.Now()) {
   330  		return expiryTimestamp, fmt.Errorf("expired %s found, x509.NotAfter %v, please transit your %s", certType, end, certType)
   331  	}
   332  	return expiryTimestamp, nil
   333  }
   334  
   335  func copyBytes(src []byte) []byte {
   336  	bs := make([]byte, len(src))
   337  	copy(bs, src)
   338  	return bs
   339  }