github.com/snowflakedb/gosnowflake@v1.9.0/restful_test.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"net/http"
    11  	"net/url"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
    19  	return &http.Response{
    20  		StatusCode: http.StatusOK,
    21  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    22  	}, errors.New("failed to run post method")
    23  }
    24  
    25  func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
    26  	return &http.Response{
    27  		StatusCode: http.StatusOK,
    28  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    29  	}, errors.New("failed to run post method")
    30  }
    31  
    32  func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
    33  	return &http.Response{
    34  		StatusCode: http.StatusOK,
    35  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    36  	}, nil
    37  }
    38  
    39  func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
    40  	return &http.Response{
    41  		StatusCode: http.StatusBadGateway,
    42  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    43  	}, nil
    44  }
    45  
    46  func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
    47  	return &http.Response{
    48  		StatusCode: http.StatusBadGateway,
    49  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    50  	}, nil
    51  }
    52  
    53  func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
    54  	return &http.Response{
    55  		StatusCode: http.StatusForbidden,
    56  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    57  	}, nil
    58  }
    59  
    60  func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
    61  	return &http.Response{
    62  		StatusCode: http.StatusForbidden,
    63  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    64  	}, nil
    65  }
    66  
    67  func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
    68  	return &http.Response{
    69  		StatusCode: http.StatusInsufficientStorage,
    70  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    71  	}, nil
    72  }
    73  
    74  func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
    75  	dd := &execResponseData{}
    76  	er := &execResponse{
    77  		Data:    *dd,
    78  		Message: "",
    79  		Code:    queryNotExecuting,
    80  		Success: false,
    81  	}
    82  	ba, err := json.Marshal(er)
    83  	if err != nil {
    84  		panic(err)
    85  	}
    86  
    87  	return &http.Response{
    88  		StatusCode: http.StatusOK,
    89  		Body:       &fakeResponseBody{body: ba},
    90  	}, nil
    91  }
    92  
    93  func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
    94  	dd := &execResponseData{}
    95  	er := &execResponse{
    96  		Data:    *dd,
    97  		Message: "",
    98  		Code:    sessionExpiredCode,
    99  		Success: true,
   100  	}
   101  
   102  	ba, err := json.Marshal(er)
   103  	logger.Infof("encoded JSON: %v", ba)
   104  	if err != nil {
   105  		panic(err)
   106  	}
   107  	return &http.Response{
   108  		StatusCode: http.StatusOK,
   109  		Body:       &fakeResponseBody{body: ba},
   110  	}, nil
   111  }
   112  
   113  func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
   114  	dd := &execResponseData{}
   115  	er := &execResponse{
   116  		Data:    *dd,
   117  		Message: "",
   118  		Code:    "",
   119  		Success: true,
   120  	}
   121  
   122  	ba, err := json.Marshal(er)
   123  	logger.Infof("encoded JSON: %v", ba)
   124  	if err != nil {
   125  		panic(err)
   126  	}
   127  	return &http.Response{
   128  		StatusCode: http.StatusOK,
   129  		Body:       &fakeResponseBody{body: ba},
   130  	}, nil
   131  }
   132  
   133  func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
   134  	dd := &execResponseData{}
   135  	er := &execResponse{
   136  		Data:    *dd,
   137  		Message: "",
   138  		Code:    "",
   139  		Success: true,
   140  	}
   141  
   142  	ba, err := json.Marshal(er)
   143  	logger.Infof("encoded JSON: %v", ba)
   144  	if err != nil {
   145  		panic(err)
   146  	}
   147  	return &http.Response{
   148  		StatusCode: http.StatusOK,
   149  		Body:       &fakeResponseBody{body: ba},
   150  	}, nil
   151  }
   152  
   153  func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID, timeout time.Duration) error {
   154  	ctxRetry := getCancelRetry(ctx)
   155  	u := url.URL{}
   156  	reqByte, err := json.Marshal(make(map[string]string))
   157  	if err != nil {
   158  		return err
   159  	}
   160  	resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider, nil)
   161  	if err != nil {
   162  		return err
   163  	}
   164  	if resp.StatusCode == http.StatusOK {
   165  		var respd cancelQueryResponse
   166  		err = json.NewDecoder(resp.Body).Decode(&respd)
   167  		if err != nil {
   168  			return err
   169  		}
   170  		if !respd.Success && respd.Code == queryNotExecuting && ctxRetry != 0 {
   171  			return sr.FuncCancelQuery(context.WithValue(ctx, cancelRetry, ctxRetry-1), sr, requestID, timeout)
   172  		}
   173  		if ctxRetry == 0 {
   174  			return nil
   175  		}
   176  	}
   177  	return fmt.Errorf("cancel retry failed")
   178  }
   179  
   180  func TestUnitPostQueryHelperError(t *testing.T) {
   181  	sr := &snowflakeRestful{
   182  		FuncPost:      postTestError,
   183  		TokenAccessor: getSimpleTokenAccessor(),
   184  	}
   185  	var err error
   186  	requestID := NewUUID()
   187  	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{})
   188  	if err == nil {
   189  		t.Fatalf("should have failed to post")
   190  	}
   191  	sr.FuncPost = postTestAppBadGatewayError
   192  	requestID = NewUUID()
   193  	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{})
   194  	if err == nil {
   195  		t.Fatalf("should have failed to post")
   196  	}
   197  	sr.FuncPost = postTestSuccessButInvalidJSON
   198  	requestID = NewUUID()
   199  	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, requestID, &Config{})
   200  	if err == nil {
   201  		t.Fatalf("should have failed to post")
   202  	}
   203  }
   204  
   205  func renewSessionTest(_ context.Context, _ *snowflakeRestful, _ time.Duration) error {
   206  	return nil
   207  }
   208  
   209  func renewSessionTestError(_ context.Context, _ *snowflakeRestful, _ time.Duration) error {
   210  	return errors.New("failed to renew session in tests")
   211  }
   212  
   213  func TestUnitTokenAccessorDoesNotRenewStaleToken(t *testing.T) {
   214  	accessor := getSimpleTokenAccessor()
   215  	oldToken := "test"
   216  	accessor.SetTokens(oldToken, "master", 123)
   217  
   218  	renewSessionCalled := false
   219  	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
   220  		// should not have gotten to actual renewal
   221  		renewSessionCalled = true
   222  		return nil
   223  	}
   224  
   225  	sr := &snowflakeRestful{
   226  		FuncRenewSession: renewSessionDummy,
   227  		TokenAccessor:    accessor,
   228  	}
   229  
   230  	// try to intentionally renew with stale token
   231  	sr.renewExpiredSessionToken(context.Background(), time.Hour, "stale-token")
   232  
   233  	if renewSessionCalled {
   234  		t.Fatal("FuncRenewSession should not have been called")
   235  	}
   236  
   237  	// set the current token to empty, should still call renew even if stale token is passed in
   238  	accessor.SetTokens("", "master", 123)
   239  	sr.renewExpiredSessionToken(context.Background(), time.Hour, "stale-token")
   240  
   241  	if !renewSessionCalled {
   242  		t.Fatal("FuncRenewSession should have been called because current token is empty")
   243  	}
   244  }
   245  
   246  type wrappedAccessor struct {
   247  	ta              TokenAccessor
   248  	lockCallCount   int32
   249  	unlockCallCount int32
   250  }
   251  
   252  func (wa *wrappedAccessor) Lock() error {
   253  	atomic.AddInt32(&wa.lockCallCount, 1)
   254  	err := wa.ta.Lock()
   255  	return err
   256  }
   257  
   258  func (wa *wrappedAccessor) Unlock() {
   259  	atomic.AddInt32(&wa.unlockCallCount, 1)
   260  	wa.ta.Unlock()
   261  }
   262  
   263  func (wa *wrappedAccessor) GetTokens() (token string, masterToken string, sessionID int64) {
   264  	return wa.ta.GetTokens()
   265  }
   266  
   267  func (wa *wrappedAccessor) SetTokens(token string, masterToken string, sessionID int64) {
   268  	wa.ta.SetTokens(token, masterToken, sessionID)
   269  }
   270  
   271  func TestUnitTokenAccessorRenewBlocked(t *testing.T) {
   272  	accessor := wrappedAccessor{
   273  		ta: getSimpleTokenAccessor(),
   274  	}
   275  	oldToken := "test"
   276  	accessor.SetTokens(oldToken, "master", 123)
   277  
   278  	renewSessionCalled := false
   279  	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
   280  		renewSessionCalled = true
   281  		return nil
   282  	}
   283  
   284  	sr := &snowflakeRestful{
   285  		FuncRenewSession: renewSessionDummy,
   286  		TokenAccessor:    &accessor,
   287  	}
   288  
   289  	// intentionally lock the accessor first
   290  	accessor.Lock()
   291  
   292  	// try to intentionally renew with stale token
   293  	var renewalStart sync.WaitGroup
   294  	var renewalDone sync.WaitGroup
   295  	renewalStart.Add(1)
   296  	renewalDone.Add(1)
   297  	go func() {
   298  		renewalStart.Done()
   299  		sr.renewExpiredSessionToken(context.Background(), time.Hour, oldToken)
   300  		renewalDone.Done()
   301  	}()
   302  
   303  	// wait for renewal to start and get blocked on lock
   304  	renewalStart.Wait()
   305  	// should be blocked and not be able to call renew session
   306  	if renewSessionCalled {
   307  		t.Fail()
   308  	}
   309  
   310  	// rotate the token again so that the session token is considered stale
   311  	accessor.SetTokens("new-token", "m", 321)
   312  
   313  	// unlock so that renew can happen
   314  	accessor.Unlock()
   315  	renewalDone.Wait()
   316  
   317  	// renewal should be done but token should still not
   318  	// have been renewed since we intentionally swapped token while locked
   319  	if renewSessionCalled {
   320  		t.Fail()
   321  	}
   322  
   323  	// wait for accessor defer unlock
   324  	accessor.Lock()
   325  	if accessor.lockCallCount != 3 {
   326  		t.Fatalf("Expected Lock() to be called thrice, but got %v", accessor.lockCallCount)
   327  	}
   328  	if accessor.unlockCallCount != 2 {
   329  		t.Fatalf("Expected Unlock() to be called twice, but got %v", accessor.unlockCallCount)
   330  	}
   331  }
   332  
   333  func TestUnitTokenAccessorRenewSessionContention(t *testing.T) {
   334  	accessor := getSimpleTokenAccessor()
   335  	oldToken := "test"
   336  	accessor.SetTokens(oldToken, "master", 123)
   337  	var counter int32 = 0
   338  
   339  	expectedToken := "new token"
   340  	expectedMaster := "new master"
   341  	expectedSession := int64(321)
   342  
   343  	renewSessionDummy := func(_ context.Context, sr *snowflakeRestful, _ time.Duration) error {
   344  		accessor.SetTokens(expectedToken, expectedMaster, expectedSession)
   345  		atomic.AddInt32(&counter, 1)
   346  		return nil
   347  	}
   348  
   349  	sr := &snowflakeRestful{
   350  		FuncRenewSession: renewSessionDummy,
   351  		TokenAccessor:    accessor,
   352  	}
   353  
   354  	var renewalsStart sync.WaitGroup
   355  	var renewalsDone sync.WaitGroup
   356  	var renewalError error
   357  	numRoutines := 50
   358  	for i := 0; i < numRoutines; i++ {
   359  		renewalsDone.Add(1)
   360  		renewalsStart.Add(1)
   361  		go func() {
   362  			// wait for all goroutines to have been created before proceeding to race against each other
   363  			renewalsStart.Wait()
   364  			err := sr.renewExpiredSessionToken(context.Background(), time.Hour, oldToken)
   365  			if err != nil {
   366  				renewalError = err
   367  			}
   368  			renewalsDone.Done()
   369  		}()
   370  	}
   371  
   372  	// unlock all of the waiting goroutines simultaneously
   373  	renewalsStart.Add(-numRoutines)
   374  
   375  	// wait for all competing goroutines to finish calling renew expired session token
   376  	renewalsDone.Wait()
   377  
   378  	if renewalError != nil {
   379  		t.Fatalf("failed to renew session, error %v", renewalError)
   380  	}
   381  	newToken, newMaster, newSession := accessor.GetTokens()
   382  	if newToken != expectedToken {
   383  		t.Fatalf("token %v does not match expected %v", newToken, expectedToken)
   384  	}
   385  	if newMaster != expectedMaster {
   386  		t.Fatalf("master token %v does not match expected %v", newMaster, expectedMaster)
   387  	}
   388  	if newSession != expectedSession {
   389  		t.Fatalf("session %v does not match expected %v", newSession, expectedSession)
   390  	}
   391  	// only the first renewal will go through and FuncRenewSession should be called exactly once
   392  	if counter != 1 {
   393  		t.Fatalf("renew expired session was called more than once: %v", counter)
   394  	}
   395  }
   396  
   397  func TestUnitPostQueryHelperUsesToken(t *testing.T) {
   398  	accessor := getSimpleTokenAccessor()
   399  	token := "token123"
   400  	accessor.SetTokens(token, "", 0)
   401  
   402  	var err error
   403  	postQueryTest := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) {
   404  		if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, token) {
   405  			t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, token))
   406  		}
   407  		dd := &execResponseData{}
   408  		return &execResponse{
   409  			Data:    *dd,
   410  			Message: "",
   411  			Code:    "0",
   412  			Success: true,
   413  		}, nil
   414  	}
   415  	sr := &snowflakeRestful{
   416  		FuncPost:         postTestRenew,
   417  		FuncPostQuery:    postQueryTest,
   418  		FuncRenewSession: renewSessionTest,
   419  		TokenAccessor:    accessor,
   420  	}
   421  	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, NewUUID(), &Config{})
   422  	if err != nil {
   423  		t.Fatalf("err: %v", err)
   424  	}
   425  }
   426  
   427  func TestUnitPostQueryHelperRenewSession(t *testing.T) {
   428  	var err error
   429  	origRequestID := NewUUID()
   430  	postQueryTest := func(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, requestID UUID, _ *Config) (*execResponse, error) {
   431  		// ensure the same requestID is used after the session token is renewed.
   432  		if requestID != origRequestID {
   433  			t.Fatal("requestID doesn't match")
   434  		}
   435  		dd := &execResponseData{}
   436  		return &execResponse{
   437  			Data:    *dd,
   438  			Message: "",
   439  			Code:    "0",
   440  			Success: true,
   441  		}, nil
   442  	}
   443  	sr := &snowflakeRestful{
   444  		FuncPost:         postTestRenew,
   445  		FuncPostQuery:    postQueryTest,
   446  		FuncRenewSession: renewSessionTest,
   447  		TokenAccessor:    getSimpleTokenAccessor(),
   448  	}
   449  
   450  	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, origRequestID, &Config{})
   451  	if err != nil {
   452  		t.Fatalf("err: %v", err)
   453  	}
   454  	sr.FuncRenewSession = renewSessionTestError
   455  	_, err = postRestfulQueryHelper(context.Background(), sr, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0, origRequestID, &Config{})
   456  	if err == nil {
   457  		t.Fatal("should have failed to renew session")
   458  	}
   459  }
   460  
   461  func TestUnitRenewRestfulSession(t *testing.T) {
   462  	accessor := getSimpleTokenAccessor()
   463  	oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100)
   464  	newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200)
   465  	postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
   466  		if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) {
   467  			t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken))
   468  		}
   469  		tr := &renewSessionResponse{
   470  			Data: renewSessionResponseMain{
   471  				SessionToken: newToken,
   472  				MasterToken:  newMasterToken,
   473  				SessionID:    newSessionID,
   474  			},
   475  			Message: "",
   476  			Success: true,
   477  		}
   478  		ba, err := json.Marshal(tr)
   479  		if err != nil {
   480  			t.Fatalf("failed to serialize token response %v", err)
   481  		}
   482  		return &http.Response{
   483  			StatusCode: http.StatusOK,
   484  			Body:       &fakeResponseBody{body: ba},
   485  		}, nil
   486  	}
   487  
   488  	sr := &snowflakeRestful{
   489  		FuncPost:      postTestAfterRenew,
   490  		TokenAccessor: accessor,
   491  	}
   492  	err := renewRestfulSession(context.Background(), sr, time.Second)
   493  	if err != nil {
   494  		t.Fatalf("err: %v", err)
   495  	}
   496  	sr.FuncPost = postTestError
   497  	err = renewRestfulSession(context.Background(), sr, time.Second)
   498  	if err == nil {
   499  		t.Fatal("should have failed to run post request after the renewal")
   500  	}
   501  	sr.FuncPost = postTestAppBadGatewayError
   502  	err = renewRestfulSession(context.Background(), sr, time.Second)
   503  	if err == nil {
   504  		t.Fatal("should have failed to run post request after the renewal")
   505  	}
   506  	sr.FuncPost = postTestSuccessButInvalidJSON
   507  	err = renewRestfulSession(context.Background(), sr, time.Second)
   508  	if err == nil {
   509  		t.Fatal("should have failed to run post request after the renewal")
   510  	}
   511  	accessor.SetTokens(oldToken, oldMasterToken, oldSessionID)
   512  	sr.FuncPost = postTestSuccessWithNewTokens
   513  	err = renewRestfulSession(context.Background(), sr, time.Second)
   514  	if err != nil {
   515  		t.Fatal("should not have failed to run post request after the renewal")
   516  	}
   517  	token, masterToken, sessionID := accessor.GetTokens()
   518  	if token != newToken {
   519  		t.Fatalf("unexpected new token %v", token)
   520  	}
   521  	if masterToken != newMasterToken {
   522  		t.Fatalf("unexpected new master token %v", masterToken)
   523  	}
   524  	if sessionID != newSessionID {
   525  		t.Fatalf("unexpected new session id %v", sessionID)
   526  	}
   527  }
   528  
   529  func TestUnitCloseSession(t *testing.T) {
   530  	sr := &snowflakeRestful{
   531  		FuncPost:      postTestAfterRenew,
   532  		TokenAccessor: getSimpleTokenAccessor(),
   533  	}
   534  	err := closeSession(context.Background(), sr, time.Second)
   535  	if err != nil {
   536  		t.Fatalf("err: %v", err)
   537  	}
   538  	sr.FuncPost = postTestError
   539  	err = closeSession(context.Background(), sr, time.Second)
   540  	if err == nil {
   541  		t.Fatal("should have failed to close session")
   542  	}
   543  	sr.FuncPost = postTestAppBadGatewayError
   544  	err = closeSession(context.Background(), sr, time.Second)
   545  	if err == nil {
   546  		t.Fatal("should have failed to close session")
   547  	}
   548  	sr.FuncPost = postTestSuccessButInvalidJSON
   549  	err = closeSession(context.Background(), sr, time.Second)
   550  	if err == nil {
   551  		t.Fatal("should have failed to close session")
   552  	}
   553  }
   554  
   555  func TestUnitCancelQuery(t *testing.T) {
   556  	sr := &snowflakeRestful{
   557  		FuncPost:      postTestAfterRenew,
   558  		TokenAccessor: getSimpleTokenAccessor(),
   559  	}
   560  	ctx := context.Background()
   561  	err := cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
   562  	if err != nil {
   563  		t.Fatalf("err: %v", err)
   564  	}
   565  	sr.FuncPost = postTestError
   566  	err = cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
   567  	if err == nil {
   568  		t.Fatal("should have failed to close session")
   569  	}
   570  	sr.FuncPost = postTestAppBadGatewayError
   571  	err = cancelQuery(context.Background(), sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
   572  	if err == nil {
   573  		t.Fatal("should have failed to close session")
   574  	}
   575  	sr.FuncPost = postTestSuccessButInvalidJSON
   576  	err = cancelQuery(context.Background(), sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
   577  	if err == nil {
   578  		t.Fatal("should have failed to close session")
   579  	}
   580  }
   581  
   582  func TestCancelRetry(t *testing.T) {
   583  	sr := &snowflakeRestful{
   584  		TokenAccessor:   getSimpleTokenAccessor(),
   585  		FuncPost:        postTestQueryNotExecuting,
   586  		FuncCancelQuery: cancelTestRetry,
   587  	}
   588  	ctx := context.Background()
   589  	err := cancelQuery(ctx, sr, getOrGenerateRequestIDFromContext(ctx), time.Second)
   590  	if err != nil {
   591  		t.Fatal(err)
   592  	}
   593  }