github.com/snowflakedb/gosnowflake@v1.9.0/ocsp.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bufio" 7 "context" 8 "crypto" 9 "crypto/tls" 10 "crypto/x509" 11 "crypto/x509/pkix" 12 "encoding/asn1" 13 "encoding/base64" 14 "encoding/json" 15 "encoding/pem" 16 "errors" 17 "fmt" 18 "io" 19 "math/big" 20 "net" 21 "net/http" 22 "net/url" 23 "os" 24 "path/filepath" 25 "runtime" 26 "strconv" 27 "strings" 28 "sync" 29 "sync/atomic" 30 "time" 31 32 "golang.org/x/crypto/ocsp" 33 ) 34 35 var ( 36 // caRoot includes the CA certificates. 37 caRoot map[string]*x509.Certificate 38 // certPOol includes the CA certificates. 39 certPool *x509.CertPool 40 // cacheDir is the location of OCSP response cache file 41 cacheDir = "" 42 // cacheFileName is the file name of OCSP response cache file 43 cacheFileName = "" 44 // cacheUpdated is true if the memory cache is updated 45 cacheUpdated = true 46 ) 47 48 // OCSPFailOpenMode is OCSP fail open mode. OCSPFailOpenTrue by default and may 49 // set to ocspModeFailClosed for fail closed mode 50 type OCSPFailOpenMode uint32 51 52 const ( 53 ocspFailOpenNotSet OCSPFailOpenMode = iota 54 // OCSPFailOpenTrue represents OCSP fail open mode. 55 OCSPFailOpenTrue 56 // OCSPFailOpenFalse represents OCSP fail closed mode. 57 OCSPFailOpenFalse 58 ) 59 const ( 60 ocspModeFailOpen = "FAIL_OPEN" 61 ocspModeFailClosed = "FAIL_CLOSED" 62 ocspModeInsecure = "INSECURE" 63 ) 64 65 // OCSP fail open mode 66 var ocspFailOpen = OCSPFailOpenTrue 67 68 const ( 69 // defaultOCSPCacheServerTimeout is the total timeout for OCSP cache server. 70 defaultOCSPCacheServerTimeout = 5 * time.Second 71 72 // defaultOCSPResponderTimeout is the total timeout for OCSP responder. 73 defaultOCSPResponderTimeout = 10 * time.Second 74 ) 75 76 const ( 77 cacheFileBaseName = "ocsp_response_cache.json" 78 // cacheExpire specifies cache data expiration time in seconds. 79 cacheExpire = float64(24 * 60 * 60) 80 cacheServerURL = "http://ocsp.snowflakecomputing.com" 81 cacheServerEnabledEnv = "SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED" 82 cacheServerURLEnv = "SF_OCSP_RESPONSE_CACHE_SERVER_URL" 83 cacheDirEnv = "SF_OCSP_RESPONSE_CACHE_DIR" 84 ocspRetryURLEnv = "SF_OCSP_RESPONSE_RETRY_URL" 85 ) 86 87 const ( 88 ocspTestInjectValidityErrorEnv = "SF_OCSP_TEST_INJECT_VALIDITY_ERROR" 89 ocspTestInjectUnknownStatusEnv = "SF_OCSP_TEST_INJECT_UNKNOWN_STATUS" 90 ocspTestResponseCacheServerTimeoutEnv = "SF_OCSP_TEST_OCSP_RESPONSE_CACHE_SERVER_TIMEOUT" 91 ocspTestResponderTimeoutEnv = "SF_OCSP_TEST_OCSP_RESPONDER_TIMEOUT" 92 ocspTestResponderURLEnv = "SF_OCSP_TEST_RESPONDER_URL" 93 ocspTestNoOCSPURLEnv = "SF_OCSP_TEST_NO_OCSP_RESPONDER_URL" 94 ) 95 96 const ( 97 tolerableValidityRatio = 100 // buffer for certificate revocation update time 98 maxClockSkew = 900 * time.Second // buffer for clock skew 99 ) 100 101 type ocspStatusCode int 102 103 type ocspStatus struct { 104 code ocspStatusCode 105 err error 106 } 107 108 const ( 109 ocspSuccess ocspStatusCode = 0 110 ocspStatusGood ocspStatusCode = -1 111 ocspStatusRevoked ocspStatusCode = -2 112 ocspStatusUnknown ocspStatusCode = -3 113 ocspStatusOthers ocspStatusCode = -4 114 ocspNoServer ocspStatusCode = -5 115 ocspFailedParseOCSPHost ocspStatusCode = -6 116 ocspFailedComposeRequest ocspStatusCode = -7 117 ocspFailedDecomposeRequest ocspStatusCode = -8 118 ocspFailedSubmit ocspStatusCode = -9 119 ocspFailedResponse ocspStatusCode = -10 120 ocspFailedExtractResponse ocspStatusCode = -11 121 ocspFailedParseResponse ocspStatusCode = -12 122 ocspInvalidValidity ocspStatusCode = -13 123 ocspMissedCache ocspStatusCode = -14 124 ocspCacheExpired ocspStatusCode = -15 125 ocspFailedDecodeResponse ocspStatusCode = -16 126 ) 127 128 // copied from crypto/ocsp.go 129 type certID struct { 130 HashAlgorithm pkix.AlgorithmIdentifier 131 NameHash []byte 132 IssuerKeyHash []byte 133 SerialNumber *big.Int 134 } 135 136 // cache key 137 type certIDKey struct { 138 HashAlgorithm crypto.Hash 139 NameHash string 140 IssuerKeyHash string 141 SerialNumber string 142 } 143 144 type certCacheValue struct { 145 ts float64 146 ocspRespBase64 string 147 } 148 149 type parsedOcspRespKey struct { 150 ocspRespBase64 string 151 certIDBase64 string 152 } 153 154 var ( 155 ocspResponseCache map[certIDKey]*certCacheValue 156 ocspParsedRespCache map[parsedOcspRespKey]*ocspStatus 157 ocspResponseCacheLock *sync.RWMutex 158 ocspParsedRespCacheLock *sync.Mutex 159 ) 160 161 // copied from crypto/ocsp 162 var hashOIDs = map[crypto.Hash]asn1.ObjectIdentifier{ 163 crypto.SHA1: asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}), 164 crypto.SHA256: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 1}), 165 crypto.SHA384: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 2}), 166 crypto.SHA512: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 3}), 167 } 168 169 // copied from crypto/ocsp 170 func getOIDFromHashAlgorithm(target crypto.Hash) asn1.ObjectIdentifier { 171 for hash, oid := range hashOIDs { 172 if hash == target { 173 return oid 174 } 175 } 176 logger.Errorf("no valid OID is found for the hash algorithm. %#v", target) 177 return nil 178 } 179 180 func getHashAlgorithmFromOID(target pkix.AlgorithmIdentifier) crypto.Hash { 181 for hash, oid := range hashOIDs { 182 if oid.Equal(target.Algorithm) { 183 return hash 184 } 185 } 186 logger.Errorf("no valid hash algorithm is found for the oid. Falling back to SHA1: %#v", target) 187 return crypto.SHA1 188 } 189 190 // calcTolerableValidity returns the maximum validity buffer 191 func calcTolerableValidity(thisUpdate, nextUpdate time.Time) time.Duration { 192 return durationMax(time.Duration(nextUpdate.Sub(thisUpdate)/tolerableValidityRatio), maxClockSkew) 193 } 194 195 // isInValidityRange checks the validity 196 func isInValidityRange(currTime, thisUpdate, nextUpdate time.Time) bool { 197 if currTime.Sub(thisUpdate.Add(-maxClockSkew)) < 0 { 198 return false 199 } 200 if nextUpdate.Add(calcTolerableValidity(thisUpdate, nextUpdate)).Sub(currTime) < 0 { 201 return false 202 } 203 return true 204 } 205 206 func isTestInvalidValidity() bool { 207 return strings.EqualFold(os.Getenv(ocspTestInjectValidityErrorEnv), "true") 208 } 209 210 func extractCertIDKeyFromRequest(ocspReq []byte) (*certIDKey, *ocspStatus) { 211 r, err := ocsp.ParseRequest(ocspReq) 212 if err != nil { 213 return nil, &ocspStatus{ 214 code: ocspFailedDecomposeRequest, 215 err: err, 216 } 217 } 218 219 // encode CertID, used as a key in the cache 220 encodedCertID := &certIDKey{ 221 r.HashAlgorithm, 222 base64.StdEncoding.EncodeToString(r.IssuerNameHash), 223 base64.StdEncoding.EncodeToString(r.IssuerKeyHash), 224 r.SerialNumber.String(), 225 } 226 return encodedCertID, &ocspStatus{ 227 code: ocspSuccess, 228 } 229 } 230 231 func decodeCertIDKey(certIDKeyBase64 string) *certIDKey { 232 r, err := base64.StdEncoding.DecodeString(certIDKeyBase64) 233 if err != nil { 234 return nil 235 } 236 var c certID 237 rest, err := asn1.Unmarshal(r, &c) 238 if err != nil { 239 // error in parsing 240 return nil 241 } 242 if len(rest) > 0 { 243 // extra bytes to the end 244 return nil 245 } 246 return &certIDKey{ 247 getHashAlgorithmFromOID(c.HashAlgorithm), 248 base64.StdEncoding.EncodeToString(c.NameHash), 249 base64.StdEncoding.EncodeToString(c.IssuerKeyHash), 250 c.SerialNumber.String(), 251 } 252 } 253 254 func encodeCertIDKey(k *certIDKey) string { 255 serialNumber := new(big.Int) 256 serialNumber.SetString(k.SerialNumber, 10) 257 nameHash, err := base64.StdEncoding.DecodeString(k.NameHash) 258 if err != nil { 259 return "" 260 } 261 issuerKeyHash, err := base64.StdEncoding.DecodeString(k.IssuerKeyHash) 262 if err != nil { 263 return "" 264 } 265 encodedCertID, err := asn1.Marshal(certID{ 266 pkix.AlgorithmIdentifier{ 267 Algorithm: getOIDFromHashAlgorithm(k.HashAlgorithm), 268 Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, 269 }, 270 nameHash, 271 issuerKeyHash, 272 serialNumber, 273 }) 274 if err != nil { 275 return "" 276 } 277 return base64.StdEncoding.EncodeToString(encodedCertID) 278 } 279 280 func checkOCSPResponseCache(certIDKey *certIDKey, subject, issuer *x509.Certificate) *ocspStatus { 281 if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { 282 return &ocspStatus{code: ocspNoServer} 283 } 284 285 gotValueFromCache, ok := func() (*certCacheValue, bool) { 286 ocspResponseCacheLock.RLock() 287 defer ocspResponseCacheLock.RUnlock() 288 valueFromCache, ok := ocspResponseCache[*certIDKey] 289 return valueFromCache, ok 290 }() 291 if !ok { 292 return &ocspStatus{ 293 code: ocspMissedCache, 294 err: fmt.Errorf("miss cache data. subject: %v", subject), 295 } 296 } 297 298 status := extractOCSPCacheResponseValue(certIDKey, gotValueFromCache, subject, issuer) 299 if !isValidOCSPStatus(status.code) { 300 deleteOCSPCache(certIDKey) 301 } 302 return status 303 } 304 305 func deleteOCSPCache(encodedCertID *certIDKey) { 306 ocspResponseCacheLock.Lock() 307 defer ocspResponseCacheLock.Unlock() 308 delete(ocspResponseCache, *encodedCertID) 309 cacheUpdated = true 310 } 311 312 func validateOCSP(ocspRes *ocsp.Response) *ocspStatus { 313 curTime := time.Now() 314 315 if ocspRes == nil { 316 return &ocspStatus{ 317 code: ocspFailedDecomposeRequest, 318 err: errors.New("OCSP Response is nil"), 319 } 320 } 321 if isTestInvalidValidity() || !isInValidityRange(curTime, ocspRes.ThisUpdate, ocspRes.NextUpdate) { 322 return &ocspStatus{ 323 code: ocspInvalidValidity, 324 err: &SnowflakeError{ 325 Number: ErrOCSPInvalidValidity, 326 Message: errMsgOCSPInvalidValidity, 327 MessageArgs: []interface{}{ocspRes.ProducedAt, ocspRes.ThisUpdate, ocspRes.NextUpdate}, 328 }, 329 } 330 } 331 if isTestUnknownStatus() { 332 ocspRes.Status = ocsp.Unknown 333 } 334 return returnOCSPStatus(ocspRes) 335 } 336 337 func returnOCSPStatus(ocspRes *ocsp.Response) *ocspStatus { 338 switch ocspRes.Status { 339 case ocsp.Good: 340 return &ocspStatus{ 341 code: ocspStatusGood, 342 err: nil, 343 } 344 case ocsp.Revoked: 345 return &ocspStatus{ 346 code: ocspStatusRevoked, 347 err: &SnowflakeError{ 348 Number: ErrOCSPStatusRevoked, 349 Message: errMsgOCSPStatusRevoked, 350 MessageArgs: []interface{}{ocspRes.RevocationReason, ocspRes.RevokedAt}, 351 }, 352 } 353 case ocsp.Unknown: 354 return &ocspStatus{ 355 code: ocspStatusUnknown, 356 err: &SnowflakeError{ 357 Number: ErrOCSPStatusUnknown, 358 Message: errMsgOCSPStatusUnknown, 359 }, 360 } 361 default: 362 return &ocspStatus{ 363 code: ocspStatusOthers, 364 err: fmt.Errorf("OCSP others. %v", ocspRes.Status), 365 } 366 } 367 } 368 369 func isTestUnknownStatus() bool { 370 return strings.EqualFold(os.Getenv(ocspTestInjectUnknownStatusEnv), "true") 371 } 372 373 func checkOCSPCacheServer( 374 ctx context.Context, 375 client clientInterface, 376 req requestFunc, 377 ocspServerHost *url.URL, 378 totalTimeout time.Duration) ( 379 cacheContent *map[string]*certCacheValue, 380 ocspS *ocspStatus) { 381 var respd map[string][]interface{} 382 headers := make(map[string]string) 383 res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultMaxRetryCount, defaultTimeProvider, nil).execute() 384 if err != nil { 385 logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err) 386 return nil, &ocspStatus{ 387 code: ocspFailedSubmit, 388 err: err, 389 } 390 } 391 defer res.Body.Close() 392 logger.Debugf("StatusCode from OCSP Cache Server: %v", res.StatusCode) 393 if res.StatusCode != http.StatusOK { 394 return nil, &ocspStatus{ 395 code: ocspFailedResponse, 396 err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), 397 } 398 } 399 logger.Debugf("reading contents") 400 401 dec := json.NewDecoder(res.Body) 402 for { 403 if err := dec.Decode(&respd); err == io.EOF { 404 break 405 } else if err != nil { 406 logger.Errorf("failed to decode OCSP cache. %v", err) 407 return nil, &ocspStatus{ 408 code: ocspFailedExtractResponse, 409 err: err, 410 } 411 } 412 } 413 buf := make(map[string]*certCacheValue) 414 for key, value := range respd { 415 ok, ts, ocspRespBase64 := extractTsAndOcspRespBase64(value) 416 if !ok { 417 continue 418 } 419 buf[key] = &certCacheValue{ts, ocspRespBase64} 420 } 421 return &buf, &ocspStatus{ 422 code: ocspSuccess, 423 } 424 } 425 426 // retryOCSP is the second level of retry method if the returned contents are corrupted. It often happens with OCSP 427 // serer and retry helps. 428 func retryOCSP( 429 ctx context.Context, 430 client clientInterface, 431 req requestFunc, 432 ocspHost *url.URL, 433 headers map[string]string, 434 reqBody []byte, 435 issuer *x509.Certificate, 436 totalTimeout time.Duration) ( 437 ocspRes *ocsp.Response, 438 ocspResBytes []byte, 439 ocspS *ocspStatus) { 440 multiplier := 1 441 if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { 442 multiplier = 3 // up to 3 times for Fail Close mode 443 } 444 res, err := newRetryHTTP( 445 ctx, client, req, ocspHost, headers, 446 totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).doPost().setBody(reqBody).execute() 447 if err != nil { 448 return ocspRes, ocspResBytes, &ocspStatus{ 449 code: ocspFailedSubmit, 450 err: err, 451 } 452 } 453 defer res.Body.Close() 454 logger.Debugf("StatusCode from OCSP Server: %v\n", res.StatusCode) 455 if res.StatusCode != http.StatusOK { 456 return ocspRes, ocspResBytes, &ocspStatus{ 457 code: ocspFailedResponse, 458 err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), 459 } 460 } 461 ocspResBytes, err = io.ReadAll(res.Body) 462 if err != nil { 463 return ocspRes, ocspResBytes, &ocspStatus{ 464 code: ocspFailedExtractResponse, 465 err: err, 466 } 467 } 468 ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) 469 if err != nil { 470 logger.Warnf("error when parsing ocsp response: %v", err) 471 logger.Warnf("performing GET fallback request to OCSP") 472 return fallbackRetryOCSPToGETRequest(ctx, client, req, ocspHost, headers, issuer, totalTimeout) 473 } 474 475 logger.Debugf("OCSP Status from server: %v", printStatus(ocspRes)) 476 return ocspRes, ocspResBytes, &ocspStatus{ 477 code: ocspSuccess, 478 } 479 } 480 481 // fallbackRetryOCSPToGETRequest is the third level of retry method. Some OCSP responders do not support POST requests 482 // and will return with a "malformed" request error. In that case we also try to perform a GET request 483 func fallbackRetryOCSPToGETRequest( 484 ctx context.Context, 485 client clientInterface, 486 req requestFunc, 487 ocspHost *url.URL, 488 headers map[string]string, 489 issuer *x509.Certificate, 490 totalTimeout time.Duration) ( 491 ocspRes *ocsp.Response, 492 ocspResBytes []byte, 493 ocspS *ocspStatus) { 494 multiplier := 1 495 if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { 496 multiplier = 3 // up to 3 times for Fail Close mode 497 } 498 res, err := newRetryHTTP(ctx, client, req, ocspHost, headers, 499 totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).execute() 500 if err != nil { 501 return ocspRes, ocspResBytes, &ocspStatus{ 502 code: ocspFailedSubmit, 503 err: err, 504 } 505 } 506 defer res.Body.Close() 507 logger.Debugf("GET fallback StatusCode from OCSP Server: %v", res.StatusCode) 508 if res.StatusCode != http.StatusOK { 509 return ocspRes, ocspResBytes, &ocspStatus{ 510 code: ocspFailedResponse, 511 err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), 512 } 513 } 514 ocspResBytes, err = io.ReadAll(res.Body) 515 if err != nil { 516 return ocspRes, ocspResBytes, &ocspStatus{ 517 code: ocspFailedExtractResponse, 518 err: err, 519 } 520 } 521 ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) 522 if err != nil { 523 return ocspRes, ocspResBytes, &ocspStatus{ 524 code: ocspFailedParseResponse, 525 err: err, 526 } 527 } 528 529 logger.Debugf("GET fallback OCSP Status from server: %v", printStatus(ocspRes)) 530 return ocspRes, ocspResBytes, &ocspStatus{ 531 code: ocspSuccess, 532 } 533 } 534 535 func printStatus(response *ocsp.Response) string { 536 switch response.Status { 537 case ocsp.Good: 538 return "Good" 539 case ocsp.Revoked: 540 return "Revoked" 541 case ocsp.Unknown: 542 return "Unknown" 543 default: 544 return fmt.Sprintf("%d", response.Status) 545 } 546 } 547 548 func fullOCSPURL(url *url.URL) string { 549 fullURL := url.Hostname() 550 if url.Path != "" { 551 if !strings.HasPrefix(url.Path, "/") { 552 fullURL += "/" 553 } 554 fullURL += url.Path 555 } 556 return fullURL 557 } 558 559 // getRevocationStatus checks the certificate revocation status for subject using issuer certificate. 560 func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) *ocspStatus { 561 logger.Infof("Subject: %v, Issuer: %v", subject.Subject, issuer.Subject) 562 563 status, ocspReq, encodedCertID := validateWithCache(subject, issuer) 564 if isValidOCSPStatus(status.code) { 565 return status 566 } 567 if ocspReq == nil || encodedCertID == nil { 568 return status 569 } 570 logger.Infof("cache missed") 571 logger.Infof("OCSP Server: %v", subject.OCSPServer) 572 if len(subject.OCSPServer) == 0 || isTestNoOCSPURL() { 573 return &ocspStatus{ 574 code: ocspNoServer, 575 err: &SnowflakeError{ 576 Number: ErrOCSPNoOCSPResponderURL, 577 Message: errMsgOCSPNoOCSPResponderURL, 578 MessageArgs: []interface{}{subject.Subject}, 579 }, 580 } 581 } 582 ocspHost := subject.OCSPServer[0] 583 u, err := url.Parse(ocspHost) 584 if err != nil { 585 return &ocspStatus{ 586 code: ocspFailedParseOCSPHost, 587 err: fmt.Errorf("failed to parse OCSP server host. %v", ocspHost), 588 } 589 } 590 hostnameStr := os.Getenv(ocspTestResponderURLEnv) 591 var hostname string 592 if retryURL := os.Getenv(ocspRetryURLEnv); retryURL != "" { 593 hostname = fmt.Sprintf(retryURL, fullOCSPURL(u), base64.StdEncoding.EncodeToString(ocspReq)) 594 u0, err := url.Parse(hostname) 595 if err == nil { 596 hostname = u0.Hostname() 597 u = u0 598 } 599 } else { 600 hostname = fullOCSPURL(u) 601 } 602 if hostnameStr != "" { 603 u0, err := url.Parse(hostnameStr) 604 if err == nil { 605 hostname = u0.Hostname() 606 u = u0 607 } 608 } 609 610 logger.Debugf("Fetching OCSP response from server: %v", u) 611 logger.Debugf("Host in headers: %v", hostname) 612 613 headers := make(map[string]string) 614 headers[httpHeaderContentType] = "application/ocsp-request" 615 headers[httpHeaderAccept] = "application/ocsp-response" 616 headers[httpHeaderContentLength] = strconv.Itoa(len(ocspReq)) 617 headers[httpHeaderHost] = hostname 618 timeoutStr := os.Getenv(ocspTestResponderTimeoutEnv) 619 timeout := defaultOCSPResponderTimeout 620 if timeoutStr != "" { 621 var timeoutInMilliseconds int 622 timeoutInMilliseconds, err = strconv.Atoi(timeoutStr) 623 if err == nil { 624 timeout = time.Duration(timeoutInMilliseconds) * time.Millisecond 625 } 626 } 627 ocspClient := &http.Client{ 628 Timeout: timeout, 629 Transport: snowflakeInsecureTransport, 630 } 631 ocspRes, ocspResBytes, ocspS := retryOCSP( 632 ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer, timeout) 633 if ocspS.code != ocspSuccess { 634 return ocspS 635 } 636 637 ret := validateOCSP(ocspRes) 638 if !isValidOCSPStatus(ret.code) { 639 return ret // return invalid 640 } 641 v := &certCacheValue{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)} 642 ocspResponseCacheLock.Lock() 643 ocspResponseCache[*encodedCertID] = v 644 cacheUpdated = true 645 ocspResponseCacheLock.Unlock() 646 return ret 647 } 648 649 func isTestNoOCSPURL() bool { 650 return strings.EqualFold(os.Getenv(ocspTestNoOCSPURLEnv), "true") 651 } 652 653 func isValidOCSPStatus(status ocspStatusCode) bool { 654 return status == ocspStatusGood || status == ocspStatusRevoked || status == ocspStatusUnknown 655 } 656 657 // verifyPeerCertificate verifies all of certificate revocation status 658 func verifyPeerCertificate(ctx context.Context, verifiedChains [][]*x509.Certificate) (err error) { 659 for i := 0; i < len(verifiedChains); i++ { 660 // Certificate signed by Root CA. This should be one before the last in the Certificate Chain 661 numberOfNoneRootCerts := len(verifiedChains[i]) - 1 662 if !verifiedChains[i][numberOfNoneRootCerts].IsCA || string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer) != string(verifiedChains[i][numberOfNoneRootCerts].RawSubject) { 663 // Check if the last Non Root Cert is also a CA or is self signed. 664 // if the last certificate is not, add it to the list 665 rca := caRoot[string(verifiedChains[i][numberOfNoneRootCerts].RawIssuer)] 666 if rca == nil { 667 return fmt.Errorf("failed to find root CA. pkix.name: %v", verifiedChains[i][numberOfNoneRootCerts].Issuer) 668 } 669 verifiedChains[i] = append(verifiedChains[i], rca) 670 numberOfNoneRootCerts++ 671 } 672 results := getAllRevocationStatus(ctx, verifiedChains[i]) 673 if r := canEarlyExitForOCSP(results, numberOfNoneRootCerts); r != nil { 674 return r.err 675 } 676 } 677 678 ocspResponseCacheLock.Lock() 679 if cacheUpdated { 680 writeOCSPCacheFile() 681 } 682 cacheUpdated = false 683 ocspResponseCacheLock.Unlock() 684 return nil 685 } 686 687 func canEarlyExitForOCSP(results []*ocspStatus, chainSize int) *ocspStatus { 688 msg := "" 689 if atomic.LoadUint32((*uint32)(&ocspFailOpen)) == (uint32)(OCSPFailOpenFalse) { 690 // Fail closed. any error is returned to stop connection 691 for _, r := range results { 692 if r.err != nil { 693 return r 694 } 695 } 696 } else { 697 // Fail open and all results are valid. 698 allValid := len(results) == chainSize 699 for _, r := range results { 700 if !isValidOCSPStatus(r.code) { 701 allValid = false 702 break 703 } 704 } 705 for _, r := range results { 706 if allValid && r.code == ocspStatusRevoked { 707 return r 708 } 709 if r != nil && r.code != ocspStatusGood && r.err != nil { 710 msg += "\n" + r.err.Error() 711 } 712 } 713 } 714 if len(msg) > 0 { 715 logger.Warnf( 716 "WARNING!!! Using fail-open to connect. Driver is connecting to an "+ 717 "HTTPS endpoint without OCSP based Certificate Revocation checking "+ 718 "as it could not obtain a valid OCSP Response to use from the CA OCSP "+ 719 "responder. Detail: %v", msg[1:]) 720 } 721 return nil 722 } 723 724 func validateWithCacheForAllCertificates(verifiedChains []*x509.Certificate) bool { 725 n := len(verifiedChains) - 1 726 for j := 0; j < n; j++ { 727 subject := verifiedChains[j] 728 issuer := verifiedChains[j+1] 729 status, _, _ := validateWithCache(subject, issuer) 730 if !isValidOCSPStatus(status.code) { 731 return false 732 } 733 } 734 return true 735 } 736 737 func validateWithCache(subject, issuer *x509.Certificate) (*ocspStatus, []byte, *certIDKey) { 738 ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) 739 if err != nil { 740 logger.Errorf("failed to create OCSP request from the certificates.\n") 741 return &ocspStatus{ 742 code: ocspFailedComposeRequest, 743 err: errors.New("failed to create a OCSP request"), 744 }, nil, nil 745 } 746 encodedCertID, ocspS := extractCertIDKeyFromRequest(ocspReq) 747 if ocspS.code != ocspSuccess { 748 logger.Errorf("failed to extract CertID from OCSP Request.\n") 749 return &ocspStatus{ 750 code: ocspFailedComposeRequest, 751 err: errors.New("failed to extract cert ID Key"), 752 }, ocspReq, nil 753 } 754 status := checkOCSPResponseCache(encodedCertID, subject, issuer) 755 return status, ocspReq, encodedCertID 756 } 757 758 func downloadOCSPCacheServer() { 759 if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { 760 return 761 } 762 ocspCacheServerURL := os.Getenv(cacheServerURLEnv) 763 if ocspCacheServerURL == "" { 764 ocspCacheServerURL = fmt.Sprintf("%v/%v", cacheServerURL, cacheFileBaseName) 765 } 766 u, err := url.Parse(ocspCacheServerURL) 767 if err != nil { 768 return 769 } 770 logger.Infof("downloading OCSP Cache from server %v", ocspCacheServerURL) 771 timeoutStr := os.Getenv(ocspTestResponseCacheServerTimeoutEnv) 772 timeout := defaultOCSPCacheServerTimeout 773 if timeoutStr != "" { 774 var timeoutInMilliseconds int 775 timeoutInMilliseconds, err = strconv.Atoi(timeoutStr) 776 if err == nil { 777 timeout = time.Duration(timeoutInMilliseconds) * time.Millisecond 778 } 779 } 780 ocspClient := &http.Client{ 781 Timeout: timeout, 782 Transport: snowflakeInsecureTransport, 783 } 784 ret, ocspStatus := checkOCSPCacheServer(context.Background(), ocspClient, http.NewRequest, u, timeout) 785 if ocspStatus.code != ocspSuccess { 786 return 787 } 788 789 ocspResponseCacheLock.Lock() 790 for k, cacheValue := range *ret { 791 cacheKey := decodeCertIDKey(k) 792 status := extractOCSPCacheResponseValueWithoutSubject(cacheKey, cacheValue) 793 if !isValidOCSPStatus(status.code) { 794 continue 795 } 796 ocspResponseCache[*cacheKey] = cacheValue 797 } 798 cacheUpdated = true 799 ocspResponseCacheLock.Unlock() 800 } 801 802 func getAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certificate) []*ocspStatus { 803 cached := validateWithCacheForAllCertificates(verifiedChains) 804 if !cached { 805 downloadOCSPCacheServer() 806 } 807 n := len(verifiedChains) - 1 808 results := make([]*ocspStatus, n) 809 for j := 0; j < n; j++ { 810 results[j] = getRevocationStatus(ctx, verifiedChains[j], verifiedChains[j+1]) 811 if !isValidOCSPStatus(results[j].code) { 812 return results 813 } 814 } 815 return results 816 } 817 818 // verifyPeerCertificateSerial verifies the certificate revocation status in serial. 819 func verifyPeerCertificateSerial(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { 820 overrideCacheDir() 821 return verifyPeerCertificate(context.Background(), verifiedChains) 822 } 823 824 func overrideCacheDir() { 825 if os.Getenv(cacheDirEnv) != "" { 826 ocspResponseCacheLock.Lock() 827 defer ocspResponseCacheLock.Unlock() 828 createOCSPCacheDir() 829 } 830 } 831 832 // initOCSPCache initializes OCSP Response cache file. 833 func initOCSPCache() { 834 if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { 835 return 836 } 837 ocspResponseCache = make(map[certIDKey]*certCacheValue) 838 ocspParsedRespCache = make(map[parsedOcspRespKey]*ocspStatus) 839 ocspResponseCacheLock = &sync.RWMutex{} 840 ocspParsedRespCacheLock = &sync.Mutex{} 841 842 logger.Infof("reading OCSP Response cache file. %v\n", cacheFileName) 843 f, err := os.OpenFile(cacheFileName, os.O_CREATE|os.O_RDONLY, readWriteFileMode) 844 if err != nil { 845 logger.Debugf("failed to open. Ignored. %v\n", err) 846 return 847 } 848 defer f.Close() 849 850 buf := make(map[string][]interface{}) 851 r := bufio.NewReader(f) 852 dec := json.NewDecoder(r) 853 for { 854 if err = dec.Decode(&buf); err == io.EOF { 855 break 856 } else if err != nil { 857 logger.Debugf("failed to read. Ignored. %v\n", err) 858 return 859 } 860 } 861 862 for k, cacheValue := range buf { 863 ok, ts, ocspRespBase64 := extractTsAndOcspRespBase64(cacheValue) 864 if !ok { 865 continue 866 } 867 certValue := &certCacheValue{ts, ocspRespBase64} 868 cacheKey := decodeCertIDKey(k) 869 status := extractOCSPCacheResponseValueWithoutSubject(cacheKey, certValue) 870 if !isValidOCSPStatus(status.code) { 871 continue 872 } 873 ocspResponseCache[*cacheKey] = certValue 874 875 } 876 cacheUpdated = false 877 } 878 879 func extractTsAndOcspRespBase64(value []interface{}) (bool, float64, string) { 880 ts, ok := value[0].(float64) 881 if !ok { 882 logger.Warnf("cannot cast %v as float64", value[0]) 883 return false, -1, "" 884 } 885 ocspRespBase64, ok := value[1].(string) 886 if !ok { 887 logger.Warnf("cannot cast %v as string", value[1]) 888 return false, -1, "" 889 } 890 return true, ts, ocspRespBase64 891 } 892 893 func extractOCSPCacheResponseValueWithoutSubject(cacheKey *certIDKey, cacheValue *certCacheValue) *ocspStatus { 894 return extractOCSPCacheResponseValue(cacheKey, cacheValue, nil, nil) 895 } 896 897 func extractOCSPCacheResponseValue(certIDKey *certIDKey, certCacheValue *certCacheValue, subject, issuer *x509.Certificate) *ocspStatus { 898 subjectName := "Unknown" 899 if subject != nil { 900 subjectName = subject.Subject.CommonName 901 } 902 903 curTime := time.Now() 904 currentTime := float64(curTime.UTC().Unix()) 905 if currentTime-certCacheValue.ts >= cacheExpire { 906 return &ocspStatus{ 907 code: ocspCacheExpired, 908 err: fmt.Errorf("cache expired. current: %v, cache: %v", 909 time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(certCacheValue.ts), 0).UTC()), 910 } 911 } 912 913 ocspParsedRespCacheLock.Lock() 914 defer ocspParsedRespCacheLock.Unlock() 915 916 var cacheKey parsedOcspRespKey 917 if certIDKey != nil { 918 cacheKey = parsedOcspRespKey{certCacheValue.ocspRespBase64, encodeCertIDKey(certIDKey)} 919 } else { 920 cacheKey = parsedOcspRespKey{certCacheValue.ocspRespBase64, ""} 921 } 922 status, ok := ocspParsedRespCache[cacheKey] 923 if !ok { 924 logger.Debugf("OCSP status not found in cache; certIdKey: %v", certIDKey) 925 var err error 926 var b []byte 927 b, err = base64.StdEncoding.DecodeString(certCacheValue.ocspRespBase64) 928 if err != nil { 929 return &ocspStatus{ 930 code: ocspFailedDecodeResponse, 931 err: fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err), 932 } 933 } 934 // check the revocation status here 935 ocspResponse, err := ocsp.ParseResponse(b, issuer) 936 937 if err != nil { 938 logger.Warnf("the second cache element is not a valid OCSP Response. Ignored. subject: %v\n", subjectName) 939 return &ocspStatus{ 940 code: ocspFailedParseResponse, 941 err: fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err), 942 } 943 } 944 status = validateOCSP(ocspResponse) 945 ocspParsedRespCache[cacheKey] = status 946 } 947 logger.Debugf("OCSP status found in cache: %v; certIdKey: %v", status, certIDKey) 948 return status 949 } 950 951 // writeOCSPCacheFile writes a OCSP Response cache file. This is called if all revocation status is success. 952 // lock file is used to mitigate race condition with other process. 953 func writeOCSPCacheFile() { 954 if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { 955 return 956 } 957 logger.Infof("writing OCSP Response cache file. %v\n", cacheFileName) 958 cacheLockFileName := cacheFileName + ".lck" 959 err := os.Mkdir(cacheLockFileName, 0600) 960 switch { 961 case os.IsExist(err): 962 statinfo, err := os.Stat(cacheLockFileName) 963 if err != nil { 964 logger.Debugf("failed to get file info for cache lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) 965 return 966 } 967 if time.Since(statinfo.ModTime()) < 15*time.Minute { 968 logger.Debugf("other process locks the cache file. %v. ignored.\n", cacheLockFileName) 969 return 970 } 971 if err = os.Remove(cacheLockFileName); err != nil { 972 logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) 973 return 974 } 975 if err = os.Mkdir(cacheLockFileName, 0600); err != nil { 976 logger.Debugf("failed to create lock file. file: %v, err: %v. ignored.\n", cacheLockFileName, err) 977 return 978 } 979 } 980 // if mkdir fails for any other reason: permission denied, operation not permitted, I/O error, too many open files, etc. 981 if err != nil { 982 logger.Debugf("failed to create lock file. file %v, err: %v. ignored.\n", cacheLockFileName, err) 983 return 984 } 985 defer os.RemoveAll(cacheLockFileName) 986 987 buf := make(map[string][]interface{}) 988 for k, v := range ocspResponseCache { 989 cacheKeyInBase64 := encodeCertIDKey(&k) 990 buf[cacheKeyInBase64] = []interface{}{v.ts, v.ocspRespBase64} 991 } 992 993 j, err := json.Marshal(buf) 994 if err != nil { 995 logger.Debugf("failed to convert OCSP Response cache to JSON. ignored.") 996 return 997 } 998 if err = os.WriteFile(cacheFileName, j, 0644); err != nil { 999 logger.Debugf("failed to write OCSP Response cache. err: %v. ignored.\n", err) 1000 } 1001 } 1002 1003 // readCACerts read a set of root CAs 1004 func readCACerts() { 1005 raw := []byte(caRootPEM) 1006 certPool = x509.NewCertPool() 1007 caRoot = make(map[string]*x509.Certificate) 1008 var p *pem.Block 1009 for { 1010 p, raw = pem.Decode(raw) 1011 if p == nil { 1012 break 1013 } 1014 if p.Type != "CERTIFICATE" { 1015 continue 1016 } 1017 c, err := x509.ParseCertificate(p.Bytes) 1018 if err != nil { 1019 panic("failed to parse CA certificate.") 1020 } 1021 certPool.AddCert(c) 1022 caRoot[string(c.RawSubject)] = c 1023 } 1024 } 1025 1026 // createOCSPCacheDir creates OCSP response cache directory and set the cache file name. 1027 func createOCSPCacheDir() { 1028 if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { 1029 logger.Info(`OCSP Cache Server disabled. All further access and use of 1030 OCSP Cache will be disabled for this OCSP Status Query`) 1031 return 1032 } 1033 cacheDir = os.Getenv(cacheDirEnv) 1034 if cacheDir == "" { 1035 cacheDir = os.Getenv("SNOWFLAKE_TEST_WORKSPACE") 1036 } 1037 if cacheDir == "" { 1038 switch runtime.GOOS { 1039 case "windows": 1040 cacheDir = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local", "Snowflake", "Caches") 1041 case "darwin": 1042 home := os.Getenv("HOME") 1043 if home == "" { 1044 logger.Info("HOME is blank.") 1045 } 1046 cacheDir = filepath.Join(home, "Library", "Caches", "Snowflake") 1047 default: 1048 home := os.Getenv("HOME") 1049 if home == "" { 1050 logger.Info("HOME is blank") 1051 } 1052 cacheDir = filepath.Join(home, ".cache", "snowflake") 1053 } 1054 } 1055 1056 if _, err := os.Stat(cacheDir); os.IsNotExist(err) { 1057 if err = os.MkdirAll(cacheDir, os.ModePerm); err != nil { 1058 logger.Debugf("failed to create cache directory. %v, err: %v. ignored\n", cacheDir, err) 1059 } 1060 } 1061 cacheFileName = filepath.Join(cacheDir, cacheFileBaseName) 1062 logger.Infof("reset OCSP cache file. %v", cacheFileName) 1063 } 1064 1065 func init() { 1066 readCACerts() 1067 createOCSPCacheDir() 1068 initOCSPCache() 1069 } 1070 1071 // snowflakeInsecureTransport is the transport object that doesn't do certificate revocation check. 1072 var snowflakeInsecureTransport = &http.Transport{ 1073 MaxIdleConns: 10, 1074 IdleConnTimeout: 30 * time.Minute, 1075 Proxy: http.ProxyFromEnvironment, 1076 DialContext: (&net.Dialer{ 1077 Timeout: 30 * time.Second, 1078 KeepAlive: 30 * time.Second, 1079 }).DialContext, 1080 } 1081 1082 // SnowflakeTransport includes the certificate revocation check with OCSP in sequential. By default, the driver uses 1083 // this transport object. 1084 var SnowflakeTransport = &http.Transport{ 1085 TLSClientConfig: &tls.Config{ 1086 RootCAs: certPool, 1087 VerifyPeerCertificate: verifyPeerCertificateSerial, 1088 }, 1089 MaxIdleConns: 10, 1090 IdleConnTimeout: 30 * time.Minute, 1091 Proxy: http.ProxyFromEnvironment, 1092 DialContext: (&net.Dialer{ 1093 Timeout: 30 * time.Second, 1094 KeepAlive: 30 * time.Second, 1095 }).DialContext, 1096 } 1097 1098 // SnowflakeTransportTest includes the certificate revocation check in parallel 1099 var SnowflakeTransportTest = SnowflakeTransport