github.com/snowflakedb/gosnowflake@v1.9.0/ocsp_test.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "crypto" 9 "crypto/tls" 10 "crypto/x509" 11 "encoding/base64" 12 "errors" 13 "fmt" 14 "io" 15 "net" 16 "net/http" 17 "net/url" 18 "os" 19 "testing" 20 "time" 21 22 "golang.org/x/crypto/ocsp" 23 ) 24 25 func TestOCSP(t *testing.T) { 26 cacheServerEnabled := []string{ 27 "true", 28 "false", 29 } 30 targetURL := []string{ 31 "https://sfctest0.snowflakecomputing.com/", 32 "https://s3-us-west-2.amazonaws.com/sfc-snowsql-updates/?prefix=1.1/windows_x86_64", 33 "https://sfcdev2.blob.core.windows.net/", 34 } 35 36 transports := []*http.Transport{ 37 snowflakeInsecureTransport, 38 SnowflakeTransport, 39 } 40 41 for _, enabled := range cacheServerEnabled { 42 for _, tgt := range targetURL { 43 _ = os.Setenv(cacheServerEnabledEnv, enabled) 44 _ = os.Remove(cacheFileName) // clear cache file 45 ocspResponseCache = make(map[certIDKey]*certCacheValue) 46 for _, tr := range transports { 47 t.Run(fmt.Sprintf("%v_%v", tgt, enabled), func(t *testing.T) { 48 c := &http.Client{ 49 Transport: tr, 50 Timeout: 30 * time.Second, 51 } 52 req, err := http.NewRequest("GET", tgt, bytes.NewReader(nil)) 53 if err != nil { 54 t.Fatalf("fail to create a request. err: %v", err) 55 } 56 res, err := c.Do(req) 57 if err != nil { 58 t.Fatalf("failed to GET contents. err: %v", err) 59 } 60 defer res.Body.Close() 61 _, err = io.ReadAll(res.Body) 62 if err != nil { 63 t.Fatalf("failed to read content body for %v", tgt) 64 } 65 }) 66 } 67 } 68 } 69 _ = os.Unsetenv(cacheServerEnabledEnv) 70 } 71 72 type tcValidityRange struct { 73 thisTime time.Time 74 nextTime time.Time 75 ret bool 76 } 77 78 func TestUnitIsInValidityRange(t *testing.T) { 79 currentTime := time.Now() 80 testcases := []tcValidityRange{ 81 { 82 // basic tests 83 thisTime: currentTime.Add(-100 * time.Second), 84 nextTime: currentTime.Add(maxClockSkew), 85 ret: true, 86 }, 87 { 88 // on the border 89 thisTime: currentTime.Add(maxClockSkew), 90 nextTime: currentTime.Add(maxClockSkew), 91 ret: true, 92 }, 93 { 94 // 1 earlier late 95 thisTime: currentTime.Add(maxClockSkew + 1*time.Second), 96 nextTime: currentTime.Add(maxClockSkew), 97 ret: false, 98 }, 99 { 100 // on the border 101 thisTime: currentTime.Add(-maxClockSkew), 102 nextTime: currentTime.Add(-maxClockSkew), 103 ret: true, 104 }, 105 { 106 // around the border 107 thisTime: currentTime.Add(-24*time.Hour - 40*time.Second), 108 nextTime: currentTime.Add(-24*time.Hour/time.Duration(100) - 40*time.Second), 109 ret: false, 110 }, 111 { 112 // on the border 113 thisTime: currentTime.Add(-48*time.Hour - 29*time.Minute), 114 nextTime: currentTime.Add(-48 * time.Hour / time.Duration(100)), 115 ret: true, 116 }, 117 } 118 for _, tc := range testcases { 119 t.Run(fmt.Sprintf("%v_%v", tc.thisTime, tc.nextTime), func(t *testing.T) { 120 if tc.ret != isInValidityRange(currentTime, tc.thisTime, tc.nextTime) { 121 t.Fatalf("failed to check validity. should be: %v, currentTime: %v, thisTime: %v, nextTime: %v", tc.ret, currentTime, tc.thisTime, tc.nextTime) 122 } 123 }) 124 } 125 } 126 127 func TestUnitEncodeCertIDGood(t *testing.T) { 128 targetURLs := []string{ 129 "faketestaccount.snowflakecomputing.com:443", 130 "s3-us-west-2.amazonaws.com:443", 131 "sfcdev2.blob.core.windows.net:443", 132 } 133 for _, tt := range targetURLs { 134 t.Run(tt, func(t *testing.T) { 135 chainedCerts := getCert(tt) 136 for i := 0; i < len(chainedCerts)-1; i++ { 137 subject := chainedCerts[i] 138 issuer := chainedCerts[i+1] 139 ocspServers := subject.OCSPServer 140 if len(ocspServers) == 0 { 141 t.Fatalf("no OCSP server is found. cert: %v", subject.Subject) 142 } 143 ocspReq, err := ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{}) 144 if err != nil { 145 t.Fatalf("failed to create OCSP request. err: %v", err) 146 } 147 var ost *ocspStatus 148 _, ost = extractCertIDKeyFromRequest(ocspReq) 149 if ost.err != nil { 150 t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) 151 } 152 // better hash. Not sure if the actual OCSP server accepts this, though. 153 ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) 154 if err != nil { 155 t.Fatalf("failed to create OCSP request. err: %v", err) 156 } 157 _, ost = extractCertIDKeyFromRequest(ocspReq) 158 if ost.err != nil { 159 t.Fatalf("failed to extract cert ID from the OCSP request. err: %v", ost.err) 160 } 161 // tweaked request binary 162 ocspReq, err = ocsp.CreateRequest(subject, issuer, &ocsp.RequestOptions{Hash: crypto.SHA512}) 163 if err != nil { 164 t.Fatalf("failed to create OCSP request. err: %v", err) 165 } 166 ocspReq[10] = 0 // random change 167 _, ost = extractCertIDKeyFromRequest(ocspReq) 168 if ost.err == nil { 169 t.Fatal("should have failed") 170 } 171 } 172 }) 173 } 174 } 175 176 func TestUnitCheckOCSPResponseCache(t *testing.T) { 177 dummyKey0 := certIDKey{ 178 HashAlgorithm: crypto.SHA1, 179 NameHash: "dummy0", 180 IssuerKeyHash: "dummy0", 181 SerialNumber: "dummy0", 182 } 183 dummyKey := certIDKey{ 184 HashAlgorithm: crypto.SHA1, 185 NameHash: "dummy1", 186 IssuerKeyHash: "dummy1", 187 SerialNumber: "dummy1", 188 } 189 b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE")) 190 currentTime := float64(time.Now().UTC().Unix()) 191 ocspResponseCache[dummyKey0] = &certCacheValue{currentTime, b64Key} 192 subject := &x509.Certificate{} 193 issuer := &x509.Certificate{} 194 ost := checkOCSPResponseCache(&dummyKey, subject, issuer) 195 if ost.code != ocspMissedCache { 196 t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code) 197 } 198 // old timestamp 199 ocspResponseCache[dummyKey] = &certCacheValue{float64(1395054952), b64Key} 200 ost = checkOCSPResponseCache(&dummyKey, subject, issuer) 201 if ost.code != ocspCacheExpired { 202 t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code) 203 } 204 // future timestamp 205 ocspResponseCache[dummyKey] = &certCacheValue{float64(1805054952), b64Key} 206 ost = checkOCSPResponseCache(&dummyKey, subject, issuer) 207 if ost.code != ocspFailedParseResponse { 208 t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code) 209 } 210 // actual OCSP but it fails to parse, because an invalid issuer certificate is given. 211 actualOcspResponse := "MIIB0woBAKCCAcwwggHIBgkrBgEFBQcwAQEEggG5MIIBtTCBnqIWBBSxPsNpA/i/RwHUmCYaCALvY2QrwxgPMjAxNz" + // pragma: allowlist secret 212 "A1MTYyMjAwMDBaMHMwcTBJMAkGBSsOAwIaBQAEFN+qEuMosQlBk+KfQoLOR0BClVijBBSxPsNpA/i/RwHUmCYaCALvY2QrwwIQBOHnp" + // pragma: allowlist secret 213 "Nxc8vNtwCtCuF0Vn4AAGA8yMDE3MDUxNjIyMDAwMFqgERgPMjAxNzA1MjMyMjAwMDBaMA0GCSqGSIb3DQEBCwUAA4IBAQCuRGwqQsKy" + // pragma: allowlist secret 214 "IAAGHgezTfG0PzMYgGD/XRDhU+2i08WTJ4Zs40Lu88cBeRXWF3iiJSpiX3/OLgfI7iXmHX9/sm2SmeNWc0Kb39bk5Lw1jwezf8hcI9+" + // pragma: allowlist secret 215 "mZHt60vhUgtgZk21SsRlTZ+S4VXwtDqB1Nhv6cnSnfrL2A9qJDZS2ltPNOwebWJnznDAs2dg+KxmT2yBXpHM1kb0EOolWvNgORbgIgB" + // pragma: allowlist secret 216 "koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + // pragma: allowlist secret 217 "Jb4xI/OTXIrRA5TsT4cCioE" 218 // issuer is not a true issuer certificate 219 ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse} 220 ost = checkOCSPResponseCache(&dummyKey, subject, issuer) 221 if ost.code != ocspFailedParseResponse { 222 t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedParseResponse, ost.code) 223 } 224 // invalid validity 225 ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse} 226 ost = checkOCSPResponseCache(&dummyKey, subject, nil) 227 if ost.code != ocspInvalidValidity { 228 t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code) 229 } 230 } 231 232 func TestUnitValidateOCSP(t *testing.T) { 233 ocspRes := &ocsp.Response{} 234 ost := validateOCSP(ocspRes) 235 if ost.code != ocspInvalidValidity { 236 t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code) 237 } 238 currentTime := time.Now() 239 ocspRes.ThisUpdate = currentTime.Add(-2 * time.Hour) 240 ocspRes.NextUpdate = currentTime.Add(2 * time.Hour) 241 ocspRes.Status = ocsp.Revoked 242 ost = validateOCSP(ocspRes) 243 if ost.code != ocspStatusRevoked { 244 t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusRevoked, ost.code) 245 } 246 ocspRes.Status = ocsp.Good 247 ost = validateOCSP(ocspRes) 248 if ost.code != ocspStatusGood { 249 t.Fatalf("should have success. expected: %v, got: %v", ocspStatusGood, ost.code) 250 } 251 ocspRes.Status = ocsp.Unknown 252 ost = validateOCSP(ocspRes) 253 if ost.code != ocspStatusUnknown { 254 t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusUnknown, ost.code) 255 } 256 ocspRes.Status = ocsp.ServerFailed 257 ost = validateOCSP(ocspRes) 258 if ost.code != ocspStatusOthers { 259 t.Fatalf("should have failed. expected: %v, got: %v", ocspStatusOthers, ost.code) 260 } 261 } 262 263 func TestUnitEncodeCertID(t *testing.T) { 264 var st *ocspStatus 265 _, st = extractCertIDKeyFromRequest([]byte{0x1, 0x2}) 266 if st.code != ocspFailedDecomposeRequest { 267 t.Fatalf("failed to get OCSP status. expected: %v, got: %v", ocspFailedDecomposeRequest, st.code) 268 } 269 } 270 271 func getCert(addr string) []*x509.Certificate { 272 tcpConn, err := net.DialTimeout("tcp", addr, 40*time.Second) 273 if err != nil { 274 panic(err) 275 } 276 defer tcpConn.Close() 277 278 err = tcpConn.SetDeadline(time.Now().Add(10 * time.Second)) 279 if err != nil { 280 panic(err) 281 } 282 config := tls.Config{InsecureSkipVerify: true, ServerName: addr} 283 284 conn := tls.Client(tcpConn, &config) 285 defer conn.Close() 286 287 err = conn.Handshake() 288 if err != nil { 289 panic(err) 290 } 291 292 state := conn.ConnectionState() 293 294 return state.PeerCertificates 295 } 296 297 func TestOCSPRetry(t *testing.T) { 298 certs := getCert("s3-us-west-2.amazonaws.com:443") 299 dummyOCSPHost := &url.URL{ 300 Scheme: "https", 301 Host: "dummyOCSPHost", 302 } 303 client := &fakeHTTPClient{ 304 cnt: 3, 305 success: true, 306 body: []byte{1, 2, 3}, 307 } 308 res, b, st := retryOCSP( 309 context.Background(), 310 client, emptyRequest, 311 dummyOCSPHost, 312 make(map[string]string), []byte{0}, certs[len(certs)-1], 10*time.Second) 313 if st.err == nil { 314 fmt.Printf("should fail: %v, %v, %v\n", res, b, st) 315 } 316 client = &fakeHTTPClient{ 317 cnt: 30, 318 success: true, 319 body: []byte{1, 2, 3}, 320 } 321 res, b, st = retryOCSP( 322 context.Background(), 323 client, fakeRequestFunc, 324 dummyOCSPHost, 325 make(map[string]string), []byte{0}, certs[len(certs)-1], 5*time.Second) 326 if st.err == nil { 327 fmt.Printf("should fail: %v, %v, %v\n", res, b, st) 328 } 329 } 330 331 func TestFullOCSPURL(t *testing.T) { 332 testcases := []tcFullOCSPURL{ 333 { 334 url: &url.URL{Host: "some-ocsp-url.com"}, 335 expectedURLString: "some-ocsp-url.com", 336 }, 337 { 338 url: &url.URL{ 339 Host: "some-ocsp-url.com", 340 Path: "/some-path", 341 }, 342 expectedURLString: "some-ocsp-url.com/some-path", 343 }, 344 { 345 url: &url.URL{ 346 Host: "some-ocsp-url.com", 347 Path: "some-path", 348 }, 349 expectedURLString: "some-ocsp-url.com/some-path", 350 }, 351 } 352 353 for _, testcase := range testcases { 354 t.Run("", func(t *testing.T) { 355 returnedStringURL := fullOCSPURL(testcase.url) 356 if returnedStringURL != testcase.expectedURLString { 357 t.Fatalf("failed to match returned OCSP url string; expected: %v, got: %v", 358 testcase.expectedURLString, returnedStringURL) 359 } 360 }) 361 } 362 } 363 364 type tcFullOCSPURL struct { 365 url *url.URL 366 expectedURLString string 367 } 368 369 func TestOCSPCacheServerRetry(t *testing.T) { 370 dummyOCSPHost := &url.URL{ 371 Scheme: "https", 372 Host: "dummyOCSPHost", 373 } 374 client := &fakeHTTPClient{ 375 cnt: 3, 376 success: true, 377 body: []byte{1, 2, 3}, 378 } 379 res, st := checkOCSPCacheServer( 380 context.Background(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second) 381 if st.err == nil { 382 t.Errorf("should fail: %v", res) 383 } 384 client = &fakeHTTPClient{ 385 cnt: 30, 386 success: true, 387 body: []byte{1, 2, 3}, 388 } 389 res, st = checkOCSPCacheServer( 390 context.Background(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second) 391 if st.err == nil { 392 t.Errorf("should fail: %v", res) 393 } 394 } 395 396 type tcCanEarlyExit struct { 397 results []*ocspStatus 398 resultLen int 399 retFailOpen *ocspStatus 400 retFailClosed *ocspStatus 401 } 402 403 func TestCanEarlyExitForOCSP(t *testing.T) { 404 testcases := []tcCanEarlyExit{ 405 { // 0 406 results: []*ocspStatus{ 407 { 408 code: ocspStatusGood, 409 }, 410 { 411 code: ocspStatusGood, 412 }, 413 { 414 code: ocspStatusGood, 415 }, 416 }, 417 retFailOpen: nil, 418 retFailClosed: nil, 419 }, 420 { // 1 421 results: []*ocspStatus{ 422 { 423 code: ocspStatusRevoked, 424 err: errors.New("revoked"), 425 }, 426 { 427 code: ocspStatusGood, 428 }, 429 { 430 code: ocspStatusGood, 431 }, 432 }, 433 retFailOpen: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, 434 retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, 435 }, 436 { // 2 437 results: []*ocspStatus{ 438 { 439 code: ocspStatusUnknown, 440 err: errors.New("unknown"), 441 }, 442 { 443 code: ocspStatusGood, 444 }, 445 { 446 code: ocspStatusGood, 447 }, 448 }, 449 retFailOpen: nil, 450 retFailClosed: &ocspStatus{ocspStatusUnknown, errors.New("unknown")}, 451 }, 452 { // 3: not taken as revoked if any invalid OCSP response (ocspInvalidValidity) is included. 453 results: []*ocspStatus{ 454 { 455 code: ocspStatusRevoked, 456 err: errors.New("revoked"), 457 }, 458 { 459 code: ocspInvalidValidity, 460 }, 461 { 462 code: ocspStatusGood, 463 }, 464 }, 465 retFailOpen: nil, 466 retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, 467 }, 468 { // 4: not taken as revoked if the number of results don't match the expected results. 469 results: []*ocspStatus{ 470 { 471 code: ocspStatusRevoked, 472 err: errors.New("revoked"), 473 }, 474 { 475 code: ocspStatusGood, 476 }, 477 }, 478 resultLen: 3, 479 retFailOpen: nil, 480 retFailClosed: &ocspStatus{ocspStatusRevoked, errors.New("revoked")}, 481 }, 482 } 483 484 for idx, tt := range testcases { 485 t.Run("", func(t *testing.T) { 486 ocspFailOpen = OCSPFailOpenTrue 487 expectedLen := len(tt.results) 488 if tt.resultLen > 0 { 489 expectedLen = tt.resultLen 490 } 491 r := canEarlyExitForOCSP(tt.results, expectedLen) 492 if !(tt.retFailOpen == nil && r == nil) && !(tt.retFailOpen != nil && r != nil && tt.retFailOpen.code == r.code) { 493 t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailOpen, r) 494 } 495 ocspFailOpen = OCSPFailOpenFalse 496 r = canEarlyExitForOCSP(tt.results, expectedLen) 497 if !(tt.retFailClosed == nil && r == nil) && !(tt.retFailClosed != nil && r != nil && tt.retFailClosed.code == r.code) { 498 t.Fatalf("%d: failed to match return. expected: %v, got: %v", idx, tt.retFailClosed, r) 499 } 500 }) 501 } 502 } 503 504 func TestInitOCSPCacheFileCreation(t *testing.T) { 505 if runningOnGithubAction() { 506 t.Skip("cannot write to github file system") 507 } 508 dirName, err := os.UserHomeDir() 509 if err != nil { 510 t.Error(err) 511 } 512 srcFileName := dirName + "/.cache/snowflake/ocsp_response_cache.json" 513 tmpFileName := srcFileName + "_tmp" 514 dst, err := os.Create(tmpFileName) 515 if err != nil { 516 t.Error(err) 517 } 518 defer dst.Close() 519 520 var src *os.File 521 if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) { 522 // file does not exist 523 if err = os.MkdirAll(dirName+"/.cache/snowflake/", os.ModePerm); err != nil { 524 t.Error(err) 525 } 526 if _, err = os.Create(srcFileName); err != nil { 527 t.Error(err) 528 } 529 } else if err != nil { 530 t.Error(err) 531 } else { 532 // file exists 533 src, err = os.Open(srcFileName) 534 if err != nil { 535 t.Error(err) 536 } 537 defer src.Close() 538 // copy original contents to temporary file 539 if _, err = io.Copy(dst, src); err != nil { 540 t.Error(err) 541 } 542 if err = os.Remove(srcFileName); err != nil { 543 t.Error(err) 544 } 545 } 546 547 // cleanup 548 defer func() { 549 src, _ = os.Open(tmpFileName) 550 defer src.Close() 551 dst, _ = os.OpenFile(srcFileName, os.O_WRONLY, readWriteFileMode) 552 defer dst.Close() 553 // copy temporary file contents back to original file 554 if _, err = io.Copy(dst, src); err != nil { 555 t.Fatal(err) 556 } 557 if err = os.Remove(tmpFileName); err != nil { 558 t.Error(err) 559 } 560 }() 561 562 initOCSPCache() 563 if _, err = os.Stat(srcFileName); errors.Is(err, os.ErrNotExist) { 564 t.Error(err) 565 } else if err != nil { 566 t.Error(err) 567 } 568 }