github.com/hashicorp/vault/sdk@v0.13.0/helper/ocsp/ocsp_test.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package ocsp
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto"
     9  	"crypto/ecdsa"
    10  	"crypto/elliptic"
    11  	"crypto/rand"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"crypto/x509/pkix"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"io/ioutil"
    19  	"math/big"
    20  	"net"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"net/url"
    24  	"sync/atomic"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/hashicorp/go-hclog"
    29  	"github.com/hashicorp/go-retryablehttp"
    30  	lru "github.com/hashicorp/golang-lru"
    31  	"github.com/stretchr/testify/require"
    32  	"golang.org/x/crypto/ocsp"
    33  )
    34  
    35  func TestOCSP(t *testing.T) {
    36  	targetURL := []string{
    37  		"https://sfcdev1.blob.core.windows.net/",
    38  		"https://sfctest0.snowflakecomputing.com/",
    39  		"https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64",
    40  	}
    41  
    42  	conf := VerifyConfig{
    43  		OcspFailureMode: FailOpenFalse,
    44  	}
    45  	c := New(testLogFactory, 10)
    46  	transports := []*http.Transport{
    47  		newInsecureOcspTransport(nil),
    48  		c.NewTransport(&conf),
    49  	}
    50  
    51  	for _, tgt := range targetURL {
    52  		c.ocspResponseCache, _ = lru.New2Q(10)
    53  		for _, tr := range transports {
    54  			c := &http.Client{
    55  				Transport: tr,
    56  				Timeout:   30 * time.Second,
    57  			}
    58  			req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
    59  			if err != nil {
    60  				t.Fatalf("fail to create a request. err: %v", err)
    61  			}
    62  			res, err := c.Do(req)
    63  			if err != nil {
    64  				t.Fatalf("failed to GET contents. err: %v", err)
    65  			}
    66  			defer res.Body.Close()
    67  			_, err = ioutil.ReadAll(res.Body)
    68  			if err != nil {
    69  				t.Fatalf("failed to read content body for %v", tgt)
    70  			}
    71  
    72  		}
    73  	}
    74  }
    75  
    76  /**
    77  // Used for development, requires an active Vault with PKI setup
    78  func TestMultiOCSP(t *testing.T) {
    79  
    80  	targetURL := []string{
    81  		"https://localhost:8200/v1/pki/ocsp",
    82  		"https://localhost:8200/v1/pki/ocsp",
    83  		"https://localhost:8200/v1/pki/ocsp",
    84  	}
    85  
    86  	b, _ := pem.Decode([]byte(vaultCert))
    87  	caCert, _ := x509.ParseCertificate(b.Bytes)
    88  	conf := VerifyConfig{
    89  		OcspFailureMode:     FailOpenFalse,
    90  		QueryAllServers:     true,
    91  		OcspServersOverride: targetURL,
    92  		ExtraCas:            []*x509.Certificate{caCert},
    93  	}
    94  	c := New(testLogFactory, 10)
    95  	transports := []*http.Transport{
    96  		newInsecureOcspTransport(conf.ExtraCas),
    97  		c.NewTransport(&conf),
    98  	}
    99  
   100  	tgt := "https://localhost:8200/v1/pki/ca/pem"
   101  	c.ocspResponseCache, _ = lru.New2Q(10)
   102  	for _, tr := range transports {
   103  		c := &http.Client{
   104  			Transport: tr,
   105  			Timeout:   30 * time.Second,
   106  		}
   107  		req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
   108  		if err != nil {
   109  			t.Fatalf("fail to create a request. err: %v", err)
   110  		}
   111  		res, err := c.Do(req)
   112  		if err != nil {
   113  			t.Fatalf("failed to GET contents. err: %v", err)
   114  		}
   115  		defer res.Body.Close()
   116  		_, err = ioutil.ReadAll(res.Body)
   117  		if err != nil {
   118  			t.Fatalf("failed to read content body for %v", tgt)
   119  		}
   120  	}
   121  }
   122  */
   123  
   124  func TestUnitEncodeCertIDGood(t *testing.T) {
   125  	targetURLs := []string{
   126  		"faketestaccount.snowflakecomputing.com:443",
   127  		"s3-us-west-2.amazonaws.com:443",
   128  		"sfcdev1.blob.core.windows.net:443",
   129  	}
   130  	for _, tt := range targetURLs {
   131  		chainedCerts := getCert(tt)
   132  		for i := 0; i < len(chainedCerts)-1; i++ {
   133  			subject := chainedCerts[i]
   134  			issuer := chainedCerts[i+1]
   135  			ocspServers := subject.OCSPServer
   136  			if len(ocspServers) == 0 {
   137  				t.Fatalf("no OCSP server is found. cert: %v", subject.Subject)
   138  			}
   139  			ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{})
   140  			if err != nil {
   141  				t.Fatalf("failed to create OCSP request. err: %v", err)
   142  			}
   143  			var ost *ocspStatus
   144  			_, ost = extractCertIDKeyFromRequest(ocspReq)
   145  			if ost.err != nil {
   146  				t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
   147  			}
   148  			// better hash. Not sure if the actual OCSP server accepts this, though.
   149  			ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
   150  			if err != nil {
   151  				t.Fatalf("failed to create OCSP request. err: %v", err)
   152  			}
   153  			_, ost = extractCertIDKeyFromRequest(ocspReq)
   154  			if ost.err != nil {
   155  				t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
   156  			}
   157  			// tweaked request binary
   158  			ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
   159  			if err != nil {
   160  				t.Fatalf("failed to create OCSP request. err: %v", err)
   161  			}
   162  			ocspReq[10] = 0 // random change
   163  			_, ost = extractCertIDKeyFromRequest(ocspReq)
   164  			if ost.err == nil {
   165  				t.Fatal("should have failed")
   166  			}
   167  		}
   168  	}
   169  }
   170  
   171  func TestUnitCheckOCSPResponseCache(t *testing.T) {
   172  	conf := &VerifyConfig{OcspEnabled: true}
   173  	c := New(testLogFactory, 10)
   174  	dummyKey0 := certIDKey{
   175  		NameHash:      "dummy0",
   176  		IssuerKeyHash: "dummy0",
   177  		SerialNumber:  "dummy0",
   178  	}
   179  	dummyKey := certIDKey{
   180  		NameHash:      "dummy1",
   181  		IssuerKeyHash: "dummy1",
   182  		SerialNumber:  "dummy1",
   183  	}
   184  	currentTime := float64(time.Now().UTC().Unix())
   185  	c.ocspResponseCache.Add(dummyKey0, &ocspCachedResponse{time: currentTime})
   186  	subject := &x509.Certificate{}
   187  	issuer := &x509.Certificate{}
   188  	ost, err := c.checkOCSPResponseCache(&dummyKey, subject, issuer, conf)
   189  	if err != nil {
   190  		t.Fatal(err)
   191  	}
   192  	if ost.code != ocspMissedCache {
   193  		t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code)
   194  	}
   195  	// old timestamp
   196  	c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(1395054952)})
   197  	ost, err = c.checkOCSPResponseCache(&dummyKey, subject, issuer, conf)
   198  	if err != nil {
   199  		t.Fatal(err)
   200  	}
   201  	if ost.code != ocspCacheExpired {
   202  		t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code)
   203  	}
   204  
   205  	// invalid validity
   206  	c.ocspResponseCache.Add(dummyKey, &ocspCachedResponse{time: float64(currentTime - 1000)})
   207  	ost, err = c.checkOCSPResponseCache(&dummyKey, subject, nil, conf)
   208  	if err == nil && isValidOCSPStatus(ost.code) {
   209  		t.Fatalf("should have failed.")
   210  	}
   211  }
   212  
   213  // TestUnitValidOCSPResponse validates various combinations of acceptable OCSP responses
   214  func TestUnitValidOCSPResponse(t *testing.T) {
   215  	rootCaKey, rootCa, leafCert := createCaLeafCerts(t)
   216  
   217  	type tests struct {
   218  		name           string
   219  		ocspRes        ocsp.Response
   220  		expectedStatus ocspStatusCode
   221  	}
   222  
   223  	now := time.Now()
   224  	ctx := context.Background()
   225  
   226  	tt := []tests{
   227  		{
   228  			name: "normal",
   229  			ocspRes: ocsp.Response{
   230  				SerialNumber: leafCert.SerialNumber,
   231  				ThisUpdate:   now.Add(-1 * time.Hour),
   232  				NextUpdate:   now.Add(30 * time.Minute),
   233  				Status:       ocsp.Good,
   234  			},
   235  			expectedStatus: ocspStatusGood,
   236  		},
   237  		{
   238  			name: "no-next-update",
   239  			ocspRes: ocsp.Response{
   240  				SerialNumber: leafCert.SerialNumber,
   241  				ThisUpdate:   now.Add(-1 * time.Hour),
   242  				Status:       ocsp.Good,
   243  			},
   244  			expectedStatus: ocspStatusGood,
   245  		},
   246  		{
   247  			name: "revoked-update",
   248  			ocspRes: ocsp.Response{
   249  				SerialNumber: leafCert.SerialNumber,
   250  				ThisUpdate:   now.Add(-1 * time.Hour),
   251  				Status:       ocsp.Revoked,
   252  			},
   253  			expectedStatus: ocspStatusRevoked,
   254  		},
   255  		{
   256  			name: "revoked-update-with-next-update",
   257  			ocspRes: ocsp.Response{
   258  				SerialNumber: leafCert.SerialNumber,
   259  				ThisUpdate:   now.Add(-1 * time.Hour),
   260  				NextUpdate:   now.Add(1 * time.Hour),
   261  				Status:       ocsp.Revoked,
   262  			},
   263  			expectedStatus: ocspStatusRevoked,
   264  		},
   265  	}
   266  	for _, tc := range tt {
   267  		for _, maxAge := range []time.Duration{time.Duration(0), time.Duration(2 * time.Hour)} {
   268  			t.Run(tc.name+"-max-age-"+maxAge.String(), func(t *testing.T) {
   269  				ocspHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   270  					response := buildOcspResponse(t, rootCa, rootCaKey, tc.ocspRes)
   271  					_, _ = w.Write(response)
   272  				})
   273  				ts := httptest.NewServer(ocspHandler)
   274  				defer ts.Close()
   275  
   276  				logFactory := func() hclog.Logger {
   277  					return hclog.NewNullLogger()
   278  				}
   279  				client := New(logFactory, 100)
   280  				config := &VerifyConfig{
   281  					OcspEnabled:          true,
   282  					OcspServersOverride:  []string{ts.URL},
   283  					OcspFailureMode:      FailOpenFalse,
   284  					QueryAllServers:      false,
   285  					OcspThisUpdateMaxAge: maxAge,
   286  				}
   287  
   288  				status, err := client.GetRevocationStatus(ctx, leafCert, rootCa, config)
   289  				require.NoError(t, err, "ocsp response should have been considered valid")
   290  				require.NoError(t, status.err, "ocsp status should not contain an error")
   291  				require.Equal(t, &ocspStatus{code: tc.expectedStatus}, status)
   292  			})
   293  		}
   294  	}
   295  }
   296  
   297  // TestUnitBadOCSPResponses verifies that we fail properly on a bunch of different
   298  // OCSP response conditions
   299  func TestUnitBadOCSPResponses(t *testing.T) {
   300  	rootCaKey, rootCa, leafCert := createCaLeafCerts(t)
   301  	rootCaKey2, rootCa2, _ := createCaLeafCerts(t)
   302  
   303  	type tests struct {
   304  		name        string
   305  		ocspRes     ocsp.Response
   306  		maxAge      time.Duration
   307  		ca          *x509.Certificate
   308  		caKey       *ecdsa.PrivateKey
   309  		errContains string
   310  	}
   311  
   312  	now := time.Now()
   313  	ctx := context.Background()
   314  
   315  	tt := []tests{
   316  		{
   317  			name: "bad-signing-issuer",
   318  			ocspRes: ocsp.Response{
   319  				SerialNumber: leafCert.SerialNumber,
   320  				ThisUpdate:   now.Add(-1 * time.Hour),
   321  				NextUpdate:   now.Add(30 * time.Minute),
   322  				Status:       ocsp.Good,
   323  			},
   324  			ca:          rootCa2,
   325  			caKey:       rootCaKey2,
   326  			errContains: "error directly verifying signature",
   327  		},
   328  		{
   329  			name: "incorrect-serial-number",
   330  			ocspRes: ocsp.Response{
   331  				SerialNumber: big.NewInt(1000),
   332  				ThisUpdate:   now.Add(-1 * time.Hour),
   333  				NextUpdate:   now.Add(30 * time.Minute),
   334  				Status:       ocsp.Good,
   335  			},
   336  			ca:          rootCa,
   337  			caKey:       rootCaKey,
   338  			errContains: "did not match the leaf certificate serial number",
   339  		},
   340  		{
   341  			name: "expired-next-update",
   342  			ocspRes: ocsp.Response{
   343  				SerialNumber: leafCert.SerialNumber,
   344  				ThisUpdate:   now.Add(-1 * time.Hour),
   345  				NextUpdate:   now.Add(-30 * time.Minute),
   346  				Status:       ocsp.Good,
   347  			},
   348  			errContains: "invalid validity",
   349  		},
   350  		{
   351  			name: "this-update-in-future",
   352  			ocspRes: ocsp.Response{
   353  				SerialNumber: leafCert.SerialNumber,
   354  				ThisUpdate:   now.Add(1 * time.Hour),
   355  				NextUpdate:   now.Add(2 * time.Hour),
   356  				Status:       ocsp.Good,
   357  			},
   358  			errContains: "invalid validity",
   359  		},
   360  		{
   361  			name: "next-update-before-this-update",
   362  			ocspRes: ocsp.Response{
   363  				SerialNumber: leafCert.SerialNumber,
   364  				ThisUpdate:   now.Add(-1 * time.Hour),
   365  				NextUpdate:   now.Add(-2 * time.Hour),
   366  				Status:       ocsp.Good,
   367  			},
   368  			errContains: "invalid validity",
   369  		},
   370  		{
   371  			name: "missing-this-update",
   372  			ocspRes: ocsp.Response{
   373  				SerialNumber: leafCert.SerialNumber,
   374  				NextUpdate:   now.Add(2 * time.Hour),
   375  				Status:       ocsp.Good,
   376  			},
   377  			errContains: "invalid validity",
   378  		},
   379  		{
   380  			name: "unknown-status",
   381  			ocspRes: ocsp.Response{
   382  				SerialNumber: leafCert.SerialNumber,
   383  				ThisUpdate:   now.Add(-1 * time.Hour),
   384  				NextUpdate:   now.Add(30 * time.Minute),
   385  				Status:       ocsp.Unknown,
   386  			},
   387  			errContains: "OCSP status unknown",
   388  		},
   389  		{
   390  			name: "over-max-age",
   391  			ocspRes: ocsp.Response{
   392  				SerialNumber: leafCert.SerialNumber,
   393  				ThisUpdate:   now.Add(-1 * time.Hour),
   394  				NextUpdate:   now.Add(30 * time.Minute),
   395  				Status:       ocsp.Good,
   396  			},
   397  			maxAge:      10 * time.Minute,
   398  			errContains: "is greater than max age",
   399  		},
   400  	}
   401  	for _, tc := range tt {
   402  		t.Run(tc.name, func(t *testing.T) {
   403  			ocspHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   404  				useCa := rootCa
   405  				useCaKey := rootCaKey
   406  				if tc.ca != nil {
   407  					useCa = tc.ca
   408  				}
   409  				if tc.caKey != nil {
   410  					useCaKey = tc.caKey
   411  				}
   412  				response := buildOcspResponse(t, useCa, useCaKey, tc.ocspRes)
   413  				_, _ = w.Write(response)
   414  			})
   415  			ts := httptest.NewServer(ocspHandler)
   416  			defer ts.Close()
   417  
   418  			logFactory := func() hclog.Logger {
   419  				return hclog.NewNullLogger()
   420  			}
   421  			client := New(logFactory, 100)
   422  
   423  			config := &VerifyConfig{
   424  				OcspEnabled:          true,
   425  				OcspServersOverride:  []string{ts.URL},
   426  				OcspFailureMode:      FailOpenFalse,
   427  				QueryAllServers:      false,
   428  				OcspThisUpdateMaxAge: tc.maxAge,
   429  			}
   430  
   431  			status, err := client.GetRevocationStatus(ctx, leafCert, rootCa, config)
   432  			if err == nil && status == nil || (status != nil && status.err == nil) {
   433  				t.Fatalf("expected an error got none")
   434  			}
   435  			if err != nil {
   436  				require.ErrorContains(t, err, tc.errContains,
   437  					"Expected error got response: %v, %v", status, err)
   438  			}
   439  			if status != nil && status.err != nil {
   440  				require.ErrorContains(t, status.err, tc.errContains,
   441  					"Expected error got response: %v, %v", status, err)
   442  			}
   443  		})
   444  	}
   445  }
   446  
   447  // TestUnitZeroNextUpdateAreNotCached verifies that we are not caching the responses
   448  // with no NextUpdate field set as according to RFC6960 4.2.2.1
   449  // "If nextUpdate is not set, the responder is indicating that newer
   450  // revocation information is available all the time."
   451  func TestUnitZeroNextUpdateAreNotCached(t *testing.T) {
   452  	rootCaKey, rootCa, leafCert := createCaLeafCerts(t)
   453  	numQueries := &atomic.Uint32{}
   454  	ocspHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   455  		numQueries.Add(1)
   456  		now := time.Now()
   457  		ocspRes := ocsp.Response{
   458  			SerialNumber: leafCert.SerialNumber,
   459  			ThisUpdate:   now.Add(-1 * time.Hour),
   460  			Status:       ocsp.Good,
   461  		}
   462  		response := buildOcspResponse(t, rootCa, rootCaKey, ocspRes)
   463  		_, _ = w.Write(response)
   464  	})
   465  	ts := httptest.NewServer(ocspHandler)
   466  	defer ts.Close()
   467  
   468  	logFactory := func() hclog.Logger {
   469  		return hclog.NewNullLogger()
   470  	}
   471  	client := New(logFactory, 100)
   472  
   473  	config := &VerifyConfig{
   474  		OcspEnabled:         true,
   475  		OcspServersOverride: []string{ts.URL},
   476  	}
   477  
   478  	_, err := client.GetRevocationStatus(context.Background(), leafCert, rootCa, config)
   479  	require.NoError(t, err, "Failed fetching revocation status")
   480  
   481  	_, err = client.GetRevocationStatus(context.Background(), leafCert, rootCa, config)
   482  	require.NoError(t, err, "Failed fetching revocation status second time")
   483  
   484  	require.Equal(t, uint32(2), numQueries.Load())
   485  }
   486  
   487  // TestUnitResponsesAreCached verify that the OCSP responses are properly cached when
   488  // querying for the same leaf certificates
   489  func TestUnitResponsesAreCached(t *testing.T) {
   490  	rootCaKey, rootCa, leafCert := createCaLeafCerts(t)
   491  	numQueries := &atomic.Uint32{}
   492  	ocspHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   493  		numQueries.Add(1)
   494  		now := time.Now()
   495  		ocspRes := ocsp.Response{
   496  			SerialNumber: leafCert.SerialNumber,
   497  			ThisUpdate:   now.Add(-1 * time.Hour),
   498  			NextUpdate:   now.Add(1 * time.Hour),
   499  			Status:       ocsp.Good,
   500  		}
   501  		response := buildOcspResponse(t, rootCa, rootCaKey, ocspRes)
   502  		_, _ = w.Write(response)
   503  	})
   504  	ts1 := httptest.NewServer(ocspHandler)
   505  	ts2 := httptest.NewServer(ocspHandler)
   506  	defer ts1.Close()
   507  	defer ts2.Close()
   508  
   509  	logFactory := func() hclog.Logger {
   510  		return hclog.NewNullLogger()
   511  	}
   512  	client := New(logFactory, 100)
   513  
   514  	config := &VerifyConfig{
   515  		OcspEnabled:         true,
   516  		OcspServersOverride: []string{ts1.URL, ts2.URL},
   517  		QueryAllServers:     true,
   518  	}
   519  
   520  	_, err := client.GetRevocationStatus(context.Background(), leafCert, rootCa, config)
   521  	require.NoError(t, err, "Failed fetching revocation status")
   522  	// Make sure that we queried both servers and not the cache
   523  	require.Equal(t, uint32(2), numQueries.Load())
   524  
   525  	// These query should be cached and not influence our counter
   526  	_, err = client.GetRevocationStatus(context.Background(), leafCert, rootCa, config)
   527  	require.NoError(t, err, "Failed fetching revocation status second time")
   528  
   529  	require.Equal(t, uint32(2), numQueries.Load())
   530  }
   531  
   532  func buildOcspResponse(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, ocspRes ocsp.Response) []byte {
   533  	response, err := ocsp.CreateResponse(ca, ca, ocspRes, caKey)
   534  	if err != nil {
   535  		t.Fatalf("failed generating OCSP response: %v", err)
   536  	}
   537  	return response
   538  }
   539  
   540  func createCaLeafCerts(t *testing.T) (*ecdsa.PrivateKey, *x509.Certificate, *x509.Certificate) {
   541  	rootCaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   542  	require.NoError(t, err, "failed generated root key for CA")
   543  
   544  	// Validate we reject CSRs that contain CN that aren't in the original order
   545  	cr := &x509.Certificate{
   546  		Subject:               pkix.Name{CommonName: "Root Cert"},
   547  		SerialNumber:          big.NewInt(1),
   548  		IsCA:                  true,
   549  		BasicConstraintsValid: true,
   550  		SignatureAlgorithm:    x509.ECDSAWithSHA256,
   551  		NotBefore:             time.Now().Add(-1 * time.Second),
   552  		NotAfter:              time.Now().AddDate(1, 0, 0),
   553  		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
   554  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageOCSPSigning},
   555  	}
   556  	rootCaBytes, err := x509.CreateCertificate(rand.Reader, cr, cr, &rootCaKey.PublicKey, rootCaKey)
   557  	require.NoError(t, err, "failed generating root ca")
   558  
   559  	rootCa, err := x509.ParseCertificate(rootCaBytes)
   560  	require.NoError(t, err, "failed parsing root ca")
   561  
   562  	leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   563  	require.NoError(t, err, "failed generated leaf key")
   564  
   565  	cr = &x509.Certificate{
   566  		Subject:            pkix.Name{CommonName: "Leaf Cert"},
   567  		SerialNumber:       big.NewInt(2),
   568  		SignatureAlgorithm: x509.ECDSAWithSHA256,
   569  		NotBefore:          time.Now().Add(-1 * time.Second),
   570  		NotAfter:           time.Now().AddDate(1, 0, 0),
   571  		KeyUsage:           x509.KeyUsageDigitalSignature,
   572  		ExtKeyUsage:        []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   573  	}
   574  	leafCertBytes, err := x509.CreateCertificate(rand.Reader, cr, rootCa, &leafKey.PublicKey, rootCaKey)
   575  	require.NoError(t, err, "failed generating root ca")
   576  
   577  	leafCert, err := x509.ParseCertificate(leafCertBytes)
   578  	require.NoError(t, err, "failed parsing root ca")
   579  	return rootCaKey, rootCa, leafCert
   580  }
   581  
   582  func TestUnitValidateOCSP(t *testing.T) {
   583  	conf := &VerifyConfig{OcspEnabled: true}
   584  	ocspRes := &ocsp.Response{}
   585  	ost, err := validateOCSP(conf, ocspRes)
   586  	if err == nil && isValidOCSPStatus(ost.code) {
   587  		t.Fatalf("should have failed.")
   588  	}
   589  
   590  	currentTime := time.Now()
   591  	ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour)
   592  	ocspRes.NextUpdate = currentTime.Add(2 * time.Hour)
   593  	ocspRes.Status = ocsp.Revoked
   594  	ost, err = validateOCSP(conf, ocspRes)
   595  	if err != nil {
   596  		t.Fatal(err)
   597  	}
   598  
   599  	if ost.code != ocspStatusRevoked {
   600  		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code)
   601  	}
   602  	ocspRes.Status = ocsp.Good
   603  	ost, err = validateOCSP(conf, ocspRes)
   604  	if err != nil {
   605  		t.Fatal(err)
   606  	}
   607  
   608  	if ost.code != ocspStatusGood {
   609  		t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code)
   610  	}
   611  	ocspRes.Status = ocsp.Unknown
   612  	ost, err = validateOCSP(conf, ocspRes)
   613  	if err != nil {
   614  		t.Fatal(err)
   615  	}
   616  	if ost.code != ocspStatusUnknown {
   617  		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code)
   618  	}
   619  	ocspRes.Status = ocsp.ServerFailed
   620  	ost, err = validateOCSP(conf, ocspRes)
   621  	if err != nil {
   622  		t.Fatal(err)
   623  	}
   624  	if ost.code != ocspStatusOthers {
   625  		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code)
   626  	}
   627  }
   628  
   629  func TestUnitEncodeCertID(t *testing.T) {
   630  	var st *ocspStatus
   631  	_, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2})
   632  	if st.code != ocspFailedDecomposeRequest {
   633  		t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code)
   634  	}
   635  }
   636  
   637  func getCert(addr string) []*x509.Certificate {
   638  	tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second)
   639  	if err != nil {
   640  		panic(err)
   641  	}
   642  	defer tcpConn.Close()
   643  
   644  	err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second))
   645  	if err != nil {
   646  		panic(err)
   647  	}
   648  	config := tls.Config{InsecureSkipVerify: true, ServerName: addr}
   649  
   650  	conn := tls.Client(tcpConn, &config)
   651  	defer conn.Close()
   652  
   653  	err = conn.Handshake()
   654  	if err != nil {
   655  		panic(err)
   656  	}
   657  
   658  	state := conn.ConnectionState()
   659  
   660  	return state.PeerCertificates
   661  }
   662  
   663  func TestOCSPRetry(t *testing.T) {
   664  	c := New(testLogFactory, 10)
   665  	certs := getCert("s3-us-west-2.amazonaws.com:443")
   666  	dummyOCSPHost := &url.URL{
   667  		Scheme: "https",
   668  		Host:   "dummyOCSPHost",
   669  	}
   670  	client := &fakeHTTPClient{
   671  		cnt:     3,
   672  		success: true,
   673  		body:    []byte{1, 2, 3},
   674  		logger:  hclog.New(hclog.DefaultOptions),
   675  		t:       t,
   676  	}
   677  	res, b, st, err := c.retryOCSP(
   678  		context.TODO(),
   679  		client, fakeRequestFunc,
   680  		dummyOCSPHost,
   681  		make(map[string]string), []byte{0}, certs[0], certs[len(certs)-1])
   682  	if err == nil {
   683  		fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
   684  	}
   685  	client = &fakeHTTPClient{
   686  		cnt:     30,
   687  		success: true,
   688  		body:    []byte{1, 2, 3},
   689  		logger:  hclog.New(hclog.DefaultOptions),
   690  		t:       t,
   691  	}
   692  	res, b, st, err = c.retryOCSP(
   693  		context.TODO(),
   694  		client, fakeRequestFunc,
   695  		dummyOCSPHost,
   696  		make(map[string]string), []byte{0}, certs[0], certs[len(certs)-1])
   697  	if err == nil {
   698  		fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
   699  	}
   700  }
   701  
   702  type tcCanEarlyExit struct {
   703  	results       []*ocspStatus
   704  	resultLen     int
   705  	retFailOpen   *ocspStatus
   706  	retFailClosed *ocspStatus
   707  }
   708  
   709  func TestCanEarlyExitForOCSP(t *testing.T) {
   710  	testcases := []tcCanEarlyExit{
   711  		{ // 0
   712  			results: []*ocspStatus{
   713  				{
   714  					code: ocspStatusGood,
   715  				},
   716  				{
   717  					code: ocspStatusGood,
   718  				},
   719  				{
   720  					code: ocspStatusGood,
   721  				},
   722  			},
   723  			retFailOpen:   nil,
   724  			retFailClosed: nil,
   725  		},
   726  		{ // 1
   727  			results: []*ocspStatus{
   728  				{
   729  					code: ocspStatusRevoked,
   730  					err:  errors.New("revoked"),
   731  				},
   732  				{
   733  					code: ocspStatusGood,
   734  				},
   735  				{
   736  					code: ocspStatusGood,
   737  				},
   738  			},
   739  			retFailOpen:   &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   740  			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   741  		},
   742  		{ // 2
   743  			results: []*ocspStatus{
   744  				{
   745  					code: ocspStatusUnknown,
   746  					err:  errors.New("unknown"),
   747  				},
   748  				{
   749  					code: ocspStatusGood,
   750  				},
   751  				{
   752  					code: ocspStatusGood,
   753  				},
   754  			},
   755  			retFailOpen:   nil,
   756  			retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")},
   757  		},
   758  		{ // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included.
   759  			results: []*ocspStatus{
   760  				{
   761  					code: ocspStatusRevoked,
   762  					err:  errors.New("revoked"),
   763  				},
   764  				{
   765  					code: ocspInvalidValidity,
   766  				},
   767  				{
   768  					code: ocspStatusGood,
   769  				},
   770  			},
   771  			retFailOpen:   nil,
   772  			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   773  		},
   774  		{ // 4: not taken as revoked if the number of results don't match the expected results.
   775  			results: []*ocspStatus{
   776  				{
   777  					code: ocspStatusRevoked,
   778  					err:  errors.New("revoked"),
   779  				},
   780  				{
   781  					code: ocspStatusGood,
   782  				},
   783  			},
   784  			resultLen:     3,
   785  			retFailOpen:   nil,
   786  			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   787  		},
   788  	}
   789  	c := New(testLogFactory, 10)
   790  	for idx, tt := range testcases {
   791  		expectedLen := len(tt.results)
   792  		if tt.resultLen > 0 {
   793  			expectedLen = tt.resultLen
   794  		}
   795  		r := c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenTrue})
   796  		if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) {
   797  			t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r)
   798  		}
   799  		r = c.canEarlyExitForOCSP(tt.results, expectedLen, &VerifyConfig{OcspFailureMode: FailOpenFalse})
   800  		if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) {
   801  			t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r)
   802  		}
   803  	}
   804  }
   805  
   806  var testLogger = hclog.New(hclog.DefaultOptions)
   807  
   808  func testLogFactory() hclog.Logger {
   809  	return testLogger
   810  }
   811  
   812  type fakeHTTPClient struct {
   813  	cnt        int    // number of retry
   814  	success    bool   // return success after retry in cnt times
   815  	timeout    bool   // timeout
   816  	body       []byte // return body
   817  	t          *testing.T
   818  	logger     hclog.Logger
   819  	redirected bool
   820  }
   821  
   822  func (c *fakeHTTPClient) Do(_ *retryablehttp.Request) (*http.Response, error) {
   823  	c.cnt--
   824  	if c.cnt < 0 {
   825  		c.cnt = 0
   826  	}
   827  	c.t.Log("fakeHTTPClient.cnt", c.cnt)
   828  
   829  	var retcode int
   830  	if !c.redirected {
   831  		c.redirected = true
   832  		c.cnt++
   833  		retcode = 405
   834  	} else if c.success && c.cnt == 1 {
   835  		retcode = 200
   836  	} else {
   837  		if c.timeout {
   838  			// simulate timeout
   839  			time.Sleep(time.Second * 1)
   840  			return nil, &fakeHTTPError{
   841  				err:     "Whatever reason (Client.Timeout exceeded while awaiting headers)",
   842  				timeout: true,
   843  			}
   844  		}
   845  		retcode = 0
   846  	}
   847  
   848  	ret := &http.Response{
   849  		StatusCode: retcode,
   850  		Body:       &fakeResponseBody{body: c.body},
   851  	}
   852  	return ret, nil
   853  }
   854  
   855  type fakeHTTPError struct {
   856  	err     string
   857  	timeout bool
   858  }
   859  
   860  func (e *fakeHTTPError) Error() string   { return e.err }
   861  func (e *fakeHTTPError) Timeout() bool   { return e.timeout }
   862  func (e *fakeHTTPError) Temporary() bool { return true }
   863  
   864  type fakeResponseBody struct {
   865  	body []byte
   866  	cnt  int
   867  }
   868  
   869  func (b *fakeResponseBody) Read(p []byte) (n int, err error) {
   870  	if b.cnt == 0 {
   871  		copy(p, b.body)
   872  		b.cnt = 1
   873  		return len(b.body), nil
   874  	}
   875  	b.cnt = 0
   876  	return 0, io.EOF
   877  }
   878  
   879  func (b *fakeResponseBody) Close() error {
   880  	return nil
   881  }
   882  
   883  func fakeRequestFunc(_, _ string, _ interface{}) (*retryablehttp.Request, error) {
   884  	return nil, nil
   885  }
   886  
// vaultCert is a PEM-encoded certificate referenced only by the commented-out
// TestMultiOCSP above. That test decodes it and passes it as
// VerifyConfig.ExtraCas when talking to a local Vault PKI mount.
// NOTE(review): the encoded validity period appears to end in September 2024.
// Confirm, and regenerate it before re-enabling TestMultiOCSP.
const vaultCert = `-----BEGIN CERTIFICATE-----
MIIDuTCCAqGgAwIBAgIUA6VeVD1IB5rXcCZRAqPO4zr/GAMwDQYJKoZIhvcNAQEL
BQAwcjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAlZBMREwDwYDVQQHDAhTb21lQ2l0
eTESMBAGA1UECgwJTXlDb21wYW55MRMwEQYDVQQLDApNeURpdmlzaW9uMRowGAYD
VQQDDBF3d3cuY29uaHVnZWNvLmNvbTAeFw0yMjA5MDcxOTA1MzdaFw0yNDA5MDYx
OTA1MzdaMHIxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJWQTERMA8GA1UEBwwIU29t
ZUNpdHkxEjAQBgNVBAoMCU15Q29tcGFueTETMBEGA1UECwwKTXlEaXZpc2lvbjEa
MBgGA1UEAwwRd3d3LmNvbmh1Z2Vjby5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB
DwAwggEKAoIBAQDL9qzEXi4PIafSAqfcwcmjujFvbG1QZbI8swxnD+w8i4ufAQU5
LDmvMrGo3ZbhJ0mCihYmFxpjhRdP2raJQ9TysHlPXHtDRpr9ckWTKBz2oIfqVtJ2
qzteQkWCkDAO7kPqzgCFsMeoMZeONRkeGib0lEzQAbW/Rqnphg8zVVkyQ71DZ7Pc
d5WkC2E28kKcSramhWfVFpxG3hSIrLOX2esEXteLRzKxFPf+gi413JZFKYIWrebP
u5t0++MLNpuX322geoki4BWMjQsd47XILmxZ4aj33ScZvdrZESCnwP76hKIxg9mO
lMxrqSWKVV5jHZrElSEj9LYJgDO1Y6eItn7hAgMBAAGjRzBFMAsGA1UdDwQEAwIE
MDATBgNVHSUEDDAKBggrBgEFBQcDATAhBgNVHREEGjAYggtleGFtcGxlLmNvbYIJ
bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQA5dPdf5SdtMwe2uSspO/EuWqbM
497vMQBW1Ey8KRKasJjhvOVYMbe7De5YsnW4bn8u5pl0zQGF4hEtpmifAtVvziH/
K+ritQj9VVNbLLCbFcg+b0kfjt4yrDZ64vWvIeCgPjG1Kme8gdUUWgu9dOud5gdx
qg/tIFv4TRS/eIIymMlfd9owOD3Ig6S5fy4NaAJFAwXf8+3Rzuc+e7JSAPgAufjh
tOTWinxvoiOLuYwo9CyGgq4qKBFsrY0aE0gdA7oTQkpbEbo2EbqiWUl/PTCl1Y4Z
nSZ0n+4q9QC9RLrWwYTwh838d5RVLUst2mBKSA+vn7YkqmBJbdBC6nkd7n7H
-----END CERTIFICATE-----
`