github.com/ztalab/ZACA@v0.0.1/pkg/caclient/transport.go (about)

     1  package caclient
     2  
     3  import (
     4  	"crypto/tls"
     5  	"time"
     6  
     7  	"github.com/pkg/errors"
     8  	"go.uber.org/zap"
     9  
    10  	"github.com/cloudflare/backoff"
    11  	"github.com/ztalab/cfssl/csr"
    12  	"github.com/ztalab/cfssl/transport/ca"
    13  	"github.com/ztalab/cfssl/transport/core"
    14  	"github.com/ztalab/cfssl/transport/kp"
    15  	"github.com/ztalab/cfssl/transport/roots"
    16  )
    17  
// A Transport is capable of providing transport-layer security using
// TLS. It combines a key provider, a certificate authority client and
// the trust stores needed to build tls.Config values whose certificate
// is looked up at handshake time.
type Transport struct {
	// CertRefreshDurationRate is the divisor Lifespan applies to the
	// certificate's total validity period (NotAfter - NotBefore) to
	// obtain the refresh window.
	CertRefreshDurationRate int

	// Provider contains a key management provider.
	Provider kp.KeyProvider

	// CA contains a mechanism for obtaining signed certificates.
	CA ca.CertificateAuthority

	// TrustStore contains the certificates trusted by this
	// transport.
	TrustStore *roots.TrustStore

	// ClientTrustStore contains the certificate authorities to
	// use in verifying client authentication certificates.
	ClientTrustStore *roots.TrustStore

	// Identity contains information about the entity that will be
	// used to construct certificates.
	Identity *core.Identity

	// Backoff is used to control the behaviour of a Transport
	// when it is attempting to automatically update a certificate
	// as part of AutoUpdate.
	Backoff *backoff.Backoff

	// RevokeSoftFail, if true, will cause a failure to check
	// revocation (such that the revocation status of a
	// certificate cannot be checked) to not be treated as an
	// error.
	RevokeSoftFail bool

	// manualRevoke is set by ManualRevoke. While it is set, the next
	// Lifespan call that sees an unexpired certificate reports (0, 0)
	// and clears the flag.
	// NOTE(review): the flag is read and written without
	// synchronisation; confirm callers do not use it concurrently.
	manualRevoke bool

	// logger receives the diagnostic output of every Transport method.
	logger *zap.SugaredLogger
}
    56  
    57  // TLSClientAuthClientConfig Client TLS configuration, changing certificate dynamically
    58  func (tr *Transport) TLSClientAuthClientConfig(host string) (*tls.Config, error) {
    59  	return &tls.Config{
    60  		GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
    61  			cert, err := tr.GetCertificate()
    62  			if err != nil {
    63  				tr.logger.Errorf("Client certificate acquisition error: %v", err)
    64  				return nil, err
    65  			}
    66  			return cert, nil
    67  		},
    68  		RootCAs:      tr.TrustStore.Pool(),
    69  		ServerName:   host,
    70  		CipherSuites: core.CipherSuites,
    71  		MinVersion:   tls.VersionTLS12,
    72  	}, nil
    73  }
    74  
    75  // TLSClientAuthServerConfig The server TLS configuration needs to be changed dynamically
    76  func (tr *Transport) TLSClientAuthServerConfig() (*tls.Config, error) {
    77  	return &tls.Config{
    78  		// Get configuration dynamically
    79  		GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
    80  			tlsConfig := &tls.Config{
    81  				GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
    82  					cert, err := tr.GetCertificate()
    83  					if err != nil {
    84  						tr.logger.Errorf("Server certificate acquisition error: %v", err)
    85  						return nil, err
    86  					}
    87  					return cert, nil
    88  				},
    89  				RootCAs:   tr.TrustStore.Pool(),
    90  				ClientCAs: tr.ClientTrustStore.Pool(),
    91  			}
    92  			return tlsConfig, nil
    93  		},
    94  		ClientAuth:   tls.RequireAndVerifyClientCert,
    95  		CipherSuites: core.CipherSuites,
    96  		MinVersion:   tls.VersionTLS12,
    97  	}, nil
    98  }
    99  
   100  // TLSServerConfig is a general server configuration that should be
   101  // used for non-client authentication purposes, such as HTTPS.
   102  func (tr *Transport) TLSServerConfig() (*tls.Config, error) {
   103  	return &tls.Config{
   104  		GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   105  			cert, err := tr.GetCertificate()
   106  			if err != nil {
   107  				tr.logger.Errorf("Server certificate acquisition error: %v", err)
   108  				return nil, err
   109  			}
   110  			return cert, nil
   111  		},
   112  		RootCAs:      tr.TrustStore.Pool(),
   113  		ClientCAs:    tr.ClientTrustStore.Pool(),
   114  		CipherSuites: core.CipherSuites,
   115  		MinVersion:   tls.VersionTLS12,
   116  		ClientAuth:   tls.VerifyClientCertIfGiven,
   117  	}, nil
   118  }
   119  
   120  // Lifespan Returns the remaining replacement time of a certificate. If it is less than or equal to 0, the certificate must be replaced
   121  // remain Total remaining time of certificate, ava update time
   122  func (tr *Transport) Lifespan() (remain time.Duration, ava time.Duration) {
   123  	cert := tr.Provider.Certificate()
   124  	if cert == nil {
   125  		return 0, 0
   126  	}
   127  
   128  	now := time.Now()
   129  	if now.After(cert.NotAfter) {
   130  		return 0, 0
   131  	}
   132  	remain = cert.NotAfter.Sub(now)
   133  
   134  	certLong := cert.NotAfter.Sub(cert.NotBefore)
   135  	ava = certLong / time.Duration(tr.CertRefreshDurationRate)
   136  
   137  	if tr.manualRevoke {
   138  		tr.manualRevoke = false
   139  		return 0, 0
   140  	}
   141  
   142  	return remain, ava
   143  }
   144  
// ManualRevoke marks the current certificate for replacement by setting
// the manualRevoke flag. The next Lifespan call that finds an unexpired
// certificate then reports (0, 0) and clears the flag, which makes
// AutoUpdate and AsyncRefreshKeys request a new certificate.
func (tr *Transport) ManualRevoke() {
	tr.manualRevoke = true
}
   149  
   150  // RefreshKeys will make sure the Transport has loaded keys and has a
   151  // valid certificate. It will handle any persistence, check that the
   152  // certificate is valid (i.e. that its expiry date is within the
   153  // Before date), and handle certificate reissuance as needed.
   154  func (tr *Transport) RefreshKeys() (err error) {
   155  	ch := make(chan error, 1)
   156  	go func(tr *Transport) {
   157  		ch <- tr.AsyncRefreshKeys()
   158  	}(tr)
   159  	select {
   160  	case err := <-ch:
   161  		return err
   162  	case <-time.After(5 * time.Second): // 5 seconds timeout
   163  		return errors.New("RefreshKeys timeout")
   164  	}
   165  
   166  }
   167  
   168  // AsyncRefreshKeys timeout handler
   169  func (tr *Transport) AsyncRefreshKeys() error {
   170  	if !tr.Provider.Ready() {
   171  		tr.logger.Debug("key and certificate aren't ready, loading")
   172  		err := tr.Provider.Load()
   173  		if err != nil && !errors.Is(err, kp.ErrCertificateUnavailable) {
   174  			tr.logger.Debugf("failed to load keypair: %v", err)
   175  			kr := tr.Identity.Request.KeyRequest
   176  			if kr == nil {
   177  				kr = csr.NewKeyRequest()
   178  			}
   179  
   180  			// Create a new private key
   181  			tr.logger.Debug("Create a new private key")
   182  			err = tr.Provider.Generate(kr.Algo(), kr.Size())
   183  			if err != nil {
   184  				tr.logger.Debugf("failed to generate key: %v", err)
   185  				return err
   186  			}
   187  			tr.logger.Debug("Created successfully")
   188  		}
   189  	}
   190  
   191  	// Certificate validity
   192  	remain, lifespan := tr.Lifespan()
   193  	if remain < lifespan || lifespan <= 0 {
   194  		// Read the CSR configuration from the filled in request structure
   195  		tr.logger.Debug("Create csr")
   196  		req, err := tr.Provider.CertificateRequest(tr.Identity.Request)
   197  		if err != nil {
   198  			tr.logger.Debugf("couldn't get a CSR: %v", err)
   199  			if tr.Provider.SignalFailure(err) {
   200  				return tr.RefreshKeys()
   201  			}
   202  			return err
   203  		}
   204  		tr.logger.Debug("Create CSR complete")
   205  
   206  		tr.logger.Debug("requesting certificate from CA")
   207  		cert, err := tr.CA.SignCSR(req)
   208  		if err != nil {
   209  			if tr.Provider.SignalFailure(err) {
   210  				return tr.RefreshKeys()
   211  			}
   212  			tr.logger.Debugf("failed to get the certificate signed: %v", err)
   213  			return err
   214  		}
   215  
   216  		tr.logger.Debug("giving the certificate to the provider")
   217  		err = tr.Provider.SetCertificatePEM(cert)
   218  		if err != nil {
   219  			tr.logger.Debugf("failed to set the provider's certificate: %v", err)
   220  			if tr.Provider.SignalFailure(err) {
   221  				return tr.RefreshKeys()
   222  			}
   223  			return err
   224  		}
   225  
   226  		if tr.Provider.Persistent() {
   227  			tr.logger.Debug("storing the certificate")
   228  			err = tr.Provider.Store()
   229  
   230  			if err != nil {
   231  				tr.logger.Debugf("the provider failed to store the certificate: %v", err)
   232  				if tr.Provider.SignalFailure(err) {
   233  					return tr.RefreshKeys()
   234  				}
   235  				return err
   236  			}
   237  		}
   238  	}
   239  	return nil
   240  }
   241  
   242  // GetCertificate ...
   243  func (tr *Transport) GetCertificate() (*tls.Certificate, error) {
   244  	tr.logger.Debug("keygen")
   245  	if !tr.Provider.Ready() {
   246  		tr.logger.Debug("transport isn't ready; attempting to refresh keypair")
   247  		err := tr.RefreshKeys()
   248  		if err != nil {
   249  			tr.logger.Debugf("transport couldn't get a certificate: %v", err)
   250  			return nil, err
   251  		}
   252  	}
   253  
   254  	tr.logger.Debug("keypair")
   255  	cert, err := tr.Provider.X509KeyPair()
   256  	if err != nil {
   257  		tr.logger.Debugf("couldn't generate an X.509 keypair: %v", err)
   258  	}
   259  
   260  	return &cert, nil
   261  }
   262  
   263  // AutoUpdate will automatically update the listener. If a non-nil
   264  // certUpdates chan is provided, it will receive timestamps for
   265  // reissued certificates. If errChan is non-nil, any errors that occur
   266  // in the updater will be passed along.
   267  func (tr *Transport) AutoUpdate() error {
   268  	defer func() {
   269  		if r := recover(); r != nil {
   270  			tr.logger.Errorf("AutoUpdate certificates: %v", r)
   271  		}
   272  	}()
   273  	remain, nextUpdateAt := tr.Lifespan()
   274  	tr.logger.Debugf("attempting to refresh keypair")
   275  	if remain > nextUpdateAt { // Failure to arrive at the rotation time: the rotation time is the certificate validity period of 1/2
   276  		tr.logger.Debugf("Rotation time not reached %v %v", remain, nextUpdateAt)
   277  		return nil
   278  	}
   279  	err := tr.RefreshKeys()
   280  	if err != nil {
   281  		retry := tr.Backoff.Duration()
   282  		tr.logger.Debugf("failed to update certificate, will try again in %s", retry)
   283  		return err
   284  	}
   285  	tr.logger.Debugf("certificate updated")
   286  	tr.Backoff.Reset()
   287  	return nil
   288  }