github.com/slackhq/nebula@v1.9.0/cert/ca.go (about)

     1  package cert
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"strings"
     7  	"time"
     8  )
     9  
// NebulaCAPool is a set of trusted CA certificates plus a blocklist of
// certificate fingerprints.
type NebulaCAPool struct {
	// CAs holds the trusted CA certificates. AddCACertificate keys entries by
	// the CA's Sha256Sum, and GetCAForCert looks them up by a certificate's
	// Details.Issuer.
	CAs map[string]*NebulaCertificate

	// certBlocklist is a set of blocklisted fingerprints; only key presence
	// matters (values are empty structs).
	certBlocklist map[string]struct{}
}
    14  
    15  // NewCAPool creates a CAPool
    16  func NewCAPool() *NebulaCAPool {
    17  	ca := NebulaCAPool{
    18  		CAs:           make(map[string]*NebulaCertificate),
    19  		certBlocklist: make(map[string]struct{}),
    20  	}
    21  
    22  	return &ca
    23  }
    24  
    25  // NewCAPoolFromBytes will create a new CA pool from the provided
    26  // input bytes, which must be a PEM-encoded set of nebula certificates.
    27  // If the pool contains any expired certificates, an ErrExpired will be
    28  // returned along with the pool. The caller must handle any such errors.
    29  func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
    30  	pool := NewCAPool()
    31  	var err error
    32  	var expired bool
    33  	for {
    34  		caPEMs, err = pool.AddCACertificate(caPEMs)
    35  		if errors.Is(err, ErrExpired) {
    36  			expired = true
    37  			err = nil
    38  		}
    39  		if err != nil {
    40  			return nil, err
    41  		}
    42  		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
    43  			break
    44  		}
    45  	}
    46  
    47  	if expired {
    48  		return pool, ErrExpired
    49  	}
    50  
    51  	return pool, nil
    52  }
    53  
    54  // AddCACertificate verifies a Nebula CA certificate and adds it to the pool
    55  // Only the first pem encoded object will be consumed, any remaining bytes are returned.
    56  // Parsed certificates will be verified and must be a CA
    57  func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
    58  	c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes)
    59  	if err != nil {
    60  		return pemBytes, err
    61  	}
    62  
    63  	if !c.Details.IsCA {
    64  		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA)
    65  	}
    66  
    67  	if !c.CheckSignature(c.Details.PublicKey) {
    68  		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned)
    69  	}
    70  
    71  	sum, err := c.Sha256Sum()
    72  	if err != nil {
    73  		return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name)
    74  	}
    75  
    76  	ncp.CAs[sum] = c
    77  	if c.Expired(time.Now()) {
    78  		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired)
    79  	}
    80  
    81  	return pemBytes, nil
    82  }
    83  
    84  // BlocklistFingerprint adds a cert fingerprint to the blocklist
    85  func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
    86  	ncp.certBlocklist[f] = struct{}{}
    87  }
    88  
    89  // ResetCertBlocklist removes all previously blocklisted cert fingerprints
    90  func (ncp *NebulaCAPool) ResetCertBlocklist() {
    91  	ncp.certBlocklist = make(map[string]struct{})
    92  }
    93  
    94  // NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated
    95  // automatically if you manually change any fields in the NebulaCertificate.
    96  func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
    97  	return ncp.isBlocklistedWithCache(c, false)
    98  }
    99  
   100  // IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
   101  func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool {
   102  	h, err := c.sha256SumWithCache(useCache)
   103  	if err != nil {
   104  		return true
   105  	}
   106  
   107  	if _, ok := ncp.certBlocklist[h]; ok {
   108  		return true
   109  	}
   110  
   111  	return false
   112  }
   113  
   114  // GetCAForCert attempts to return the signing certificate for the provided certificate.
   115  // No signature validation is performed
   116  func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) {
   117  	if c.Details.Issuer == "" {
   118  		return nil, fmt.Errorf("no issuer in certificate")
   119  	}
   120  
   121  	signer, ok := ncp.CAs[c.Details.Issuer]
   122  	if ok {
   123  		return signer, nil
   124  	}
   125  
   126  	return nil, fmt.Errorf("could not find ca for the certificate")
   127  }
   128  
   129  // GetFingerprints returns an array of trusted CA fingerprints
   130  func (ncp *NebulaCAPool) GetFingerprints() []string {
   131  	fp := make([]string, len(ncp.CAs))
   132  
   133  	i := 0
   134  	for k := range ncp.CAs {
   135  		fp[i] = k
   136  		i++
   137  	}
   138  
   139  	return fp
   140  }