github.com/snowflakedb/gosnowflake@v1.9.0/restful_test.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "net/http" 11 "net/url" 12 "sync" 13 "sync/atomic" 14 "testing" 15 "time" 16 ) 17 18 func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 19 return &http.Response{ 20 StatusCode: http.StatusOK, 21 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 22 }, errors.New("failed to run post method") 23 } 24 25 func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { 26 return &http.Response{ 27 StatusCode: http.StatusOK, 28 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 29 }, errors.New("failed to run post method") 30 } 31 32 func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 33 return &http.Response{ 34 StatusCode: http.StatusOK, 35 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 36 }, nil 37 } 38 39 func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 40 return &http.Response{ 41 StatusCode: http.StatusBadGateway, 42 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 43 }, nil 44 } 45 46 func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { 47 return &http.Response{ 48 StatusCode: http.StatusBadGateway, 49 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 50 }, nil 51 } 52 53 func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 54 return &http.Response{ 55 StatusCode: http.StatusForbidden, 56 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 57 }, nil 58 } 59 60 func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { 61 return &http.Response{ 62 StatusCode: http.StatusForbidden, 63 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 64 }, nil 65 } 66 67 func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { 68 return &http.Response{ 69 StatusCode: http.StatusInsufficientStorage, 70 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 71 }, nil 72 } 73 74 func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 75 dd := &execResponseData{} 76 er := &execResponse{ 77 Data: *dd, 78 Message: "", 79 Code: queryNotExecuting, 80 Success: false, 81 } 82 ba, err := json.Marshal(er) 83 if err != nil { 84 panic(err) 85 } 86 87 return &http.Response{ 88 StatusCode: http.StatusOK, 89 Body: &fakeResponseBody{body: ba}, 90 }, nil 91 } 92 93 func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 94 dd := &execResponseData{} 95 er := &execResponse{ 96 Data: *dd, 97 Message: "", 98 Code: sessionExpiredCode, 99 Success: true, 100 } 101 102 ba, err := json.Marshal(er) 103 logger.Infof("encoded JSON: %v", ba) 104 if err != nil { 105 panic(err) 106 } 107 return &http.Response{ 108 StatusCode: http.StatusOK, 109 Body: &fakeResponseBody{body: ba}, 110 }, nil 111 } 112 113 func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { 114 dd := &execResponseData{} 115 er := &execResponse{ 116 Data: *dd, 117 Message: "", 118 Code: "", 119 Success: true, 120 } 121 122 ba, err := json.Marshal(er) 123 logger.Infof("encoded JSON: %v", ba) 124 if err != nil { 125 panic(err) 126 } 127 return &http.Response{ 128 StatusCode: http.StatusOK, 129 Body: &fakeResponseBody{body: ba}, 130 }, nil 131 } 132 133 func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 134 dd := &execResponseData{} 135 er := &execResponse{ 136 Data: *dd, 137 Message: "", 138 Code: "", 139 Success: true, 140 } 141 142 ba, err := json.Marshal(er) 143 logger.Infof("encoded JSON: %v", ba) 144 if err != nil { 145 panic(err) 146 } 147 return &http.Response{ 148 StatusCode: http.StatusOK, 149 Body: &fakeResponseBody{body: ba}, 150 }, nil 151 } 152 153 func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID, timeout time.Duration) error { 154 ctxRetry := getCancelRetry(ctx) 155 u := url.URL{} 156 reqByte, err := json.Marshal(make(map[string]string)) 157 if err != nil { 158 return err 159 } 160 resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider, nil) 161 if err != nil { 162 return err 163 } 164 if resp.StatusCode == http.StatusOK { 165 var respd cancelQueryResponse 166 err = json.NewDecoder(resp.Body).Decode(&respd) 167 if err != nil { 168 return err 169 } 170 if !respd.Success && respd.Code == queryNotExecuting && ctxRetry != 0 { 171 return sr.FuncCancelQuery(context.WithValue(ctx, cancelRetry, ctxRetry-1), sr, requestID, timeout) 172 } 173 if ctxRetry == 0 { 174 return nil 175 } 176 } 177 return fmt.Errorf("cancel retry failed") 178 } 179 180 func TestUnitPostQueryHelperError(t *testing.T) { 181 sr := &snowflakeRestful{ 182 FuncPost: postTestError, 183 TokenAccessor: getSimpleTokenAccessor(), 184 } 185 var err error 186 requestID := NewUUID() 187 _, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{}) 188 if err == nil { 189 t.Fatalf("should have failed to post") 190 } 191 sr.FuncPost = postTestAppBadGatewayError 192 requestID = NewUUID() 193 _, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{}) 194 if err == nil { 195 t.Fatalf("should have failed to post") 196 } 197 sr.FuncPost = postTestSuccessButInvalidJSON 198 requestID = NewUUID() 199 _, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{}) 200 if err == nil { 201 t.Fatalf("should have failed to post") 202 } 203 } 204 205 func renewSessionTest(_ context.Context, _ *snowflakeRestful, _ time.Duration) error { 206 return nil 207 } 208 209 func renewSessionTestError(_ context.Context, _ *snowflakeRestful, _ time.Duration) error { 210 return errors.New("failed to renew session in tests") 211 } 212 213 func TestUnitTokenAccessorDoesNotRenewStaleToken(t *testing.T) { 214 accessor := getSimpleTokenAccessor() 215 oldToken := "test" 216 accessor.SetTokens(oldToken, "master", 123) 217 218 renewSessionCalled := false 219 renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error { 220 // should not have gotten to actual renewal 221 renewSessionCalled = true 222 return nil 223 } 224 225 sr := &snowflakeRestful{ 226 FuncRenewSession: renewSessionDummy, 227 TokenAccessor: accessor, 228 } 229 230 // try to intentionally renew with stale token 231 sr.renewExpiredSessionToken(context.Background(), time.Hour, "stale-token") 232 233 if renewSessionCalled { 234 t.Fatal("FuncRenewSession should not have been called") 235 } 236 237 // set the current token to empty, should still call renew even if stale token is passed in 238 accessor.SetTokens("", "master", 123) 239 sr.renewExpiredSessionToken(context.Background(), time.Hour, "stale-token") 240 241 if !renewSessionCalled { 242 t.Fatal("FuncRenewSession should have been called because current token is empty") 243 } 244 } 245 246 type wrappedAccessor struct { 247 ta TokenAccessor 248 lockCallCount int32 249 unlockCallCount int32 250 } 251 252 func (wa *wrappedAccessor) Lock() error { 253 atomic.AddInt32(&wa.lockCallCount, 1) 254 err := wa.ta.Lock() 255 return err 256 } 257 258 func (wa *wrappedAccessor) Unlock() { 259 atomic.AddInt32(&wa.unlockCallCount, 1) 260 wa.ta.Unlock() 261 } 262 263 func (wa *wrappedAccessor) GetTokens() (token string, masterToken string, sessionID int64) { 264 return wa.ta.GetTokens() 265 } 266 267 func (wa *wrappedAccessor) SetTokens(token string, masterToken string, sessionID int64) { 268 wa.ta.SetTokens(token, masterToken, sessionID) 269 } 270 271 func TestUnitTokenAccessorRenewBlocked(t *testing.T) { 272 accessor := wrappedAccessor{ 273 ta: getSimpleTokenAccessor(), 274 } 275 oldToken := "test" 276 accessor.SetTokens(oldToken, "master", 123) 277 278 renewSessionCalled := false 279 renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error { 280 renewSessionCalled = true 281 return nil 282 } 283 284 sr := &snowflakeRestful{ 285 FuncRenewSession: renewSessionDummy, 286 TokenAccessor: &accessor, 287 } 288 289 // intentionally lock the accessor first 290 accessor.Lock() 291 292 // try to intentionally renew with stale token 293 var renewalStart sync.WaitGroup 294 var renewalDone sync.WaitGroup 295 renewalStart.Add(1) 296 renewalDone.Add(1) 297 go func() { 298 renewalStart.Done() 299 sr.renewExpiredSessionToken(context.Background(), time.Hour, oldToken) 300 renewalDone.Done() 301 }() 302 303 // wait for renewal to start and get blocked on lock 304 renewalStart.Wait() 305 // should be blocked and not be able to call renew session 306 if renewSessionCalled { 307 t.Fail() 308 } 309 310 // rotate the token again so that the session token is considered stale 311 accessor.SetTokens("new-token", "m", 321) 312 313 // unlock so that renew can happen 314 accessor.Unlock() 315 renewalDone.Wait() 316 317 // renewal should be done but token should still not 318 // have been renewed since we intentionally swapped token while locked 319 if renewSessionCalled { 320 t.Fail() 321 } 322 323 // wait for accessor defer unlock 324 accessor.Lock() 325 if accessor.lockCallCount != 3 { 326 t.Fatalf("Expected Lock() to be called thrice, but got %v", accessor.lockCallCount) 327 } 328 if accessor.unlockCallCount != 2 { 329 t.Fatalf("Expected Unlock() to be called twice, but got %v", accessor.unlockCallCount) 330 } 331 } 332 333 func TestUnitTokenAccessorRenewSessionContention(t *testing.T) { 334 accessor := getSimpleTokenAccessor() 335 oldToken := "test" 336 accessor.SetTokens(oldToken, "master", 123) 337 var counter int32 = 0 338 339 expectedToken := "new token" 340 expectedMaster := "new master" 341 expectedSession := int64(321) 342 343 renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error { 344 accessor.SetTokens(expectedToken, expectedMaster, expectedSession) 345 atomic.AddInt32(&counter, 1) 346 return nil 347 } 348 349 sr := &snowflakeRestful{ 350 FuncRenewSession: renewSessionDummy, 351 TokenAccessor: accessor, 352 } 353 354 var renewalsStart sync.WaitGroup 355 var renewalsDone sync.WaitGroup 356 var renewalError error 357 numRoutines := 50 358 for i := 0; i < numRoutines; i++ { 359 renewalsDone.Add(1) 360 renewalsStart.Add(1) 361 go func() { 362 // wait for all goroutines to have been created before proceeding to race against each other 363 renewalsStart.Wait() 364 err := sr.renewExpiredSessionToken(context.Background(), time.Hour, oldToken) 365 if err != nil { 366 renewalError = err 367 } 368 renewalsDone.Done() 369 }() 370 } 371 372 // unlock all of the waiting goroutines simultaneously 373 renewalsStart.Add(-numRoutines) 374 375 // wait for all competing goroutines to finish calling renew expired session token 376 renewalsDone.Wait() 377 378 if renewalError != nil { 379 t.Fatalf("failed to renew session, error %v", renewalError) 380 } 381 newToken, newMaster, newSession := accessor.GetTokens() 382 if newToken != expectedToken { 383 t.Fatalf("token %v does not match expected %v", newToken, expectedToken) 384 } 385 if newMaster != expectedMaster { 386 t.Fatalf("master token %v does not match expected %v", newMaster, expectedMaster) 387 } 388 if newSession != expectedSession { 389 t.Fatalf("session %v does not match expected %v", newSession, expectedSession) 390 } 391 // only the first renewal will go through and FuncRenewSession should be called exactly once 392 if counter != 1 { 393 t.Fatalf("renew expired session was called more than once: %v", counter) 394 } 395 } 396 397 func TestUnitPostQueryHelperUsesToken(t *testing.T) { 398 accessor := getSimpleTokenAccessor() 399 token := "token123" 400 accessor.SetTokens(token, "", 0) 401 402 var err error 403 postQueryTest := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) { 404 if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, token) { 405 t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, token)) 406 } 407 dd := &execResponseData{} 408 return &execResponse{ 409 Data: *dd, 410 Message: "", 411 Code: "0", 412 Success: true, 413 }, nil 414 } 415 sr := &snowflakeRestful{ 416 FuncPost: postTestRenew, 417 FuncPostQuery: postQueryTest, 418 FuncRenewSession: renewSessionTest, 419 TokenAccessor: accessor, 420 } 421 _, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, NewUUID(), &Config{}) 422 if err != nil { 423 t.Fatalf("err: %v", err) 424 } 425 } 426 427 func TestUnitPostQueryHelperRenewSession(t *testing.T) { 428 var err error 429 origRequestID := NewUUID() 430 postQueryTest := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, requestID UUID, _ *Config) (*execResponse, error) { 431 // ensure the same requestID is used after the session token is renewed. 432 if requestID != origRequestID { 433 t.Fatal("requestID doesn't match") 434 } 435 dd := &execResponseData{} 436 return &execResponse{ 437 Data: *dd, 438 Message: "", 439 Code: "0", 440 Success: true, 441 }, nil 442 } 443 sr := &snowflakeRestful{ 444 FuncPost: postTestRenew, 445 FuncPostQuery: postQueryTest, 446 FuncRenewSession: renewSessionTest, 447 TokenAccessor: getSimpleTokenAccessor(), 448 } 449 450 _, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, origRequestID, &Config{}) 451 if err != nil { 452 t.Fatalf("err: %v", err) 453 } 454 sr.FuncRenewSession = renewSessionTestError 455 _, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, origRequestID, &Config{}) 456 if err == nil { 457 t.Fatal("should have failed to renew session") 458 } 459 } 460 461 func TestUnitRenewRestfulSession(t *testing.T) { 462 accessor := getSimpleTokenAccessor() 463 oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100) 464 newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200) 465 postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) { 466 if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) { 467 t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken)) 468 } 469 tr := &renewSessionResponse{ 470 Data: renewSessionResponseMain{ 471 SessionToken: newToken, 472 MasterToken: newMasterToken, 473 SessionID: newSessionID, 474 }, 475 Message: "", 476 Success: true, 477 } 478 ba, err := json.Marshal(tr) 479 if err != nil { 480 t.Fatalf("failed to serialize token response %v", err) 481 } 482 return &http.Response{ 483 StatusCode: http.StatusOK, 484 Body: &fakeResponseBody{body: ba}, 485 }, nil 486 } 487 488 sr := &snowflakeRestful{ 489 FuncPost: postTestAfterRenew, 490 TokenAccessor: accessor, 491 } 492 err := renewRestfulSession(context.Background(), sr, time.Second) 493 if err != nil { 494 t.Fatalf("err: %v", err) 495 } 496 sr.FuncPost = postTestError 497 err = renewRestfulSession(context.Background(), sr, time.Second) 498 if err == nil { 499 t.Fatal("should have failed to run post request after the renewal") 500 } 501 sr.FuncPost = postTestAppBadGatewayError 502 err = renewRestfulSession(context.Background(), sr, time.Second) 503 if err == nil { 504 t.Fatal("should have failed to run post request after the renewal") 505 } 506 sr.FuncPost = postTestSuccessButInvalidJSON 507 err = renewRestfulSession(context.Background(), sr, time.Second) 508 if err == nil { 509 t.Fatal("should have failed to run post request after the renewal") 510 } 511 accessor.SetTokens(oldToken, oldMasterToken, oldSessionID) 512 sr.FuncPost = postTestSuccessWithNewTokens 513 err = renewRestfulSession(context.Background(), sr, time.Second) 514 if err != nil { 515 t.Fatal("should not have failed to run post request after the renewal") 516 } 517 token, masterToken, sessionID := accessor.GetTokens() 518 if token != newToken { 519 t.Fatalf("unexpected new token %v", token) 520 } 521 if masterToken != newMasterToken { 522 t.Fatalf("unexpected new master token %v", masterToken) 523 } 524 if sessionID != newSessionID { 525 t.Fatalf("unexpected new session id %v", sessionID) 526 } 527 } 528 529 func TestUnitCloseSession(t *testing.T) { 530 sr := &snowflakeRestful{ 531 FuncPost: postTestAfterRenew, 532 TokenAccessor: getSimpleTokenAccessor(), 533 } 534 err := closeSession(context.Background(), sr, time.Second) 535 if err != nil { 536 t.Fatalf("err: %v", err) 537 } 538 sr.FuncPost = postTestError 539 err = closeSession(context.Background(), sr, time.Second) 540 if err == nil { 541 t.Fatal("should have failed to close session") 542 } 543 sr.FuncPost = postTestAppBadGatewayError 544 err = closeSession(context.Background(), sr, time.Second) 545 if err == nil { 546 t.Fatal("should have failed to close session") 547 } 548 sr.FuncPost = postTestSuccessButInvalidJSON 549 err = closeSession(context.Background(), sr, time.Second) 550 if err == nil { 551 t.Fatal("should have failed to close session") 552 } 553 } 554 555 func TestUnitCancelQuery(t *testing.T) { 556 sr := &snowflakeRestful{ 557 FuncPost: postTestAfterRenew, 558 TokenAccessor: getSimpleTokenAccessor(), 559 } 560 ctx := context.Background() 561 err := cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second) 562 if err != nil { 563 t.Fatalf("err: %v", err) 564 } 565 sr.FuncPost = postTestError 566 err = cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second) 567 if err == nil { 568 t.Fatal("should have failed to close session") 569 } 570 sr.FuncPost = postTestAppBadGatewayError 571 err = cancelQuery(context.Background(), sr, getOrGenerateRequestIDFromContext(ctx), time.Second) 572 if err == nil { 573 t.Fatal("should have failed to close session") 574 } 575 sr.FuncPost = postTestSuccessButInvalidJSON 576 err = cancelQuery(context.Background(), sr, getOrGenerateRequestIDFromContext(ctx), time.Second) 577 if err == nil { 578 t.Fatal("should have failed to close session") 579 } 580 } 581 582 func TestCancelRetry(t *testing.T) { 583 sr := &snowflakeRestful{ 584 TokenAccessor: getSimpleTokenAccessor(), 585 FuncPost: postTestQueryNotExecuting, 586 FuncCancelQuery: cancelTestRetry, 587 } 588 ctx := context.Background() 589 err := cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second) 590 if err != nil { 591 t.Fatal(err) 592 } 593 }