
     1  package tlsclient
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"encoding/hex"
     9  	"encoding/pem"
    10  	"fmt"
    11  	"net"
    12  	"net/http"
    13  	"net/url"
    14  	"os"
    15  	"strconv"
    16  	"time"
    18  	""
    19  )
// HTTPEndpoint is a struct for specifying which parameters to use when
// connecting to a HTTP(S) endpoint.
//
// It is consumed by NewHTTPClient. The TLS-related fields below are applied
// after any TLS settings found in the Host URL's query string and in the
// EnvPrefix environment variables; they add to those settings rather than
// replacing them.
type HTTPEndpoint struct {
	// Host is either a bare host name / IP address, which is combined with
	// Port into an "http" URL, or a full URL with a scheme, which is used as
	// is (Port is then not used).
	Host string
	// Port is the TCP port joined to a scheme-less Host.
	Port int
	// Timeout is copied to http.Client.Timeout.
	Timeout time.Duration
	// EnvPrefix, when non-empty, makes NewHTTPClient read the variables
	// <EnvPrefix>_CA, <EnvPrefix>_CERT, <EnvPrefix>_KEY,
	// <EnvPrefix>_FINGERPRINT and <EnvPrefix>_VALIDATE.
	EnvPrefix string

	// RootCAFile is the path of a PEM file with trusted root CA certificates.
	RootCAFile string
	// ClientCertificateFiles is loaded when its CertificateFile is non-empty.
	ClientCertificateFiles ClientCertificateFilePair
	// PinnedKey is the hex-encoded SHA-256 digest of a certificate's
	// SubjectPublicKeyInfo that the server must present.
	PinnedKey string
	// InsecureSkipValidation sets tls.Config.InsecureSkipVerify.
	InsecureSkipValidation bool
	// MaxIdleConnsPerHost overrides the transport's value when > 0.
	MaxIdleConnsPerHost int
	// DisableCompression sets http.Transport.DisableCompression when true.
	DisableCompression bool
}
// ClientCertificateFilePair names the files holding a client certificate and
// its private key; both paths are passed to tls.LoadX509KeyPair.
type ClientCertificateFilePair struct {
	// KeyFile is the path of the PEM-encoded private key.
	KeyFile string
	// CertificateFile is the path of the PEM-encoded certificate (chain).
	CertificateFile string
}
// tlsConfig accumulates TLS client settings from several sources (URL query
// string, environment variables, explicit options) before Config turns them
// into a *tls.Config.
type tlsConfig struct {
	// clientCertificates become tls.Config.Certificates.
	clientCertificates []tls.Certificate
	// rootCAs, when non-empty, are placed in the tls.Config.RootCAs pool.
	rootCAs []*x509.Certificate
	// pinnedKeys are SHA-256 digests (sha256.Size bytes each) of
	// SubjectPublicKeyInfo values; when non-empty a peer must match one.
	pinnedKeys [][]byte
	// skipVerification becomes tls.Config.InsecureSkipVerify.
	skipVerification bool
}
    50  func generateURL(host string, port int) (*url.URL, error) {
    51  	u, err := url.Parse(host)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	if u.Scheme == "" {
    56  		u = &url.URL{
    57  			Scheme: "http",
    58  			Host:   net.JoinHostPort(host, strconv.Itoa(port)),
    59  		}
    60  	}
    61  	return u, nil
    62  }
    64  // NewHTTPClient creates a http.Client and an url.URL for the given HTTP endpoint
    65  func NewHTTPClient(opt HTTPEndpoint) (client *http.Client, u *url.URL, err error) {
    66  	if opt.Host != "" || opt.Port > 0 {
    67  		u, err = generateURL(opt.Host, opt.Port)
    68  		if err != nil {
    69  			return
    70  		}
    71  	}
    72  	c := &tlsConfig{}
    73  	if u != nil {
    74  		c, u, err = fromURL(c, u)
    75  		if err != nil {
    76  			return
    77  		}
    78  	}
    79  	if opt.EnvPrefix != "" {
    80  		c, err = fromEnv(c, opt.EnvPrefix)
    81  		if err != nil {
    82  			return
    83  		}
    84  	}
    85  	if opt.RootCAFile != "" {
    86  		if err = c.LoadRootCAFile(opt.RootCAFile); err != nil {
    87  			return
    88  		}
    89  	}
    90  	if opt.ClientCertificateFiles.CertificateFile != "" {
    91  		if err = c.LoadClientCertificateFile(
    92  			opt.ClientCertificateFiles.CertificateFile,
    93  			opt.ClientCertificateFiles.KeyFile,
    94  		); err != nil {
    95  			return
    96  		}
    97  	}
    98  	if opt.PinnedKey != "" {
    99  		if err = c.AddHexPinnedKey(opt.PinnedKey); err != nil {
   100  			return
   101  		}
   102  	}
   103  	if opt.InsecureSkipValidation {
   104  		c.SetInsecureSkipValidation()
   105  	}
   106  	transport := http.DefaultTransport.(*http.Transport).Clone()
   107  	transport.TLSClientConfig = c.Config()
   108  	if opt.MaxIdleConnsPerHost > 0 {
   109  		transport.MaxIdleConnsPerHost = opt.MaxIdleConnsPerHost
   110  	}
   111  	if opt.DisableCompression {
   112  		transport.DisableCompression = true
   113  	}
   114  	client = &http.Client{
   115  		Timeout:   opt.Timeout,
   116  		Transport: transport,
   117  	}
   118  	return
   119  }
   121  func fromURL(c *tlsConfig, u *url.URL) (conf *tlsConfig, uCopy *url.URL, err error) {
   122  	uCopy = utils.CloneURL(u)
   123  	q := uCopy.Query()
   124  	if u.Scheme == "https" {
   125  		if rootCAFile := q.Get("ca"); rootCAFile != "" {
   126  			if err = c.LoadRootCAFile(rootCAFile); err != nil {
   127  				return
   128  			}
   129  		}
   130  		if certFile := q.Get("cert"); certFile != "" {
   131  			if keyFile := q.Get("key"); keyFile != "" {
   132  				if err = c.LoadClientCertificateFile(certFile, keyFile); err != nil {
   133  					return
   134  				}
   135  			}
   136  		}
   137  		if hexPinnedKey := q.Get("fp"); hexPinnedKey != "" {
   138  			if err = c.AddHexPinnedKey(hexPinnedKey); err != nil {
   139  				return
   140  			}
   141  		}
   142  		if t := q.Get("validate"); t == "0" || t == "false" || t == "FALSE" {
   143  			c.SetInsecureSkipValidation()
   144  		}
   145  	}
   146  	q.Del("ca")
   147  	q.Del("cert")
   148  	q.Del("key")
   149  	q.Del("fp")
   150  	q.Del("validate")
   151  	uCopy.RawQuery = q.Encode()
   152  	return c, uCopy, nil
   153  }
   155  func fromEnv(c *tlsConfig, envPrefix string) (conf *tlsConfig, err error) {
   156  	if rootCAFile := os.Getenv(envPrefix + "_CA"); rootCAFile != "" {
   157  		if err = c.LoadRootCAFile(rootCAFile); err != nil {
   158  			return
   159  		}
   160  	}
   161  	if certFile := os.Getenv(envPrefix + "_CERT"); certFile != "" {
   162  		if keyFile := os.Getenv(envPrefix + "_KEY"); keyFile != "" {
   163  			if err = c.LoadClientCertificateFile(certFile, keyFile); err != nil {
   164  				return
   165  			}
   166  		}
   167  	}
   168  	if hexPinnedKey := os.Getenv(envPrefix + "_FINGERPRINT"); hexPinnedKey != "" {
   169  		if err = c.AddHexPinnedKey(hexPinnedKey); err != nil {
   170  			return
   171  		}
   172  	}
   173  	if t := os.Getenv(envPrefix + "_VALIDATE"); t == "0" || t == "false" || t == "FALSE" {
   174  		c.SetInsecureSkipValidation()
   175  	}
   176  	return c, nil
   177  }
   179  func (s *tlsConfig) LoadClientCertificateFile(certFile, keyFile string) error {
   180  	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
   181  	if err != nil {
   182  		return fmt.Errorf("tlsclient: could not load client certificate file: %s", err)
   183  	}
   184  	s.clientCertificates = append(s.clientCertificates, cert)
   185  	return nil
   186  }
   188  func (s *tlsConfig) LoadClientCertificate(certPEMBlock, keyPEMBlock []byte) error {
   189  	cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
   190  	if err != nil {
   191  		return fmt.Errorf("tlsclient: could not load client certificate file: %s", err)
   192  	}
   193  	s.clientCertificates = append(s.clientCertificates, cert)
   194  	return nil
   195  }
   197  func (s *tlsConfig) LoadRootCA(rootCA []byte) error {
   198  	cert, err := x509.ParseCertificate(rootCA)
   199  	if err != nil {
   200  		return err
   201  	}
   202  	s.rootCAs = append(s.rootCAs, cert)
   203  	return nil
   204  }
   206  func (s *tlsConfig) LoadRootCAFile(rootCAFile string) error {
   207  	pemCerts, err := os.ReadFile(rootCAFile)
   208  	if err != nil {
   209  		return fmt.Errorf("tlsclient: could not load root CA file %q: %s", rootCAFile, err)
   210  	}
   211  	ok := false
   212  	for len(pemCerts) > 0 {
   213  		var block *pem.Block
   214  		block, pemCerts = pem.Decode(pemCerts)
   215  		if block == nil {
   216  			break
   217  		}
   218  		if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   219  			continue
   220  		}
   221  		if err = s.LoadRootCA(block.Bytes); err != nil {
   222  			continue
   223  		}
   224  		ok = true
   225  	}
   226  	if !ok {
   227  		return fmt.Errorf("tlsclient: could not load any certificate from the given ROOTCA file: %q", rootCAFile)
   228  	}
   229  	return nil
   230  }
// SetInsecureSkipValidation disables certificate verification: Config sets
// tls.Config.InsecureSkipVerify from this flag. Once set, the flag is never
// cleared by this type. Any pinned keys are still enforced through
// VerifyPeerCertificate.
func (s *tlsConfig) SetInsecureSkipValidation() {
	s.skipVerification = true
}
   236  func (s *tlsConfig) AddHexPinnedKey(hexPinnedKey string) error {
   237  	pinnedKey, err := hex.DecodeString(hexPinnedKey)
   238  	if err != nil {
   239  		return fmt.Errorf("tlsclient: invalid hexadecimal fingerprint: %s", err)
   240  	}
   241  	expected := sha256.Size
   242  	given := len(pinnedKey)
   243  	if given != expected {
   244  		return fmt.Errorf("tlsclient: invalid fingerprint size for %s, expected %d got %d", hexPinnedKey,
   245  			expected, given)
   246  	}
   247  	s.pinnedKeys = append(s.pinnedKeys, pinnedKey)
   248  	return nil
   249  }
   251  func (s *tlsConfig) Config() *tls.Config {
   252  	conf := &tls.Config{}
   253  	conf.InsecureSkipVerify = s.skipVerification
   255  	if len(s.rootCAs) > 0 {
   256  		rootCAs := x509.NewCertPool()
   257  		for _, cert := range s.rootCAs {
   258  			rootCAs.AddCert(cert)
   259  		}
   260  		conf.RootCAs = rootCAs
   261  	}
   263  	if len(s.clientCertificates) > 0 {
   264  		conf.Certificates = make([]tls.Certificate, len(s.clientCertificates))
   265  		copy(conf.Certificates, s.clientCertificates)
   266  	}
   268  	if len(s.pinnedKeys) > 0 {
   269  		conf.VerifyPeerCertificate = verifyCertificatePinnedKey(s.pinnedKeys)
   270  	}
   271  	return conf
   272  }
   274  func verifyCertificatePinnedKey(pinnedKeys [][]byte) func(certs [][]byte, verifiedChains [][]*x509.Certificate) error {
   275  	return func(certs [][]byte, verifiedChains [][]*x509.Certificate) error {
   276  		// Check for leaf pinning first
   277  		for _, asn1 := range certs {
   278  			cert, err := x509.ParseCertificate(asn1)
   279  			if err != nil {
   280  				return err
   281  			}
   282  			fingerPrint := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
   283  			for _, pinnedKey := range pinnedKeys {
   284  				if bytes.Equal(pinnedKey, fingerPrint[:]) {
   285  					return nil
   286  				}
   287  			}
   288  		}
   289  		// Then check for intermediate pinning
   290  		for _, verifiedChain := range verifiedChains {
   291  			if len(verifiedChain) > 0 {
   292  				verifiedCert := verifiedChain[0]
   293  				fingerPrint := sha256.Sum256(verifiedCert.RawSubjectPublicKeyInfo)
   294  				for _, pinnedKey := range pinnedKeys {
   295  					if bytes.Equal(pinnedKey, fingerPrint[:]) {
   296  						return nil
   297  					}
   298  				}
   299  			}
   300  		}
   301  		return fmt.Errorf("tlsclient: could not find the valid pinned key from proposed ones")
   302  	}
   303  }