github.com/snowflakedb/gosnowflake@v1.9.0/retry.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "fmt" 9 "io" 10 "math" 11 "math/rand" 12 "net/http" 13 "net/url" 14 "strconv" 15 "strings" 16 "sync" 17 "time" 18 ) 19 20 type waitAlgo struct { 21 mutex *sync.Mutex // required for *rand.Rand usage 22 random *rand.Rand 23 base time.Duration // base wait time 24 cap time.Duration // maximum wait time 25 } 26 27 var random *rand.Rand 28 var defaultWaitAlgo *waitAlgo 29 30 var authEndpoints = []string{ 31 loginRequestPath, 32 tokenRequestPath, 33 authenticatorRequestPath, 34 } 35 36 var clientErrorsStatusCodesEligibleForRetry = []int{ 37 http.StatusTooManyRequests, 38 http.StatusRequestTimeout, 39 } 40 41 func init() { 42 random = rand.New(rand.NewSource(time.Now().UnixNano())) 43 // sleep time before retrying starts from 1s and the max sleep time is 16s 44 defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random, base: 1 * time.Second, cap: 16 * time.Second} 45 } 46 47 const ( 48 // requestGUIDKey is attached to every request against Snowflake 49 requestGUIDKey string = "request_guid" 50 // retryCountKey is attached to query-request from the second time 51 retryCountKey string = "retryCount" 52 // retryReasonKey contains last HTTP status or 0 if timeout 53 retryReasonKey string = "retryReason" 54 // clientStartTime contains a time when client started request (first request, not retries) 55 clientStartTimeKey string = "clientStartTime" 56 // requestIDKey is attached to all requests to Snowflake 57 requestIDKey string = "requestId" 58 ) 59 60 // This class takes in an url during construction and replaces the value of 61 // request_guid every time replace() is called. If the url does not contain 62 // request_guid, just return the original url 63 type requestGUIDReplacer interface { 64 // replace the url with new ID 65 replace() *url.URL 66 } 67 68 // Make requestGUIDReplacer given a url string 69 func newRequestGUIDReplace(urlPtr *url.URL) requestGUIDReplacer { 70 values, err := url.ParseQuery(urlPtr.RawQuery) 71 if err != nil { 72 // nop if invalid query parameters 73 return &transientReplace{urlPtr} 74 } 75 if len(values.Get(requestGUIDKey)) == 0 { 76 // nop if no request_guid is included. 77 return &transientReplace{urlPtr} 78 } 79 80 return &requestGUIDReplace{urlPtr, values} 81 } 82 83 // this replacer does nothing but replace the url 84 type transientReplace struct { 85 urlPtr *url.URL 86 } 87 88 func (replacer *transientReplace) replace() *url.URL { 89 return replacer.urlPtr 90 } 91 92 /* 93 requestGUIDReplacer is a one-shot object that is created out of the retry loop and 94 called with replace to change the retry_guid's value upon every retry 95 */ 96 type requestGUIDReplace struct { 97 urlPtr *url.URL 98 urlValues url.Values 99 } 100 101 /* 102 * 103 This function would replace they value of the requestGUIDKey in a url with a newly 104 generated UUID 105 */ 106 func (replacer *requestGUIDReplace) replace() *url.URL { 107 replacer.urlValues.Del(requestGUIDKey) 108 replacer.urlValues.Add(requestGUIDKey, NewUUID().String()) 109 replacer.urlPtr.RawQuery = replacer.urlValues.Encode() 110 return replacer.urlPtr 111 } 112 113 type retryCountUpdater interface { 114 replaceOrAdd(retry int) *url.URL 115 } 116 117 type retryCountUpdate struct { 118 urlPtr *url.URL 119 urlValues url.Values 120 } 121 122 // this replacer does nothing but replace the url 123 type transientRetryCountUpdater struct { 124 urlPtr *url.URL 125 } 126 127 func (replaceOrAdder *transientRetryCountUpdater) replaceOrAdd(retry int) *url.URL { 128 return replaceOrAdder.urlPtr 129 } 130 131 func (replacer *retryCountUpdate) replaceOrAdd(retry int) *url.URL { 132 replacer.urlValues.Del(retryCountKey) 133 replacer.urlValues.Add(retryCountKey, strconv.Itoa(retry)) 134 replacer.urlPtr.RawQuery = replacer.urlValues.Encode() 135 return replacer.urlPtr 136 } 137 138 func newRetryCountUpdater(urlPtr *url.URL) retryCountUpdater { 139 if !isQueryRequest(urlPtr) { 140 // nop if not query-request 141 return &transientRetryCountUpdater{urlPtr} 142 } 143 values, err := url.ParseQuery(urlPtr.RawQuery) 144 if err != nil { 145 // nop if the URL is not valid 146 return &transientRetryCountUpdater{urlPtr} 147 } 148 return &retryCountUpdate{urlPtr, values} 149 } 150 151 type retryReasonUpdater interface { 152 replaceOrAdd(reason int) *url.URL 153 } 154 155 type retryReasonUpdate struct { 156 url *url.URL 157 } 158 159 func (retryReasonUpdater *retryReasonUpdate) replaceOrAdd(reason int) *url.URL { 160 query := retryReasonUpdater.url.Query() 161 query.Del(retryReasonKey) 162 query.Add(retryReasonKey, strconv.Itoa(reason)) 163 retryReasonUpdater.url.RawQuery = query.Encode() 164 return retryReasonUpdater.url 165 } 166 167 type transientRetryReasonUpdater struct { 168 url *url.URL 169 } 170 171 func (retryReasonUpdater *transientRetryReasonUpdater) replaceOrAdd(_ int) *url.URL { 172 return retryReasonUpdater.url 173 } 174 175 func newRetryReasonUpdater(url *url.URL, cfg *Config) retryReasonUpdater { 176 // not a query request 177 if !isQueryRequest(url) { 178 return &transientRetryReasonUpdater{url} 179 } 180 // implicitly disabled retry reason 181 if cfg != nil && cfg.IncludeRetryReason == ConfigBoolFalse { 182 return &transientRetryReasonUpdater{url} 183 } 184 return &retryReasonUpdate{url} 185 } 186 187 func ensureClientStartTimeIsSet(url *url.URL, clientStartTime string) *url.URL { 188 if !isQueryRequest(url) { 189 // nop if not query-request 190 return url 191 } 192 query := url.Query() 193 if query.Has(clientStartTimeKey) { 194 return url 195 } 196 query.Add(clientStartTimeKey, clientStartTime) 197 url.RawQuery = query.Encode() 198 return url 199 } 200 201 func isQueryRequest(url *url.URL) bool { 202 return strings.HasPrefix(url.Path, queryRequestPath) 203 } 204 205 // jitter backoff in seconds 206 func (w *waitAlgo) calculateWaitBeforeRetryForAuthRequest(attempt int, currWaitTimeDuration time.Duration) time.Duration { 207 w.mutex.Lock() 208 defer w.mutex.Unlock() 209 currWaitTimeInSeconds := currWaitTimeDuration.Seconds() 210 jitterAmount := w.getJitter(currWaitTimeInSeconds) 211 jitteredSleepTime := chooseRandomFromRange(currWaitTimeInSeconds+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount) 212 return time.Duration(jitteredSleepTime * float64(time.Second)) 213 } 214 215 func (w *waitAlgo) calculateWaitBeforeRetry(sleep time.Duration) time.Duration { 216 w.mutex.Lock() 217 defer w.mutex.Unlock() 218 // use decorrelated jitter in retry time 219 randDuration := randMilliSecondDuration(w.base, sleep*3) 220 return durationMin(w.cap, randDuration) 221 } 222 223 func randMilliSecondDuration(base time.Duration, bound time.Duration) time.Duration { 224 baseNumber := int64(base / time.Millisecond) 225 boundNumber := int64(bound / time.Millisecond) 226 randomDuration := random.Int63n(boundNumber-baseNumber) + baseNumber 227 return time.Duration(randomDuration) * time.Millisecond 228 } 229 230 func (w *waitAlgo) getJitter(currWaitTime float64) float64 { 231 multiplicationFactor := chooseRandomFromRange(-1, 1) 232 jitterAmount := 0.5 * currWaitTime * multiplicationFactor 233 return jitterAmount 234 } 235 236 type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) 237 238 type clientInterface interface { 239 Do(req *http.Request) (*http.Response, error) 240 } 241 242 type retryHTTP struct { 243 ctx context.Context 244 client clientInterface 245 req requestFunc 246 method string 247 fullURL *url.URL 248 headers map[string]string 249 bodyCreator bodyCreatorType 250 timeout time.Duration 251 maxRetryCount int 252 currentTimeProvider currentTimeProvider 253 cfg *Config 254 } 255 256 func newRetryHTTP(ctx context.Context, 257 client clientInterface, 258 req requestFunc, 259 fullURL *url.URL, 260 headers map[string]string, 261 timeout time.Duration, 262 maxRetryCount int, 263 currentTimeProvider currentTimeProvider, 264 cfg *Config) *retryHTTP { 265 instance := retryHTTP{} 266 instance.ctx = ctx 267 instance.client = client 268 instance.req = req 269 instance.method = "GET" 270 instance.fullURL = fullURL 271 instance.headers = headers 272 instance.timeout = timeout 273 instance.maxRetryCount = maxRetryCount 274 instance.bodyCreator = emptyBodyCreator 275 instance.currentTimeProvider = currentTimeProvider 276 instance.cfg = cfg 277 return &instance 278 } 279 280 func (r *retryHTTP) doPost() *retryHTTP { 281 r.method = "POST" 282 return r 283 } 284 285 func (r *retryHTTP) setBody(body []byte) *retryHTTP { 286 r.bodyCreator = func() ([]byte, error) { 287 return body, nil 288 } 289 return r 290 } 291 292 func (r *retryHTTP) setBodyCreator(bodyCreator bodyCreatorType) *retryHTTP { 293 r.bodyCreator = bodyCreator 294 return r 295 } 296 297 func (r *retryHTTP) execute() (res *http.Response, err error) { 298 totalTimeout := r.timeout 299 logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout) 300 retryCounter := 0 301 sleepTime := time.Duration(time.Second) 302 clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10) 303 304 var requestGUIDReplacer requestGUIDReplacer 305 var retryCountUpdater retryCountUpdater 306 var retryReasonUpdater retryReasonUpdater 307 308 for { 309 logger.Debugf("retry count: %v", retryCounter) 310 body, err := r.bodyCreator() 311 if err != nil { 312 return nil, err 313 } 314 req, err := r.req(r.method, r.fullURL.String(), bytes.NewReader(body)) 315 if err != nil { 316 return nil, err 317 } 318 if req != nil { 319 // req can be nil in tests 320 req = req.WithContext(r.ctx) 321 } 322 for k, v := range r.headers { 323 req.Header.Set(k, v) 324 } 325 res, err = r.client.Do(req) 326 // check if it can retry. 327 retryable, err := isRetryableError(req, res, err) 328 if !retryable { 329 return res, err 330 } 331 if err != nil { 332 logger.WithContext(r.ctx).Warningf( 333 "failed http connection. err: %v. retrying...\n", err) 334 } else { 335 logger.WithContext(r.ctx).Warningf( 336 "failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode) 337 res.Body.Close() 338 } 339 // uses exponential jitter backoff 340 retryCounter++ 341 if isLoginRequest(req) { 342 sleepTime = defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(retryCounter, sleepTime) 343 } else { 344 sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(sleepTime) 345 } 346 347 if totalTimeout > 0 { 348 logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout) 349 // if any timeout is set 350 totalTimeout -= sleepTime 351 if totalTimeout <= 0 || retryCounter > r.maxRetryCount { 352 if err != nil { 353 return nil, err 354 } 355 if res != nil { 356 return nil, fmt.Errorf("timeout after %s and %v retries. HTTP Status: %v. Hanging?", r.timeout, retryCounter, res.StatusCode) 357 } 358 return nil, fmt.Errorf("timeout after %s and %v retries. Hanging?", r.timeout, retryCounter) 359 } 360 } 361 if requestGUIDReplacer == nil { 362 requestGUIDReplacer = newRequestGUIDReplace(r.fullURL) 363 } 364 r.fullURL = requestGUIDReplacer.replace() 365 if retryCountUpdater == nil { 366 retryCountUpdater = newRetryCountUpdater(r.fullURL) 367 } 368 r.fullURL = retryCountUpdater.replaceOrAdd(retryCounter) 369 if retryReasonUpdater == nil { 370 retryReasonUpdater = newRetryReasonUpdater(r.fullURL, r.cfg) 371 } 372 retryReason := 0 373 if res != nil { 374 retryReason = res.StatusCode 375 } 376 r.fullURL = retryReasonUpdater.replaceOrAdd(retryReason) 377 r.fullURL = ensureClientStartTimeIsSet(r.fullURL, clientStartTime) 378 logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout) 379 logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason) 380 381 await := time.NewTimer(sleepTime) 382 select { 383 case <-await.C: 384 // retry the request 385 case <-r.ctx.Done(): 386 await.Stop() 387 return res, r.ctx.Err() 388 } 389 } 390 } 391 392 func isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) { 393 if err != nil && res == nil { // Failed http connection. Most probably client timeout. 394 return true, err 395 } 396 if res == nil || req == nil { 397 return false, err 398 } 399 return isRetryableStatus(res.StatusCode), err 400 } 401 402 func isRetryableStatus(statusCode int) bool { 403 return (statusCode >= 500 && statusCode < 600) || contains(clientErrorsStatusCodesEligibleForRetry, statusCode) 404 } 405 406 func isLoginRequest(req *http.Request) bool { 407 return contains(authEndpoints, req.URL.Path) 408 }