github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/httpserver/tls.go (about) 1 // Copyright 2018 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package httpserver 5 6 import ( 7 "crypto/tls" 8 9 "github.com/juju/errors" 10 "github.com/juju/http/v2" 11 "golang.org/x/crypto/acme" 12 "golang.org/x/crypto/acme/autocert" 13 14 "github.com/juju/juju/state" 15 ) 16 17 type SNIGetterFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) 18 19 func aggregateSNIGetter(getters ...SNIGetterFunc) SNIGetterFunc { 20 return SNIGetterFunc(func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 21 lastErr := errors.Errorf("unable to find certificate for %s", 22 hello.ServerName) 23 for _, getter := range getters { 24 cert, err := getter(hello) 25 if err != nil { 26 lastErr = err 27 continue 28 } 29 if cert != nil { 30 return cert, nil 31 } 32 } 33 return nil, lastErr 34 }) 35 } 36 37 // NewTLSConfig returns the TLS configuration for the HTTP server to use 38 // based on controller configuration stored in the state database. 39 func NewTLSConfig(st *state.State, defaultSNI SNIGetterFunc, logger Logger) (*tls.Config, error) { 40 controllerConfig, err := st.ControllerConfig() 41 if err != nil { 42 return nil, errors.Trace(err) 43 } 44 return newTLSConfig( 45 controllerConfig.AutocertDNSName(), 46 controllerConfig.AutocertURL(), 47 st.AutocertCache(), 48 defaultSNI, 49 logger, 50 ), nil 51 } 52 53 func newTLSConfig( 54 autocertDNSName, autocertURL string, 55 autocertCache autocert.Cache, 56 defaultSNI SNIGetterFunc, 57 logger Logger, 58 ) *tls.Config { 59 tlsConfig := http.SecureTLSConfig() 60 if autocertDNSName == "" { 61 // No official DNS name, no certificate. 62 tlsConfig.GetCertificate = defaultSNI 63 return tlsConfig 64 } 65 66 m := autocert.Manager{ 67 Prompt: autocert.AcceptTOS, 68 Cache: autocertCache, 69 HostPolicy: autocert.HostWhitelist(autocertDNSName), 70 } 71 if autocertURL != "" { 72 m.Client = &acme.Client{ 73 DirectoryURL: autocertURL, 74 } 75 } 76 certLogger := SNIGetterFunc(func(h *tls.ClientHelloInfo) (*tls.Certificate, error) { 77 logger.Infof("getting certificate for server name %q", h.ServerName) 78 return nil, nil 79 }) 80 81 autoCertGetter := SNIGetterFunc(func(h *tls.ClientHelloInfo) (*tls.Certificate, error) { 82 c, err := m.GetCertificate(h) 83 if err != nil { 84 logger.Errorf("cannot get autocert certificate for %q: %v", 85 h.ServerName, err) 86 } 87 return c, err 88 }) 89 90 tlsConfig.GetCertificate = aggregateSNIGetter( 91 certLogger, autoCertGetter, defaultSNI) 92 tlsConfig.NextProtos = []string{ 93 "h2", "http/1.1", // Enable HTTP/2. 94 acme.ALPNProto, // Enable TLS-ALPN ACME challenges. 95 } 96 return tlsConfig 97 }