github.com/snowflakedb/gosnowflake@v1.9.0/ocsp_test.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"encoding/base64"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"net"
    16  	"net/http"
    17  	"net/url"
    18  	"os"
    19  	"testing"
    20  	"time"
    21  
    22  	"golang.org/x/crypto/ocsp"
    23  )
    24  
    25  func TestOCSP(t *testing.T) {
    26  	cacheServerEnabled := []string{
    27  		"true",
    28  		"false",
    29  	}
    30  	targetURL := []string{
    31  		"https://sfctest0.snowflakecomputing.com/",
    32  		"https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64",
    33  		"https://sfcdev2.blob.core.windows.net/",
    34  	}
    35  
    36  	transports := []*http.Transport{
    37  		snowflakeInsecureTransport,
    38  		SnowflakeTransport,
    39  	}
    40  
    41  	for _, enabled := range cacheServerEnabled {
    42  		for _, tgt := range targetURL {
    43  			_ = os.Setenv(cacheServerEnabledEnv, enabled)
    44  			_ = os.Remove(cacheFileName) // clear cache file
    45  			ocspResponseCache = make(map[certIDKey]*certCacheValue)
    46  			for _, tr := range transports {
    47  				t.Run(fmt.Sprintf("%v_%v", tgt, enabled), func(t *testing.T) {
    48  					c := &http.Client{
    49  						Transport: tr,
    50  						Timeout:   30 * time.Second,
    51  					}
    52  					req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil))
    53  					if err != nil {
    54  						t.Fatalf("fail to create a request. err: %v", err)
    55  					}
    56  					res, err := c.Do(req)
    57  					if err != nil {
    58  						t.Fatalf("failed to GET contents. err: %v", err)
    59  					}
    60  					defer res.Body.Close()
    61  					_, err = io.ReadAll(res.Body)
    62  					if err != nil {
    63  						t.Fatalf("failed to read content body for %v", tgt)
    64  					}
    65  				})
    66  			}
    67  		}
    68  	}
    69  	_ = os.Unsetenv(cacheServerEnabledEnv)
    70  }
    71  
    72  type tcValidityRange struct {
    73  	thisTime time.Time
    74  	nextTime time.Time
    75  	ret      bool
    76  }
    77  
    78  func TestUnitIsInValidityRange(t *testing.T) {
    79  	currentTime := time.Now()
    80  	testcases := []tcValidityRange{
    81  		{
    82  			// basic tests
    83  			thisTime: currentTime.Add(-100 * time.Second),
    84  			nextTime: currentTime.Add(maxClockSkew),
    85  			ret:      true,
    86  		},
    87  		{
    88  			// on the border
    89  			thisTime: currentTime.Add(maxClockSkew),
    90  			nextTime: currentTime.Add(maxClockSkew),
    91  			ret:      true,
    92  		},
    93  		{
    94  			// 1 earlier late
    95  			thisTime: currentTime.Add(maxClockSkew + 1*time.Second),
    96  			nextTime: currentTime.Add(maxClockSkew),
    97  			ret:      false,
    98  		},
    99  		{
   100  			// on the border
   101  			thisTime: currentTime.Add(-maxClockSkew),
   102  			nextTime: currentTime.Add(-maxClockSkew),
   103  			ret:      true,
   104  		},
   105  		{
   106  			// around the border
   107  			thisTime: currentTime.Add(-24*time.Hour - 40*time.Second),
   108  			nextTime: currentTime.Add(-24*time.Hour/time.Duration(100) - 40*time.Second),
   109  			ret:      false,
   110  		},
   111  		{
   112  			// on the border
   113  			thisTime: currentTime.Add(-48*time.Hour - 29*time.Minute),
   114  			nextTime: currentTime.Add(-48 * time.Hour / time.Duration(100)),
   115  			ret:      true,
   116  		},
   117  	}
   118  	for _, tc := range testcases {
   119  		t.Run(fmt.Sprintf("%v_%v", tc.thisTime, tc.nextTime), func(t *testing.T) {
   120  			if tc.ret != isInValidityRange(currentTime, tc.thisTime, tc.nextTime) {
   121  				t.Fatalf("failed to check validity. should be: %v, currentTime: %v, thisTime: %v, nextTime: %v", tc.ret, currentTime, tc.thisTime, tc.nextTime)
   122  			}
   123  		})
   124  	}
   125  }
   126  
   127  func TestUnitEncodeCertIDGood(t *testing.T) {
   128  	targetURLs := []string{
   129  		"faketestaccount.snowflakecomputing.com:443",
   130  		"s3-us-west-2.amazonaws.com:443",
   131  		"sfcdev2.blob.core.windows.net:443",
   132  	}
   133  	for _, tt := range targetURLs {
   134  		t.Run(tt, func(t *testing.T) {
   135  			chainedCerts := getCert(tt)
   136  			for i := 0; i < len(chainedCerts)-1; i++ {
   137  				subject := chainedCerts[i]
   138  				issuer := chainedCerts[i+1]
   139  				ocspServers := subject.OCSPServer
   140  				if len(ocspServers) == 0 {
   141  					t.Fatalf("no OCSP server is found. cert: %v", subject.Subject)
   142  				}
   143  				ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{})
   144  				if err != nil {
   145  					t.Fatalf("failed to create OCSP request. err: %v", err)
   146  				}
   147  				var ost *ocspStatus
   148  				_, ost = extractCertIDKeyFromRequest(ocspReq)
   149  				if ost.err != nil {
   150  					t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
   151  				}
   152  				// better hash. Not sure if the actual OCSP server accepts this, though.
   153  				ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
   154  				if err != nil {
   155  					t.Fatalf("failed to create OCSP request. err: %v", err)
   156  				}
   157  				_, ost = extractCertIDKeyFromRequest(ocspReq)
   158  				if ost.err != nil {
   159  					t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err)
   160  				}
   161  				// tweaked request binary
   162  				ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512})
   163  				if err != nil {
   164  					t.Fatalf("failed to create OCSP request. err: %v", err)
   165  				}
   166  				ocspReq[10] = 0 // random change
   167  				_, ost = extractCertIDKeyFromRequest(ocspReq)
   168  				if ost.err == nil {
   169  					t.Fatal("should have failed")
   170  				}
   171  			}
   172  		})
   173  	}
   174  }
   175  
   176  func TestUnitCheckOCSPResponseCache(t *testing.T) {
   177  	dummyKey0 := certIDKey{
   178  		HashAlgorithm: crypto.SHA1,
   179  		NameHash:      "dummy0",
   180  		IssuerKeyHash: "dummy0",
   181  		SerialNumber:  "dummy0",
   182  	}
   183  	dummyKey := certIDKey{
   184  		HashAlgorithm: crypto.SHA1,
   185  		NameHash:      "dummy1",
   186  		IssuerKeyHash: "dummy1",
   187  		SerialNumber:  "dummy1",
   188  	}
   189  	b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE"))
   190  	currentTime := float64(time.Now().UTC().Unix())
   191  	ocspResponseCache[dummyKey0] = &certCacheValue{currentTime, b64Key}
   192  	subject := &x509.Certificate{}
   193  	issuer := &x509.Certificate{}
   194  	ost := checkOCSPResponseCache(&dummyKey, subject, issuer)
   195  	if ost.code != ocspMissedCache {
   196  		t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code)
   197  	}
   198  	// old timestamp
   199  	ocspResponseCache[dummyKey] = &certCacheValue{float64(1395054952), b64Key}
   200  	ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
   201  	if ost.code != ocspCacheExpired {
   202  		t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code)
   203  	}
   204  	// future timestamp
   205  	ocspResponseCache[dummyKey] = &certCacheValue{float64(1805054952), b64Key}
   206  	ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
   207  	if ost.code != ocspFailedParseResponse {
   208  		t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code)
   209  	}
   210  	// actual OCSP but it fails to parse, because an invalid issuer certificate is given.
   211  	actualOcspResponse := "MIIB0woBAKCCAcwwggHIBgkrBgEFBQcwAQEEggG5MIIBtTCBnqIWBBSxPsNpA/i/RwHUmCYaCALvY2QrwxgPMjAxNz" + // pragma: allowlist secret
   212  		"A1MTYyMjAwMDBaMHMwcTBJMAkGBSsOAwIaBQAEFN+qEuMosQlBk+KfQoLOR0BClVijBBSxPsNpA/i/RwHUmCYaCALvY2QrwwIQBOHnp" + // pragma: allowlist secret
   213  		"Nxc8vNtwCtCuF0Vn4AAGA8yMDE3MDUxNjIyMDAwMFqgERgPMjAxNzA1MjMyMjAwMDBaMA0GCSqGSIb3DQEBCwUAA4IBAQCuRGwqQsKy" + // pragma: allowlist secret
   214  		"IAAGHgezTfG0PzMYgGD/XRDhU+2i08WTJ4Zs40Lu88cBeRXWF3iiJSpiX3/OLgfI7iXmHX9/sm2SmeNWc0Kb39bk5Lw1jwezf8hcI9+" + // pragma: allowlist secret
   215  		"mZHt60vhUgtgZk21SsRlTZ+S4VXwtDqB1Nhv6cnSnfrL2A9qJDZS2ltPNOwebWJnznDAs2dg+KxmT2yBXpHM1kb0EOolWvNgORbgIgB" + // pragma: allowlist secret
   216  		"koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + // pragma: allowlist secret
   217  		"Jb4xI/OTXIrRA5TsT4cCioE"
   218  	// issuer is not a true issuer certificate
   219  	ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse}
   220  	ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
   221  	if ost.code != ocspFailedParseResponse {
   222  		t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedParseResponse, ost.code)
   223  	}
   224  	// invalid validity
   225  	ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse}
   226  	ost = checkOCSPResponseCache(&dummyKey, subject, nil)
   227  	if ost.code != ocspInvalidValidity {
   228  		t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code)
   229  	}
   230  }
   231  
   232  func TestUnitValidateOCSP(t *testing.T) {
   233  	ocspRes := &ocsp.Response{}
   234  	ost := validateOCSP(ocspRes)
   235  	if ost.code != ocspInvalidValidity {
   236  		t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code)
   237  	}
   238  	currentTime := time.Now()
   239  	ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour)
   240  	ocspRes.NextUpdate = currentTime.Add(2 * time.Hour)
   241  	ocspRes.Status = ocsp.Revoked
   242  	ost = validateOCSP(ocspRes)
   243  	if ost.code != ocspStatusRevoked {
   244  		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code)
   245  	}
   246  	ocspRes.Status = ocsp.Good
   247  	ost = validateOCSP(ocspRes)
   248  	if ost.code != ocspStatusGood {
   249  		t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code)
   250  	}
   251  	ocspRes.Status = ocsp.Unknown
   252  	ost = validateOCSP(ocspRes)
   253  	if ost.code != ocspStatusUnknown {
   254  		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code)
   255  	}
   256  	ocspRes.Status = ocsp.ServerFailed
   257  	ost = validateOCSP(ocspRes)
   258  	if ost.code != ocspStatusOthers {
   259  		t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code)
   260  	}
   261  }
   262  
   263  func TestUnitEncodeCertID(t *testing.T) {
   264  	var st *ocspStatus
   265  	_, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2})
   266  	if st.code != ocspFailedDecomposeRequest {
   267  		t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code)
   268  	}
   269  }
   270  
   271  func getCert(addr string) []*x509.Certificate {
   272  	tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second)
   273  	if err != nil {
   274  		panic(err)
   275  	}
   276  	defer tcpConn.Close()
   277  
   278  	err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second))
   279  	if err != nil {
   280  		panic(err)
   281  	}
   282  	config := tls.Config{InsecureSkipVerify: true, ServerName: addr}
   283  
   284  	conn := tls.Client(tcpConn, &config)
   285  	defer conn.Close()
   286  
   287  	err = conn.Handshake()
   288  	if err != nil {
   289  		panic(err)
   290  	}
   291  
   292  	state := conn.ConnectionState()
   293  
   294  	return state.PeerCertificates
   295  }
   296  
   297  func TestOCSPRetry(t *testing.T) {
   298  	certs := getCert("s3-us-west-2.amazonaws.com:443")
   299  	dummyOCSPHost := &url.URL{
   300  		Scheme: "https",
   301  		Host:   "dummyOCSPHost",
   302  	}
   303  	client := &fakeHTTPClient{
   304  		cnt:     3,
   305  		success: true,
   306  		body:    []byte{1, 2, 3},
   307  	}
   308  	res, b, st := retryOCSP(
   309  		context.Background(),
   310  		client, emptyRequest,
   311  		dummyOCSPHost,
   312  		make(map[string]string), []byte{0}, certs[len(certs)-1], 10*time.Second)
   313  	if st.err == nil {
   314  		fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
   315  	}
   316  	client = &fakeHTTPClient{
   317  		cnt:     30,
   318  		success: true,
   319  		body:    []byte{1, 2, 3},
   320  	}
   321  	res, b, st = retryOCSP(
   322  		context.Background(),
   323  		client, fakeRequestFunc,
   324  		dummyOCSPHost,
   325  		make(map[string]string), []byte{0}, certs[len(certs)-1], 5*time.Second)
   326  	if st.err == nil {
   327  		fmt.Printf("should fail: %v, %v, %v\n", res, b, st)
   328  	}
   329  }
   330  
   331  func TestFullOCSPURL(t *testing.T) {
   332  	testcases := []tcFullOCSPURL{
   333  		{
   334  			url:               &url.URL{Host: "some-ocsp-url.com"},
   335  			expectedURLString: "some-ocsp-url.com",
   336  		},
   337  		{
   338  			url: &url.URL{
   339  				Host: "some-ocsp-url.com",
   340  				Path: "/some-path",
   341  			},
   342  			expectedURLString: "some-ocsp-url.com/some-path",
   343  		},
   344  		{
   345  			url: &url.URL{
   346  				Host: "some-ocsp-url.com",
   347  				Path: "some-path",
   348  			},
   349  			expectedURLString: "some-ocsp-url.com/some-path",
   350  		},
   351  	}
   352  
   353  	for _, testcase := range testcases {
   354  		t.Run("", func(t *testing.T) {
   355  			returnedStringURL := fullOCSPURL(testcase.url)
   356  			if returnedStringURL != testcase.expectedURLString {
   357  				t.Fatalf("failed to match returned OCSP url string; expected: %v, got: %v",
   358  					testcase.expectedURLString, returnedStringURL)
   359  			}
   360  		})
   361  	}
   362  }
   363  
   364  type tcFullOCSPURL struct {
   365  	url               *url.URL
   366  	expectedURLString string
   367  }
   368  
   369  func TestOCSPCacheServerRetry(t *testing.T) {
   370  	dummyOCSPHost := &url.URL{
   371  		Scheme: "https",
   372  		Host:   "dummyOCSPHost",
   373  	}
   374  	client := &fakeHTTPClient{
   375  		cnt:     3,
   376  		success: true,
   377  		body:    []byte{1, 2, 3},
   378  	}
   379  	res, st := checkOCSPCacheServer(
   380  		context.Background(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second)
   381  	if st.err == nil {
   382  		t.Errorf("should fail: %v", res)
   383  	}
   384  	client = &fakeHTTPClient{
   385  		cnt:     30,
   386  		success: true,
   387  		body:    []byte{1, 2, 3},
   388  	}
   389  	res, st = checkOCSPCacheServer(
   390  		context.Background(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second)
   391  	if st.err == nil {
   392  		t.Errorf("should fail: %v", res)
   393  	}
   394  }
   395  
   396  type tcCanEarlyExit struct {
   397  	results       []*ocspStatus
   398  	resultLen     int
   399  	retFailOpen   *ocspStatus
   400  	retFailClosed *ocspStatus
   401  }
   402  
   403  func TestCanEarlyExitForOCSP(t *testing.T) {
   404  	testcases := []tcCanEarlyExit{
   405  		{ // 0
   406  			results: []*ocspStatus{
   407  				{
   408  					code: ocspStatusGood,
   409  				},
   410  				{
   411  					code: ocspStatusGood,
   412  				},
   413  				{
   414  					code: ocspStatusGood,
   415  				},
   416  			},
   417  			retFailOpen:   nil,
   418  			retFailClosed: nil,
   419  		},
   420  		{ // 1
   421  			results: []*ocspStatus{
   422  				{
   423  					code: ocspStatusRevoked,
   424  					err:  errors.New("revoked"),
   425  				},
   426  				{
   427  					code: ocspStatusGood,
   428  				},
   429  				{
   430  					code: ocspStatusGood,
   431  				},
   432  			},
   433  			retFailOpen:   &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   434  			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   435  		},
   436  		{ // 2
   437  			results: []*ocspStatus{
   438  				{
   439  					code: ocspStatusUnknown,
   440  					err:  errors.New("unknown"),
   441  				},
   442  				{
   443  					code: ocspStatusGood,
   444  				},
   445  				{
   446  					code: ocspStatusGood,
   447  				},
   448  			},
   449  			retFailOpen:   nil,
   450  			retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")},
   451  		},
   452  		{ // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included.
   453  			results: []*ocspStatus{
   454  				{
   455  					code: ocspStatusRevoked,
   456  					err:  errors.New("revoked"),
   457  				},
   458  				{
   459  					code: ocspInvalidValidity,
   460  				},
   461  				{
   462  					code: ocspStatusGood,
   463  				},
   464  			},
   465  			retFailOpen:   nil,
   466  			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   467  		},
   468  		{ // 4: not taken as revoked if the number of results don't match the expected results.
   469  			results: []*ocspStatus{
   470  				{
   471  					code: ocspStatusRevoked,
   472  					err:  errors.New("revoked"),
   473  				},
   474  				{
   475  					code: ocspStatusGood,
   476  				},
   477  			},
   478  			resultLen:     3,
   479  			retFailOpen:   nil,
   480  			retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")},
   481  		},
   482  	}
   483  
   484  	for idx, tt := range testcases {
   485  		t.Run("", func(t *testing.T) {
   486  			ocspFailOpen = OCSPFailOpenTrue
   487  			expectedLen := len(tt.results)
   488  			if tt.resultLen > 0 {
   489  				expectedLen = tt.resultLen
   490  			}
   491  			r := canEarlyExitForOCSP(tt.results, expectedLen)
   492  			if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) {
   493  				t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r)
   494  			}
   495  			ocspFailOpen = OCSPFailOpenFalse
   496  			r = canEarlyExitForOCSP(tt.results, expectedLen)
   497  			if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) {
   498  				t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r)
   499  			}
   500  		})
   501  	}
   502  }
   503  
   504  func TestInitOCSPCacheFileCreation(t *testing.T) {
   505  	if runningOnGithubAction() {
   506  		t.Skip("cannot write to github file system")
   507  	}
   508  	dirName, err := os.UserHomeDir()
   509  	if err != nil {
   510  		t.Error(err)
   511  	}
   512  	srcFileName := dirName + "/.cache/snowflake/ocsp_response_cache.json"
   513  	tmpFileName := srcFileName + "_tmp"
   514  	dst, err := os.Create(tmpFileName)
   515  	if err != nil {
   516  		t.Error(err)
   517  	}
   518  	defer dst.Close()
   519  
   520  	var src *os.File
   521  	if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) {
   522  		// file does not exist
   523  		if err = os.MkdirAll(dirName+"/.cache/snowflake/", os.ModePerm); err != nil {
   524  			t.Error(err)
   525  		}
   526  		if _, err = os.Create(srcFileName); err != nil {
   527  			t.Error(err)
   528  		}
   529  	} else if err != nil {
   530  		t.Error(err)
   531  	} else {
   532  		// file exists
   533  		src, err = os.Open(srcFileName)
   534  		if err != nil {
   535  			t.Error(err)
   536  		}
   537  		defer src.Close()
   538  		// copy original contents to temporary file
   539  		if _, err = io.Copy(dst, src); err != nil {
   540  			t.Error(err)
   541  		}
   542  		if err = os.Remove(srcFileName); err != nil {
   543  			t.Error(err)
   544  		}
   545  	}
   546  
   547  	// cleanup
   548  	defer func() {
   549  		src, _ = os.Open(tmpFileName)
   550  		defer src.Close()
   551  		dst, _ = os.OpenFile(srcFileName, os.O_WRONLY, readWriteFileMode)
   552  		defer dst.Close()
   553  		// copy temporary file contents back to original file
   554  		if _, err = io.Copy(dst, src); err != nil {
   555  			t.Fatal(err)
   556  		}
   557  		if err = os.Remove(tmpFileName); err != nil {
   558  			t.Error(err)
   559  		}
   560  	}()
   561  
   562  	initOCSPCache()
   563  	if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) {
   564  		t.Error(err)
   565  	} else if err != nil {
   566  		t.Error(err)
   567  	}
   568  }