github.com/microsoft/moc@v0.17.1/pkg/certs/certs_test.go (about)

     1  // Copyright (c) Microsoft Corporation. All rights reserved.
     2  // Licensed under the Apache v2.0 license.
     3  package certs
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto/rand"
     9  	"crypto/rsa"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"crypto/x509/pkix"
    13  	"encoding/asn1"
    14  	"fmt"
    15  	"log"
    16  	"math"
    17  	"math/big"
    18  	"net"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/microsoft/moc/pkg/errors"
    23  
    24  	gomock "github.com/golang/mock/gomock"
    25  	mock "github.com/microsoft/moc/pkg/certs/mock"
    26  	"github.com/microsoft/moc/rpc/testagent"
    27  	"github.com/stretchr/testify/assert"
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/codes"
    30  	"google.golang.org/grpc/credentials"
    31  	"google.golang.org/grpc/status"
    32  )
    33  
    34  func IsTransportUnavailable(err error) bool {
    35  	if e, ok := status.FromError(err); ok && e.Code() == codes.Unavailable {
    36  		return true
    37  	}
    38  	return false
    39  }
    40  
    41  type TestTlsServer struct {
    42  }
    43  
    44  func (s *TestTlsServer) PingHello(ctx context.Context, in *testagent.Hello) (*testagent.Hello, error) {
    45  	return &testagent.Hello{Name: "Hello From the Server!" + in.Name}, nil
    46  }
    47  
    48  func startHelloServer(grpcServer *grpc.Server, address string) {
    49  	lis, err := net.Listen("tcp", address)
    50  	if err != nil {
    51  		log.Fatalf("failed to listen: %v", err)
    52  	}
    53  	tlsServer := TestTlsServer{}
    54  	testagent.RegisterHelloAgentServer(grpcServer, &tlsServer)
    55  	if err := grpcServer.Serve(lis); err != nil {
    56  		log.Fatalf("failed to serve: %s", err)
    57  	}
    58  }
    59  
    60  type CertAuthority struct {
    61  	ca *CertificateAuthority
    62  }
    63  
    64  func (auth *CertAuthority) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
    65  	return auth.ca.VerifyClientCertificate(rawCerts)
    66  }
    67  
    68  func getTlsCreds(t *testing.T, tlsCert tls.Certificate, certAuth *CertAuthority) credentials.TransportCredentials {
    69  
    70  	return credentials.NewTLS(&tls.Config{
    71  		CipherSuites: []uint16{
    72  			tls.TLS_AES_128_GCM_SHA256,
    73  			tls.TLS_AES_256_GCM_SHA384,
    74  			tls.TLS_CHACHA20_POLY1305_SHA256,
    75  			tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
    76  			tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
    77  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
    78  		},
    79  		MinVersion:               tls.VersionTLS12,
    80  		PreferServerCipherSuites: true,
    81  		ClientAuth:               tls.RequestClientCert,
    82  		Certificates:             []tls.Certificate{tlsCert},
    83  		VerifyPeerCertificate:    certAuth.VerifyPeerCertificate,
    84  	})
    85  }
    86  
    87  func getGrpcServer(t *testing.T, creds credentials.TransportCredentials) *grpc.Server {
    88  	var opts []grpc.ServerOption
    89  	opts = append(opts, grpc.Creds(creds))
    90  	grpcServer := grpc.NewServer(opts...)
    91  	return grpcServer
    92  }
    93  
    94  func makeTlsCall(t *testing.T, address string, provider credentials.TransportCredentials) (*testagent.Hello, error) {
    95  	var conn *grpc.ClientConn
    96  	var err error
    97  	if provider != nil {
    98  		conn, err = grpc.Dial(address, grpc.WithTransportCredentials(provider))
    99  	} else {
   100  		conn, err = grpc.Dial(address, grpc.WithInsecure())
   101  	}
   102  	assert.NoErrorf(t, err, "Failed to dial", err)
   103  	defer conn.Close()
   104  	c := testagent.NewHelloAgentClient(conn)
   105  	return c.PingHello(context.Background(), &testagent.Hello{Name: "TLSServer"})
   106  }
   107  
   108  func createTestCertificate(before, after time.Time) (string, error) {
   109  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   110  	if err != nil {
   111  		return "", err
   112  	}
   113  
   114  	serial, err := rand.Int(rand.Reader, new(big.Int).SetInt64(math.MaxInt64))
   115  	if err != nil {
   116  		return "", err
   117  	}
   118  
   119  	tmpl := x509.Certificate{
   120  		SerialNumber: serial,
   121  		Subject: pkix.Name{
   122  			CommonName:   "test",
   123  			Organization: []string{"microsoft"},
   124  		},
   125  		NotBefore:             before,
   126  		NotAfter:              after,
   127  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   128  		MaxPathLenZero:        true,
   129  		BasicConstraintsValid: true,
   130  		MaxPathLen:            0,
   131  		IsCA:                  true,
   132  	}
   133  
   134  	b, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, key.Public(), key)
   135  	if err != nil {
   136  		return "", err
   137  	}
   138  
   139  	x509Cert, err := x509.ParseCertificate(b)
   140  	if err != nil {
   141  		return "", err
   142  	}
   143  
   144  	pemCert := EncodeCertPEM(x509Cert)
   145  	return string(pemCert), nil
   146  }
   147  
   148  func NewTransportCredentialFromAuthFromPem(serverName string, tlsCert tls.Certificate, caCertPem []byte) (credentials.TransportCredentials, error) {
   149  	certPool := x509.NewCertPool()
   150  	// Append the client certificates from the CA
   151  	if ok := certPool.AppendCertsFromPEM(caCertPem); !ok {
   152  		return nil, fmt.Errorf("could not append the server certificate")
   153  	}
   154  	creds := &tls.Config{
   155  		ServerName:   serverName,
   156  		RootCAs:      certPool,
   157  		Certificates: []tls.Certificate{tlsCert},
   158  	}
   159  	return credentials.NewTLS(creds), nil
   160  }
   161  
   162  func Test_TLSServer(t *testing.T) {
   163  	server := "localhost"
   164  	port := "9000"
   165  	address := server + ":" + port
   166  	ca, key, err := GenerateClientCertificate("test CA")
   167  	assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err)
   168  
   169  	rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key))
   170  	assert.NoErrorf(t, err, "Failed to load root key pair: %v", err)
   171  
   172  	caConfig := CAConfig{
   173  		RootSigner: &rootSigner,
   174  	}
   175  
   176  	caAuth, err := NewCertificateAuthority(&caConfig)
   177  	assert.NoErrorf(t, err, "Error creation CA Auth: %v", err)
   178  
   179  	certPem := EncodeCertPEM(ca)
   180  	keyPem := EncodePrivateKeyPEM(key)
   181  	tlsCert, err := tls.X509KeyPair(certPem, keyPem)
   182  	assert.NoErrorf(t, err, "Failed to get tls cert", err)
   183  
   184  	creds := getTlsCreds(t, tlsCert, &CertAuthority{caAuth})
   185  	grpcServer := getGrpcServer(t, creds)
   186  	go startHelloServer(grpcServer, address)
   187  	defer grpcServer.Stop()
   188  	time.Sleep((time.Second * 3))
   189  	conf := Config{
   190  		CommonName:   "Test Cert",
   191  		Organization: []string{"microsoft"},
   192  	}
   193  	conf.AltNames.DNSNames = []string{"Test Cert"}
   194  	csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil)
   195  	assert.NoErrorf(t, err, "Error creation in CSR: %v", err)
   196  
   197  	signConf := SignConfig{Offset: time.Second * 5}
   198  	clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf)
   199  	assert.NoErrorf(t, err, "Error signing CSR: %v", err)
   200  	tlsClientCert, err := tls.X509KeyPair(clientCertPem, keyClientPem)
   201  	assert.NoErrorf(t, err, "Failed to get tls cert", err)
   202  
   203  	provider, err := NewTransportCredentialFromAuthFromPem(server, tlsClientCert, EncodeCertPEM(ca))
   204  	assert.NoErrorf(t, err, "Failed to create TLS Credentials", err)
   205  	// Making the certificate invalid
   206  	time.Sleep((time.Second * 10))
   207  	_, err = makeTlsCall(t, address, provider)
   208  	assert.True(t, IsTransportUnavailable(err))
   209  }
   210  
   211  func Test_CACerts(t *testing.T) {
   212  	ca, key, err := GenerateClientCertificate("test CA")
   213  	assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err)
   214  
   215  	rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key))
   216  	assert.NoErrorf(t, err, "Failed to load root key pair: %v", err)
   217  
   218  	caConfig := CAConfig{
   219  		RootSigner: &rootSigner,
   220  	}
   221  	caAuth, err := NewCertificateAuthority(&caConfig)
   222  	assert.NoErrorf(t, err, "Error creation CA Auth: %v", err)
   223  
   224  	conf := Config{
   225  		CommonName:   "Test Cert",
   226  		Organization: []string{"microsoft"},
   227  	}
   228  	conf.AltNames.DNSNames = []string{"Test Cert"}
   229  	csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil)
   230  	assert.NoErrorf(t, err, "Error creation in CSR: %v", err)
   231  	keyClient, err := DecodePrivateKeyPEM(keyClientPem)
   232  	assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err)
   233  	clientCertPem, err := caAuth.SignRequest(csr, nil, nil)
   234  	assert.NoErrorf(t, err, "Error signing CSR: %v", err)
   235  	clientCert, err := DecodeCertPEM(clientCertPem)
   236  	assert.NoErrorf(t, err, "Failed Decoding cert: %v", err)
   237  	if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != (time.Hour * 24 * 365) {
   238  		t.Errorf("Invalid certificate expiry")
   239  	}
   240  
   241  	foundCertDER := false
   242  	foundRenewCount := false
   243  	for _, ext := range clientCert.Extensions {
   244  		if ext.Id.Equal(oidOriginalCertificate) {
   245  			foundCertDER = true
   246  		} else if ext.Id.Equal(oidRenewCount) {
   247  			foundRenewCount = true
   248  		}
   249  	}
   250  
   251  	if foundRenewCount || foundCertDER {
   252  		t.Errorf("Found certDER or renewCount Extensions")
   253  	}
   254  
   255  	roots := x509.NewCertPool()
   256  	roots.AddCert(ca)
   257  
   258  	opts := x509.VerifyOptions{
   259  		Roots:     roots,
   260  		DNSName:   "Test Cert",
   261  		KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
   262  	}
   263  
   264  	if _, err := clientCert.Verify(opts); err != nil {
   265  		panic("failed to verify certificate: " + err.Error())
   266  	}
   267  	if _, err = tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient)); err != nil {
   268  		t.Errorf("Error Verifying key and cert: %s", err.Error())
   269  	}
   270  }
   271  
   272  func Test_CACertsVerify(t *testing.T) {
   273  	ca, key, err := GenerateClientCertificate("test CA")
   274  	assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err)
   275  
   276  	rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key))
   277  	assert.NoErrorf(t, err, "Failed to load root key pair: %v", err)
   278  
   279  	caConfig := CAConfig{
   280  		RootSigner: &rootSigner,
   281  	}
   282  
   283  	caAuth, err := NewCertificateAuthority(&caConfig)
   284  	assert.NoErrorf(t, err, "Error creation CA Auth: %v", err)
   285  
   286  	conf := Config{
   287  		CommonName:   "Test Cert",
   288  		Organization: []string{"microsoft"},
   289  	}
   290  	conf.AltNames.DNSNames = []string{"Test Cert"}
   291  	csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil)
   292  	assert.NoErrorf(t, err, "Error creation in CSR: %v", err)
   293  	keyClient, err := DecodePrivateKeyPEM(keyClientPem)
   294  	assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err)
   295  
   296  	signConf := SignConfig{Offset: time.Second * 5}
   297  	clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf)
   298  	assert.NoErrorf(t, err, "Error signing CSR: %v", err)
   299  
   300  	clientCert, err := DecodeCertPEM(clientCertPem)
   301  	assert.NoErrorf(t, err, "Failed Decoding cert: %v", err)
   302  
   303  	if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != signConf.Offset {
   304  		t.Errorf("Invalid certificate expiry")
   305  	}
   306  
   307  	foundCertDER := false
   308  	foundRenewCount := false
   309  	for _, ext := range clientCert.Extensions {
   310  		if ext.Id.Equal(oidOriginalCertificate) {
   311  			foundCertDER = true
   312  		} else if ext.Id.Equal(oidRenewCount) {
   313  			foundRenewCount = true
   314  		}
   315  	}
   316  
   317  	if foundRenewCount || foundCertDER {
   318  		t.Errorf("Found certDER or renewCount Extensions")
   319  	}
   320  
   321  	clientCerts := [][]byte{clientCert.Raw}
   322  
   323  	err = caAuth.VerifyClientCertificate(clientCerts)
   324  	assert.NoErrorf(t, err, "failed to verify certificate: %v", err)
   325  
   326  	time.Sleep(time.Second * 6)
   327  	err = caAuth.VerifyClientCertificate(clientCerts)
   328  	assert.Errorf(t, err, "failed to verify certificate after Expiry")
   329  
   330  	_, err = tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient))
   331  	assert.NoErrorf(t, err, "Error Verifying key and cert: %v", err)
   332  }
   333  
   334  func Test_CACertsRenewVerify(t *testing.T) {
   335  	ca, key, err := GenerateClientCertificate("test CA")
   336  	assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err)
   337  
   338  	rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key))
   339  	assert.NoErrorf(t, err, "Failed to load root key pair: %v", err)
   340  
   341  	caConfig := CAConfig{
   342  		RootSigner: &rootSigner,
   343  	}
   344  	caAuth, err := NewCertificateAuthority(&caConfig)
   345  	assert.NoErrorf(t, err, "Error creation CA Auth: %v", err)
   346  
   347  	conf := Config{
   348  		CommonName:   "Test Cert",
   349  		Organization: []string{"microsoft"},
   350  	}
   351  	conf.AltNames.DNSNames = []string{"Test Cert"}
   352  	csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil)
   353  	assert.NoErrorf(t, err, "Error creation in CSR: %v", err)
   354  
   355  	keyClient, err := DecodePrivateKeyPEM(keyClientPem)
   356  	assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err)
   357  
   358  	signConf := SignConfig{Offset: time.Second * 5}
   359  	clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf)
   360  	assert.NoErrorf(t, err, "Error signing CSR: %v", err)
   361  
   362  	clientCert, err := DecodeCertPEM(clientCertPem)
   363  	assert.NoErrorf(t, err, "Failed Decoding cert: %v", err)
   364  
   365  	// Test certificate duration
   366  	if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != signConf.Offset {
   367  		t.Errorf("Invalid certificate expiry")
   368  	}
   369  
   370  	clientCerts := [][]byte{clientCert.Raw}
   371  
   372  	err = caAuth.VerifyClientCertificate(clientCerts)
   373  	assert.NoErrorf(t, err, "Failed to verify certificate: %v", err)
   374  
   375  	oldcert, err := tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient))
   376  	assert.NoErrorf(t, err, "Error creating X509 keypair: %v", err)
   377  
   378  	// ================= Renew 1 ========================
   379  	csr1, keyClient1Pem, err := GenerateCertificateRenewRequest(&oldcert)
   380  	assert.NoErrorf(t, err, "Error creating renew CSR: %v", err)
   381  
   382  	keyClient1, err := DecodePrivateKeyPEM(keyClient1Pem)
   383  	assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err)
   384  
   385  	signConf = SignConfig{Offset: time.Second * 20}
   386  	certClient1Pem, err := caAuth.SignRequest(csr1, clientCert.Raw, &signConf)
   387  	assert.NoErrorf(t, err, "Error signing CSR: %v", err)
   388  
   389  	certClient1, err := DecodeCertPEM(certClient1Pem)
   390  	assert.NoErrorf(t, err, "Failed Decoding cert: %v", err)
   391  
   392  	// Test certificate duration
   393  	if (certClient1.NotAfter.Sub(certClient1.NotBefore)) != (time.Second * 5) {
   394  		t.Errorf("Invalid certificate expiry")
   395  	}
   396  
   397  	foundCertDER := false
   398  	foundRenewCount := false
   399  	var origCertDER []byte
   400  	var renewCount int64 = 0
   401  	for _, ext := range certClient1.Extensions {
   402  		if ext.Id.Equal(oidOriginalCertificate) {
   403  			origCertDER = ext.Value
   404  			foundCertDER = true
   405  		} else if ext.Id.Equal(oidRenewCount) {
   406  			asn1.Unmarshal(ext.Value, &renewCount)
   407  			foundRenewCount = true
   408  		}
   409  	}
   410  
   411  	if !(foundRenewCount && foundCertDER) {
   412  		t.Errorf("Not found certDER or renewCount Extensions")
   413  	}
   414  
   415  	if !bytes.Equal(origCertDER, clientCert.Raw) {
   416  		t.Errorf("Extension not Matching old cert")
   417  	}
   418  
   419  	if renewCount != 1 {
   420  		t.Errorf("Extension renew count is wrong")
   421  	}
   422  
   423  	clientCerts = [][]byte{certClient1.Raw}
   424  	err = caAuth.VerifyClientCertificate(clientCerts)
   425  	assert.NoErrorf(t, err, "failed to verify certificate: %v", err)
   426  	_, err = tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient1))
   427  	assert.NoErrorf(t, err, "Error Verifying key and cert: %v", err)
   428  
   429  	// ================= Renew 2 ========================
   430  	oldcert, err = tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient1))
   431  	assert.NoErrorf(t, err, "Error creating X509 keypair: %v", err)
   432  
   433  	csr2, keyClient2Pem, err := GenerateCertificateRenewRequest(&oldcert)
   434  	assert.NoErrorf(t, err, "Error creating renew CSR: %v", err)
   435  
   436  	keyClient2, err := DecodePrivateKeyPEM(keyClient2Pem)
   437  	assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err)
   438  
   439  	certClient2Pem, err := caAuth.SignRequest(csr2, certClient1.Raw, nil)
   440  	assert.NoErrorf(t, err, "Error signing CSR: %v", err)
   441  	certClient2, err := DecodeCertPEM(certClient2Pem)
   442  	assert.NoErrorf(t, err, "Failed Decoding cert: %v", err)
   443  
   444  	// Test certificate duration
   445  	if (certClient2.NotAfter.Sub(certClient2.NotBefore)) != (time.Second * 5) {
   446  		t.Errorf("Invalid certificate expiry")
   447  	}
   448  
   449  	foundCertDER = false
   450  	foundRenewCount = false
   451  	for _, ext := range certClient2.Extensions {
   452  		if ext.Id.Equal(oidOriginalCertificate) {
   453  			origCertDER = ext.Value
   454  			foundCertDER = true
   455  		} else if ext.Id.Equal(oidRenewCount) {
   456  			asn1.Unmarshal(ext.Value, &renewCount)
   457  			foundRenewCount = true
   458  		}
   459  	}
   460  
   461  	if !(foundRenewCount && foundCertDER) {
   462  		t.Errorf("Not found certDER or renewCount Extensions")
   463  	}
   464  
   465  	// The origCertDER should point to the first cert
   466  	if !bytes.Equal(origCertDER, clientCert.Raw) {
   467  		t.Errorf("Extension not Matching old cert")
   468  	}
   469  
   470  	if renewCount != 2 {
   471  		t.Errorf("Extension renew count is wrong")
   472  	}
   473  
   474  	clientCerts = [][]byte{certClient2.Raw}
   475  	err = caAuth.VerifyClientCertificate(clientCerts)
   476  	assert.NoErrorf(t, err, "failed to verify certificate: %v", err)
   477  	_, err = tls.X509KeyPair(EncodeCertPEM(certClient2), EncodePrivateKeyPEM(keyClient2))
   478  	assert.NoErrorf(t, err, "Error Verifying key and cert: %v", err)
   479  }
   480  
   481  func Test_CACertsRenewVerifySameKey(t *testing.T) {
   482  	ca, key, err := GenerateClientCertificate("test CA")
   483  	if err != nil {
   484  		t.Errorf("Error creation in CA certificate failed: %s", err.Error())
   485  	}
   486  
   487  	rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key))
   488  	if err != nil {
   489  		t.Errorf("Failed to load root key pair: %v", err)
   490  		return
   491  	}
   492  
   493  	caConfig := CAConfig{
   494  		RootSigner: &rootSigner,
   495  	}
   496  
   497  	caAuth, err := NewCertificateAuthority(&caConfig)
   498  	if err != nil {
   499  		t.Errorf("Error creation CA Auth: %s", err.Error())
   500  	}
   501  
   502  	conf := Config{
   503  		CommonName:   "Test Cert",
   504  		Organization: []string{"microsoft"},
   505  	}
   506  	conf.AltNames.DNSNames = []string{"Test Cert"}
   507  	csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil)
   508  	if err != nil {
   509  		t.Errorf("Error creation in CSR: %s", err.Error())
   510  	}
   511  	keyClient, err := DecodePrivateKeyPEM(keyClientPem)
   512  	if err != nil {
   513  		t.Errorf("Failed Decoding privatekey: %s", err.Error())
   514  	}
   515  	signConf := SignConfig{Offset: time.Second * 5}
   516  	clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf)
   517  	if err != nil {
   518  		t.Errorf("Error signing CSR: %s", err.Error())
   519  	}
   520  	clientCert, err := DecodeCertPEM(clientCertPem)
   521  	if err != nil {
   522  		t.Errorf("Failed Decoding cert: %s", err.Error())
   523  	}
   524  	// Test certificate duration
   525  	if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != signConf.Offset {
   526  		t.Errorf("Invalid certificate expiry")
   527  	}
   528  
   529  	clientCerts := [][]byte{clientCert.Raw}
   530  
   531  	if err := caAuth.VerifyClientCertificate(clientCerts); err != nil {
   532  		panic("failed to verify certificate: " + err.Error())
   533  	}
   534  	oldcert, err := tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient))
   535  	if err != nil {
   536  		t.Errorf("Error creating X509 keypair: %s", err.Error())
   537  	}
   538  
   539  	// ================= Renew 1 ========================
   540  	csr1, err := GenerateCertificateRenewRequestSameKey(&oldcert)
   541  	if err != nil {
   542  		t.Errorf("Error creating renew CSR: %s", err.Error())
   543  	}
   544  	certClient1Pem, err := caAuth.SignRequest(csr1, clientCert.Raw, nil)
   545  	if err != nil {
   546  		t.Errorf("Error signing CSR: %s", err.Error())
   547  	}
   548  
   549  	certClient1, err := DecodeCertPEM(certClient1Pem)
   550  	if err != nil {
   551  		t.Errorf("Failed Decoding cert: %s", err.Error())
   552  	}
   553  
   554  	// Test certificate duration
   555  	if (certClient1.NotAfter.Sub(certClient1.NotBefore)) != signConf.Offset {
   556  		t.Errorf("Invalid certificate expiry")
   557  	}
   558  
   559  	foundCertDER := false
   560  	foundRenewCount := false
   561  	var origCertDER []byte
   562  	var renewCount int64 = 0
   563  	for _, ext := range certClient1.Extensions {
   564  		if ext.Id.Equal(oidOriginalCertificate) {
   565  			origCertDER = ext.Value
   566  			foundCertDER = true
   567  		} else if ext.Id.Equal(oidRenewCount) {
   568  			asn1.Unmarshal(ext.Value, &renewCount)
   569  			foundRenewCount = true
   570  		}
   571  	}
   572  
   573  	if !(foundRenewCount && foundCertDER) {
   574  		t.Errorf("Not found certDER or renewCount Extensions")
   575  	}
   576  
   577  	if !bytes.Equal(origCertDER, clientCert.Raw) {
   578  		t.Errorf("Extension not Matching old cert")
   579  	}
   580  
   581  	if renewCount != 1 {
   582  		t.Errorf("Extension renew count is wrong")
   583  	}
   584  
   585  	clientCerts = [][]byte{certClient1.Raw}
   586  	if err := caAuth.VerifyClientCertificate(clientCerts); err != nil {
   587  		t.Errorf("failed to verify certificate: " + err.Error())
   588  	}
   589  	if _, err = tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient)); err != nil {
   590  		t.Errorf("Error Verifying key and cert: %s", err.Error())
   591  	}
   592  
   593  	// ================= Renew 2 ========================
   594  	oldcert, err = tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient))
   595  	if err != nil {
   596  		t.Errorf("Error creating X509 keypair: %s", err.Error())
   597  	}
   598  	csr2, err := GenerateCertificateRenewRequestSameKey(&oldcert)
   599  	if err != nil {
   600  		t.Errorf("Error creating renew CSR: %s", err.Error())
   601  	}
   602  	certClient2Pem, err := caAuth.SignRequest(csr2, certClient1.Raw, nil)
   603  	if err != nil {
   604  		t.Errorf("Error signing CSR: %s", err.Error())
   605  	}
   606  
   607  	certClient2, err := DecodeCertPEM(certClient2Pem)
   608  	if err != nil {
   609  		t.Errorf("Failed Decoding cert: %s", err.Error())
   610  	}
   611  	// Test certificate duration
   612  	if (certClient2.NotAfter.Sub(certClient2.NotBefore)) != signConf.Offset {
   613  		t.Errorf("Invalid certificate expiry")
   614  	}
   615  
   616  	foundCertDER = false
   617  	foundRenewCount = false
   618  	for _, ext := range certClient2.Extensions {
   619  		if ext.Id.Equal(oidOriginalCertificate) {
   620  			origCertDER = ext.Value
   621  			foundCertDER = true
   622  		} else if ext.Id.Equal(oidRenewCount) {
   623  			asn1.Unmarshal(ext.Value, &renewCount)
   624  			foundRenewCount = true
   625  		}
   626  	}
   627  
   628  	if !(foundRenewCount && foundCertDER) {
   629  		t.Errorf("Not found certDER or renewCount Extensions")
   630  	}
   631  
   632  	// The origCertDER should point to the first cert
   633  	if !bytes.Equal(origCertDER, clientCert.Raw) {
   634  		t.Errorf("Extension not Matching old cert")
   635  	}
   636  
   637  	if renewCount != 2 {
   638  		t.Errorf("Extension renew count is wrong")
   639  	}
   640  
   641  	clientCerts = [][]byte{certClient2.Raw}
   642  	if err := caAuth.VerifyClientCertificate(clientCerts); err != nil {
   643  		t.Errorf("failed to verify certificate: " + err.Error())
   644  	}
   645  	if _, err = tls.X509KeyPair(EncodeCertPEM(certClient2), EncodePrivateKeyPEM(keyClient)); err != nil {
   646  		t.Errorf("Error Verifying key and cert: %s", err.Error())
   647  	}
   648  }
   649  
   650  func Test_BackoffFactor(t *testing.T) {
   651  	_, err := NewBackOffFactor(-1.0, 5)
   652  	if err == nil || !errors.IsInvalidInput(err) {
   653  		t.Errorf("Expected Error InvalidInput")
   654  	}
   655  	_, err = NewBackOffFactor(1.0, -5.0)
   656  	if err == nil || !errors.IsInvalidInput(err) {
   657  		t.Errorf("Expected Error InvalidInput")
   658  	}
   659  	_, err = NewBackOffFactor(-1.0, -5.0)
   660  	if err == nil || !errors.IsInvalidInput(err) {
   661  		t.Errorf("Expected Error InvalidInput")
   662  	}
   663  }
   664  
   665  func Test_BackoffFactor1(t *testing.T) {
   666  	factor, err := NewBackOffFactor(1.0, 5)
   667  	if err != nil {
   668  		t.Errorf("Error creating Factor: %s", err.Error())
   669  	}
   670  	if factor.errorBackoffFactor != 5 || factor.renewBackoffFactor != 1 {
   671  		t.Errorf("renewBackoffFactor Expected:1.0 Actual:%f \n errorBackoffFactor Expected:5.0 Actual:%f", factor.renewBackoffFactor, factor.errorBackoffFactor)
   672  	}
   673  }
   674  
   675  func Test_CalculateTime(t *testing.T) {
   676  	factor, err := NewBackOffFactor(0.3, 0.02)
   677  	if err != nil {
   678  		t.Errorf("Error creating Factor: %s", err.Error())
   679  	}
   680  	now := time.Now()
   681  	before := now.Add(time.Duration(time.Second * -10))
   682  	after := now.Add(time.Duration(time.Second * 10))
   683  	duration := calculateTime(before, after, now, factor)
   684  	if duration.RenewBackoffDuration != time.Duration(time.Second*4) {
   685  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*4), duration.RenewBackoffDuration)
   686  	}
   687  	if duration.RenewBackoffDuration < time.Duration(0) {
   688  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   689  	}
   690  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*400) {
   691  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   692  	}
   693  }
   694  
   695  func Test_CalculateTime1(t *testing.T) {
   696  	factor, err := NewBackOffFactor(0.1, 0.002)
   697  	if err != nil {
   698  		t.Errorf("Error creating Factor: %s", err.Error())
   699  	}
   700  	now := time.Now()
   701  	before := now.Add(time.Duration(time.Second * -30))
   702  	after := now.Add(time.Duration(time.Second * 10))
   703  	duration := calculateTime(before, after, now, factor)
   704  	if duration.RenewBackoffDuration != time.Duration(time.Second*6) {
   705  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*6), duration.RenewBackoffDuration)
   706  	}
   707  	if duration.RenewBackoffDuration < time.Duration(0) {
   708  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   709  	}
   710  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*80) {
   711  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   712  	}
   713  }
   714  
   715  func Test_CalculateTime2(t *testing.T) {
   716  	factor, err := NewBackOffFactor(0.5, 0.002)
   717  	if err != nil {
   718  		t.Errorf("Error creating Factor: %s", err.Error())
   719  	}
   720  	now := time.Now()
   721  	before := now.Add(time.Duration(time.Second * -30))
   722  	after := now.Add(time.Duration(time.Second * 10))
   723  	duration := calculateTime(before, after, now, factor)
   724  	if duration.RenewBackoffDuration != time.Duration(time.Second*-10) {
   725  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*-10), duration.RenewBackoffDuration)
   726  	}
   727  	if duration.RenewBackoffDuration > time.Duration(0) {
   728  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   729  	}
   730  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*80) {
   731  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   732  	}
   733  }
   734  
   735  func Test_CalculateTime3(t *testing.T) {
   736  	factor, err := NewBackOffFactor(30.0/100.0, 0.02)
   737  	if err != nil {
   738  		t.Errorf("Error creating Factor: %s", err.Error())
   739  	}
   740  	now := time.Now()
   741  	before := now.Add(time.Minute * -5)
   742  	after := now.Add(time.Duration(time.Minute*10 + time.Second*30))
   743  	duration := calculateTime(before, after, now, factor)
   744  	if duration.RenewBackoffDuration != time.Duration(time.Minute*5+time.Second*51) {
   745  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Minute*5+time.Second*51), duration.RenewBackoffDuration)
   746  	}
   747  	if duration.RenewBackoffDuration < time.Duration(0) {
   748  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   749  	}
   750  	if duration.ErrorBackoffDuration != time.Duration(time.Second*18+time.Millisecond*600) {
   751  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   752  	}
   753  }
   754  
   755  func Test_CalculateTimeNegative(t *testing.T) {
   756  	factor, err := NewBackOffFactor(0.3, 0.02)
   757  	if err != nil {
   758  		t.Errorf("Error creating Factor: %s", err.Error())
   759  	}
   760  	now := time.Now()
   761  	before := now.Add(time.Duration(time.Second * -20))
   762  	after := now.Add(time.Duration(time.Second * -10))
   763  	duration := calculateTime(before, after, now, factor)
   764  	if duration.RenewBackoffDuration != time.Duration(time.Second*-13) {
   765  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*-13), duration.RenewBackoffDuration)
   766  	}
   767  	if duration.RenewBackoffDuration > time.Duration(0) {
   768  		t.Errorf("Wrong wait time returned Expected less than zero %s", duration.RenewBackoffDuration)
   769  	}
   770  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*200) {
   771  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*200), duration.ErrorBackoffDuration)
   772  	}
   773  }
   774  
   775  func Test_CalculateTimeAfter(t *testing.T) {
   776  	factor, err := NewBackOffFactor(0.3, 0.02)
   777  	if err != nil {
   778  		t.Errorf("Error creating Factor: %s", err.Error())
   779  	}
   780  	now := time.Now()
   781  	before := now.Add(time.Duration(time.Second * 10))
   782  	after := now.Add(time.Duration(time.Second * 30))
   783  	duration := calculateTime(before, after, now, factor)
   784  	if duration.RenewBackoffDuration != time.Duration(time.Second*24) {
   785  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*24), duration.RenewBackoffDuration)
   786  	}
   787  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*400) {
   788  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   789  	}
   790  }
   791  
   792  func Test_CalculateRenewTime(t *testing.T) {
   793  	factor, err := NewBackOffFactor(0.3, 0.02)
   794  	if err != nil {
   795  		t.Errorf("Error creating Factor: %s", err.Error())
   796  	}
   797  	now := time.Now()
   798  	before := now.Add(time.Duration(time.Second * -10))
   799  	after := now.Add(time.Duration(time.Second * 10))
   800  	cert, err := createTestCertificate(before, after)
   801  	if err != nil {
   802  		t.Errorf("Failed creating certificate: %s", err.Error())
   803  	}
   804  	duration, err := CalculateRenewTime(cert, factor)
   805  	if err != nil {
   806  		t.Errorf("Failed calculating Certificate renewal backoff: %s", err.Error())
   807  	}
   808  	if duration.RenewBackoffDuration > time.Duration(time.Second*4) || duration.RenewBackoffDuration < time.Duration(time.Second*1) {
   809  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*4), duration.RenewBackoffDuration)
   810  	}
   811  	if duration.RenewBackoffDuration < time.Duration(0) {
   812  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   813  	}
   814  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*400) {
   815  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   816  	}
   817  }
   818  
   819  func Test_CalculateRenewTime1(t *testing.T) {
   820  	factor, err := NewBackOffFactor(0.1, 0.002)
   821  	if err != nil {
   822  		t.Errorf("Error creating Factor: %s", err.Error())
   823  	}
   824  	now := time.Now()
   825  	before := now.Add(time.Duration(time.Second * -30))
   826  	after := now.Add(time.Duration(time.Second * 10))
   827  	cert, err := createTestCertificate(before, after)
   828  	if err != nil {
   829  		t.Errorf("Failed creating certificate: %s", err.Error())
   830  	}
   831  	duration, err := CalculateRenewTime(cert, factor)
   832  	if err != nil {
   833  		t.Errorf("Failed calculating Certificate renewal backoff: %s", err.Error())
   834  	}
   835  	if duration.RenewBackoffDuration > time.Duration(time.Second*6) || duration.RenewBackoffDuration < time.Duration(time.Second*3) {
   836  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*6), duration.RenewBackoffDuration)
   837  	}
   838  	if duration.RenewBackoffDuration < time.Duration(0) {
   839  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   840  	}
   841  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*80) {
   842  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   843  	}
   844  }
   845  
   846  func Test_CalculateRenewTime2(t *testing.T) {
   847  	factor, err := NewBackOffFactor(0.5, 0.002)
   848  	if err != nil {
   849  		t.Errorf("Error creating Factor: %s", err.Error())
   850  	}
   851  	now := time.Now()
   852  	before := now.Add(time.Duration(time.Second * -30))
   853  	after := now.Add(time.Duration(time.Second * 10))
   854  	cert, err := createTestCertificate(before, after)
   855  	if err != nil {
   856  		t.Errorf("Failed creating certificate: %s", err.Error())
   857  	}
   858  	duration, err := CalculateRenewTime(cert, factor)
   859  	if err != nil {
   860  		t.Errorf("Failed calculating Certificate renewal backoff: %s", err.Error())
   861  	}
   862  	if duration.RenewBackoffDuration > time.Duration(time.Second*-10) || duration.RenewBackoffDuration < time.Duration(time.Second*-13) {
   863  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*-10), duration.RenewBackoffDuration)
   864  	}
   865  	if duration.RenewBackoffDuration > time.Duration(0) {
   866  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   867  	}
   868  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*80) {
   869  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   870  	}
   871  }
   872  
   873  func Test_CalculateRenewTimeNegative(t *testing.T) {
   874  	factor, err := NewBackOffFactor(0.3, 0.02)
   875  	if err != nil {
   876  		t.Errorf("Error creating Factor: %s", err.Error())
   877  	}
   878  	now := time.Now()
   879  	before := now.Add(time.Duration(time.Second * -20))
   880  	after := now.Add(time.Duration(time.Second * -10))
   881  	cert, err := createTestCertificate(before, after)
   882  	if err != nil {
   883  		t.Errorf("Failed creating certificate: %s", err.Error())
   884  	}
   885  	duration, err := CalculateRenewTime(cert, factor)
   886  	if err != nil {
   887  		t.Errorf("Failed calculating Certificate renewal backoff: %s", err.Error())
   888  	}
   889  	if duration.RenewBackoffDuration > time.Duration(time.Second*-13) || duration.RenewBackoffDuration < time.Duration(time.Second*-16) {
   890  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*-13), duration.RenewBackoffDuration)
   891  	}
   892  	if duration.RenewBackoffDuration > time.Duration(0) {
   893  		t.Errorf("Wrong wait time returned Expected less than zero %s", duration.RenewBackoffDuration)
   894  	}
   895  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*200) {
   896  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*200), duration.ErrorBackoffDuration)
   897  	}
   898  }
   899  
   900  func Test_CalculateRenewTimeAfter(t *testing.T) {
   901  	factor, err := NewBackOffFactor(0.3, 0.02)
   902  	if err != nil {
   903  		t.Errorf("Error creating Factor: %s", err.Error())
   904  	}
   905  	now := time.Now()
   906  	before := now.Add(time.Duration(time.Second * 10))
   907  	after := now.Add(time.Duration(time.Second * 30))
   908  	cert, err := createTestCertificate(before, after)
   909  	if err != nil {
   910  		t.Errorf("Failed creating certificate: %s", err.Error())
   911  	}
   912  	duration, err := CalculateRenewTime(cert, factor)
   913  	if err != nil {
   914  		t.Errorf("Failed calculating Certificate renewal backoff: %s", err.Error())
   915  	}
   916  	if duration.RenewBackoffDuration < time.Duration(time.Second*22) || duration.RenewBackoffDuration > time.Duration(time.Second*24) {
   917  		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*24), duration.RenewBackoffDuration)
   918  	}
   919  	if duration.RenewBackoffDuration < time.Duration(0) {
   920  		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
   921  	}
   922  	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*400) {
   923  		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
   924  	}
   925  }
   926  
   927  func Test_CalculateCertExpiry(t *testing.T) {
   928  	now := time.Now()
   929  	before := now.Add(time.Duration(time.Second * -30))
   930  	after := now.Add(time.Duration(time.Second * 10))
   931  	cert, err := createTestCertificate(before, after)
   932  	if err != nil {
   933  		t.Errorf("Failed creating certificate: %s", err.Error())
   934  	}
   935  	expired, err := IsCertificateExpired(cert)
   936  	if err != nil {
   937  		t.Errorf("Failed finding certificate expired: %s", err.Error())
   938  	}
   939  
   940  	if expired {
   941  		t.Errorf("Certificate expired")
   942  	}
   943  }
   944  
   945  func Test_CalculateCertExpiry1(t *testing.T) {
   946  	now := time.Now()
   947  	before := now.Add(time.Duration(time.Second * -20))
   948  	after := now.Add(time.Duration(time.Second * -10))
   949  	cert, err := createTestCertificate(before, after)
   950  	if err != nil {
   951  		t.Errorf("Failed creating certificate: %s", err.Error())
   952  	}
   953  	expired, err := IsCertificateExpired(cert)
   954  	if err != nil {
   955  		t.Errorf("Failed finding certificate expired: %s", err.Error())
   956  	}
   957  
   958  	if !expired {
   959  		t.Errorf("Certificate not expired")
   960  	}
   961  }
   962  
   963  func Test_Revocation_IsRevoked(t *testing.T) {
   964  	ctrl := gomock.NewController(t)
   965  	defer ctrl.Finish()
   966  
   967  	ca, _, _ := GenerateClientCertificate("test CA")
   968  	m := mock.NewMockRevocation(ctrl)
   969  	m.EXPECT().IsRevoked(ca)
   970  	m.IsRevoked(ca)
   971  }