github.com/defang-io/defang/src@v0.0.0-20240505002154-bdf411911834/pkg/cli/cert.go (about)

     1  package cli
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	cliClient "github.com/defang-io/defang/src/pkg/cli/client"
    13  	"github.com/defang-io/defang/src/pkg/spinner"
    14  	"github.com/defang-io/defang/src/pkg/term"
    15  )
    16  
    17  var resolver = net.Resolver{}
    18  var httpClient = http.Client{}
    19  
    20  func GenerateLetsEncryptCert(ctx context.Context, client cliClient.Client) error {
    21  	services, err := client.GetServices(ctx)
    22  	if err != nil {
    23  		return err
    24  	}
    25  
    26  	cnt := 0
    27  	for _, service := range services.Services {
    28  		if service.Service != nil && service.Service.Domainname != "" && service.ZoneId == "" {
    29  			cnt++
    30  			generateCert(ctx, service.Service.Domainname, service.LbDns)
    31  		}
    32  	}
    33  	if cnt == 0 {
    34  		term.Infof("No services found need to generate Let's Encrypt cert")
    35  	}
    36  
    37  	return nil
    38  }
    39  
    40  func generateCert(ctx context.Context, domain, albDns string) {
    41  	term.Infof("Triggering Let's Encrypt cert generation for %v", domain)
    42  	if err := waitForCNAME(ctx, domain, albDns); err != nil {
    43  		term.Errorf("Error waiting for CNAME: %v", err)
    44  		return
    45  	}
    46  
    47  	term.Infof("%v DNS is properly configured!", domain)
    48  	if err := checkTLSCert(ctx, domain); err == nil {
    49  		term.Infof("TLS cert for %v is already ready", domain)
    50  		return
    51  	}
    52  	term.Infof("Triggering cert generation for %v", domain)
    53  	triggerCertGeneration(ctx, domain)
    54  
    55  	term.Infof("Waiting for TLS cert to be online for %v", domain)
    56  	if err := waitForTLS(ctx, domain); err != nil {
    57  		term.Errorf("Error waiting for TLS to be online: %v", err)
    58  		// FIXME: The message below is only valid for BYOC, need to update when playground ACME cert support is added
    59  		term.Errorf("Please check for error messages from `/aws/lambda/acme-lambda` log group in cloudwatch for more details")
    60  		return
    61  	}
    62  
    63  	term.Infof("TLS cert for %v is ready", domain)
    64  }
    65  
    66  func triggerCertGeneration(ctx context.Context, domain string) {
    67  	doSpinner := term.CanColor && term.IsTerminal
    68  	if doSpinner {
    69  		spinCtx, cancel := context.WithCancel(ctx)
    70  		defer cancel()
    71  		go func() {
    72  			term.Stdout.HideCursor()
    73  			defer term.Stdout.ShowCursor()
    74  			ticker := time.NewTicker(1 * time.Second)
    75  			defer ticker.Stop()
    76  			spin := spinner.New()
    77  			for {
    78  				select {
    79  				case <-spinCtx.Done():
    80  					return
    81  				case <-ticker.C:
    82  					fmt.Print(spin.Next())
    83  				}
    84  			}
    85  		}()
    86  	}
    87  	if err := getWithRetries(ctx, fmt.Sprintf("http://%v", domain), 3); err != nil { // Retry incase of DNS error
    88  		// Ignore possible tls error as cert attachment may take time
    89  		term.Debugf("Error triggering cert generation: %v", err)
    90  	}
    91  }
    92  
    93  func waitForTLS(ctx context.Context, domain string) error {
    94  	ticker := time.NewTicker(1 * time.Second)
    95  	defer ticker.Stop()
    96  	timeout, cancel := context.WithTimeout(ctx, 1*time.Minute)
    97  	defer cancel()
    98  
    99  	doSpinner := term.CanColor && term.IsTerminal
   100  	if doSpinner {
   101  		term.Stdout.HideCursor()
   102  		defer term.Stdout.ShowCursor()
   103  	}
   104  	spin := spinner.New()
   105  	for {
   106  		select {
   107  		case <-timeout.Done():
   108  			return timeout.Err()
   109  		case <-ticker.C:
   110  			if err := checkTLSCert(timeout, domain); err == nil {
   111  				return nil
   112  			}
   113  			if doSpinner {
   114  				fmt.Print(spin.Next())
   115  			}
   116  		}
   117  	}
   118  }
   119  
   120  func waitForCNAME(ctx context.Context, domain, albDns string) error {
   121  	ticker := time.NewTicker(1 * time.Second)
   122  	defer ticker.Stop()
   123  
   124  	albDns = strings.TrimSuffix(albDns, ".")
   125  	msgShown := false
   126  	doSpinner := term.CanColor && term.IsTerminal
   127  	if doSpinner {
   128  		term.Stdout.HideCursor()
   129  		defer term.Stdout.ShowCursor()
   130  	}
   131  	spin := spinner.New()
   132  	for {
   133  		select {
   134  		case <-ctx.Done():
   135  			return ctx.Err()
   136  		case <-ticker.C:
   137  			cname, err := resolver.LookupCNAME(ctx, domain)
   138  			cname = strings.TrimSuffix(cname, ".")
   139  			if err != nil || strings.ToLower(cname) != strings.ToLower(albDns) {
   140  				if !msgShown {
   141  					term.Infof("Please setup CNAME record for %v to point to ALB %v, waiting for CNAME record setup and DNS propagation", domain, strings.ToLower(albDns))
   142  					term.Infof("Note: DNS propagation may take a while, we will proceed as soon as the CNAME record is ready, checking...")
   143  					msgShown = true
   144  				}
   145  				if doSpinner {
   146  					fmt.Print(spin.Next())
   147  				}
   148  			} else {
   149  				return nil
   150  			}
   151  		}
   152  	}
   153  }
   154  
   155  func checkTLSCert(ctx context.Context, domain string) error {
   156  	return getWithRetries(ctx, fmt.Sprintf("https://%v", domain), 3)
   157  }
   158  
   159  func getWithRetries(ctx context.Context, url string, tries int) error {
   160  	var errs []error
   161  	for i := 0; i < tries; i++ {
   162  
   163  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
   164  		if err != nil {
   165  			return err // No point retrying if we can't even create the request
   166  		}
   167  		if _, err := httpClient.Do(req); err != nil {
   168  			errs = append(errs, err)
   169  		}
   170  
   171  		delay := (100 * time.Millisecond) >> i // Simple exponential backoff
   172  		select {
   173  		case <-time.After(delay):
   174  			continue
   175  		case <-ctx.Done():
   176  			return ctx.Err()
   177  		}
   178  	}
   179  	return errors.Join(errs...)
   180  }