github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/pkg/tlsclient/tlsclient.go (about)

     1  package tlsclient
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"encoding/hex"
     9  	"encoding/pem"
    10  	"fmt"
    11  	"net"
    12  	"net/http"
    13  	"net/url"
    14  	"os"
    15  	"strconv"
    16  	"time"
    17  
    18  	"github.com/cozy/cozy-stack/pkg/utils"
    19  )
    20  
    21  // HTTPEndpoint is a struct for specifying which parameters to use when
    22  // connecting to a HTTP(S) endpoint
    23  type HTTPEndpoint struct {
    24  	Host      string
    25  	Port      int
    26  	Timeout   time.Duration
    27  	EnvPrefix string
    28  
    29  	RootCAFile             string
    30  	ClientCertificateFiles ClientCertificateFilePair
    31  	PinnedKey              string
    32  	InsecureSkipValidation bool
    33  	MaxIdleConnsPerHost    int
    34  	DisableCompression     bool
    35  }
    36  
    37  // ClientCertificateFilePair is a struct with a certificate and a key pair
    38  type ClientCertificateFilePair struct {
    39  	KeyFile         string
    40  	CertificateFile string
    41  }
    42  
    43  type tlsConfig struct {
    44  	clientCertificates []tls.Certificate
    45  	rootCAs            []*x509.Certificate
    46  	pinnedKeys         [][]byte
    47  	skipVerification   bool
    48  }
    49  
    50  func generateURL(host string, port int) (*url.URL, error) {
    51  	u, err := url.Parse(host)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	if u.Scheme == "" {
    56  		u = &url.URL{
    57  			Scheme: "http",
    58  			Host:   net.JoinHostPort(host, strconv.Itoa(port)),
    59  		}
    60  	}
    61  	return u, nil
    62  }
    63  
    64  // NewHTTPClient creates a http.Client and an url.URL for the given HTTP endpoint
    65  func NewHTTPClient(opt HTTPEndpoint) (client *http.Client, u *url.URL, err error) {
    66  	if opt.Host != "" || opt.Port > 0 {
    67  		u, err = generateURL(opt.Host, opt.Port)
    68  		if err != nil {
    69  			return
    70  		}
    71  	}
    72  	c := &tlsConfig{}
    73  	if u != nil {
    74  		c, u, err = fromURL(c, u)
    75  		if err != nil {
    76  			return
    77  		}
    78  	}
    79  	if opt.EnvPrefix != "" {
    80  		c, err = fromEnv(c, opt.EnvPrefix)
    81  		if err != nil {
    82  			return
    83  		}
    84  	}
    85  	if opt.RootCAFile != "" {
    86  		if err = c.LoadRootCAFile(opt.RootCAFile); err != nil {
    87  			return
    88  		}
    89  	}
    90  	if opt.ClientCertificateFiles.CertificateFile != "" {
    91  		if err = c.LoadClientCertificateFile(
    92  			opt.ClientCertificateFiles.CertificateFile,
    93  			opt.ClientCertificateFiles.KeyFile,
    94  		); err != nil {
    95  			return
    96  		}
    97  	}
    98  	if opt.PinnedKey != "" {
    99  		if err = c.AddHexPinnedKey(opt.PinnedKey); err != nil {
   100  			return
   101  		}
   102  	}
   103  	if opt.InsecureSkipValidation {
   104  		c.SetInsecureSkipValidation()
   105  	}
   106  	transport := http.DefaultTransport.(*http.Transport).Clone()
   107  	transport.TLSClientConfig = c.Config()
   108  	if opt.MaxIdleConnsPerHost > 0 {
   109  		transport.MaxIdleConnsPerHost = opt.MaxIdleConnsPerHost
   110  	}
   111  	if opt.DisableCompression {
   112  		transport.DisableCompression = true
   113  	}
   114  	client = &http.Client{
   115  		Timeout:   opt.Timeout,
   116  		Transport: transport,
   117  	}
   118  	return
   119  }
   120  
   121  func fromURL(c *tlsConfig, u *url.URL) (conf *tlsConfig, uCopy *url.URL, err error) {
   122  	uCopy = utils.CloneURL(u)
   123  	q := uCopy.Query()
   124  	if u.Scheme == "https" {
   125  		if rootCAFile := q.Get("ca"); rootCAFile != "" {
   126  			if err = c.LoadRootCAFile(rootCAFile); err != nil {
   127  				return
   128  			}
   129  		}
   130  		if certFile := q.Get("cert"); certFile != "" {
   131  			if keyFile := q.Get("key"); keyFile != "" {
   132  				if err = c.LoadClientCertificateFile(certFile, keyFile); err != nil {
   133  					return
   134  				}
   135  			}
   136  		}
   137  		if hexPinnedKey := q.Get("fp"); hexPinnedKey != "" {
   138  			if err = c.AddHexPinnedKey(hexPinnedKey); err != nil {
   139  				return
   140  			}
   141  		}
   142  		if t := q.Get("validate"); t == "0" || t == "false" || t == "FALSE" {
   143  			c.SetInsecureSkipValidation()
   144  		}
   145  	}
   146  	q.Del("ca")
   147  	q.Del("cert")
   148  	q.Del("key")
   149  	q.Del("fp")
   150  	q.Del("validate")
   151  	uCopy.RawQuery = q.Encode()
   152  	return c, uCopy, nil
   153  }
   154  
   155  func fromEnv(c *tlsConfig, envPrefix string) (conf *tlsConfig, err error) {
   156  	if rootCAFile := os.Getenv(envPrefix + "_CA"); rootCAFile != "" {
   157  		if err = c.LoadRootCAFile(rootCAFile); err != nil {
   158  			return
   159  		}
   160  	}
   161  	if certFile := os.Getenv(envPrefix + "_CERT"); certFile != "" {
   162  		if keyFile := os.Getenv(envPrefix + "_KEY"); keyFile != "" {
   163  			if err = c.LoadClientCertificateFile(certFile, keyFile); err != nil {
   164  				return
   165  			}
   166  		}
   167  	}
   168  	if hexPinnedKey := os.Getenv(envPrefix + "_FINGERPRINT"); hexPinnedKey != "" {
   169  		if err = c.AddHexPinnedKey(hexPinnedKey); err != nil {
   170  			return
   171  		}
   172  	}
   173  	if t := os.Getenv(envPrefix + "_VALIDATE"); t == "0" || t == "false" || t == "FALSE" {
   174  		c.SetInsecureSkipValidation()
   175  	}
   176  	return c, nil
   177  }
   178  
   179  func (s *tlsConfig) LoadClientCertificateFile(certFile, keyFile string) error {
   180  	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
   181  	if err != nil {
   182  		return fmt.Errorf("tlsclient: could not load client certificate file: %s", err)
   183  	}
   184  	s.clientCertificates = append(s.clientCertificates, cert)
   185  	return nil
   186  }
   187  
   188  func (s *tlsConfig) LoadClientCertificate(certPEMBlock, keyPEMBlock []byte) error {
   189  	cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
   190  	if err != nil {
   191  		return fmt.Errorf("tlsclient: could not load client certificate file: %s", err)
   192  	}
   193  	s.clientCertificates = append(s.clientCertificates, cert)
   194  	return nil
   195  }
   196  
   197  func (s *tlsConfig) LoadRootCA(rootCA []byte) error {
   198  	cert, err := x509.ParseCertificate(rootCA)
   199  	if err != nil {
   200  		return err
   201  	}
   202  	s.rootCAs = append(s.rootCAs, cert)
   203  	return nil
   204  }
   205  
   206  func (s *tlsConfig) LoadRootCAFile(rootCAFile string) error {
   207  	pemCerts, err := os.ReadFile(rootCAFile)
   208  	if err != nil {
   209  		return fmt.Errorf("tlsclient: could not load root CA file %q: %s", rootCAFile, err)
   210  	}
   211  	ok := false
   212  	for len(pemCerts) > 0 {
   213  		var block *pem.Block
   214  		block, pemCerts = pem.Decode(pemCerts)
   215  		if block == nil {
   216  			break
   217  		}
   218  		if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   219  			continue
   220  		}
   221  		if err = s.LoadRootCA(block.Bytes); err != nil {
   222  			continue
   223  		}
   224  		ok = true
   225  	}
   226  	if !ok {
   227  		return fmt.Errorf("tlsclient: could not load any certificate from the given ROOTCA file: %q", rootCAFile)
   228  	}
   229  	return nil
   230  }
   231  
// SetInsecureSkipValidation turns off verification of the server certificate:
// the tls.Config built by Config will have InsecureSkipVerify set. Pinned
// keys, if any, are still checked by the VerifyPeerCertificate callback.
func (s *tlsConfig) SetInsecureSkipValidation() {
	s.skipVerification = true
}
   235  
   236  func (s *tlsConfig) AddHexPinnedKey(hexPinnedKey string) error {
   237  	pinnedKey, err := hex.DecodeString(hexPinnedKey)
   238  	if err != nil {
   239  		return fmt.Errorf("tlsclient: invalid hexadecimal fingerprint: %s", err)
   240  	}
   241  	expected := sha256.Size
   242  	given := len(pinnedKey)
   243  	if given != expected {
   244  		return fmt.Errorf("tlsclient: invalid fingerprint size for %s, expected %d got %d", hexPinnedKey,
   245  			expected, given)
   246  	}
   247  	s.pinnedKeys = append(s.pinnedKeys, pinnedKey)
   248  	return nil
   249  }
   250  
   251  func (s *tlsConfig) Config() *tls.Config {
   252  	conf := &tls.Config{}
   253  	conf.InsecureSkipVerify = s.skipVerification
   254  
   255  	if len(s.rootCAs) > 0 {
   256  		rootCAs := x509.NewCertPool()
   257  		for _, cert := range s.rootCAs {
   258  			rootCAs.AddCert(cert)
   259  		}
   260  		conf.RootCAs = rootCAs
   261  	}
   262  
   263  	if len(s.clientCertificates) > 0 {
   264  		conf.Certificates = make([]tls.Certificate, len(s.clientCertificates))
   265  		copy(conf.Certificates, s.clientCertificates)
   266  	}
   267  
   268  	if len(s.pinnedKeys) > 0 {
   269  		conf.VerifyPeerCertificate = verifyCertificatePinnedKey(s.pinnedKeys)
   270  	}
   271  	return conf
   272  }
   273  
   274  func verifyCertificatePinnedKey(pinnedKeys [][]byte) func(certs [][]byte, verifiedChains [][]*x509.Certificate) error {
   275  	return func(certs [][]byte, verifiedChains [][]*x509.Certificate) error {
   276  		// Check for leaf pinning first
   277  		for _, asn1 := range certs {
   278  			cert, err := x509.ParseCertificate(asn1)
   279  			if err != nil {
   280  				return err
   281  			}
   282  			fingerPrint := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
   283  			for _, pinnedKey := range pinnedKeys {
   284  				if bytes.Equal(pinnedKey, fingerPrint[:]) {
   285  					return nil
   286  				}
   287  			}
   288  		}
   289  		// Then check for intermediate pinning
   290  		for _, verifiedChain := range verifiedChains {
   291  			if len(verifiedChain) > 0 {
   292  				verifiedCert := verifiedChain[0]
   293  				fingerPrint := sha256.Sum256(verifiedCert.RawSubjectPublicKeyInfo)
   294  				for _, pinnedKey := range pinnedKeys {
   295  					if bytes.Equal(pinnedKey, fingerPrint[:]) {
   296  						return nil
   297  					}
   298  				}
   299  			}
   300  		}
   301  		return fmt.Errorf("tlsclient: could not find the valid pinned key from proposed ones")
   302  	}
   303  }