github.com/slackhq/nebula@v1.9.0/pki.go (about)

     1  package nebula
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"os"
     7  	"strings"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/sirupsen/logrus"
    12  	"github.com/slackhq/nebula/cert"
    13  	"github.com/slackhq/nebula/config"
    14  	"github.com/slackhq/nebula/util"
    15  )
    16  
    17  type PKI struct {
    18  	cs     atomic.Pointer[CertState]
    19  	caPool atomic.Pointer[cert.NebulaCAPool]
    20  	l      *logrus.Logger
    21  }
    22  
    23  type CertState struct {
    24  	Certificate         *cert.NebulaCertificate
    25  	RawCertificate      []byte
    26  	RawCertificateNoKey []byte
    27  	PublicKey           []byte
    28  	PrivateKey          []byte
    29  }
    30  
    31  func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
    32  	pki := &PKI{l: l}
    33  	err := pki.reload(c, true)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	c.RegisterReloadCallback(func(c *config.C) {
    39  		rErr := pki.reload(c, false)
    40  		if rErr != nil {
    41  			util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
    42  		}
    43  	})
    44  
    45  	return pki, nil
    46  }
    47  
    48  func (p *PKI) GetCertState() *CertState {
    49  	return p.cs.Load()
    50  }
    51  
    52  func (p *PKI) GetCAPool() *cert.NebulaCAPool {
    53  	return p.caPool.Load()
    54  }
    55  
    56  func (p *PKI) reload(c *config.C, initial bool) error {
    57  	err := p.reloadCert(c, initial)
    58  	if err != nil {
    59  		if initial {
    60  			return err
    61  		}
    62  		err.Log(p.l)
    63  	}
    64  
    65  	err = p.reloadCAPool(c)
    66  	if err != nil {
    67  		if initial {
    68  			return err
    69  		}
    70  		err.Log(p.l)
    71  	}
    72  
    73  	return nil
    74  }
    75  
    76  func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
    77  	cs, err := newCertStateFromConfig(c)
    78  	if err != nil {
    79  		return util.NewContextualError("Could not load client cert", nil, err)
    80  	}
    81  
    82  	if !initial {
    83  		// did IP in cert change? if so, don't set
    84  		currentCert := p.cs.Load().Certificate
    85  		oldIPs := currentCert.Details.Ips
    86  		newIPs := cs.Certificate.Details.Ips
    87  		if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
    88  			return util.NewContextualError(
    89  				"IP in new cert was different from old",
    90  				m{"new_ip": newIPs[0], "old_ip": oldIPs[0]},
    91  				nil,
    92  			)
    93  		}
    94  	}
    95  
    96  	p.cs.Store(cs)
    97  	if initial {
    98  		p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
    99  	} else {
   100  		p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
   101  	}
   102  	return nil
   103  }
   104  
   105  func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
   106  	caPool, err := loadCAPoolFromConfig(p.l, c)
   107  	if err != nil {
   108  		return util.NewContextualError("Failed to load ca from config", nil, err)
   109  	}
   110  
   111  	p.caPool.Store(caPool)
   112  	p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
   113  	return nil
   114  }
   115  
   116  func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
   117  	// Marshal the certificate to ensure it is valid
   118  	rawCertificate, err := certificate.Marshal()
   119  	if err != nil {
   120  		return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
   121  	}
   122  
   123  	publicKey := certificate.Details.PublicKey
   124  	cs := &CertState{
   125  		RawCertificate: rawCertificate,
   126  		Certificate:    certificate,
   127  		PrivateKey:     privateKey,
   128  		PublicKey:      publicKey,
   129  	}
   130  
   131  	cs.Certificate.Details.PublicKey = nil
   132  	rawCertNoKey, err := cs.Certificate.Marshal()
   133  	if err != nil {
   134  		return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
   135  	}
   136  	cs.RawCertificateNoKey = rawCertNoKey
   137  	// put public key back
   138  	cs.Certificate.Details.PublicKey = cs.PublicKey
   139  	return cs, nil
   140  }
   141  
   142  func newCertStateFromConfig(c *config.C) (*CertState, error) {
   143  	var pemPrivateKey []byte
   144  	var err error
   145  
   146  	privPathOrPEM := c.GetString("pki.key", "")
   147  	if privPathOrPEM == "" {
   148  		return nil, errors.New("no pki.key path or PEM data provided")
   149  	}
   150  
   151  	if strings.Contains(privPathOrPEM, "-----BEGIN") {
   152  		pemPrivateKey = []byte(privPathOrPEM)
   153  		privPathOrPEM = "<inline>"
   154  
   155  	} else {
   156  		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
   157  		if err != nil {
   158  			return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
   159  		}
   160  	}
   161  
   162  	rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
   163  	if err != nil {
   164  		return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
   165  	}
   166  
   167  	var rawCert []byte
   168  
   169  	pubPathOrPEM := c.GetString("pki.cert", "")
   170  	if pubPathOrPEM == "" {
   171  		return nil, errors.New("no pki.cert path or PEM data provided")
   172  	}
   173  
   174  	if strings.Contains(pubPathOrPEM, "-----BEGIN") {
   175  		rawCert = []byte(pubPathOrPEM)
   176  		pubPathOrPEM = "<inline>"
   177  
   178  	} else {
   179  		rawCert, err = os.ReadFile(pubPathOrPEM)
   180  		if err != nil {
   181  			return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
   182  		}
   183  	}
   184  
   185  	nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
   186  	if err != nil {
   187  		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
   188  	}
   189  
   190  	if nebulaCert.Expired(time.Now()) {
   191  		return nil, fmt.Errorf("nebula certificate for this host is expired")
   192  	}
   193  
   194  	if len(nebulaCert.Details.Ips) == 0 {
   195  		return nil, fmt.Errorf("no IPs encoded in certificate")
   196  	}
   197  
   198  	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
   199  		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
   200  	}
   201  
   202  	return newCertState(nebulaCert, rawKey)
   203  }
   204  
   205  func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
   206  	var rawCA []byte
   207  	var err error
   208  
   209  	caPathOrPEM := c.GetString("pki.ca", "")
   210  	if caPathOrPEM == "" {
   211  		return nil, errors.New("no pki.ca path or PEM data provided")
   212  	}
   213  
   214  	if strings.Contains(caPathOrPEM, "-----BEGIN") {
   215  		rawCA = []byte(caPathOrPEM)
   216  
   217  	} else {
   218  		rawCA, err = os.ReadFile(caPathOrPEM)
   219  		if err != nil {
   220  			return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
   221  		}
   222  	}
   223  
   224  	caPool, err := cert.NewCAPoolFromBytes(rawCA)
   225  	if errors.Is(err, cert.ErrExpired) {
   226  		var expired int
   227  		for _, crt := range caPool.CAs {
   228  			if crt.Expired(time.Now()) {
   229  				expired++
   230  				l.WithField("cert", crt).Warn("expired certificate present in CA pool")
   231  			}
   232  		}
   233  
   234  		if expired >= len(caPool.CAs) {
   235  			return nil, errors.New("no valid CA certificates present")
   236  		}
   237  
   238  	} else if err != nil {
   239  		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
   240  	}
   241  
   242  	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
   243  		l.WithField("fingerprint", fp).Info("Blocklisting cert")
   244  		caPool.BlocklistFingerprint(fp)
   245  	}
   246  
   247  	return caPool, nil
   248  }