github.com/StackExchange/DNSControl@v0.2.8/pkg/acme/acme.go (about)

     1  // Package acme provides a means of performing Let's Encrypt DNS challenges via a DNSConfig
     2  package acme
     3  
     4  import (
     5  	"crypto/x509"
     6  	"encoding/pem"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"log"
    10  	"net/url"
    11  	"sort"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/StackExchange/dnscontrol/models"
    16  	"github.com/StackExchange/dnscontrol/pkg/nameservers"
    17  	"github.com/xenolf/lego/acme"
    18  	acmelog "github.com/xenolf/lego/log"
    19  )
    20  
    21  type CertConfig struct {
    22  	CertName string   `json:"cert_name"`
    23  	Names    []string `json:"names"`
    24  	UseECC   bool     `json:"use_ecc"`
    25  }
    26  
    27  type Client interface {
    28  	IssueOrRenewCert(config *CertConfig, renewUnder int, verbose bool) (bool, error)
    29  }
    30  
    31  type certManager struct {
    32  	email         string
    33  	acmeDirectory string
    34  	acmeHost      string
    35  
    36  	storage         Storage
    37  	cfg             *models.DNSConfig
    38  	domains         map[string]*models.DomainConfig
    39  	originalDomains []*models.DomainConfig
    40  
    41  	account    *Account
    42  	waitedOnce bool
    43  }
    44  
    45  const (
    46  	LetsEncryptLive  = "https://acme-v02.api.letsencrypt.org/directory"
    47  	LetsEncryptStage = "https://acme-staging-v02.api.letsencrypt.org/directory"
    48  )
    49  
    50  func New(cfg *models.DNSConfig, directory string, email string, server string) (Client, error) {
    51  	return commonNew(cfg, directoryStorage(directory), email, server)
    52  }
    53  
    54  func commonNew(cfg *models.DNSConfig, storage Storage, email string, server string) (Client, error) {
    55  	u, err := url.Parse(server)
    56  	if err != nil || u.Host == "" {
    57  		return nil, fmt.Errorf("ACME directory '%s' is not a valid URL", server)
    58  	}
    59  	c := &certManager{
    60  		storage:       storage,
    61  		email:         email,
    62  		acmeDirectory: server,
    63  		acmeHost:      u.Host,
    64  		cfg:           cfg,
    65  		domains:       map[string]*models.DomainConfig{},
    66  	}
    67  
    68  	acct, err := c.getOrCreateAccount()
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	c.account = acct
    73  	return c, nil
    74  }
    75  
    76  func NewVault(cfg *models.DNSConfig, vaultPath string, email string, server string) (Client, error) {
    77  	storage, err := makeVaultStorage(vaultPath)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	return commonNew(cfg, storage, email, server)
    82  }
    83  
    84  // IssueOrRenewCert will obtain a certificate with the given name if it does not exist,
    85  // or renew it if it is close enough to the expiration date.
    86  // It will return true if it issued or updated the certificate.
    87  func (c *certManager) IssueOrRenewCert(cfg *CertConfig, renewUnder int, verbose bool) (bool, error) {
    88  	if !verbose {
    89  		acmelog.Logger = log.New(ioutil.Discard, "", 0)
    90  	}
    91  	defer c.finalCleanUp()
    92  
    93  	log.Printf("Checking certificate [%s]", cfg.CertName)
    94  	existing, err := c.storage.GetCertificate(cfg.CertName)
    95  	if err != nil {
    96  		return false, err
    97  	}
    98  
    99  	var client *acme.Client
   100  
   101  	var action = func() (*acme.CertificateResource, error) {
   102  		return client.ObtainCertificate(cfg.Names, true, nil, true)
   103  	}
   104  
   105  	if existing == nil {
   106  		log.Println("No existing cert found. Issuing new...")
   107  	} else {
   108  		names, daysLeft, err := getCertInfo(existing.Certificate)
   109  		if err != nil {
   110  			return false, err
   111  		}
   112  		log.Printf("Found existing cert. %0.2f days remaining.", daysLeft)
   113  		namesOK := dnsNamesEqual(cfg.Names, names)
   114  		if daysLeft >= float64(renewUnder) && namesOK {
   115  			log.Println("Nothing to do")
   116  			//nothing to do
   117  			return false, nil
   118  		}
   119  		if !namesOK {
   120  			log.Println("DNS Names don't match expected set. Reissuing.")
   121  		} else {
   122  			log.Println("Renewing cert")
   123  			action = func() (*acme.CertificateResource, error) {
   124  				return client.RenewCertificate(*existing, true, true)
   125  			}
   126  		}
   127  	}
   128  
   129  	kt := acme.RSA2048
   130  	if cfg.UseECC {
   131  		kt = acme.EC256
   132  	}
   133  	client, err = acme.NewClient(c.acmeDirectory, c.account, kt)
   134  	if err != nil {
   135  		return false, err
   136  	}
   137  	client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSALPN01})
   138  	client.SetChallengeProvider(acme.DNS01, c)
   139  
   140  	acme.PreCheckDNS = c.preCheckDNS
   141  	defer func() { acme.PreCheckDNS = acmePreCheck }()
   142  
   143  	certResource, err := action()
   144  	if err != nil {
   145  		return false, err
   146  	}
   147  	fmt.Printf("Obtained certificate for %s\n", cfg.CertName)
   148  	if err = c.storage.StoreCertificate(cfg.CertName, certResource); err != nil {
   149  		return true, err
   150  	}
   151  
   152  	return true, nil
   153  }
   154  
   155  func getCertInfo(pemBytes []byte) (names []string, remaining float64, err error) {
   156  	block, _ := pem.Decode(pemBytes)
   157  	if block == nil {
   158  		return nil, 0, fmt.Errorf("Invalid certificate pem data")
   159  	}
   160  	cert, err := x509.ParseCertificate(block.Bytes)
   161  	if err != nil {
   162  		return nil, 0, err
   163  	}
   164  	var daysLeft = float64(cert.NotAfter.Sub(time.Now())) / float64(time.Hour*24)
   165  	return cert.DNSNames, daysLeft, nil
   166  }
   167  
   168  // checks two lists of sans to make sure they have all the same names in them.
   169  func dnsNamesEqual(a []string, b []string) bool {
   170  	if len(a) != len(b) {
   171  		return false
   172  	}
   173  	sort.Strings(a)
   174  	sort.Strings(b)
   175  	for i, s := range a {
   176  		if b[i] != s {
   177  			return false
   178  		}
   179  	}
   180  	return true
   181  }
   182  
   183  func (c *certManager) Present(domain, token, keyAuth string) (e error) {
   184  	d := c.cfg.DomainContainingFQDN(domain)
   185  	name := d.Name
   186  	if seen := c.domains[name]; seen != nil {
   187  		// we've already pre-processed this domain, just need to add to it.
   188  		d = seen
   189  	} else {
   190  		// one-time tasks to get this domain ready.
   191  		// if multiple validations on a single domain, we don't need to rebuild all this.
   192  
   193  		// fix NS records for this domain's DNS providers
   194  		nsList, err := nameservers.DetermineNameservers(d)
   195  		if err != nil {
   196  			return err
   197  		}
   198  		d.Nameservers = nsList
   199  		nameservers.AddNSRecords(d)
   200  
   201  		// make sure we have the latest config before we change anything.
   202  		// alternately, we could avoid a lot of this trouble if we really really trusted no-purge in all cases
   203  		if err := c.ensureNoPendingCorrections(d); err != nil {
   204  			return err
   205  		}
   206  
   207  		// copy domain and work from copy from now on. That way original config can be used to "restore" when we are all done.
   208  		copy, err := d.Copy()
   209  		if err != nil {
   210  			return err
   211  		}
   212  		c.originalDomains = append(c.originalDomains, d)
   213  		c.domains[name] = copy
   214  		d = copy
   215  	}
   216  
   217  	fqdn, val, _ := acme.DNS01Record(domain, keyAuth)
   218  	txt := &models.RecordConfig{Type: "TXT"}
   219  	txt.SetTargetTXT(val)
   220  	txt.SetLabelFromFQDN(fqdn, d.Name)
   221  	d.Records = append(d.Records, txt)
   222  	return getAndRunCorrections(d)
   223  }
   224  
   225  func (c *certManager) ensureNoPendingCorrections(d *models.DomainConfig) error {
   226  	corrections, err := getCorrections(d)
   227  	if err != nil {
   228  		return err
   229  	}
   230  	if len(corrections) != 0 {
   231  		// TODO: maybe allow forcing through this check.
   232  		for _, c := range corrections {
   233  			fmt.Println(c.Msg)
   234  		}
   235  		return fmt.Errorf("Found %d pending corrections for %s. Not going to proceed issuing certificates", len(corrections), d.Name)
   236  	}
   237  	return nil
   238  }
   239  
   240  // IgnoredProviders is a lit of provider names that should not be used to fill challenges.
   241  var IgnoredProviders = map[string]bool{}
   242  
   243  func getCorrections(d *models.DomainConfig) ([]*models.Correction, error) {
   244  	cs := []*models.Correction{}
   245  	for _, p := range d.DNSProviderInstances {
   246  		if IgnoredProviders[p.Name] {
   247  			continue
   248  		}
   249  		dc, err := d.Copy()
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  		corrections, err := p.Driver.GetDomainCorrections(dc)
   254  		if err != nil {
   255  			return nil, err
   256  		}
   257  		for _, c := range corrections {
   258  			c.Msg = fmt.Sprintf("[%s] %s", p.Name, strings.TrimSpace(c.Msg))
   259  		}
   260  		cs = append(cs, corrections...)
   261  	}
   262  	return cs, nil
   263  }
   264  
   265  func getAndRunCorrections(d *models.DomainConfig) error {
   266  	cs, err := getCorrections(d)
   267  	if err != nil {
   268  		return err
   269  	}
   270  	fmt.Printf("%d corrections\n", len(cs))
   271  	for _, c := range cs {
   272  		fmt.Printf("Running [%s]\n", c.Msg)
   273  		err = c.F()
   274  		if err != nil {
   275  			return err
   276  		}
   277  	}
   278  	return nil
   279  }
   280  
   281  func (c *certManager) CleanUp(domain, token, keyAuth string) error {
   282  	// do nothing for now. We will do a final clean up step at the very end.
   283  	return nil
   284  }
   285  
   286  func (c *certManager) finalCleanUp() error {
   287  	log.Println("Cleaning up all records we made")
   288  	var lastError error
   289  	for _, d := range c.originalDomains {
   290  		if err := getAndRunCorrections(d); err != nil {
   291  			log.Printf("ERROR cleaning up: %s", err)
   292  			lastError = err
   293  		}
   294  	}
   295  	return lastError
   296  }