github.com/snowflakedb/gosnowflake@v1.9.0/retry.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"fmt"
     9  	"io"
    10  	"math"
    11  	"math/rand"
    12  	"net/http"
    13  	"net/url"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  )
    19  
// waitAlgo carries the settings used to compute how long to sleep between
// HTTP retries.
// NOTE(review): the methods in this file draw random numbers from the
// package-level random variable and chooseRandomFromRange, not from the
// random field — confirm whether the field is read elsewhere.
type waitAlgo struct {
	mutex  *sync.Mutex // required for *rand.Rand usage
	random *rand.Rand
	base   time.Duration // base wait time
	cap    time.Duration // maximum wait time
}
    26  
// random is the package-level pseudo-random source; it is created in init.
var random *rand.Rand

// defaultWaitAlgo is the wait-time calculator used by retryHTTP.execute; it
// is created in init.
var defaultWaitAlgo *waitAlgo
    29  
// authEndpoints lists the request paths that isLoginRequest treats as
// authentication requests.
var authEndpoints = []string{
	loginRequestPath,
	tokenRequestPath,
	authenticatorRequestPath,
}
    35  
// clientErrorsStatusCodesEligibleForRetry lists the HTTP status codes that
// isRetryableStatus accepts as retryable in addition to the 5xx range.
var clientErrorsStatusCodesEligibleForRetry = []int{
	http.StatusTooManyRequests,
	http.StatusRequestTimeout,
}
    40  
    41  func init() {
    42  	random = rand.New(rand.NewSource(time.Now().UnixNano()))
    43  	// sleep time before retrying starts from 1s and the max sleep time is 16s
    44  	defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random, base: 1 * time.Second, cap: 16 * time.Second}
    45  }
    46  
const (
	// requestGUIDKey is attached to every request against Snowflake
	requestGUIDKey string = "request_guid"
	// retryCountKey is attached to query-request from the second time
	retryCountKey string = "retryCount"
	// retryReasonKey contains last HTTP status or 0 if timeout
	retryReasonKey string = "retryReason"
	// clientStartTimeKey contains the time when the client started the request (first request, not retries)
	clientStartTimeKey string = "clientStartTime"
	// requestIDKey is attached to all requests to Snowflake
	requestIDKey string = "requestId"
)
    59  
// requestGUIDReplacer is built from a URL and replaces the value of its
// request_guid query parameter every time replace() is called. If the URL does
// not contain request_guid, the original URL is returned unchanged.
type requestGUIDReplacer interface {
	// replace returns the URL carrying a new request GUID
	replace() *url.URL
}
    67  
    68  // Make requestGUIDReplacer given a url string
    69  func newRequestGUIDReplace(urlPtr *url.URL) requestGUIDReplacer {
    70  	values, err := url.ParseQuery(urlPtr.RawQuery)
    71  	if err != nil {
    72  		// nop if invalid query parameters
    73  		return &transientReplace{urlPtr}
    74  	}
    75  	if len(values.Get(requestGUIDKey)) == 0 {
    76  		// nop if no request_guid is included.
    77  		return &transientReplace{urlPtr}
    78  	}
    79  
    80  	return &requestGUIDReplace{urlPtr, values}
    81  }
    82  
    83  // this replacer does nothing but replace the url
    84  type transientReplace struct {
    85  	urlPtr *url.URL
    86  }
    87  
    88  func (replacer *transientReplace) replace() *url.URL {
    89  	return replacer.urlPtr
    90  }
    91  
    92  /*
    93  requestGUIDReplacer is a one-shot object that is created out of the retry loop and
    94  called with replace to change the retry_guid's value upon every retry
    95  */
    96  type requestGUIDReplace struct {
    97  	urlPtr    *url.URL
    98  	urlValues url.Values
    99  }
   100  
   101  /*
   102  *
   103  This function would replace they value of the requestGUIDKey in a url with a newly
   104  generated UUID
   105  */
   106  func (replacer *requestGUIDReplace) replace() *url.URL {
   107  	replacer.urlValues.Del(requestGUIDKey)
   108  	replacer.urlValues.Add(requestGUIDKey, NewUUID().String())
   109  	replacer.urlPtr.RawQuery = replacer.urlValues.Encode()
   110  	return replacer.urlPtr
   111  }
   112  
// retryCountUpdater produces the URL to use for a given retry attempt.
type retryCountUpdater interface {
	// replaceOrAdd returns the URL for the retry attempt numbered retry
	replaceOrAdd(retry int) *url.URL
}
   116  
// retryCountUpdate keeps a URL together with its parsed query values so that
// replaceOrAdd can rewrite the retryCount parameter and re-encode the query.
type retryCountUpdate struct {
	urlPtr    *url.URL
	urlValues url.Values
}
   121  
   122  // this replacer does nothing but replace the url
   123  type transientRetryCountUpdater struct {
   124  	urlPtr *url.URL
   125  }
   126  
   127  func (replaceOrAdder *transientRetryCountUpdater) replaceOrAdd(retry int) *url.URL {
   128  	return replaceOrAdder.urlPtr
   129  }
   130  
   131  func (replacer *retryCountUpdate) replaceOrAdd(retry int) *url.URL {
   132  	replacer.urlValues.Del(retryCountKey)
   133  	replacer.urlValues.Add(retryCountKey, strconv.Itoa(retry))
   134  	replacer.urlPtr.RawQuery = replacer.urlValues.Encode()
   135  	return replacer.urlPtr
   136  }
   137  
   138  func newRetryCountUpdater(urlPtr *url.URL) retryCountUpdater {
   139  	if !isQueryRequest(urlPtr) {
   140  		// nop if not query-request
   141  		return &transientRetryCountUpdater{urlPtr}
   142  	}
   143  	values, err := url.ParseQuery(urlPtr.RawQuery)
   144  	if err != nil {
   145  		// nop if the URL is not valid
   146  		return &transientRetryCountUpdater{urlPtr}
   147  	}
   148  	return &retryCountUpdate{urlPtr, values}
   149  }
   150  
// retryReasonUpdater produces the URL to use for a retry, given the reason of
// the retry.
type retryReasonUpdater interface {
	// replaceOrAdd returns the URL for a retry caused by reason
	replaceOrAdd(reason int) *url.URL
}
   154  
   155  type retryReasonUpdate struct {
   156  	url *url.URL
   157  }
   158  
   159  func (retryReasonUpdater *retryReasonUpdate) replaceOrAdd(reason int) *url.URL {
   160  	query := retryReasonUpdater.url.Query()
   161  	query.Del(retryReasonKey)
   162  	query.Add(retryReasonKey, strconv.Itoa(reason))
   163  	retryReasonUpdater.url.RawQuery = query.Encode()
   164  	return retryReasonUpdater.url
   165  }
   166  
   167  type transientRetryReasonUpdater struct {
   168  	url *url.URL
   169  }
   170  
   171  func (retryReasonUpdater *transientRetryReasonUpdater) replaceOrAdd(_ int) *url.URL {
   172  	return retryReasonUpdater.url
   173  }
   174  
   175  func newRetryReasonUpdater(url *url.URL, cfg *Config) retryReasonUpdater {
   176  	// not a query request
   177  	if !isQueryRequest(url) {
   178  		return &transientRetryReasonUpdater{url}
   179  	}
   180  	// implicitly disabled retry reason
   181  	if cfg != nil && cfg.IncludeRetryReason == ConfigBoolFalse {
   182  		return &transientRetryReasonUpdater{url}
   183  	}
   184  	return &retryReasonUpdate{url}
   185  }
   186  
   187  func ensureClientStartTimeIsSet(url *url.URL, clientStartTime string) *url.URL {
   188  	if !isQueryRequest(url) {
   189  		// nop if not query-request
   190  		return url
   191  	}
   192  	query := url.Query()
   193  	if query.Has(clientStartTimeKey) {
   194  		return url
   195  	}
   196  	query.Add(clientStartTimeKey, clientStartTime)
   197  	url.RawQuery = query.Encode()
   198  	return url
   199  }
   200  
   201  func isQueryRequest(url *url.URL) bool {
   202  	return strings.HasPrefix(url.Path, queryRequestPath)
   203  }
   204  
   205  // jitter backoff in seconds
   206  func (w *waitAlgo) calculateWaitBeforeRetryForAuthRequest(attempt int, currWaitTimeDuration time.Duration) time.Duration {
   207  	w.mutex.Lock()
   208  	defer w.mutex.Unlock()
   209  	currWaitTimeInSeconds := currWaitTimeDuration.Seconds()
   210  	jitterAmount := w.getJitter(currWaitTimeInSeconds)
   211  	jitteredSleepTime := chooseRandomFromRange(currWaitTimeInSeconds+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
   212  	return time.Duration(jitteredSleepTime * float64(time.Second))
   213  }
   214  
   215  func (w *waitAlgo) calculateWaitBeforeRetry(sleep time.Duration) time.Duration {
   216  	w.mutex.Lock()
   217  	defer w.mutex.Unlock()
   218  	// use decorrelated jitter in retry time
   219  	randDuration := randMilliSecondDuration(w.base, sleep*3)
   220  	return durationMin(w.cap, randDuration)
   221  }
   222  
   223  func randMilliSecondDuration(base time.Duration, bound time.Duration) time.Duration {
   224  	baseNumber := int64(base / time.Millisecond)
   225  	boundNumber := int64(bound / time.Millisecond)
   226  	randomDuration := random.Int63n(boundNumber-baseNumber) + baseNumber
   227  	return time.Duration(randomDuration) * time.Millisecond
   228  }
   229  
   230  func (w *waitAlgo) getJitter(currWaitTime float64) float64 {
   231  	multiplicationFactor := chooseRandomFromRange(-1, 1)
   232  	jitterAmount := 0.5 * currWaitTime * multiplicationFactor
   233  	return jitterAmount
   234  }
   235  
// requestFunc builds an *http.Request from a method, a URL string and a body;
// it has the same signature as http.NewRequest.
type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error)
   237  
// clientInterface is the part of an HTTP client that retryHTTP needs: sending
// a request and getting the response back. *http.Client satisfies it.
type clientInterface interface {
	Do(req *http.Request) (*http.Response, error)
}
   241  
// retryHTTP describes one HTTP call that execute sends and retries.
//
//   - ctx is attached to every request and is watched for cancellation while
//     sleeping between attempts.
//   - client sends the request; req builds it for every attempt.
//   - method is "GET" unless doPost is called.
//   - fullURL is the target URL; execute rewrites its query between attempts.
//   - headers are set on every request.
//   - bodyCreator is called before every attempt to produce the request body.
//   - timeout is the budget for the time slept between attempts; the timeout
//     and maxRetryCount limits are only checked when it is positive.
//   - maxRetryCount is the number of retries after which execute gives up.
//   - currentTimeProvider supplies the value sent as clientStartTime.
//   - cfg is passed to newRetryReasonUpdater, which accepts nil.
type retryHTTP struct {
	ctx                 context.Context
	client              clientInterface
	req                 requestFunc
	method              string
	fullURL             *url.URL
	headers             map[string]string
	bodyCreator         bodyCreatorType
	timeout             time.Duration
	maxRetryCount       int
	currentTimeProvider currentTimeProvider
	cfg                 *Config
}
   255  
   256  func newRetryHTTP(ctx context.Context,
   257  	client clientInterface,
   258  	req requestFunc,
   259  	fullURL *url.URL,
   260  	headers map[string]string,
   261  	timeout time.Duration,
   262  	maxRetryCount int,
   263  	currentTimeProvider currentTimeProvider,
   264  	cfg *Config) *retryHTTP {
   265  	instance := retryHTTP{}
   266  	instance.ctx = ctx
   267  	instance.client = client
   268  	instance.req = req
   269  	instance.method = "GET"
   270  	instance.fullURL = fullURL
   271  	instance.headers = headers
   272  	instance.timeout = timeout
   273  	instance.maxRetryCount = maxRetryCount
   274  	instance.bodyCreator = emptyBodyCreator
   275  	instance.currentTimeProvider = currentTimeProvider
   276  	instance.cfg = cfg
   277  	return &instance
   278  }
   279  
   280  func (r *retryHTTP) doPost() *retryHTTP {
   281  	r.method = "POST"
   282  	return r
   283  }
   284  
   285  func (r *retryHTTP) setBody(body []byte) *retryHTTP {
   286  	r.bodyCreator = func() ([]byte, error) {
   287  		return body, nil
   288  	}
   289  	return r
   290  }
   291  
// setBodyCreator installs the function that execute calls before every attempt
// to produce the request body, and returns the receiver so calls can be
// chained.
func (r *retryHTTP) setBodyCreator(bodyCreator bodyCreatorType) *retryHTTP {
	r.bodyCreator = bodyCreator
	return r
}
   296  
   297  func (r *retryHTTP) execute() (res *http.Response, err error) {
   298  	totalTimeout := r.timeout
   299  	logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout)
   300  	retryCounter := 0
   301  	sleepTime := time.Duration(time.Second)
   302  	clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10)
   303  
   304  	var requestGUIDReplacer requestGUIDReplacer
   305  	var retryCountUpdater retryCountUpdater
   306  	var retryReasonUpdater retryReasonUpdater
   307  
   308  	for {
   309  		logger.Debugf("retry count: %v", retryCounter)
   310  		body, err := r.bodyCreator()
   311  		if err != nil {
   312  			return nil, err
   313  		}
   314  		req, err := r.req(r.method, r.fullURL.String(), bytes.NewReader(body))
   315  		if err != nil {
   316  			return nil, err
   317  		}
   318  		if req != nil {
   319  			// req can be nil in tests
   320  			req = req.WithContext(r.ctx)
   321  		}
   322  		for k, v := range r.headers {
   323  			req.Header.Set(k, v)
   324  		}
   325  		res, err = r.client.Do(req)
   326  		// check if it can retry.
   327  		retryable, err := isRetryableError(req, res, err)
   328  		if !retryable {
   329  			return res, err
   330  		}
   331  		if err != nil {
   332  			logger.WithContext(r.ctx).Warningf(
   333  				"failed http connection. err: %v. retrying...\n", err)
   334  		} else {
   335  			logger.WithContext(r.ctx).Warningf(
   336  				"failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode)
   337  			res.Body.Close()
   338  		}
   339  		// uses exponential jitter backoff
   340  		retryCounter++
   341  		if isLoginRequest(req) {
   342  			sleepTime = defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(retryCounter, sleepTime)
   343  		} else {
   344  			sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(sleepTime)
   345  		}
   346  
   347  		if totalTimeout > 0 {
   348  			logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout)
   349  			// if any timeout is set
   350  			totalTimeout -= sleepTime
   351  			if totalTimeout <= 0 || retryCounter > r.maxRetryCount {
   352  				if err != nil {
   353  					return nil, err
   354  				}
   355  				if res != nil {
   356  					return nil, fmt.Errorf("timeout after %s and %v retries. HTTP Status: %v. Hanging?", r.timeout, retryCounter, res.StatusCode)
   357  				}
   358  				return nil, fmt.Errorf("timeout after %s and %v retries. Hanging?", r.timeout, retryCounter)
   359  			}
   360  		}
   361  		if requestGUIDReplacer == nil {
   362  			requestGUIDReplacer = newRequestGUIDReplace(r.fullURL)
   363  		}
   364  		r.fullURL = requestGUIDReplacer.replace()
   365  		if retryCountUpdater == nil {
   366  			retryCountUpdater = newRetryCountUpdater(r.fullURL)
   367  		}
   368  		r.fullURL = retryCountUpdater.replaceOrAdd(retryCounter)
   369  		if retryReasonUpdater == nil {
   370  			retryReasonUpdater = newRetryReasonUpdater(r.fullURL, r.cfg)
   371  		}
   372  		retryReason := 0
   373  		if res != nil {
   374  			retryReason = res.StatusCode
   375  		}
   376  		r.fullURL = retryReasonUpdater.replaceOrAdd(retryReason)
   377  		r.fullURL = ensureClientStartTimeIsSet(r.fullURL, clientStartTime)
   378  		logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout)
   379  		logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason)
   380  
   381  		await := time.NewTimer(sleepTime)
   382  		select {
   383  		case <-await.C:
   384  			// retry the request
   385  		case <-r.ctx.Done():
   386  			await.Stop()
   387  			return res, r.ctx.Err()
   388  		}
   389  	}
   390  }
   391  
   392  func isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) {
   393  	if err != nil && res == nil { // Failed http connection. Most probably client timeout.
   394  		return true, err
   395  	}
   396  	if res == nil || req == nil {
   397  		return false, err
   398  	}
   399  	return isRetryableStatus(res.StatusCode), err
   400  }
   401  
   402  func isRetryableStatus(statusCode int) bool {
   403  	return (statusCode >= 500 && statusCode < 600) || contains(clientErrorsStatusCodesEligibleForRetry, statusCode)
   404  }
   405  
   406  func isLoginRequest(req *http.Request) bool {
   407  	return contains(authEndpoints, req.URL.Path)
   408  }