github.com/thanos-io/thanos@v0.32.5/pkg/tls/options.go (about)

     1  // Copyright (c) The Thanos Authors.
     2  // Licensed under the Apache License 2.0.
     3  
     4  package tls
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"os"
    10  	"path/filepath"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/go-kit/log"
    15  	"github.com/go-kit/log/level"
    16  	"github.com/pkg/errors"
    17  )
    18  
    19  // NewServerConfig provides new server TLS configuration.
    20  func NewServerConfig(logger log.Logger, certPath, keyPath, clientCA string) (*tls.Config, error) {
    21  	if keyPath == "" && certPath == "" {
    22  		if clientCA != "" {
    23  			return nil, errors.New("when a client CA is used a server key and certificate must also be provided")
    24  		}
    25  
    26  		level.Info(logger).Log("msg", "disabled TLS, key and cert must be set to enable")
    27  		return nil, nil
    28  	}
    29  
    30  	level.Info(logger).Log("msg", "enabling server side TLS")
    31  
    32  	if keyPath == "" || certPath == "" {
    33  		return nil, errors.New("both server key and certificate must be provided")
    34  	}
    35  
    36  	tlsCfg := &tls.Config{
    37  		MinVersion: tls.VersionTLS13,
    38  	}
    39  	// Certificate is loaded during server startup to check for any errors.
    40  	certificate, err := tls.LoadX509KeyPair(certPath, keyPath)
    41  	if err != nil {
    42  		return nil, errors.Wrap(err, "server credentials")
    43  	}
    44  
    45  	mngr := &serverTLSManager{
    46  		srvCertPath: certPath,
    47  		srvKeyPath:  keyPath,
    48  		srvCert:     &certificate,
    49  	}
    50  
    51  	tlsCfg.GetCertificate = mngr.getCertificate
    52  
    53  	if clientCA != "" {
    54  		caPEM, err := os.ReadFile(filepath.Clean(clientCA))
    55  		if err != nil {
    56  			return nil, errors.Wrap(err, "reading client CA")
    57  		}
    58  
    59  		certPool := x509.NewCertPool()
    60  		if !certPool.AppendCertsFromPEM(caPEM) {
    61  			return nil, errors.Wrap(err, "building client CA")
    62  		}
    63  		tlsCfg.ClientCAs = certPool
    64  		tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
    65  
    66  		level.Info(logger).Log("msg", "server TLS client verification enabled")
    67  	}
    68  
    69  	return tlsCfg, nil
    70  }
    71  
    72  type serverTLSManager struct {
    73  	srvCertPath string
    74  	srvKeyPath  string
    75  
    76  	mtx            sync.Mutex
    77  	srvCert        *tls.Certificate
    78  	srvCertModTime time.Time
    79  	srvKeyModTime  time.Time
    80  }
    81  
    82  func (m *serverTLSManager) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
    83  	m.mtx.Lock()
    84  	defer m.mtx.Unlock()
    85  
    86  	statCert, err := os.Stat(m.srvCertPath)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	statKey, err := os.Stat(m.srvKeyPath)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	if m.srvCert == nil || !statCert.ModTime().Equal(m.srvCertModTime) || !statKey.ModTime().Equal(m.srvKeyModTime) {
    96  		cert, err := tls.LoadX509KeyPair(m.srvCertPath, m.srvKeyPath)
    97  		if err != nil {
    98  			return nil, errors.Wrap(err, "server credentials")
    99  		}
   100  		m.srvCertModTime = statCert.ModTime()
   101  		m.srvKeyModTime = statKey.ModTime()
   102  		m.srvCert = &cert
   103  	}
   104  	return m.srvCert, nil
   105  }
   106  
   107  // NewClientConfig provides new client TLS configuration.
   108  func NewClientConfig(logger log.Logger, cert, key, caCert, serverName string, skipVerify bool) (*tls.Config, error) {
   109  	var certPool *x509.CertPool
   110  	if caCert != "" {
   111  		caPEM, err := os.ReadFile(filepath.Clean(caCert))
   112  		if err != nil {
   113  			return nil, errors.Wrap(err, "reading client CA")
   114  		}
   115  
   116  		certPool = x509.NewCertPool()
   117  		if !certPool.AppendCertsFromPEM(caPEM) {
   118  			return nil, errors.Wrap(err, "building client CA")
   119  		}
   120  		level.Info(logger).Log("msg", "TLS client using provided certificate pool")
   121  	} else {
   122  		var err error
   123  		certPool, err = x509.SystemCertPool()
   124  		if err != nil {
   125  			return nil, errors.Wrap(err, "reading system certificate pool")
   126  		}
   127  		level.Info(logger).Log("msg", "TLS client using system certificate pool")
   128  	}
   129  
   130  	tlsCfg := &tls.Config{
   131  		RootCAs: certPool,
   132  	}
   133  
   134  	if serverName != "" {
   135  		tlsCfg.ServerName = serverName
   136  	}
   137  
   138  	if skipVerify {
   139  		tlsCfg.InsecureSkipVerify = true
   140  	}
   141  
   142  	if (key != "") != (cert != "") {
   143  		return nil, errors.New("both client key and certificate must be provided")
   144  	}
   145  
   146  	if cert != "" {
   147  		mngr := &clientTLSManager{
   148  			certPath: cert,
   149  			keyPath:  key,
   150  		}
   151  		tlsCfg.GetClientCertificate = mngr.getClientCertificate
   152  
   153  		level.Info(logger).Log("msg", "TLS client authentication enabled")
   154  	}
   155  	return tlsCfg, nil
   156  }
   157  
   158  type clientTLSManager struct {
   159  	certPath string
   160  	keyPath  string
   161  
   162  	mtx         sync.Mutex
   163  	cert        *tls.Certificate
   164  	certModTime time.Time
   165  	keyModTime  time.Time
   166  }
   167  
   168  func (m *clientTLSManager) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   169  	m.mtx.Lock()
   170  	defer m.mtx.Unlock()
   171  
   172  	statCert, err := os.Stat(m.certPath)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	statKey, err := os.Stat(m.keyPath)
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	if m.cert == nil || !statCert.ModTime().Equal(m.certModTime) || !statKey.ModTime().Equal(m.keyModTime) {
   182  		cert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath)
   183  		if err != nil {
   184  			return nil, errors.Wrap(err, "client credentials")
   185  		}
   186  		m.certModTime = statCert.ModTime()
   187  		m.keyModTime = statKey.ModTime()
   188  		m.cert = &cert
   189  	}
   190  
   191  	return m.cert, nil
   192  }