github.com/snowflakedb/gosnowflake@v1.9.0/ocsp.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bufio"
     7  	"context"
     8  	"crypto"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"crypto/x509/pkix"
    12  	"encoding/asn1"
    13  	"encoding/base64"
    14  	"encoding/json"
    15  	"encoding/pem"
    16  	"errors"
    17  	"fmt"
    18  	"io"
    19  	"math/big"
    20  	"net"
    21  	"net/http"
    22  	"net/url"
    23  	"os"
    24  	"path/filepath"
    25  	"runtime"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  	"sync/atomic"
    30  	"time"
    31  
    32  	"golang.org/x/crypto/ocsp"
    33  )
    34  
    35  var (
    36  	// caRoot includes the CA certificates.
    37  	caRoot map[string]*x509.Certificate
    38  	// certPOol includes the CA certificates.
    39  	certPool *x509.CertPool
    40  	// cacheDir is the location of OCSP response cache file
    41  	cacheDir = ""
    42  	// cacheFileName is the file name of OCSP response cache file
    43  	cacheFileName = ""
    44  	// cacheUpdated is true if the memory cache is updated
    45  	cacheUpdated = true
    46  )
    47  
    48  // OCSPFailOpenMode is OCSP fail open mode. OCSPFailOpenTrue by default and may
    49  // set to ocspModeFailClosed for fail closed mode
    50  type OCSPFailOpenMode uint32
    51  
    52  const (
    53  	ocspFailOpenNotSet OCSPFailOpenMode = iota
    54  	// OCSPFailOpenTrue represents OCSP fail open mode.
    55  	OCSPFailOpenTrue
    56  	// OCSPFailOpenFalse represents OCSP fail closed mode.
    57  	OCSPFailOpenFalse
    58  )
    59  const (
    60  	ocspModeFailOpen   = "FAIL_OPEN"
    61  	ocspModeFailClosed = "FAIL_CLOSED"
    62  	ocspModeInsecure   = "INSECURE"
    63  )
    64  
    65  // OCSP fail open mode
    66  var ocspFailOpen = OCSPFailOpenTrue
    67  
    68  const (
    69  	// defaultOCSPCacheServerTimeout is the total timeout for OCSP cache server.
    70  	defaultOCSPCacheServerTimeout = 5 * time.Second
    71  
    72  	// defaultOCSPResponderTimeout is the total timeout for OCSP responder.
    73  	defaultOCSPResponderTimeout = 10 * time.Second
    74  )
    75  
    76  const (
    77  	cacheFileBaseName = "ocsp_response_cache.json"
    78  	// cacheExpire specifies cache data expiration time in seconds.
    79  	cacheExpire           = float64(24 * 60 * 60)
    80  	cacheServerURL        = "http://ocsp.snowflakecomputing.com"
    81  	cacheServerEnabledEnv = "SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED"
    82  	cacheServerURLEnv     = "SF_OCSP_RESPONSE_CACHE_SERVER_URL"
    83  	cacheDirEnv           = "SF_OCSP_RESPONSE_CACHE_DIR"
    84  	ocspRetryURLEnv       = "SF_OCSP_RESPONSE_RETRY_URL"
    85  )
    86  
    87  const (
    88  	ocspTestInjectValidityErrorEnv        = "SF_OCSP_TEST_INJECT_VALIDITY_ERROR"
    89  	ocspTestInjectUnknownStatusEnv        = "SF_OCSP_TEST_INJECT_UNKNOWN_STATUS"
    90  	ocspTestResponseCacheServerTimeoutEnv = "SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT"
    91  	ocspTestResponderTimeoutEnv           = "SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT"
    92  	ocspTestResponderURLEnv               = "SF_OCSP_TEST_RESPONDER_URL"
    93  	ocspTestNoOCSPURLEnv                  = "SF_OCSP_TEST_NO_OCSP_RESPONDER_URL"
    94  )
    95  
    96  const (
    97  	tolerableValidityRatio = 100               // buffer for certificate revocation update time
    98  	maxClockSkew           = 900 * time.Second // buffer for clock skew
    99  )
   100  
   101  type ocspStatusCode int
   102  
   103  type ocspStatus struct {
   104  	code ocspStatusCode
   105  	err  error
   106  }
   107  
   108  const (
   109  	ocspSuccess                ocspStatusCode = 0
   110  	ocspStatusGood             ocspStatusCode = -1
   111  	ocspStatusRevoked          ocspStatusCode = -2
   112  	ocspStatusUnknown          ocspStatusCode = -3
   113  	ocspStatusOthers           ocspStatusCode = -4
   114  	ocspNoServer               ocspStatusCode = -5
   115  	ocspFailedParseOCSPHost    ocspStatusCode = -6
   116  	ocspFailedComposeRequest   ocspStatusCode = -7
   117  	ocspFailedDecomposeRequest ocspStatusCode = -8
   118  	ocspFailedSubmit           ocspStatusCode = -9
   119  	ocspFailedResponse         ocspStatusCode = -10
   120  	ocspFailedExtractResponse  ocspStatusCode = -11
   121  	ocspFailedParseResponse    ocspStatusCode = -12
   122  	ocspInvalidValidity        ocspStatusCode = -13
   123  	ocspMissedCache            ocspStatusCode = -14
   124  	ocspCacheExpired           ocspStatusCode = -15
   125  	ocspFailedDecodeResponse   ocspStatusCode = -16
   126  )
   127  
   128  // copied from crypto/ocsp.go
   129  type certID struct {
   130  	HashAlgorithm pkix.AlgorithmIdentifier
   131  	NameHash      []byte
   132  	IssuerKeyHash []byte
   133  	SerialNumber  *big.Int
   134  }
   135  
   136  // cache key
   137  type certIDKey struct {
   138  	HashAlgorithm crypto.Hash
   139  	NameHash      string
   140  	IssuerKeyHash string
   141  	SerialNumber  string
   142  }
   143  
   144  type certCacheValue struct {
   145  	ts             float64
   146  	ocspRespBase64 string
   147  }
   148  
   149  type parsedOcspRespKey struct {
   150  	ocspRespBase64 string
   151  	certIDBase64   string
   152  }
   153  
   154  var (
   155  	ocspResponseCache       map[certIDKey]*certCacheValue
   156  	ocspParsedRespCache     map[parsedOcspRespKey]*ocspStatus
   157  	ocspResponseCacheLock   *sync.RWMutex
   158  	ocspParsedRespCacheLock *sync.Mutex
   159  )
   160  
   161  // copied from crypto/ocsp
   162  var hashOIDs = map[crypto.Hash]asn1.ObjectIdentifier{
   163  	crypto.SHA1:   asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}),
   164  	crypto.SHA256: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 1}),
   165  	crypto.SHA384: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 2}),
   166  	crypto.SHA512: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 3}),
   167  }
   168  
   169  // copied from crypto/ocsp
   170  func getOIDFromHashAlgorithm(target crypto.Hash) asn1.ObjectIdentifier {
   171  	for hash, oid := range hashOIDs {
   172  		if hash == target {
   173  			return oid
   174  		}
   175  	}
   176  	logger.Errorf("no valid OID is found for the hash algorithm. %#v", target)
   177  	return nil
   178  }
   179  
   180  func getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto.Hash {
   181  	for hash, oid := range hashOIDs {
   182  		if oid.Equal(target.Algorithm) {
   183  			return hash
   184  		}
   185  	}
   186  	logger.Errorf("no valid hash algorithm is found for the oid. Falling back to SHA1: %#v", target)
   187  	return crypto.SHA1
   188  }
   189  
   190  // calcTolerableValidity returns the maximum validity buffer
   191  func calcTolerableValidity(thisUpdate, nextUpdate time.Time) time.Duration {
   192  	return durationMax(time.Duration(nextUpdate.Sub(thisUpdate)/tolerableValidityRatio), maxClockSkew)
   193  }
   194  
   195  // isInValidityRange checks the validity
   196  func isInValidityRange(currTime, thisUpdate, nextUpdate time.Time) bool {
   197  	if currTime.Sub(thisUpdate.Add(-maxClockSkew)) < 0 {
   198  		return false
   199  	}
   200  	if nextUpdate.Add(calcTolerableValidity(thisUpdate, nextUpdate)).Sub(currTime) < 0 {
   201  		return false
   202  	}
   203  	return true
   204  }
   205  
   206  func isTestInvalidValidity() bool {
   207  	return strings.EqualFold(os.Getenv(ocspTestInjectValidityErrorEnv), "true")
   208  }
   209  
   210  func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) {
   211  	r, err := ocsp.ParseRequest(ocspReq)
   212  	if err != nil {
   213  		return nil, &ocspStatus{
   214  			code: ocspFailedDecomposeRequest,
   215  			err:  err,
   216  		}
   217  	}
   218  
   219  	// encode CertID, used as a key in the cache
   220  	encodedCertID := &certIDKey{
   221  		r.HashAlgorithm,
   222  		base64.StdEncoding.EncodeToString(r.IssuerNameHash),
   223  		base64.StdEncoding.EncodeToString(r.IssuerKeyHash),
   224  		r.SerialNumber.String(),
   225  	}
   226  	return encodedCertID, &ocspStatus{
   227  		code: ocspSuccess,
   228  	}
   229  }
   230  
   231  func decodeCertIDKey(certIDKeyBase64 string) *certIDKey {
   232  	r, err := base64.StdEncoding.DecodeString(certIDKeyBase64)
   233  	if err != nil {
   234  		return nil
   235  	}
   236  	var c certID
   237  	rest, err := asn1.Unmarshal(r, &c)
   238  	if err != nil {
   239  		// error in parsing
   240  		return nil
   241  	}
   242  	if len(rest) > 0 {
   243  		// extra bytes to the end
   244  		return nil
   245  	}
   246  	return &certIDKey{
   247  		getHashAlgorithmFromOID(c.HashAlgorithm),
   248  		base64.StdEncoding.EncodeToString(c.NameHash),
   249  		base64.StdEncoding.EncodeToString(c.IssuerKeyHash),
   250  		c.SerialNumber.String(),
   251  	}
   252  }
   253  
   254  func encodeCertIDKey(k *certIDKey) string {
   255  	serialNumber := new(big.Int)
   256  	serialNumber.SetString(k.SerialNumber, 10)
   257  	nameHash, err := base64.StdEncoding.DecodeString(k.NameHash)
   258  	if err != nil {
   259  		return ""
   260  	}
   261  	issuerKeyHash, err := base64.StdEncoding.DecodeString(k.IssuerKeyHash)
   262  	if err != nil {
   263  		return ""
   264  	}
   265  	encodedCertID, err := asn1.Marshal(certID{
   266  		pkix.AlgorithmIdentifier{
   267  			Algorithm:  getOIDFromHashAlgorithm(k.HashAlgorithm),
   268  			Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */},
   269  		},
   270  		nameHash,
   271  		issuerKeyHash,
   272  		serialNumber,
   273  	})
   274  	if err != nil {
   275  		return ""
   276  	}
   277  	return base64.StdEncoding.EncodeToString(encodedCertID)
   278  }
   279  
   280  func checkOCSPResponseCache(certIDKey *certIDKey, subject, issuer *x509.Certificate) *ocspStatus {
   281  	if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") {
   282  		return &ocspStatus{code: ocspNoServer}
   283  	}
   284  
   285  	gotValueFromCache, ok := func() (*certCacheValue, bool) {
   286  		ocspResponseCacheLock.RLock()
   287  		defer ocspResponseCacheLock.RUnlock()
   288  		valueFromCache, ok := ocspResponseCache[*certIDKey]
   289  		return valueFromCache, ok
   290  	}()
   291  	if !ok {
   292  		return &ocspStatus{
   293  			code: ocspMissedCache,
   294  			err:  fmt.Errorf("miss cache data. subject: %v", subject),
   295  		}
   296  	}
   297  
   298  	status := extractOCSPCacheResponseValue(certIDKey, gotValueFromCache, subject, issuer)
   299  	if !isValidOCSPStatus(status.code) {
   300  		deleteOCSPCache(certIDKey)
   301  	}
   302  	return status
   303  }
   304  
   305  func deleteOCSPCache(encodedCertID *certIDKey) {
   306  	ocspResponseCacheLock.Lock()
   307  	defer ocspResponseCacheLock.Unlock()
   308  	delete(ocspResponseCache, *encodedCertID)
   309  	cacheUpdated = true
   310  }
   311  
   312  func validateOCSP(ocspRes *ocsp.Response) *ocspStatus {
   313  	curTime := time.Now()
   314  
   315  	if ocspRes == nil {
   316  		return &ocspStatus{
   317  			code: ocspFailedDecomposeRequest,
   318  			err:  errors.New("OCSP Response is nil"),
   319  		}
   320  	}
   321  	if isTestInvalidValidity() || !isInValidityRange(curTime, ocspRes.ThisUpdate, ocspRes.NextUpdate) {
   322  		return &ocspStatus{
   323  			code: ocspInvalidValidity,
   324  			err: &SnowflakeError{
   325  				Number:      ErrOCSPInvalidValidity,
   326  				Message:     errMsgOCSPInvalidValidity,
   327  				MessageArgs: []interface{}{ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate},
   328  			},
   329  		}
   330  	}
   331  	if isTestUnknownStatus() {
   332  		ocspRes.Status = ocsp.Unknown
   333  	}
   334  	return returnOCSPStatus(ocspRes)
   335  }
   336  
   337  func returnOCSPStatus(ocspRes *ocsp.Response) *ocspStatus {
   338  	switch ocspRes.Status {
   339  	case ocsp.Good:
   340  		return &ocspStatus{
   341  			code: ocspStatusGood,
   342  			err:  nil,
   343  		}
   344  	case ocsp.Revoked:
   345  		return &ocspStatus{
   346  			code: ocspStatusRevoked,
   347  			err: &SnowflakeError{
   348  				Number:      ErrOCSPStatusRevoked,
   349  				Message:     errMsgOCSPStatusRevoked,
   350  				MessageArgs: []interface{}{ocspRes.RevocationReason, ocspRes.RevokedAt},
   351  			},
   352  		}
   353  	case ocsp.Unknown:
   354  		return &ocspStatus{
   355  			code: ocspStatusUnknown,
   356  			err: &SnowflakeError{
   357  				Number:  ErrOCSPStatusUnknown,
   358  				Message: errMsgOCSPStatusUnknown,
   359  			},
   360  		}
   361  	default:
   362  		return &ocspStatus{
   363  			code: ocspStatusOthers,
   364  			err:  fmt.Errorf("OCSP others. %v", ocspRes.Status),
   365  		}
   366  	}
   367  }
   368  
   369  func isTestUnknownStatus() bool {
   370  	return strings.EqualFold(os.Getenv(ocspTestInjectUnknownStatusEnv), "true")
   371  }
   372  
   373  func checkOCSPCacheServer(
   374  	ctx context.Context,
   375  	client clientInterface,
   376  	req requestFunc,
   377  	ocspServerHost *url.URL,
   378  	totalTimeout time.Duration) (
   379  	cacheContent *map[string]*certCacheValue,
   380  	ocspS *ocspStatus) {
   381  	var respd map[string][]interface{}
   382  	headers := make(map[string]string)
   383  	res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultMaxRetryCount, defaultTimeProvider, nil).execute()
   384  	if err != nil {
   385  		logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err)
   386  		return nil, &ocspStatus{
   387  			code: ocspFailedSubmit,
   388  			err:  err,
   389  		}
   390  	}
   391  	defer res.Body.Close()
   392  	logger.Debugf("StatusCode from OCSP Cache Server: %v", res.StatusCode)
   393  	if res.StatusCode != http.StatusOK {
   394  		return nil, &ocspStatus{
   395  			code: ocspFailedResponse,
   396  			err:  fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status),
   397  		}
   398  	}
   399  	logger.Debugf("reading contents")
   400  
   401  	dec := json.NewDecoder(res.Body)
   402  	for {
   403  		if err := dec.Decode(&respd); err == io.EOF {
   404  			break
   405  		} else if err != nil {
   406  			logger.Errorf("failed to decode OCSP cache. %v", err)
   407  			return nil, &ocspStatus{
   408  				code: ocspFailedExtractResponse,
   409  				err:  err,
   410  			}
   411  		}
   412  	}
   413  	buf := make(map[string]*certCacheValue)
   414  	for key, value := range respd {
   415  		ok, ts, ocspRespBase64 := extractTsAndOcspRespBase64(value)
   416  		if !ok {
   417  			continue
   418  		}
   419  		buf[key] = &certCacheValue{ts, ocspRespBase64}
   420  	}
   421  	return &buf, &ocspStatus{
   422  		code: ocspSuccess,
   423  	}
   424  }
   425  
   426  // retryOCSP is the second level of retry method if the returned contents are corrupted. It often happens with OCSP
   427  // serer and retry helps.
   428  func retryOCSP(
   429  	ctx context.Context,
   430  	client clientInterface,
   431  	req requestFunc,
   432  	ocspHost *url.URL,
   433  	headers map[string]string,
   434  	reqBody []byte,
   435  	issuer *x509.Certificate,
   436  	totalTimeout time.Duration) (
   437  	ocspRes *ocsp.Response,
   438  	ocspResBytes []byte,
   439  	ocspS *ocspStatus) {
   440  	multiplier := 1
   441  	if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) {
   442  		multiplier = 3 // up to 3 times for Fail Close mode
   443  	}
   444  	res, err := newRetryHTTP(
   445  		ctx, client, req, ocspHost, headers,
   446  		totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).doPost().setBody(reqBody).execute()
   447  	if err != nil {
   448  		return ocspRes, ocspResBytes, &ocspStatus{
   449  			code: ocspFailedSubmit,
   450  			err:  err,
   451  		}
   452  	}
   453  	defer res.Body.Close()
   454  	logger.Debugf("StatusCode from OCSP Server: %v\n", res.StatusCode)
   455  	if res.StatusCode != http.StatusOK {
   456  		return ocspRes, ocspResBytes, &ocspStatus{
   457  			code: ocspFailedResponse,
   458  			err:  fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status),
   459  		}
   460  	}
   461  	ocspResBytes, err = io.ReadAll(res.Body)
   462  	if err != nil {
   463  		return ocspRes, ocspResBytes, &ocspStatus{
   464  			code: ocspFailedExtractResponse,
   465  			err:  err,
   466  		}
   467  	}
   468  	ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer)
   469  	if err != nil {
   470  		logger.Warnf("error when parsing ocsp response: %v", err)
   471  		logger.Warnf("performing GET fallback request to OCSP")
   472  		return fallbackRetryOCSPToGETRequest(ctx, client, req, ocspHost, headers, issuer, totalTimeout)
   473  	}
   474  
   475  	logger.Debugf("OCSP Status from server: %v", printStatus(ocspRes))
   476  	return ocspRes, ocspResBytes, &ocspStatus{
   477  		code: ocspSuccess,
   478  	}
   479  }
   480  
   481  // fallbackRetryOCSPToGETRequest is the third level of retry method. Some OCSP responders do not support POST requests
   482  // and will return with a "malformed" request error. In that case we also try to perform a GET request
   483  func fallbackRetryOCSPToGETRequest(
   484  	ctx context.Context,
   485  	client clientInterface,
   486  	req requestFunc,
   487  	ocspHost *url.URL,
   488  	headers map[string]string,
   489  	issuer *x509.Certificate,
   490  	totalTimeout time.Duration) (
   491  	ocspRes *ocsp.Response,
   492  	ocspResBytes []byte,
   493  	ocspS *ocspStatus) {
   494  	multiplier := 1
   495  	if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) {
   496  		multiplier = 3 // up to 3 times for Fail Close mode
   497  	}
   498  	res, err := newRetryHTTP(ctx, client, req, ocspHost, headers,
   499  		totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).execute()
   500  	if err != nil {
   501  		return ocspRes, ocspResBytes, &ocspStatus{
   502  			code: ocspFailedSubmit,
   503  			err:  err,
   504  		}
   505  	}
   506  	defer res.Body.Close()
   507  	logger.Debugf("GET fallback StatusCode from OCSP Server: %v", res.StatusCode)
   508  	if res.StatusCode != http.StatusOK {
   509  		return ocspRes, ocspResBytes, &ocspStatus{
   510  			code: ocspFailedResponse,
   511  			err:  fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status),
   512  		}
   513  	}
   514  	ocspResBytes, err = io.ReadAll(res.Body)
   515  	if err != nil {
   516  		return ocspRes, ocspResBytes, &ocspStatus{
   517  			code: ocspFailedExtractResponse,
   518  			err:  err,
   519  		}
   520  	}
   521  	ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer)
   522  	if err != nil {
   523  		return ocspRes, ocspResBytes, &ocspStatus{
   524  			code: ocspFailedParseResponse,
   525  			err:  err,
   526  		}
   527  	}
   528  
   529  	logger.Debugf("GET fallback OCSP Status from server: %v", printStatus(ocspRes))
   530  	return ocspRes, ocspResBytes, &ocspStatus{
   531  		code: ocspSuccess,
   532  	}
   533  }
   534  
   535  func printStatus(response *ocsp.Response) string {
   536  	switch response.Status {
   537  	case ocsp.Good:
   538  		return "Good"
   539  	case ocsp.Revoked:
   540  		return "Revoked"
   541  	case ocsp.Unknown:
   542  		return "Unknown"
   543  	default:
   544  		return fmt.Sprintf("%d", response.Status)
   545  	}
   546  }
   547  
   548  func fullOCSPURL(url *url.URL) string {
   549  	fullURL := url.Hostname()
   550  	if url.Path != "" {
   551  		if !strings.HasPrefix(url.Path, "/") {
   552  			fullURL += "/"
   553  		}
   554  		fullURL += url.Path
   555  	}
   556  	return fullURL
   557  }
   558  
   559  // getRevocationStatus checks the certificate revocation status for subject using issuer certificate.
   560  func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) *ocspStatus {
   561  	logger.Infof("Subject: %v, Issuer: %v", subject.Subject, issuer.Subject)
   562  
   563  	status, ocspReq, encodedCertID := validateWithCache(subject, issuer)
   564  	if isValidOCSPStatus(status.code) {
   565  		return status
   566  	}
   567  	if ocspReq == nil || encodedCertID == nil {
   568  		return status
   569  	}
   570  	logger.Infof("cache missed")
   571  	logger.Infof("OCSP Server: %v", subject.OCSPServer)
   572  	if len(subject.OCSPServer) == 0 || isTestNoOCSPURL() {
   573  		return &ocspStatus{
   574  			code: ocspNoServer,
   575  			err: &SnowflakeError{
   576  				Number:      ErrOCSPNoOCSPResponderURL,
   577  				Message:     errMsgOCSPNoOCSPResponderURL,
   578  				MessageArgs: []interface{}{subject.Subject},
   579  			},
   580  		}
   581  	}
   582  	ocspHost := subject.OCSPServer[0]
   583  	u, err := url.Parse(ocspHost)
   584  	if err != nil {
   585  		return &ocspStatus{
   586  			code: ocspFailedParseOCSPHost,
   587  			err:  fmt.Errorf("failed to parse OCSP server host. %v", ocspHost),
   588  		}
   589  	}
   590  	hostnameStr := os.Getenv(ocspTestResponderURLEnv)
   591  	var hostname string
   592  	if retryURL := os.Getenv(ocspRetryURLEnv); retryURL != "" {
   593  		hostname = fmt.Sprintf(retryURL, fullOCSPURL(u), base64.StdEncoding.EncodeToString(ocspReq))
   594  		u0, err := url.Parse(hostname)
   595  		if err == nil {
   596  			hostname = u0.Hostname()
   597  			u = u0
   598  		}
   599  	} else {
   600  		hostname = fullOCSPURL(u)
   601  	}
   602  	if hostnameStr != "" {
   603  		u0, err := url.Parse(hostnameStr)
   604  		if err == nil {
   605  			hostname = u0.Hostname()
   606  			u = u0
   607  		}
   608  	}
   609  
   610  	logger.Debugf("Fetching OCSP response from server: %v", u)
   611  	logger.Debugf("Host in headers: %v", hostname)
   612  
   613  	headers := make(map[string]string)
   614  	headers[httpHeaderContentType] = "application/ocsp-request"
   615  	headers[httpHeaderAccept] = "application/ocsp-response"
   616  	headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq))
   617  	headers[httpHeaderHost] = hostname
   618  	timeoutStr := os.Getenv(ocspTestResponderTimeoutEnv)
   619  	timeout := defaultOCSPResponderTimeout
   620  	if timeoutStr != "" {
   621  		var timeoutInMilliseconds int
   622  		timeoutInMilliseconds, err = strconv.Atoi(timeoutStr)
   623  		if err == nil {
   624  			timeout = time.Duration(timeoutInMilliseconds) * time.Millisecond
   625  		}
   626  	}
   627  	ocspClient := &http.Client{
   628  		Timeout:   timeout,
   629  		Transport: snowflakeInsecureTransport,
   630  	}
   631  	ocspRes, ocspResBytes, ocspS := retryOCSP(
   632  		ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer, timeout)
   633  	if ocspS.code != ocspSuccess {
   634  		return ocspS
   635  	}
   636  
   637  	ret := validateOCSP(ocspRes)
   638  	if !isValidOCSPStatus(ret.code) {
   639  		return ret // return invalid
   640  	}
   641  	v := &certCacheValue{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)}
   642  	ocspResponseCacheLock.Lock()
   643  	ocspResponseCache[*encodedCertID] = v
   644  	cacheUpdated = true
   645  	ocspResponseCacheLock.Unlock()
   646  	return ret
   647  }
   648  
   649  func isTestNoOCSPURL() bool {
   650  	return strings.EqualFold(os.Getenv(ocspTestNoOCSPURLEnv), "true")
   651  }
   652  
   653  func isValidOCSPStatus(status ocspStatusCode) bool {
   654  	return status == ocspStatusGood || status == ocspStatusRevoked || status == ocspStatusUnknown
   655  }
   656  
   657  // verifyPeerCertificate verifies all of certificate revocation status
   658  func verifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate) (err error) {
   659  	for i := 0; i < len(verifiedChains); i++ {
   660  		// Certificate signed by Root CA. This should be one before the last in the Certificate Chain
   661  		numberOfNoneRootCerts := len(verifiedChains[i]) - 1
   662  		if !verifiedChains[i][numberOfNoneRootCerts].IsCA || string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer) != string(verifiedChains[i][numberOfNoneRootCerts].RawSubject) {
   663  			// Check if the last Non Root Cert is also a CA or is self signed.
   664  			// if the last certificate is not, add it to the list
   665  			rca := caRoot[string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer)]
   666  			if rca == nil {
   667  				return fmt.Errorf("failed to find root CA. pkix.name: %v", verifiedChains[i][numberOfNoneRootCerts].Issuer)
   668  			}
   669  			verifiedChains[i] = append(verifiedChains[i], rca)
   670  			numberOfNoneRootCerts++
   671  		}
   672  		results := getAllRevocationStatus(ctx, verifiedChains[i])
   673  		if r := canEarlyExitForOCSP(results, numberOfNoneRootCerts); r != nil {
   674  			return r.err
   675  		}
   676  	}
   677  
   678  	ocspResponseCacheLock.Lock()
   679  	if cacheUpdated {
   680  		writeOCSPCacheFile()
   681  	}
   682  	cacheUpdated = false
   683  	ocspResponseCacheLock.Unlock()
   684  	return nil
   685  }
   686  
   687  func canEarlyExitForOCSP(results []*ocspStatus, chainSize int) *ocspStatus {
   688  	msg := ""
   689  	if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) {
   690  		// Fail closed. any error is returned to stop connection
   691  		for _, r := range results {
   692  			if r.err != nil {
   693  				return r
   694  			}
   695  		}
   696  	} else {
   697  		// Fail open and all results are valid.
   698  		allValid := len(results) == chainSize
   699  		for _, r := range results {
   700  			if !isValidOCSPStatus(r.code) {
   701  				allValid = false
   702  				break
   703  			}
   704  		}
   705  		for _, r := range results {
   706  			if allValid && r.code == ocspStatusRevoked {
   707  				return r
   708  			}
   709  			if r != nil && r.code != ocspStatusGood && r.err != nil {
   710  				msg += "\n" + r.err.Error()
   711  			}
   712  		}
   713  	}
   714  	if len(msg) > 0 {
   715  		logger.Warnf(
   716  			"WARNING!!! Using fail-open to connect. Driver is connecting to an "+
   717  				"HTTPS endpoint without OCSP based Certificate Revocation checking "+
   718  				"as it could not obtain a valid OCSP Response to use from the CA OCSP "+
   719  				"responder. Detail: %v", msg[1:])
   720  	}
   721  	return nil
   722  }
   723  
   724  func validateWithCacheForAllCertificates(verifiedChains []*x509.Certificate) bool {
   725  	n := len(verifiedChains) - 1
   726  	for j := 0; j < n; j++ {
   727  		subject := verifiedChains[j]
   728  		issuer := verifiedChains[j+1]
   729  		status, _, _ := validateWithCache(subject, issuer)
   730  		if !isValidOCSPStatus(status.code) {
   731  			return false
   732  		}
   733  	}
   734  	return true
   735  }
   736  
   737  func validateWithCache(subject, issuer *x509.Certificate) (*ocspStatus, []byte, *certIDKey) {
   738  	ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{})
   739  	if err != nil {
   740  		logger.Errorf("failed to create OCSP request from the certificates.\n")
   741  		return &ocspStatus{
   742  			code: ocspFailedComposeRequest,
   743  			err:  errors.New("failed to create a OCSP request"),
   744  		}, nil, nil
   745  	}
   746  	encodedCertID, ocspS := extractCertIDKeyFromRequest(ocspReq)
   747  	if ocspS.code != ocspSuccess {
   748  		logger.Errorf("failed to extract CertID from OCSP Request.\n")
   749  		return &ocspStatus{
   750  			code: ocspFailedComposeRequest,
   751  			err:  errors.New("failed to extract cert ID Key"),
   752  		}, ocspReq, nil
   753  	}
   754  	status := checkOCSPResponseCache(encodedCertID, subject, issuer)
   755  	return status, ocspReq, encodedCertID
   756  }
   757  
   758  func downloadOCSPCacheServer() {
   759  	if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") {
   760  		return
   761  	}
   762  	ocspCacheServerURL := os.Getenv(cacheServerURLEnv)
   763  	if ocspCacheServerURL == "" {
   764  		ocspCacheServerURL = fmt.Sprintf("%v/%v", cacheServerURL, cacheFileBaseName)
   765  	}
   766  	u, err := url.Parse(ocspCacheServerURL)
   767  	if err != nil {
   768  		return
   769  	}
   770  	logger.Infof("downloading OCSP Cache from server %v", ocspCacheServerURL)
   771  	timeoutStr := os.Getenv(ocspTestResponseCacheServerTimeoutEnv)
   772  	timeout := defaultOCSPCacheServerTimeout
   773  	if timeoutStr != "" {
   774  		var timeoutInMilliseconds int
   775  		timeoutInMilliseconds, err = strconv.Atoi(timeoutStr)
   776  		if err == nil {
   777  			timeout = time.Duration(timeoutInMilliseconds) * time.Millisecond
   778  		}
   779  	}
   780  	ocspClient := &http.Client{
   781  		Timeout:   timeout,
   782  		Transport: snowflakeInsecureTransport,
   783  	}
   784  	ret, ocspStatus := checkOCSPCacheServer(context.Background(), ocspClient, http.NewRequest, u, timeout)
   785  	if ocspStatus.code != ocspSuccess {
   786  		return
   787  	}
   788  
   789  	ocspResponseCacheLock.Lock()
   790  	for k, cacheValue := range *ret {
   791  		cacheKey := decodeCertIDKey(k)
   792  		status := extractOCSPCacheResponseValueWithoutSubject(cacheKey, cacheValue)
   793  		if !isValidOCSPStatus(status.code) {
   794  			continue
   795  		}
   796  		ocspResponseCache[*cacheKey] = cacheValue
   797  	}
   798  	cacheUpdated = true
   799  	ocspResponseCacheLock.Unlock()
   800  }
   801  
   802  func getAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate) []*ocspStatus {
   803  	cached := validateWithCacheForAllCertificates(verifiedChains)
   804  	if !cached {
   805  		downloadOCSPCacheServer()
   806  	}
   807  	n := len(verifiedChains) - 1
   808  	results := make([]*ocspStatus, n)
   809  	for j := 0; j < n; j++ {
   810  		results[j] = getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1])
   811  		if !isValidOCSPStatus(results[j].code) {
   812  			return results
   813  		}
   814  	}
   815  	return results
   816  }
   817  
   818  // verifyPeerCertificateSerial verifies the certificate revocation status in serial.
   819  func verifyPeerCertificateSerial(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) {
   820  	overrideCacheDir()
   821  	return verifyPeerCertificate(context.Background(), verifiedChains)
   822  }
   823  
   824  func overrideCacheDir() {
   825  	if os.Getenv(cacheDirEnv) != "" {
   826  		ocspResponseCacheLock.Lock()
   827  		defer ocspResponseCacheLock.Unlock()
   828  		createOCSPCacheDir()
   829  	}
   830  }
   831  
   832  // initOCSPCache initializes OCSP Response cache file.
   833  func initOCSPCache() {
   834  	if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") {
   835  		return
   836  	}
   837  	ocspResponseCache = make(map[certIDKey]*certCacheValue)
   838  	ocspParsedRespCache = make(map[parsedOcspRespKey]*ocspStatus)
   839  	ocspResponseCacheLock = &sync.RWMutex{}
   840  	ocspParsedRespCacheLock = &sync.Mutex{}
   841  
   842  	logger.Infof("reading OCSP Response cache file. %v\n", cacheFileName)
   843  	f, err := os.OpenFile(cacheFileName, os.O_CREATE|os.O_RDONLY, readWriteFileMode)
   844  	if err != nil {
   845  		logger.Debugf("failed to open. Ignored. %v\n", err)
   846  		return
   847  	}
   848  	defer f.Close()
   849  
   850  	buf := make(map[string][]interface{})
   851  	r := bufio.NewReader(f)
   852  	dec := json.NewDecoder(r)
   853  	for {
   854  		if err = dec.Decode(&buf); err == io.EOF {
   855  			break
   856  		} else if err != nil {
   857  			logger.Debugf("failed to read. Ignored. %v\n", err)
   858  			return
   859  		}
   860  	}
   861  
   862  	for k, cacheValue := range buf {
   863  		ok, ts, ocspRespBase64 := extractTsAndOcspRespBase64(cacheValue)
   864  		if !ok {
   865  			continue
   866  		}
   867  		certValue := &certCacheValue{ts, ocspRespBase64}
   868  		cacheKey := decodeCertIDKey(k)
   869  		status := extractOCSPCacheResponseValueWithoutSubject(cacheKey, certValue)
   870  		if !isValidOCSPStatus(status.code) {
   871  			continue
   872  		}
   873  		ocspResponseCache[*cacheKey] = certValue
   874  
   875  	}
   876  	cacheUpdated = false
   877  }
   878  
   879  func extractTsAndOcspRespBase64(value []interface{}) (bool, float64, string) {
   880  	ts, ok := value[0].(float64)
   881  	if !ok {
   882  		logger.Warnf("cannot cast %v as float64", value[0])
   883  		return false, -1, ""
   884  	}
   885  	ocspRespBase64, ok := value[1].(string)
   886  	if !ok {
   887  		logger.Warnf("cannot cast %v as string", value[1])
   888  		return false, -1, ""
   889  	}
   890  	return true, ts, ocspRespBase64
   891  }
   892  
   893  func extractOCSPCacheResponseValueWithoutSubject(cacheKey *certIDKey, cacheValue *certCacheValue) *ocspStatus {
   894  	return extractOCSPCacheResponseValue(cacheKey, cacheValue, nil, nil)
   895  }
   896  
   897  func extractOCSPCacheResponseValue(certIDKey *certIDKey, certCacheValue *certCacheValue, subject, issuer *x509.Certificate) *ocspStatus {
   898  	subjectName := "Unknown"
   899  	if subject != nil {
   900  		subjectName = subject.Subject.CommonName
   901  	}
   902  
   903  	curTime := time.Now()
   904  	currentTime := float64(curTime.UTC().Unix())
   905  	if currentTime-certCacheValue.ts >= cacheExpire {
   906  		return &ocspStatus{
   907  			code: ocspCacheExpired,
   908  			err: fmt.Errorf("cache expired. current: %v, cache: %v",
   909  				time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(certCacheValue.ts), 0).UTC()),
   910  		}
   911  	}
   912  
   913  	ocspParsedRespCacheLock.Lock()
   914  	defer ocspParsedRespCacheLock.Unlock()
   915  
   916  	var cacheKey parsedOcspRespKey
   917  	if certIDKey != nil {
   918  		cacheKey = parsedOcspRespKey{certCacheValue.ocspRespBase64, encodeCertIDKey(certIDKey)}
   919  	} else {
   920  		cacheKey = parsedOcspRespKey{certCacheValue.ocspRespBase64, ""}
   921  	}
   922  	status, ok := ocspParsedRespCache[cacheKey]
   923  	if !ok {
   924  		logger.Debugf("OCSP status not found in cache; certIdKey: %v", certIDKey)
   925  		var err error
   926  		var b []byte
   927  		b, err = base64.StdEncoding.DecodeString(certCacheValue.ocspRespBase64)
   928  		if err != nil {
   929  			return &ocspStatus{
   930  				code: ocspFailedDecodeResponse,
   931  				err:  fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err),
   932  			}
   933  		}
   934  		// check the revocation status here
   935  		ocspResponse, err := ocsp.ParseResponse(b, issuer)
   936  
   937  		if err != nil {
   938  			logger.Warnf("the second cache element is not a valid OCSP Response. Ignored. subject: %v\n", subjectName)
   939  			return &ocspStatus{
   940  				code: ocspFailedParseResponse,
   941  				err:  fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err),
   942  			}
   943  		}
   944  		status = validateOCSP(ocspResponse)
   945  		ocspParsedRespCache[cacheKey] = status
   946  	}
   947  	logger.Debugf("OCSP status found in cache: %v; certIdKey: %v", status, certIDKey)
   948  	return status
   949  }
   950  
   951  // writeOCSPCacheFile writes a OCSP Response cache file. This is called if all revocation status is success.
   952  // lock file is used to mitigate race condition with other process.
   953  func writeOCSPCacheFile() {
   954  	if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") {
   955  		return
   956  	}
   957  	logger.Infof("writing OCSP Response cache file. %v\n", cacheFileName)
   958  	cacheLockFileName := cacheFileName + ".lck"
   959  	err := os.Mkdir(cacheLockFileName, 0600)
   960  	switch {
   961  	case os.IsExist(err):
   962  		statinfo, err := os.Stat(cacheLockFileName)
   963  		if err != nil {
   964  			logger.Debugf("failed to get file info for cache lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err)
   965  			return
   966  		}
   967  		if time.Since(statinfo.ModTime()) < 15*time.Minute {
   968  			logger.Debugf("other process locks the cache file. %v. ignored.\n", cacheLockFileName)
   969  			return
   970  		}
   971  		if err = os.Remove(cacheLockFileName); err != nil {
   972  			logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err)
   973  			return
   974  		}
   975  		if err = os.Mkdir(cacheLockFileName, 0600); err != nil {
   976  			logger.Debugf("failed to create lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err)
   977  			return
   978  		}
   979  	}
   980  	// if mkdir fails for any other reason: permission denied, operation not permitted, I/O error, too many open files, etc.
   981  	if err != nil {
   982  		logger.Debugf("failed to create lock file. file %v, err: %v. ignored.\n", cacheLockFileName, err)
   983  		return
   984  	}
   985  	defer os.RemoveAll(cacheLockFileName)
   986  
   987  	buf := make(map[string][]interface{})
   988  	for k, v := range ocspResponseCache {
   989  		cacheKeyInBase64 := encodeCertIDKey(&k)
   990  		buf[cacheKeyInBase64] = []interface{}{v.ts, v.ocspRespBase64}
   991  	}
   992  
   993  	j, err := json.Marshal(buf)
   994  	if err != nil {
   995  		logger.Debugf("failed to convert OCSP Response cache to JSON. ignored.")
   996  		return
   997  	}
   998  	if err = os.WriteFile(cacheFileName, j, 0644); err != nil {
   999  		logger.Debugf("failed to write OCSP Response cache. err: %v. ignored.\n", err)
  1000  	}
  1001  }
  1002  
  1003  // readCACerts read a set of root CAs
  1004  func readCACerts() {
  1005  	raw := []byte(caRootPEM)
  1006  	certPool = x509.NewCertPool()
  1007  	caRoot = make(map[string]*x509.Certificate)
  1008  	var p *pem.Block
  1009  	for {
  1010  		p, raw = pem.Decode(raw)
  1011  		if p == nil {
  1012  			break
  1013  		}
  1014  		if p.Type != "CERTIFICATE" {
  1015  			continue
  1016  		}
  1017  		c, err := x509.ParseCertificate(p.Bytes)
  1018  		if err != nil {
  1019  			panic("failed to parse CA certificate.")
  1020  		}
  1021  		certPool.AddCert(c)
  1022  		caRoot[string(c.RawSubject)] = c
  1023  	}
  1024  }
  1025  
  1026  // createOCSPCacheDir creates OCSP response cache directory and set the cache file name.
  1027  func createOCSPCacheDir() {
  1028  	if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") {
  1029  		logger.Info(`OCSP Cache Server disabled. All further access and use of
  1030  			OCSP Cache will be disabled for this OCSP Status Query`)
  1031  		return
  1032  	}
  1033  	cacheDir = os.Getenv(cacheDirEnv)
  1034  	if cacheDir == "" {
  1035  		cacheDir = os.Getenv("SNOWFLAKE_TEST_WORKSPACE")
  1036  	}
  1037  	if cacheDir == "" {
  1038  		switch runtime.GOOS {
  1039  		case "windows":
  1040  			cacheDir = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches")
  1041  		case "darwin":
  1042  			home := os.Getenv("HOME")
  1043  			if home == "" {
  1044  				logger.Info("HOME is blank.")
  1045  			}
  1046  			cacheDir = filepath.Join(home, "Library", "Caches", "Snowflake")
  1047  		default:
  1048  			home := os.Getenv("HOME")
  1049  			if home == "" {
  1050  				logger.Info("HOME is blank")
  1051  			}
  1052  			cacheDir = filepath.Join(home, ".cache", "snowflake")
  1053  		}
  1054  	}
  1055  
  1056  	if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
  1057  		if err = os.MkdirAll(cacheDir, os.ModePerm); err != nil {
  1058  			logger.Debugf("failed to create cache directory. %v, err: %v. ignored\n", cacheDir, err)
  1059  		}
  1060  	}
  1061  	cacheFileName = filepath.Join(cacheDir, cacheFileBaseName)
  1062  	logger.Infof("reset OCSP cache file. %v", cacheFileName)
  1063  }
  1064  
// init prepares the package-level OCSP state. Order matters here:
//   - readCACerts fills certPool and caRoot from caRootPEM.
//   - createOCSPCacheDir resolves the cache directory and sets cacheFileName.
//   - initOCSPCache reads cacheFileName, so it must run last.
//
// Per the Go spec, package-level variable initializers (such as
// SnowflakeTransport) have already run by the time init is called.
func init() {
	readCACerts()
	createOCSPCacheDir()
	initOCSPCache()
}
  1070  
  1071  // snowflakeInsecureTransport is the transport object that doesn't do certificate revocation check.
  1072  var snowflakeInsecureTransport = &http.Transport{
  1073  	MaxIdleConns:    10,
  1074  	IdleConnTimeout: 30 * time.Minute,
  1075  	Proxy:           http.ProxyFromEnvironment,
  1076  	DialContext: (&net.Dialer{
  1077  		Timeout:   30 * time.Second,
  1078  		KeepAlive: 30 * time.Second,
  1079  	}).DialContext,
  1080  }
  1081  
  1082  // SnowflakeTransport includes the certificate revocation check with OCSP in sequential. By default, the driver uses
  1083  // this transport object.
  1084  var SnowflakeTransport = &http.Transport{
  1085  	TLSClientConfig: &tls.Config{
  1086  		RootCAs:               certPool,
  1087  		VerifyPeerCertificate: verifyPeerCertificateSerial,
  1088  	},
  1089  	MaxIdleConns:    10,
  1090  	IdleConnTimeout: 30 * time.Minute,
  1091  	Proxy:           http.ProxyFromEnvironment,
  1092  	DialContext: (&net.Dialer{
  1093  		Timeout:   30 * time.Second,
  1094  		KeepAlive: 30 * time.Second,
  1095  	}).DialContext,
  1096  }
  1097  
  1098  // SnowflakeTransportTest includes the certificate revocation check in parallel
  1099  var SnowflakeTransportTest = SnowflakeTransport