github.com/slackhq/nebula@v1.9.0/cmd/nebula-cert/sign.go (about)

     1  package main
     2  
     3  import (
     4  	"crypto/ecdh"
     5  	"crypto/rand"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/skip2/go-qrcode"
    15  	"github.com/slackhq/nebula/cert"
    16  	"golang.org/x/crypto/curve25519"
    17  )
    18  
    19  type signFlags struct {
    20  	set         *flag.FlagSet
    21  	caKeyPath   *string
    22  	caCertPath  *string
    23  	name        *string
    24  	ip          *string
    25  	duration    *time.Duration
    26  	inPubPath   *string
    27  	outKeyPath  *string
    28  	outCertPath *string
    29  	outQRPath   *string
    30  	groups      *string
    31  	subnets     *string
    32  }
    33  
    34  func newSignFlags() *signFlags {
    35  	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
    36  	sf.set.Usage = func() {}
    37  	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
    38  	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
    39  	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
    40  	sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
    41  	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
    42  	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
    43  	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
    44  	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
    45  	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
    46  	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
    47  	sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
    48  	return &sf
    49  
    50  }
    51  
    52  func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
    53  	sf := newSignFlags()
    54  	err := sf.set.Parse(args)
    55  	if err != nil {
    56  		return err
    57  	}
    58  
    59  	if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
    60  		return err
    61  	}
    62  	if err := mustFlagString("ca-crt", sf.caCertPath); err != nil {
    63  		return err
    64  	}
    65  	if err := mustFlagString("name", sf.name); err != nil {
    66  		return err
    67  	}
    68  	if err := mustFlagString("ip", sf.ip); err != nil {
    69  		return err
    70  	}
    71  	if *sf.inPubPath != "" && *sf.outKeyPath != "" {
    72  		return newHelpErrorf("cannot set both -in-pub and -out-key")
    73  	}
    74  
    75  	rawCAKey, err := os.ReadFile(*sf.caKeyPath)
    76  	if err != nil {
    77  		return fmt.Errorf("error while reading ca-key: %s", err)
    78  	}
    79  
    80  	var curve cert.Curve
    81  	var caKey []byte
    82  
    83  	// naively attempt to decode the private key as though it is not encrypted
    84  	caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
    85  	if err == cert.ErrPrivateKeyEncrypted {
    86  		// ask for a passphrase until we get one
    87  		var passphrase []byte
    88  		for i := 0; i < 5; i++ {
    89  			out.Write([]byte("Enter passphrase: "))
    90  			passphrase, err = pr.ReadPassword()
    91  
    92  			if err == ErrNoTerminal {
    93  				return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
    94  			} else if err != nil {
    95  				return fmt.Errorf("error reading password: %s", err)
    96  			}
    97  
    98  			if len(passphrase) > 0 {
    99  				break
   100  			}
   101  		}
   102  		if len(passphrase) == 0 {
   103  			return fmt.Errorf("cannot open encrypted ca-key without passphrase")
   104  		}
   105  
   106  		curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
   107  		if err != nil {
   108  			return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
   109  		}
   110  	} else if err != nil {
   111  		return fmt.Errorf("error while parsing ca-key: %s", err)
   112  	}
   113  
   114  	rawCACert, err := os.ReadFile(*sf.caCertPath)
   115  	if err != nil {
   116  		return fmt.Errorf("error while reading ca-crt: %s", err)
   117  	}
   118  
   119  	caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert)
   120  	if err != nil {
   121  		return fmt.Errorf("error while parsing ca-crt: %s", err)
   122  	}
   123  
   124  	if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
   125  		return fmt.Errorf("refusing to sign, root certificate does not match private key")
   126  	}
   127  
   128  	issuer, err := caCert.Sha256Sum()
   129  	if err != nil {
   130  		return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err)
   131  	}
   132  
   133  	if caCert.Expired(time.Now()) {
   134  		return fmt.Errorf("ca certificate is expired")
   135  	}
   136  
   137  	// if no duration is given, expire one second before the root expires
   138  	if *sf.duration <= 0 {
   139  		*sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1
   140  	}
   141  
   142  	ip, ipNet, err := net.ParseCIDR(*sf.ip)
   143  	if err != nil {
   144  		return newHelpErrorf("invalid ip definition: %s", err)
   145  	}
   146  	if ip.To4() == nil {
   147  		return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
   148  	}
   149  	ipNet.IP = ip
   150  
   151  	groups := []string{}
   152  	if *sf.groups != "" {
   153  		for _, rg := range strings.Split(*sf.groups, ",") {
   154  			g := strings.TrimSpace(rg)
   155  			if g != "" {
   156  				groups = append(groups, g)
   157  			}
   158  		}
   159  	}
   160  
   161  	subnets := []*net.IPNet{}
   162  	if *sf.subnets != "" {
   163  		for _, rs := range strings.Split(*sf.subnets, ",") {
   164  			rs := strings.Trim(rs, " ")
   165  			if rs != "" {
   166  				_, s, err := net.ParseCIDR(rs)
   167  				if err != nil {
   168  					return newHelpErrorf("invalid subnet definition: %s", err)
   169  				}
   170  				if s.IP.To4() == nil {
   171  					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
   172  				}
   173  				subnets = append(subnets, s)
   174  			}
   175  		}
   176  	}
   177  
   178  	var pub, rawPriv []byte
   179  	if *sf.inPubPath != "" {
   180  		rawPub, err := os.ReadFile(*sf.inPubPath)
   181  		if err != nil {
   182  			return fmt.Errorf("error while reading in-pub: %s", err)
   183  		}
   184  		var pubCurve cert.Curve
   185  		pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
   186  		if err != nil {
   187  			return fmt.Errorf("error while parsing in-pub: %s", err)
   188  		}
   189  		if pubCurve != curve {
   190  			return fmt.Errorf("curve of in-pub does not match ca")
   191  		}
   192  	} else {
   193  		pub, rawPriv = newKeypair(curve)
   194  	}
   195  
   196  	nc := cert.NebulaCertificate{
   197  		Details: cert.NebulaCertificateDetails{
   198  			Name:      *sf.name,
   199  			Ips:       []*net.IPNet{ipNet},
   200  			Groups:    groups,
   201  			Subnets:   subnets,
   202  			NotBefore: time.Now(),
   203  			NotAfter:  time.Now().Add(*sf.duration),
   204  			PublicKey: pub,
   205  			IsCA:      false,
   206  			Issuer:    issuer,
   207  			Curve:     curve,
   208  		},
   209  	}
   210  
   211  	if err := nc.CheckRootConstrains(caCert); err != nil {
   212  		return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err)
   213  	}
   214  
   215  	if *sf.outKeyPath == "" {
   216  		*sf.outKeyPath = *sf.name + ".key"
   217  	}
   218  
   219  	if *sf.outCertPath == "" {
   220  		*sf.outCertPath = *sf.name + ".crt"
   221  	}
   222  
   223  	if _, err := os.Stat(*sf.outCertPath); err == nil {
   224  		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
   225  	}
   226  
   227  	err = nc.Sign(curve, caKey)
   228  	if err != nil {
   229  		return fmt.Errorf("error while signing: %s", err)
   230  	}
   231  
   232  	if *sf.inPubPath == "" {
   233  		if _, err := os.Stat(*sf.outKeyPath); err == nil {
   234  			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
   235  		}
   236  
   237  		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
   238  		if err != nil {
   239  			return fmt.Errorf("error while writing out-key: %s", err)
   240  		}
   241  	}
   242  
   243  	b, err := nc.MarshalToPEM()
   244  	if err != nil {
   245  		return fmt.Errorf("error while marshalling certificate: %s", err)
   246  	}
   247  
   248  	err = os.WriteFile(*sf.outCertPath, b, 0600)
   249  	if err != nil {
   250  		return fmt.Errorf("error while writing out-crt: %s", err)
   251  	}
   252  
   253  	if *sf.outQRPath != "" {
   254  		b, err = qrcode.Encode(string(b), qrcode.Medium, -5)
   255  		if err != nil {
   256  			return fmt.Errorf("error while generating qr code: %s", err)
   257  		}
   258  
   259  		err = os.WriteFile(*sf.outQRPath, b, 0600)
   260  		if err != nil {
   261  			return fmt.Errorf("error while writing out-qr: %s", err)
   262  		}
   263  	}
   264  
   265  	return nil
   266  }
   267  
   268  func newKeypair(curve cert.Curve) ([]byte, []byte) {
   269  	switch curve {
   270  	case cert.Curve_CURVE25519:
   271  		return x25519Keypair()
   272  	case cert.Curve_P256:
   273  		return p256Keypair()
   274  	default:
   275  		return nil, nil
   276  	}
   277  }
   278  
   279  func x25519Keypair() ([]byte, []byte) {
   280  	privkey := make([]byte, 32)
   281  	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
   282  		panic(err)
   283  	}
   284  
   285  	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
   286  	if err != nil {
   287  		panic(err)
   288  	}
   289  
   290  	return pubkey, privkey
   291  }
   292  
   293  func p256Keypair() ([]byte, []byte) {
   294  	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
   295  	if err != nil {
   296  		panic(err)
   297  	}
   298  	pubkey := privkey.PublicKey()
   299  	return pubkey.Bytes(), privkey.Bytes()
   300  }
   301  
   302  func signSummary() string {
   303  	return "sign <flags>: create and sign a certificate"
   304  }
   305  
   306  func signHelp(out io.Writer) {
   307  	sf := newSignFlags()
   308  	out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n"))
   309  	sf.set.SetOutput(out)
   310  	sf.set.PrintDefaults()
   311  }