github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/httpserver/tls.go (about)

     1  // Copyright 2018 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package httpserver
     5  
     6  import (
     7  	"crypto/tls"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/http/v2"
    11  	"golang.org/x/crypto/acme"
    12  	"golang.org/x/crypto/acme/autocert"
    13  
    14  	"github.com/juju/juju/state"
    15  )
    16  
    17  type SNIGetterFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
    18  
    19  func aggregateSNIGetter(getters ...SNIGetterFunc) SNIGetterFunc {
    20  	return SNIGetterFunc(func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
    21  		lastErr := errors.Errorf("unable to find certificate for %s",
    22  			hello.ServerName)
    23  		for _, getter := range getters {
    24  			cert, err := getter(hello)
    25  			if err != nil {
    26  				lastErr = err
    27  				continue
    28  			}
    29  			if cert != nil {
    30  				return cert, nil
    31  			}
    32  		}
    33  		return nil, lastErr
    34  	})
    35  }
    36  
    37  // NewTLSConfig returns the TLS configuration for the HTTP server to use
    38  // based on controller configuration stored in the state database.
    39  func NewTLSConfig(st *state.State, defaultSNI SNIGetterFunc, logger Logger) (*tls.Config, error) {
    40  	controllerConfig, err := st.ControllerConfig()
    41  	if err != nil {
    42  		return nil, errors.Trace(err)
    43  	}
    44  	return newTLSConfig(
    45  		controllerConfig.AutocertDNSName(),
    46  		controllerConfig.AutocertURL(),
    47  		st.AutocertCache(),
    48  		defaultSNI,
    49  		logger,
    50  	), nil
    51  }
    52  
    53  func newTLSConfig(
    54  	autocertDNSName, autocertURL string,
    55  	autocertCache autocert.Cache,
    56  	defaultSNI SNIGetterFunc,
    57  	logger Logger,
    58  ) *tls.Config {
    59  	tlsConfig := http.SecureTLSConfig()
    60  	if autocertDNSName == "" {
    61  		// No official DNS name, no certificate.
    62  		tlsConfig.GetCertificate = defaultSNI
    63  		return tlsConfig
    64  	}
    65  
    66  	m := autocert.Manager{
    67  		Prompt:     autocert.AcceptTOS,
    68  		Cache:      autocertCache,
    69  		HostPolicy: autocert.HostWhitelist(autocertDNSName),
    70  	}
    71  	if autocertURL != "" {
    72  		m.Client = &acme.Client{
    73  			DirectoryURL: autocertURL,
    74  		}
    75  	}
    76  	certLogger := SNIGetterFunc(func(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
    77  		logger.Infof("getting certificate for server name %q", h.ServerName)
    78  		return nil, nil
    79  	})
    80  
    81  	autoCertGetter := SNIGetterFunc(func(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
    82  		c, err := m.GetCertificate(h)
    83  		if err != nil {
    84  			logger.Errorf("cannot get autocert certificate for %q: %v",
    85  				h.ServerName, err)
    86  		}
    87  		return c, err
    88  	})
    89  
    90  	tlsConfig.GetCertificate = aggregateSNIGetter(
    91  		certLogger, autoCertGetter, defaultSNI)
    92  	tlsConfig.NextProtos = []string{
    93  		"h2", "http/1.1", // Enable HTTP/2.
    94  		acme.ALPNProto, // Enable TLS-ALPN ACME challenges.
    95  	}
    96  	return tlsConfig
    97  }