github.com/slackhq/nebula@v1.9.0/pki.go (about) 1 package nebula 2 3 import ( 4 "errors" 5 "fmt" 6 "os" 7 "strings" 8 "sync/atomic" 9 "time" 10 11 "github.com/sirupsen/logrus" 12 "github.com/slackhq/nebula/cert" 13 "github.com/slackhq/nebula/config" 14 "github.com/slackhq/nebula/util" 15 ) 16 17 type PKI struct { 18 cs atomic.Pointer[CertState] 19 caPool atomic.Pointer[cert.NebulaCAPool] 20 l *logrus.Logger 21 } 22 23 type CertState struct { 24 Certificate *cert.NebulaCertificate 25 RawCertificate []byte 26 RawCertificateNoKey []byte 27 PublicKey []byte 28 PrivateKey []byte 29 } 30 31 func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { 32 pki := &PKI{l: l} 33 err := pki.reload(c, true) 34 if err != nil { 35 return nil, err 36 } 37 38 c.RegisterReloadCallback(func(c *config.C) { 39 rErr := pki.reload(c, false) 40 if rErr != nil { 41 util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l) 42 } 43 }) 44 45 return pki, nil 46 } 47 48 func (p *PKI) GetCertState() *CertState { 49 return p.cs.Load() 50 } 51 52 func (p *PKI) GetCAPool() *cert.NebulaCAPool { 53 return p.caPool.Load() 54 } 55 56 func (p *PKI) reload(c *config.C, initial bool) error { 57 err := p.reloadCert(c, initial) 58 if err != nil { 59 if initial { 60 return err 61 } 62 err.Log(p.l) 63 } 64 65 err = p.reloadCAPool(c) 66 if err != nil { 67 if initial { 68 return err 69 } 70 err.Log(p.l) 71 } 72 73 return nil 74 } 75 76 func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { 77 cs, err := newCertStateFromConfig(c) 78 if err != nil { 79 return util.NewContextualError("Could not load client cert", nil, err) 80 } 81 82 if !initial { 83 // did IP in cert change? if so, don't set 84 currentCert := p.cs.Load().Certificate 85 oldIPs := currentCert.Details.Ips 86 newIPs := cs.Certificate.Details.Ips 87 if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { 88 return util.NewContextualError( 89 "IP in new cert was different from old", 90 m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, 91 nil, 92 ) 93 } 94 } 95 96 p.cs.Store(cs) 97 if initial { 98 p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") 99 } else { 100 p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") 101 } 102 return nil 103 } 104 105 func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { 106 caPool, err := loadCAPoolFromConfig(p.l, c) 107 if err != nil { 108 return util.NewContextualError("Failed to load ca from config", nil, err) 109 } 110 111 p.caPool.Store(caPool) 112 p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") 113 return nil 114 } 115 116 func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { 117 // Marshal the certificate to ensure it is valid 118 rawCertificate, err := certificate.Marshal() 119 if err != nil { 120 return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) 121 } 122 123 publicKey := certificate.Details.PublicKey 124 cs := &CertState{ 125 RawCertificate: rawCertificate, 126 Certificate: certificate, 127 PrivateKey: privateKey, 128 PublicKey: publicKey, 129 } 130 131 cs.Certificate.Details.PublicKey = nil 132 rawCertNoKey, err := cs.Certificate.Marshal() 133 if err != nil { 134 return nil, fmt.Errorf("error marshalling certificate no key: %s", err) 135 } 136 cs.RawCertificateNoKey = rawCertNoKey 137 // put public key back 138 cs.Certificate.Details.PublicKey = cs.PublicKey 139 return cs, nil 140 } 141 142 func newCertStateFromConfig(c *config.C) (*CertState, error) { 143 var pemPrivateKey []byte 144 var err error 145 146 privPathOrPEM := c.GetString("pki.key", "") 147 if privPathOrPEM == "" { 148 return nil, errors.New("no pki.key path or PEM data provided") 149 } 150 151 if strings.Contains(privPathOrPEM, "-----BEGIN") { 152 pemPrivateKey = []byte(privPathOrPEM) 153 privPathOrPEM = "<inline>" 154 155 } else { 156 pemPrivateKey, err = os.ReadFile(privPathOrPEM) 157 if err != nil { 158 return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) 159 } 160 } 161 162 rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) 163 if err != nil { 164 return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) 165 } 166 167 var rawCert []byte 168 169 pubPathOrPEM := c.GetString("pki.cert", "") 170 if pubPathOrPEM == "" { 171 return nil, errors.New("no pki.cert path or PEM data provided") 172 } 173 174 if strings.Contains(pubPathOrPEM, "-----BEGIN") { 175 rawCert = []byte(pubPathOrPEM) 176 pubPathOrPEM = "<inline>" 177 178 } else { 179 rawCert, err = os.ReadFile(pubPathOrPEM) 180 if err != nil { 181 return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) 182 } 183 } 184 185 nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) 186 if err != nil { 187 return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) 188 } 189 190 if nebulaCert.Expired(time.Now()) { 191 return nil, fmt.Errorf("nebula certificate for this host is expired") 192 } 193 194 if len(nebulaCert.Details.Ips) == 0 { 195 return nil, fmt.Errorf("no IPs encoded in certificate") 196 } 197 198 if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { 199 return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") 200 } 201 202 return newCertState(nebulaCert, rawKey) 203 } 204 205 func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { 206 var rawCA []byte 207 var err error 208 209 caPathOrPEM := c.GetString("pki.ca", "") 210 if caPathOrPEM == "" { 211 return nil, errors.New("no pki.ca path or PEM data provided") 212 } 213 214 if strings.Contains(caPathOrPEM, "-----BEGIN") { 215 rawCA = []byte(caPathOrPEM) 216 217 } else { 218 rawCA, err = os.ReadFile(caPathOrPEM) 219 if err != nil { 220 return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) 221 } 222 } 223 224 caPool, err := cert.NewCAPoolFromBytes(rawCA) 225 if errors.Is(err, cert.ErrExpired) { 226 var expired int 227 for _, crt := range caPool.CAs { 228 if crt.Expired(time.Now()) { 229 expired++ 230 l.WithField("cert", crt).Warn("expired certificate present in CA pool") 231 } 232 } 233 234 if expired >= len(caPool.CAs) { 235 return nil, errors.New("no valid CA certificates present") 236 } 237 238 } else if err != nil { 239 return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) 240 } 241 242 for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { 243 l.WithField("fingerprint", fp).Info("Blocklisting cert") 244 caPool.BlocklistFingerprint(fp) 245 } 246 247 return caPool, nil 248 }