github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/pkg/tlsclient/tlsclient.go (about) 1 package tlsclient 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "crypto/tls" 7 "crypto/x509" 8 "encoding/hex" 9 "encoding/pem" 10 "fmt" 11 "net" 12 "net/http" 13 "net/url" 14 "os" 15 "strconv" 16 "time" 17 18 "github.com/cozy/cozy-stack/pkg/utils" 19 ) 20 21 // HTTPEndpoint is a struct for specifying which parameters to use when 22 // connecting to a HTTP(S) endpoint 23 type HTTPEndpoint struct { 24 Host string 25 Port int 26 Timeout time.Duration 27 EnvPrefix string 28 29 RootCAFile string 30 ClientCertificateFiles ClientCertificateFilePair 31 PinnedKey string 32 InsecureSkipValidation bool 33 MaxIdleConnsPerHost int 34 DisableCompression bool 35 } 36 37 // ClientCertificateFilePair is a struct with a certificate and a key pair 38 type ClientCertificateFilePair struct { 39 KeyFile string 40 CertificateFile string 41 } 42 43 type tlsConfig struct { 44 clientCertificates []tls.Certificate 45 rootCAs []*x509.Certificate 46 pinnedKeys [][]byte 47 skipVerification bool 48 } 49 50 func generateURL(host string, port int) (*url.URL, error) { 51 u, err := url.Parse(host) 52 if err != nil { 53 return nil, err 54 } 55 if u.Scheme == "" { 56 u = &url.URL{ 57 Scheme: "http", 58 Host: net.JoinHostPort(host, strconv.Itoa(port)), 59 } 60 } 61 return u, nil 62 } 63 64 // NewHTTPClient creates a http.Client and an url.URL for the given HTTP endpoint 65 func NewHTTPClient(opt HTTPEndpoint) (client *http.Client, u *url.URL, err error) { 66 if opt.Host != "" || opt.Port > 0 { 67 u, err = generateURL(opt.Host, opt.Port) 68 if err != nil { 69 return 70 } 71 } 72 c := &tlsConfig{} 73 if u != nil { 74 c, u, err = fromURL(c, u) 75 if err != nil { 76 return 77 } 78 } 79 if opt.EnvPrefix != "" { 80 c, err = fromEnv(c, opt.EnvPrefix) 81 if err != nil { 82 return 83 } 84 } 85 if opt.RootCAFile != "" { 86 if err = c.LoadRootCAFile(opt.RootCAFile); err != nil { 87 return 88 } 89 } 90 if opt.ClientCertificateFiles.CertificateFile != "" { 91 if err = c.LoadClientCertificateFile( 92 opt.ClientCertificateFiles.CertificateFile, 93 opt.ClientCertificateFiles.KeyFile, 94 ); err != nil { 95 return 96 } 97 } 98 if opt.PinnedKey != "" { 99 if err = c.AddHexPinnedKey(opt.PinnedKey); err != nil { 100 return 101 } 102 } 103 if opt.InsecureSkipValidation { 104 c.SetInsecureSkipValidation() 105 } 106 transport := http.DefaultTransport.(*http.Transport).Clone() 107 transport.TLSClientConfig = c.Config() 108 if opt.MaxIdleConnsPerHost > 0 { 109 transport.MaxIdleConnsPerHost = opt.MaxIdleConnsPerHost 110 } 111 if opt.DisableCompression { 112 transport.DisableCompression = true 113 } 114 client = &http.Client{ 115 Timeout: opt.Timeout, 116 Transport: transport, 117 } 118 return 119 } 120 121 func fromURL(c *tlsConfig, u *url.URL) (conf *tlsConfig, uCopy *url.URL, err error) { 122 uCopy = utils.CloneURL(u) 123 q := uCopy.Query() 124 if u.Scheme == "https" { 125 if rootCAFile := q.Get("ca"); rootCAFile != "" { 126 if err = c.LoadRootCAFile(rootCAFile); err != nil { 127 return 128 } 129 } 130 if certFile := q.Get("cert"); certFile != "" { 131 if keyFile := q.Get("key"); keyFile != "" { 132 if err = c.LoadClientCertificateFile(certFile, keyFile); err != nil { 133 return 134 } 135 } 136 } 137 if hexPinnedKey := q.Get("fp"); hexPinnedKey != "" { 138 if err = c.AddHexPinnedKey(hexPinnedKey); err != nil { 139 return 140 } 141 } 142 if t := q.Get("validate"); t == "0" || t == "false" || t == "FALSE" { 143 c.SetInsecureSkipValidation() 144 } 145 } 146 q.Del("ca") 147 q.Del("cert") 148 q.Del("key") 149 q.Del("fp") 150 q.Del("validate") 151 uCopy.RawQuery = q.Encode() 152 return c, uCopy, nil 153 } 154 155 func fromEnv(c *tlsConfig, envPrefix string) (conf *tlsConfig, err error) { 156 if rootCAFile := os.Getenv(envPrefix + "_CA"); rootCAFile != "" { 157 if err = c.LoadRootCAFile(rootCAFile); err != nil { 158 return 159 } 160 } 161 if certFile := os.Getenv(envPrefix + "_CERT"); certFile != "" { 162 if keyFile := os.Getenv(envPrefix + "_KEY"); keyFile != "" { 163 if err = c.LoadClientCertificateFile(certFile, keyFile); err != nil { 164 return 165 } 166 } 167 } 168 if hexPinnedKey := os.Getenv(envPrefix + "_FINGERPRINT"); hexPinnedKey != "" { 169 if err = c.AddHexPinnedKey(hexPinnedKey); err != nil { 170 return 171 } 172 } 173 if t := os.Getenv(envPrefix + "_VALIDATE"); t == "0" || t == "false" || t == "FALSE" { 174 c.SetInsecureSkipValidation() 175 } 176 return c, nil 177 } 178 179 func (s *tlsConfig) LoadClientCertificateFile(certFile, keyFile string) error { 180 cert, err := tls.LoadX509KeyPair(certFile, keyFile) 181 if err != nil { 182 return fmt.Errorf("tlsclient: could not load client certificate file: %s", err) 183 } 184 s.clientCertificates = append(s.clientCertificates, cert) 185 return nil 186 } 187 188 func (s *tlsConfig) LoadClientCertificate(certPEMBlock, keyPEMBlock []byte) error { 189 cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) 190 if err != nil { 191 return fmt.Errorf("tlsclient: could not load client certificate file: %s", err) 192 } 193 s.clientCertificates = append(s.clientCertificates, cert) 194 return nil 195 } 196 197 func (s *tlsConfig) LoadRootCA(rootCA []byte) error { 198 cert, err := x509.ParseCertificate(rootCA) 199 if err != nil { 200 return err 201 } 202 s.rootCAs = append(s.rootCAs, cert) 203 return nil 204 } 205 206 func (s *tlsConfig) LoadRootCAFile(rootCAFile string) error { 207 pemCerts, err := os.ReadFile(rootCAFile) 208 if err != nil { 209 return fmt.Errorf("tlsclient: could not load root CA file %q: %s", rootCAFile, err) 210 } 211 ok := false 212 for len(pemCerts) > 0 { 213 var block *pem.Block 214 block, pemCerts = pem.Decode(pemCerts) 215 if block == nil { 216 break 217 } 218 if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { 219 continue 220 } 221 if err = s.LoadRootCA(block.Bytes); err != nil { 222 continue 223 } 224 ok = true 225 } 226 if !ok { 227 return fmt.Errorf("tlsclient: could not load any certificate from the given ROOTCA file: %q", rootCAFile) 228 } 229 return nil 230 } 231 232 func (s *tlsConfig) SetInsecureSkipValidation() { 233 s.skipVerification = true 234 } 235 236 func (s *tlsConfig) AddHexPinnedKey(hexPinnedKey string) error { 237 pinnedKey, err := hex.DecodeString(hexPinnedKey) 238 if err != nil { 239 return fmt.Errorf("tlsclient: invalid hexadecimal fingerprint: %s", err) 240 } 241 expected := sha256.Size 242 given := len(pinnedKey) 243 if given != expected { 244 return fmt.Errorf("tlsclient: invalid fingerprint size for %s, expected %d got %d", hexPinnedKey, 245 expected, given) 246 } 247 s.pinnedKeys = append(s.pinnedKeys, pinnedKey) 248 return nil 249 } 250 251 func (s *tlsConfig) Config() *tls.Config { 252 conf := &tls.Config{} 253 conf.InsecureSkipVerify = s.skipVerification 254 255 if len(s.rootCAs) > 0 { 256 rootCAs := x509.NewCertPool() 257 for _, cert := range s.rootCAs { 258 rootCAs.AddCert(cert) 259 } 260 conf.RootCAs = rootCAs 261 } 262 263 if len(s.clientCertificates) > 0 { 264 conf.Certificates = make([]tls.Certificate, len(s.clientCertificates)) 265 copy(conf.Certificates, s.clientCertificates) 266 } 267 268 if len(s.pinnedKeys) > 0 { 269 conf.VerifyPeerCertificate = verifyCertificatePinnedKey(s.pinnedKeys) 270 } 271 return conf 272 } 273 274 func verifyCertificatePinnedKey(pinnedKeys [][]byte) func(certs [][]byte, verifiedChains [][]*x509.Certificate) error { 275 return func(certs [][]byte, verifiedChains [][]*x509.Certificate) error { 276 // Check for leaf pinning first 277 for _, asn1 := range certs { 278 cert, err := x509.ParseCertificate(asn1) 279 if err != nil { 280 return err 281 } 282 fingerPrint := sha256.Sum256(cert.RawSubjectPublicKeyInfo) 283 for _, pinnedKey := range pinnedKeys { 284 if bytes.Equal(pinnedKey, fingerPrint[:]) { 285 return nil 286 } 287 } 288 } 289 // Then check for intermediate pinning 290 for _, verifiedChain := range verifiedChains { 291 if len(verifiedChain) > 0 { 292 verifiedCert := verifiedChain[0] 293 fingerPrint := sha256.Sum256(verifiedCert.RawSubjectPublicKeyInfo) 294 for _, pinnedKey := range pinnedKeys { 295 if bytes.Equal(pinnedKey, fingerPrint[:]) { 296 return nil 297 } 298 } 299 } 300 } 301 return fmt.Errorf("tlsclient: could not find the valid pinned key from proposed ones") 302 } 303 }