github.com/slackhq/nebula@v1.9.0/cert/cert.go (about)

     1  package cert
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ecdh"
     6  	"crypto/ecdsa"
     7  	"crypto/ed25519"
     8  	"crypto/elliptic"
     9  	"crypto/rand"
    10  	"crypto/sha256"
    11  	"encoding/binary"
    12  	"encoding/hex"
    13  	"encoding/json"
    14  	"encoding/pem"
    15  	"errors"
    16  	"fmt"
    17  	"math"
    18  	"math/big"
    19  	"net"
    20  	"sync/atomic"
    21  	"time"
    22  
    23  	"golang.org/x/crypto/curve25519"
    24  	"google.golang.org/protobuf/proto"
    25  )
    26  
    27  const publicKeyLen = 32
    28  
    29  const (
    30  	CertBanner                       = "NEBULA CERTIFICATE"
    31  	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
    32  	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
    33  	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
    34  	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
    35  	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
    36  
    37  	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
    38  	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
    39  	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
    40  	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
    41  )
    42  
// NebulaCertificate is a decoded Nebula certificate: the details that are
// signed, plus the signature over their protobuf encoding (see Sign and
// CheckSignature).
//
// It holds atomic.Pointer caches, so it should be handled as
// *NebulaCertificate rather than copied by value once in use.
type NebulaCertificate struct {
	Details   NebulaCertificateDetails
	Signature []byte

	// the cached hex string of the calculated sha256sum (of Marshal's output)
	// for VerifyWithCache; written by sha256SumWithCache, cleared by ResetCache
	sha256sum atomic.Pointer[string]

	// the cached public key bytes if they were verified as the signer
	// for VerifyWithCache; written by checkSignatureWithCache, cleared by ResetCache
	signatureVerified atomic.Pointer[[]byte]
}

// NebulaCertificateDetails holds every certificate field covered by the
// signature.
//
// For a CA certificate, non-empty Ips, Subnets and InvertedGroups restrict
// what that CA may sign (see CheckRootConstrains). PublicKey is a 32 byte key
// for CURVE25519; for P256 the package compares it against 65 byte
// uncompressed ecdh public keys. Issuer is the hex encoding of the raw issuer
// bytes; presumably the issuing CA's fingerprint (cf. Sha256Sum) — TODO
// confirm against NebulaCAPool.
type NebulaCertificateDetails struct {
	Name      string
	Ips       []*net.IPNet
	Subnets   []*net.IPNet
	Groups    []string
	NotBefore time.Time
	NotAfter  time.Time
	PublicKey []byte
	IsCA      bool
	Issuer    string

	// Map of groups for faster lookup. Populated from Groups by
	// UnmarshalNebulaCertificate; CheckRootConstrains consults this map, not
	// Groups, when enforcing a signer's group restriction.
	InvertedGroups map[string]struct{}

	Curve Curve
}

// NebulaEncryptedData is the decoded envelope of an encrypted private key:
// how it was encrypted, plus the ciphertext.
type NebulaEncryptedData struct {
	EncryptionMetadata NebulaEncryptionMetadata
	Ciphertext         []byte
}

// NebulaEncryptionMetadata records the cipher name (this package writes and
// accepts only "AES-256-GCM") and the Argon2 key-derivation parameters.
type NebulaEncryptionMetadata struct {
	EncryptionAlgorithm string
	Argon2Parameters    Argon2Parameters
}

// m is a short alias for a generic string-keyed map; it is not used in this
// part of the file — presumably it builds JSON output elsewhere in the package.
type m map[string]interface{}

// Returned if we try to unmarshal an encrypted private key without a passphrase
var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
    87  
    88  // UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert
    89  func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
    90  	if len(b) == 0 {
    91  		return nil, fmt.Errorf("nil byte array")
    92  	}
    93  	var rc RawNebulaCertificate
    94  	err := proto.Unmarshal(b, &rc)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	if rc.Details == nil {
   100  		return nil, fmt.Errorf("encoded Details was nil")
   101  	}
   102  
   103  	if len(rc.Details.Ips)%2 != 0 {
   104  		return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
   105  	}
   106  
   107  	if len(rc.Details.Subnets)%2 != 0 {
   108  		return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
   109  	}
   110  
   111  	nc := NebulaCertificate{
   112  		Details: NebulaCertificateDetails{
   113  			Name:           rc.Details.Name,
   114  			Groups:         make([]string, len(rc.Details.Groups)),
   115  			Ips:            make([]*net.IPNet, len(rc.Details.Ips)/2),
   116  			Subnets:        make([]*net.IPNet, len(rc.Details.Subnets)/2),
   117  			NotBefore:      time.Unix(rc.Details.NotBefore, 0),
   118  			NotAfter:       time.Unix(rc.Details.NotAfter, 0),
   119  			PublicKey:      make([]byte, len(rc.Details.PublicKey)),
   120  			IsCA:           rc.Details.IsCA,
   121  			InvertedGroups: make(map[string]struct{}),
   122  			Curve:          rc.Details.Curve,
   123  		},
   124  		Signature: make([]byte, len(rc.Signature)),
   125  	}
   126  
   127  	copy(nc.Signature, rc.Signature)
   128  	copy(nc.Details.Groups, rc.Details.Groups)
   129  	nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer)
   130  
   131  	if len(rc.Details.PublicKey) < publicKeyLen {
   132  		return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
   133  	}
   134  	copy(nc.Details.PublicKey, rc.Details.PublicKey)
   135  
   136  	for i, rawIp := range rc.Details.Ips {
   137  		if i%2 == 0 {
   138  			nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)}
   139  		} else {
   140  			nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp))
   141  		}
   142  	}
   143  
   144  	for i, rawIp := range rc.Details.Subnets {
   145  		if i%2 == 0 {
   146  			nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)}
   147  		} else {
   148  			nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp))
   149  		}
   150  	}
   151  
   152  	for _, g := range rc.Details.Groups {
   153  		nc.Details.InvertedGroups[g] = struct{}{}
   154  	}
   155  
   156  	return &nc, nil
   157  }
   158  
   159  // UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data
   160  // or an error on failure
   161  func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
   162  	p, r := pem.Decode(b)
   163  	if p == nil {
   164  		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
   165  	}
   166  	if p.Type != CertBanner {
   167  		return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
   168  	}
   169  	nc, err := UnmarshalNebulaCertificate(p.Bytes)
   170  	return nc, r, err
   171  }
   172  
   173  func MarshalPrivateKey(curve Curve, b []byte) []byte {
   174  	switch curve {
   175  	case Curve_CURVE25519:
   176  		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
   177  	case Curve_P256:
   178  		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
   179  	default:
   180  		return nil
   181  	}
   182  }
   183  
   184  func MarshalSigningPrivateKey(curve Curve, b []byte) []byte {
   185  	switch curve {
   186  	case Curve_CURVE25519:
   187  		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
   188  	case Curve_P256:
   189  		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
   190  	default:
   191  		return nil
   192  	}
   193  }
   194  
   195  // MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key
   196  func MarshalX25519PrivateKey(b []byte) []byte {
   197  	return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
   198  }
   199  
   200  // MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key
   201  func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte {
   202  	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key})
   203  }
   204  
   205  func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
   206  	k, r := pem.Decode(b)
   207  	if k == nil {
   208  		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
   209  	}
   210  	var expectedLen int
   211  	var curve Curve
   212  	switch k.Type {
   213  	case X25519PrivateKeyBanner:
   214  		expectedLen = 32
   215  		curve = Curve_CURVE25519
   216  	case P256PrivateKeyBanner:
   217  		expectedLen = 32
   218  		curve = Curve_P256
   219  	default:
   220  		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner")
   221  	}
   222  	if len(k.Bytes) != expectedLen {
   223  		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
   224  	}
   225  	return k.Bytes, r, curve, nil
   226  }
   227  
   228  func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
   229  	k, r := pem.Decode(b)
   230  	if k == nil {
   231  		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
   232  	}
   233  	var curve Curve
   234  	switch k.Type {
   235  	case EncryptedEd25519PrivateKeyBanner:
   236  		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
   237  	case EncryptedECDSAP256PrivateKeyBanner:
   238  		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
   239  	case Ed25519PrivateKeyBanner:
   240  		curve = Curve_CURVE25519
   241  		if len(k.Bytes) != ed25519.PrivateKeySize {
   242  			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
   243  		}
   244  	case ECDSAP256PrivateKeyBanner:
   245  		curve = Curve_P256
   246  		if len(k.Bytes) != 32 {
   247  			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
   248  		}
   249  	default:
   250  		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
   251  	}
   252  	return k.Bytes, r, curve, nil
   253  }
   254  
   255  // EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
   256  func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
   257  	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  
   262  	b, err = proto.Marshal(&RawNebulaEncryptedData{
   263  		EncryptionMetadata: &RawNebulaEncryptionMetadata{
   264  			EncryptionAlgorithm: "AES-256-GCM",
   265  			Argon2Parameters: &RawNebulaArgon2Parameters{
   266  				Version:     kdfParams.version,
   267  				Memory:      kdfParams.Memory,
   268  				Parallelism: uint32(kdfParams.Parallelism),
   269  				Iterations:  kdfParams.Iterations,
   270  				Salt:        kdfParams.salt,
   271  			},
   272  		},
   273  		Ciphertext: ciphertext,
   274  	})
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	switch curve {
   280  	case Curve_CURVE25519:
   281  		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
   282  	case Curve_P256:
   283  		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
   284  	default:
   285  		return nil, fmt.Errorf("invalid curve: %v", curve)
   286  	}
   287  }
   288  
   289  // UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b
   290  // or an error on failure
   291  func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) {
   292  	k, r := pem.Decode(b)
   293  	if k == nil {
   294  		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
   295  	}
   296  	if k.Type != X25519PrivateKeyBanner {
   297  		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner")
   298  	}
   299  	if len(k.Bytes) != publicKeyLen {
   300  		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key")
   301  	}
   302  
   303  	return k.Bytes, r, nil
   304  }
   305  
   306  // UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b
   307  // or an error on failure
   308  func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
   309  	k, r := pem.Decode(b)
   310  	if k == nil {
   311  		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
   312  	}
   313  
   314  	if k.Type == EncryptedEd25519PrivateKeyBanner {
   315  		return nil, r, ErrPrivateKeyEncrypted
   316  	} else if k.Type != Ed25519PrivateKeyBanner {
   317  		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner")
   318  	}
   319  
   320  	if len(k.Bytes) != ed25519.PrivateKeySize {
   321  		return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
   322  	}
   323  
   324  	return k.Bytes, r, nil
   325  }
   326  
   327  // UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
   328  // protobuf-generated struct.
   329  func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
   330  	if len(b) == 0 {
   331  		return nil, fmt.Errorf("nil byte array")
   332  	}
   333  	var rned RawNebulaEncryptedData
   334  	err := proto.Unmarshal(b, &rned)
   335  	if err != nil {
   336  		return nil, err
   337  	}
   338  
   339  	if rned.EncryptionMetadata == nil {
   340  		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
   341  	}
   342  
   343  	if rned.EncryptionMetadata.Argon2Parameters == nil {
   344  		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
   345  	}
   346  
   347  	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
   348  	if err != nil {
   349  		return nil, err
   350  	}
   351  
   352  	ned := NebulaEncryptedData{
   353  		EncryptionMetadata: NebulaEncryptionMetadata{
   354  			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
   355  			Argon2Parameters:    *params,
   356  		},
   357  		Ciphertext: rned.Ciphertext,
   358  	}
   359  
   360  	return &ned, nil
   361  }
   362  
   363  func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
   364  	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
   365  		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
   366  	}
   367  	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
   368  		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
   369  	}
   370  	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
   371  		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
   372  	}
   373  	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
   374  		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
   375  	}
   376  
   377  	return &Argon2Parameters{
   378  		version:     rune(params.Version),
   379  		Memory:      uint32(params.Memory),
   380  		Parallelism: uint8(params.Parallelism),
   381  		Iterations:  uint32(params.Iterations),
   382  		salt:        params.Salt,
   383  	}, nil
   384  
   385  }
   386  
   387  // DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
   388  // the given passphrase, returning any other bytes b or an error on failure
   389  func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
   390  	var curve Curve
   391  
   392  	k, r := pem.Decode(b)
   393  	if k == nil {
   394  		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
   395  	}
   396  
   397  	switch k.Type {
   398  	case EncryptedEd25519PrivateKeyBanner:
   399  		curve = Curve_CURVE25519
   400  	case EncryptedECDSAP256PrivateKeyBanner:
   401  		curve = Curve_P256
   402  	default:
   403  		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
   404  	}
   405  
   406  	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
   407  	if err != nil {
   408  		return curve, nil, r, err
   409  	}
   410  
   411  	var bytes []byte
   412  	switch ned.EncryptionMetadata.EncryptionAlgorithm {
   413  	case "AES-256-GCM":
   414  		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
   415  		if err != nil {
   416  			return curve, nil, r, err
   417  		}
   418  	default:
   419  		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
   420  	}
   421  
   422  	switch curve {
   423  	case Curve_CURVE25519:
   424  		if len(bytes) != ed25519.PrivateKeySize {
   425  			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
   426  		}
   427  	case Curve_P256:
   428  		if len(bytes) != 32 {
   429  			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
   430  		}
   431  	}
   432  
   433  	return curve, bytes, r, nil
   434  }
   435  
   436  func MarshalPublicKey(curve Curve, b []byte) []byte {
   437  	switch curve {
   438  	case Curve_CURVE25519:
   439  		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
   440  	case Curve_P256:
   441  		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
   442  	default:
   443  		return nil
   444  	}
   445  }
   446  
   447  // MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key
   448  func MarshalX25519PublicKey(b []byte) []byte {
   449  	return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
   450  }
   451  
   452  // MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key
   453  func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte {
   454  	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key})
   455  }
   456  
   457  func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) {
   458  	k, r := pem.Decode(b)
   459  	if k == nil {
   460  		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
   461  	}
   462  	var expectedLen int
   463  	var curve Curve
   464  	switch k.Type {
   465  	case X25519PublicKeyBanner:
   466  		expectedLen = 32
   467  		curve = Curve_CURVE25519
   468  	case P256PublicKeyBanner:
   469  		// Uncompressed
   470  		expectedLen = 65
   471  		curve = Curve_P256
   472  	default:
   473  		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner")
   474  	}
   475  	if len(k.Bytes) != expectedLen {
   476  		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
   477  	}
   478  	return k.Bytes, r, curve, nil
   479  }
   480  
   481  // UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b
   482  // or an error on failure
   483  func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) {
   484  	k, r := pem.Decode(b)
   485  	if k == nil {
   486  		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
   487  	}
   488  	if k.Type != X25519PublicKeyBanner {
   489  		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner")
   490  	}
   491  	if len(k.Bytes) != publicKeyLen {
   492  		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key")
   493  	}
   494  
   495  	return k.Bytes, r, nil
   496  }
   497  
   498  // UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b
   499  // or an error on failure
   500  func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) {
   501  	k, r := pem.Decode(b)
   502  	if k == nil {
   503  		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
   504  	}
   505  	if k.Type != Ed25519PublicKeyBanner {
   506  		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner")
   507  	}
   508  	if len(k.Bytes) != ed25519.PublicKeySize {
   509  		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key")
   510  	}
   511  
   512  	return k.Bytes, r, nil
   513  }
   514  
   515  // Sign signs a nebula cert with the provided private key
   516  func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error {
   517  	if curve != nc.Details.Curve {
   518  		return fmt.Errorf("curve in cert and private key supplied don't match")
   519  	}
   520  
   521  	b, err := proto.Marshal(nc.getRawDetails())
   522  	if err != nil {
   523  		return err
   524  	}
   525  
   526  	var sig []byte
   527  
   528  	switch curve {
   529  	case Curve_CURVE25519:
   530  		signer := ed25519.PrivateKey(key)
   531  		sig = ed25519.Sign(signer, b)
   532  	case Curve_P256:
   533  		signer := &ecdsa.PrivateKey{
   534  			PublicKey: ecdsa.PublicKey{
   535  				Curve: elliptic.P256(),
   536  			},
   537  			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
   538  			D: new(big.Int).SetBytes(key),
   539  		}
   540  		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
   541  		signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
   542  
   543  		// We need to hash first for ECDSA
   544  		// - https://pkg.go.dev/crypto/ecdsa#SignASN1
   545  		hashed := sha256.Sum256(b)
   546  		sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
   547  		if err != nil {
   548  			return err
   549  		}
   550  	default:
   551  		return fmt.Errorf("invalid curve: %s", nc.Details.Curve)
   552  	}
   553  
   554  	nc.Signature = sig
   555  	return nil
   556  }
   557  
   558  // CheckSignature verifies the signature against the provided public key
   559  func (nc *NebulaCertificate) CheckSignature(key []byte) bool {
   560  	b, err := proto.Marshal(nc.getRawDetails())
   561  	if err != nil {
   562  		return false
   563  	}
   564  	switch nc.Details.Curve {
   565  	case Curve_CURVE25519:
   566  		return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature)
   567  	case Curve_P256:
   568  		x, y := elliptic.Unmarshal(elliptic.P256(), key)
   569  		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
   570  		hashed := sha256.Sum256(b)
   571  		return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature)
   572  	default:
   573  		return false
   574  	}
   575  }
   576  
   577  // NOTE: This uses an internal cache that will not be invalidated automatically
   578  // if you manually change any fields in the NebulaCertificate.
   579  func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool {
   580  	if !useCache {
   581  		return nc.CheckSignature(key)
   582  	}
   583  
   584  	if v := nc.signatureVerified.Load(); v != nil {
   585  		return bytes.Equal(*v, key)
   586  	}
   587  
   588  	verified := nc.CheckSignature(key)
   589  	if verified {
   590  		keyCopy := make([]byte, len(key))
   591  		copy(keyCopy, key)
   592  		nc.signatureVerified.Store(&keyCopy)
   593  	}
   594  
   595  	return verified
   596  }
   597  
   598  // Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false
   599  func (nc *NebulaCertificate) Expired(t time.Time) bool {
   600  	return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
   601  }
   602  
   603  // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
   604  func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
   605  	return nc.verify(t, ncp, false)
   606  }
   607  
   608  // VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
   609  //
   610  // NOTE: This uses an internal cache that will not be invalidated automatically
   611  // if you manually change any fields in the NebulaCertificate.
   612  func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) {
   613  	return nc.verify(t, ncp, true)
   614  }
   615  
   616  // ResetCache resets the cache used by VerifyWithCache.
   617  func (nc *NebulaCertificate) ResetCache() {
   618  	nc.sha256sum.Store(nil)
   619  	nc.signatureVerified.Store(nil)
   620  }
   621  
   622  // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
   623  func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) {
   624  	if ncp.isBlocklistedWithCache(nc, useCache) {
   625  		return false, ErrBlockListed
   626  	}
   627  
   628  	signer, err := ncp.GetCAForCert(nc)
   629  	if err != nil {
   630  		return false, err
   631  	}
   632  
   633  	if signer.Expired(t) {
   634  		return false, ErrRootExpired
   635  	}
   636  
   637  	if nc.Expired(t) {
   638  		return false, ErrExpired
   639  	}
   640  
   641  	if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) {
   642  		return false, ErrSignatureMismatch
   643  	}
   644  
   645  	if err := nc.CheckRootConstrains(signer); err != nil {
   646  		return false, err
   647  	}
   648  
   649  	return true, nil
   650  }
   651  
   652  // CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets)
   653  func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error {
   654  	// Make sure this cert wasn't valid before the root
   655  	if signer.Details.NotAfter.Before(nc.Details.NotAfter) {
   656  		return fmt.Errorf("certificate expires after signing certificate")
   657  	}
   658  
   659  	// Make sure this cert isn't valid after the root
   660  	if signer.Details.NotBefore.After(nc.Details.NotBefore) {
   661  		return fmt.Errorf("certificate is valid before the signing certificate")
   662  	}
   663  
   664  	// If the signer has a limited set of groups make sure the cert only contains a subset
   665  	if len(signer.Details.InvertedGroups) > 0 {
   666  		for _, g := range nc.Details.Groups {
   667  			if _, ok := signer.Details.InvertedGroups[g]; !ok {
   668  				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
   669  			}
   670  		}
   671  	}
   672  
   673  	// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
   674  	if len(signer.Details.Ips) > 0 {
   675  		for _, ip := range nc.Details.Ips {
   676  			if !netMatch(ip, signer.Details.Ips) {
   677  				return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String())
   678  			}
   679  		}
   680  	}
   681  
   682  	// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
   683  	if len(signer.Details.Subnets) > 0 {
   684  		for _, subnet := range nc.Details.Subnets {
   685  			if !netMatch(subnet, signer.Details.Subnets) {
   686  				return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet)
   687  			}
   688  		}
   689  	}
   690  
   691  	return nil
   692  }
   693  
   694  // VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match
   695  func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error {
   696  	if curve != nc.Details.Curve {
   697  		return fmt.Errorf("curve in cert and private key supplied don't match")
   698  	}
   699  	if nc.Details.IsCA {
   700  		switch curve {
   701  		case Curve_CURVE25519:
   702  			// the call to PublicKey below will panic slice bounds out of range otherwise
   703  			if len(key) != ed25519.PrivateKeySize {
   704  				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
   705  			}
   706  
   707  			if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
   708  				return fmt.Errorf("public key in cert and private key supplied don't match")
   709  			}
   710  		case Curve_P256:
   711  			privkey, err := ecdh.P256().NewPrivateKey(key)
   712  			if err != nil {
   713  				return fmt.Errorf("cannot parse private key as P256")
   714  			}
   715  			pub := privkey.PublicKey().Bytes()
   716  			if !bytes.Equal(pub, nc.Details.PublicKey) {
   717  				return fmt.Errorf("public key in cert and private key supplied don't match")
   718  			}
   719  		default:
   720  			return fmt.Errorf("invalid curve: %s", curve)
   721  		}
   722  		return nil
   723  	}
   724  
   725  	var pub []byte
   726  	switch curve {
   727  	case Curve_CURVE25519:
   728  		var err error
   729  		pub, err = curve25519.X25519(key, curve25519.Basepoint)
   730  		if err != nil {
   731  			return err
   732  		}
   733  	case Curve_P256:
   734  		privkey, err := ecdh.P256().NewPrivateKey(key)
   735  		if err != nil {
   736  			return err
   737  		}
   738  		pub = privkey.PublicKey().Bytes()
   739  	default:
   740  		return fmt.Errorf("invalid curve: %s", curve)
   741  	}
   742  	if !bytes.Equal(pub, nc.Details.PublicKey) {
   743  		return fmt.Errorf("public key in cert and private key supplied don't match")
   744  	}
   745  
   746  	return nil
   747  }
   748  
   749  // String will return a pretty printed representation of a nebula cert
   750  func (nc *NebulaCertificate) String() string {
   751  	if nc == nil {
   752  		return "NebulaCertificate {}\n"
   753  	}
   754  
   755  	s := "NebulaCertificate {\n"
   756  	s += "\tDetails {\n"
   757  	s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name)
   758  
   759  	if len(nc.Details.Ips) > 0 {
   760  		s += "\t\tIps: [\n"
   761  		for _, ip := range nc.Details.Ips {
   762  			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
   763  		}
   764  		s += "\t\t]\n"
   765  	} else {
   766  		s += "\t\tIps: []\n"
   767  	}
   768  
   769  	if len(nc.Details.Subnets) > 0 {
   770  		s += "\t\tSubnets: [\n"
   771  		for _, ip := range nc.Details.Subnets {
   772  			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
   773  		}
   774  		s += "\t\t]\n"
   775  	} else {
   776  		s += "\t\tSubnets: []\n"
   777  	}
   778  
   779  	if len(nc.Details.Groups) > 0 {
   780  		s += "\t\tGroups: [\n"
   781  		for _, g := range nc.Details.Groups {
   782  			s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
   783  		}
   784  		s += "\t\t]\n"
   785  	} else {
   786  		s += "\t\tGroups: []\n"
   787  	}
   788  
   789  	s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore)
   790  	s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter)
   791  	s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA)
   792  	s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer)
   793  	s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey)
   794  	s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve)
   795  	s += "\t}\n"
   796  	fp, err := nc.Sha256Sum()
   797  	if err == nil {
   798  		s += fmt.Sprintf("\tFingerprint: %s\n", fp)
   799  	}
   800  	s += fmt.Sprintf("\tSignature: %x\n", nc.Signature)
   801  	s += "}"
   802  
   803  	return s
   804  }
   805  
   806  // getRawDetails marshals the raw details into protobuf ready struct
   807  func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails {
   808  	rd := &RawNebulaCertificateDetails{
   809  		Name:      nc.Details.Name,
   810  		Groups:    nc.Details.Groups,
   811  		NotBefore: nc.Details.NotBefore.Unix(),
   812  		NotAfter:  nc.Details.NotAfter.Unix(),
   813  		PublicKey: make([]byte, len(nc.Details.PublicKey)),
   814  		IsCA:      nc.Details.IsCA,
   815  		Curve:     nc.Details.Curve,
   816  	}
   817  
   818  	for _, ipNet := range nc.Details.Ips {
   819  		rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask))
   820  	}
   821  
   822  	for _, ipNet := range nc.Details.Subnets {
   823  		rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask))
   824  	}
   825  
   826  	copy(rd.PublicKey, nc.Details.PublicKey[:])
   827  
   828  	// I know, this is terrible
   829  	rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer)
   830  
   831  	return rd
   832  }
   833  
   834  // Marshal will marshal a nebula cert into a protobuf byte array
   835  func (nc *NebulaCertificate) Marshal() ([]byte, error) {
   836  	rc := RawNebulaCertificate{
   837  		Details:   nc.getRawDetails(),
   838  		Signature: nc.Signature,
   839  	}
   840  
   841  	return proto.Marshal(&rc)
   842  }
   843  
   844  // MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result
   845  func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) {
   846  	b, err := nc.Marshal()
   847  	if err != nil {
   848  		return nil, err
   849  	}
   850  	return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil
   851  }
   852  
   853  // Sha256Sum calculates a sha-256 sum of the marshaled certificate
   854  func (nc *NebulaCertificate) Sha256Sum() (string, error) {
   855  	b, err := nc.Marshal()
   856  	if err != nil {
   857  		return "", err
   858  	}
   859  
   860  	sum := sha256.Sum256(b)
   861  	return hex.EncodeToString(sum[:]), nil
   862  }
   863  
   864  // NOTE: This uses an internal cache that will not be invalidated automatically
   865  // if you manually change any fields in the NebulaCertificate.
   866  func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) {
   867  	if !useCache {
   868  		return nc.Sha256Sum()
   869  	}
   870  
   871  	if s := nc.sha256sum.Load(); s != nil {
   872  		return *s, nil
   873  	}
   874  	s, err := nc.Sha256Sum()
   875  	if err != nil {
   876  		return s, err
   877  	}
   878  
   879  	nc.sha256sum.Store(&s)
   880  	return s, nil
   881  }
   882  
   883  func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
   884  	toString := func(ips []*net.IPNet) []string {
   885  		s := []string{}
   886  		for _, ip := range ips {
   887  			s = append(s, ip.String())
   888  		}
   889  		return s
   890  	}
   891  
   892  	fp, _ := nc.Sha256Sum()
   893  	jc := m{
   894  		"details": m{
   895  			"name":      nc.Details.Name,
   896  			"ips":       toString(nc.Details.Ips),
   897  			"subnets":   toString(nc.Details.Subnets),
   898  			"groups":    nc.Details.Groups,
   899  			"notBefore": nc.Details.NotBefore,
   900  			"notAfter":  nc.Details.NotAfter,
   901  			"publicKey": fmt.Sprintf("%x", nc.Details.PublicKey),
   902  			"isCa":      nc.Details.IsCA,
   903  			"issuer":    nc.Details.Issuer,
   904  			"curve":     nc.Details.Curve.String(),
   905  		},
   906  		"fingerprint": fp,
   907  		"signature":   fmt.Sprintf("%x", nc.Signature),
   908  	}
   909  	return json.Marshal(jc)
   910  }
   911  
   912  //func (nc *NebulaCertificate) Copy() *NebulaCertificate {
   913  //	r, err := nc.Marshal()
   914  //	if err != nil {
   915  //		//TODO
   916  //		return nil
   917  //	}
   918  //
   919  //	c, err := UnmarshalNebulaCertificate(r)
   920  //	return c
   921  //}
   922  
   923  func (nc *NebulaCertificate) Copy() *NebulaCertificate {
   924  	c := &NebulaCertificate{
   925  		Details: NebulaCertificateDetails{
   926  			Name:           nc.Details.Name,
   927  			Groups:         make([]string, len(nc.Details.Groups)),
   928  			Ips:            make([]*net.IPNet, len(nc.Details.Ips)),
   929  			Subnets:        make([]*net.IPNet, len(nc.Details.Subnets)),
   930  			NotBefore:      nc.Details.NotBefore,
   931  			NotAfter:       nc.Details.NotAfter,
   932  			PublicKey:      make([]byte, len(nc.Details.PublicKey)),
   933  			IsCA:           nc.Details.IsCA,
   934  			Issuer:         nc.Details.Issuer,
   935  			InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
   936  		},
   937  		Signature: make([]byte, len(nc.Signature)),
   938  	}
   939  
   940  	copy(c.Signature, nc.Signature)
   941  	copy(c.Details.Groups, nc.Details.Groups)
   942  	copy(c.Details.PublicKey, nc.Details.PublicKey)
   943  
   944  	for i, p := range nc.Details.Ips {
   945  		c.Details.Ips[i] = &net.IPNet{
   946  			IP:   make(net.IP, len(p.IP)),
   947  			Mask: make(net.IPMask, len(p.Mask)),
   948  		}
   949  		copy(c.Details.Ips[i].IP, p.IP)
   950  		copy(c.Details.Ips[i].Mask, p.Mask)
   951  	}
   952  
   953  	for i, p := range nc.Details.Subnets {
   954  		c.Details.Subnets[i] = &net.IPNet{
   955  			IP:   make(net.IP, len(p.IP)),
   956  			Mask: make(net.IPMask, len(p.Mask)),
   957  		}
   958  		copy(c.Details.Subnets[i].IP, p.IP)
   959  		copy(c.Details.Subnets[i].Mask, p.Mask)
   960  	}
   961  
   962  	for g := range nc.Details.InvertedGroups {
   963  		c.Details.InvertedGroups[g] = struct{}{}
   964  	}
   965  
   966  	return c
   967  }
   968  
   969  func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
   970  	for _, net := range rootIps {
   971  		if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
   972  			return true
   973  		}
   974  	}
   975  
   976  	return false
   977  }
   978  
   979  func maskContains(caMask, certMask net.IPMask) bool {
   980  	caM := maskTo4(caMask)
   981  	cM := maskTo4(certMask)
   982  	// Make sure forcing to ipv4 didn't nuke us
   983  	if caM == nil || cM == nil {
   984  		return false
   985  	}
   986  
   987  	// Make sure the cert mask is not greater than the ca mask
   988  	for i := 0; i < len(caMask); i++ {
   989  		if caM[i] > cM[i] {
   990  			return false
   991  		}
   992  	}
   993  
   994  	return true
   995  }
   996  
   997  func maskTo4(ip net.IPMask) net.IPMask {
   998  	if len(ip) == net.IPv4len {
   999  		return ip
  1000  	}
  1001  
  1002  	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
  1003  		return ip[12:16]
  1004  	}
  1005  
  1006  	return nil
  1007  }
  1008  
  1009  func isZeros(b []byte) bool {
  1010  	for i := 0; i < len(b); i++ {
  1011  		if b[i] != 0 {
  1012  			return false
  1013  		}
  1014  	}
  1015  	return true
  1016  }
  1017  
  1018  func ip2int(ip []byte) uint32 {
  1019  	if len(ip) == 16 {
  1020  		return binary.BigEndian.Uint32(ip[12:16])
  1021  	}
  1022  	return binary.BigEndian.Uint32(ip)
  1023  }
  1024  
  1025  func int2ip(nn uint32) net.IP {
  1026  	ip := make(net.IP, net.IPv4len)
  1027  	binary.BigEndian.PutUint32(ip, nn)
  1028  	return ip
  1029  }