github.com/blend/go-sdk@v1.20220411.3/grpcutil/client_retry.go

/*

Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
Use of this source code is governed by a MIT license that can be found in the LICENSE file.

*/

package grpcutil

import (
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"strings"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

var (
	// DefaultRetriableCodes is a set of well-known gRPC codes that should be retriable.
	//
	// `ResourceExhausted` means that the user quota, e.g. per-RPC limits, has been reached.
	// `Unavailable` means that the system is currently unavailable and the client should retry.
	DefaultRetriableCodes = []codes.Code{codes.ResourceExhausted, codes.Unavailable}

	defaultRetryOptions = &retryOptions{
		max:            0, // disabled
		perCallTimeout: 0, // disabled
		includeHeader:  true,
		codes:          DefaultRetriableCodes,
		backoffFunc: BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration {
			return BackoffLinearWithJitter(50*time.Millisecond, 0.10)(attempt)
		}),
	}
)

// Metadata Keys
const (
	MetadataKeyAttempt = "x-retry-attempt"
)

// WithRetriesDisabled disables the retry behavior on this call, or this interceptor.
//
// It is semantically the same as `WithClientRetries(0)`.
func WithRetriesDisabled() CallOption {
	return WithClientRetries(0)
}

// WithClientRetries sets the maximum number of retries on this call, or this interceptor.
func WithClientRetries(maxRetries uint) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.max = maxRetries
	}}
}

// WithClientRetryBackoffLinear sets the retry backoff to a fixed duration.
func WithClientRetryBackoffLinear(d time.Duration) CallOption {
	return WithClientRetryBackoffFunc(BackoffLinear(d))
}

// WithClientRetryBackoffFunc sets the `BackoffFunc` used to control time between retries.
func WithClientRetryBackoffFunc(bf BackoffFunc) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.backoffFunc = BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration {
			return bf(attempt)
		})
	}}
}

// WithClientRetryBackoffContext sets the `BackoffFuncContext` used to control time between retries.
func WithClientRetryBackoffContext(bf BackoffFuncContext) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.backoffFunc = bf
	}}
}
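// A minimal sketch of a custom backoff wired through the option above; the
// exponential schedule is illustrative, not part of this package:
//
//	backoffOpt := WithClientRetryBackoffContext(func(ctx context.Context, attempt uint) time.Duration {
//		// double the wait each attempt: 50ms, 100ms, 200ms, ... capped at 1s.
//		d := 50 * time.Millisecond << attempt
//		if d > time.Second {
//			d = time.Second
//		}
//		return d
//	})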
// WithClientRetryCodes sets which codes should be retried.
//
// Please *use with care*, as you may be retrying non-idempotent calls.
//
// You cannot automatically retry on Canceled and DeadlineExceeded; please use
// `WithClientRetryPerRetryTimeout` for those.
func WithClientRetryCodes(retryCodes ...codes.Code) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.codes = retryCodes
	}}
}

// WithClientRetryPerRetryTimeout sets the RPC timeout per call (including the initial call) on this call, or this interceptor.
//
// The context.Deadline of the call takes precedence and sets the maximum time the whole invocation
// will take, but WithClientRetryPerRetryTimeout can be used to limit the RPC time per each call.
//
// For example, with context.Deadline = now + 10s, and WithClientRetryPerRetryTimeout(3*time.Second), each
// of the retry calls (including the initial one) will have a deadline of now + 3s.
//
// A value of 0 disables the timeout override completely and leaves each retry call using the
// parent `context.Deadline`.
//
// Note that when this is enabled, any DeadlineExceeded errors that are propagated up will be retried.
func WithClientRetryPerRetryTimeout(timeout time.Duration) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.perCallTimeout = timeout
	}}
}
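// A minimal per-call usage sketch; `client` and `req` are hypothetical. The
// retry options pass through any gRPC call as regular grpc.CallOptions
// because CallOption embeds grpc.EmptyCallOption:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	resp, err := client.GetUser(ctx, req,
//		WithClientRetries(3),
//		WithClientRetryPerRetryTimeout(3*time.Second),
//	)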
type retryOptions struct {
	max            uint
	perCallTimeout time.Duration
	includeHeader  bool
	codes          []codes.Code
	backoffFunc    BackoffFuncContext
	abortOnFailure bool
}

// CallOption is a grpc.CallOption that is local to the grpcutil retry interceptors.
type CallOption struct {
	grpc.EmptyCallOption // make sure we implement the private after() and before() methods so we don't panic.
	applyFunc func(opt *retryOptions)
}

func reuseOrNewWithCallOptions(opt *retryOptions, callOptions []CallOption) *retryOptions {
	if len(callOptions) == 0 {
		return opt
	}
	optCopy := new(retryOptions)
	*optCopy = *opt
	for _, f := range callOptions {
		f.applyFunc(optCopy)
	}
	return optCopy
}

func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOption, retryOptions []CallOption) {
	for _, opt := range callOptions {
		if co, ok := opt.(CallOption); ok {
			retryOptions = append(retryOptions, co)
		} else {
			grpcOptions = append(grpcOptions, opt)
		}
	}
	return grpcOptions, retryOptions
}

// RetryUnaryClientInterceptor returns a new retrying unary client interceptor.
//
// The default configuration of the interceptor is to not retry *at all*. This behavior can be
// changed through options (e.g. WithClientRetries) on creation of the interceptor or per call (through grpc.CallOptions).
func RetryUnaryClientInterceptor(optFuncs ...CallOption) grpc.UnaryClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultRetryOptions, optFuncs)
	return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		if callOpts.max == 0 {
			return invoker(parentCtx, method, req, reply, cc, grpcOpts...)
		}
		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			callCtx, cancel := perCallContext(parentCtx, callOpts, attempt)
			func() {
				defer cancel()
				lastErr = invoker(callCtx, method, req, reply, cc, grpcOpts...)
			}()
			if lastErr == nil {
				return nil
			}
			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					// it's the parent context deadline or cancellation.
					return lastErr
				} else if callOpts.perCallTimeout != 0 {
					// We have set a perCallTimeout in the retry middleware, which would result in a context error if
					// the deadline was exceeded, in which case try again.
					continue
				}
			}
			if !isRetriable(lastErr, callOpts) {
				return lastErr
			}
			if err := waitRetryBackoff(parentCtx, attempt, callOpts); err != nil {
				return err
			}
		}
		return lastErr
	}
}
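// A minimal wiring sketch for the unary interceptor; the target address is
// illustrative:
//
//	conn, err := grpc.Dial("localhost:8080",
//		grpc.WithUnaryInterceptor(RetryUnaryClientInterceptor(
//			WithClientRetries(3),
//			WithClientRetryBackoffLinear(100*time.Millisecond),
//		)),
//	)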
// RetryStreamClientInterceptor returns a new retrying stream client interceptor for server side streaming calls.
//
// The default configuration of the interceptor is to not retry *at all*. This behavior can be
// changed through options (e.g. WithClientRetries) on creation of the interceptor or per call (through grpc.CallOptions).
//
// Retry logic is available *only for ServerStreams*, i.e. 1:n streams, as the internal logic needs
// to buffer the messages sent by the client. If retry is enabled on any other stream type (ClientStreams,
// BidiStreams), the retry interceptor will fail the call.
func RetryStreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultRetryOptions, optFuncs)
	return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		// short circuit for simplicity, and to avoid allocations.
		if callOpts.max == 0 {
			return streamer(parentCtx, desc, cc, method, grpcOpts...)
		}
		if desc.ClientStreams {
			return nil, status.Error(codes.Unimplemented, "grpcutil: cannot retry on ClientStreams, set WithRetriesDisabled()")
		}

		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			if err := waitRetryBackoff(parentCtx, attempt, callOpts); err != nil {
				return nil, err
			}
			callCtx, cancel := perCallContext(parentCtx, callOpts, 0)

			var newStreamer grpc.ClientStream
			func() {
				defer cancel()
				newStreamer, lastErr = streamer(callCtx, desc, cc, method, grpcOpts...)
			}()
			if lastErr == nil {
				retryingStreamer := &serverStreamingRetryingStream{
					ClientStream: newStreamer,
					callOpts:     callOpts,
					parentCtx:    parentCtx,
					streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
						return streamer(ctx, desc, cc, method, grpcOpts...)
					},
				}
				return retryingStreamer, nil
			}

			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					// it's the parent context deadline or cancellation.
					return nil, lastErr
				} else if callOpts.perCallTimeout != 0 {
					// We have set a perCallTimeout in the retry middleware, which would result in a context error if
					// the deadline was exceeded, in which case try again.
					continue
				}
			}
			if !isRetriable(lastErr, callOpts) {
				return nil, lastErr
			}
		}
		return nil, lastErr
	}
}
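// A minimal wiring sketch for server-side streaming retries; the target
// address is illustrative:
//
//	conn, err := grpc.Dial("localhost:8080",
//		grpc.WithStreamInterceptor(RetryStreamClientInterceptor(
//			WithClientRetries(3),
//		)),
//	)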
// serverStreamingRetryingStream is an implementation of grpc.ClientStream that acts as a
// proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish
// a new ClientStream according to the retry policy.
type serverStreamingRetryingStream struct {
	grpc.ClientStream
	bufferedSends []interface{} // messages sent by the client, buffered for replay on retry
	receivedGood  bool          // indicates whether any prior receives were successful
	wasClosedSend bool          // indicates that CloseSend was called
	parentCtx     context.Context
	callOpts      *retryOptions
	streamerCall  func(ctx context.Context) (grpc.ClientStream, error)
	mu            sync.RWMutex
}

func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) {
	s.mu.Lock()
	s.ClientStream = clientStream
	s.mu.Unlock()
}

func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.ClientStream
}

func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error {
	s.mu.Lock()
	s.bufferedSends = append(s.bufferedSends, m)
	s.mu.Unlock()
	return s.getStream().SendMsg(m)
}

func (s *serverStreamingRetryingStream) CloseSend() error {
	s.mu.Lock()
	s.wasClosedSend = true
	s.mu.Unlock()
	return s.getStream().CloseSend()
}

func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
	return s.getStream().Header()
}

func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
	return s.getStream().Trailer()
}

func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error {
	attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m)
	if !attemptRetry {
		return lastErr // success or hard failure
	}
	// We start off from attempt 1, because the zeroth attempt was the initial RecvMsg above.
	for attempt := uint(1); attempt < s.callOpts.max; attempt++ {
		if err := waitRetryBackoff(s.parentCtx, attempt, s.callOpts); err != nil {
			return err
		}
		callCtx, cancel := perCallContext(s.parentCtx, s.callOpts, attempt)

		var newStream grpc.ClientStream
		var err error
		func() {
			defer cancel()
			newStream, err = s.reestablishStreamAndResendBuffer(callCtx)
		}()
		if err != nil {
			// TODO(mwitkow): Maybe dial and transport errors should be retriable?
			return err
		}
		s.setStream(newStream)
		attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m)
		if !attemptRetry {
			return lastErr
		}
	}
	return lastErr
}

func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) {
	s.mu.RLock()
	wasGood := s.receivedGood
	s.mu.RUnlock()
	err := s.getStream().RecvMsg(m)
	if err == nil || err == io.EOF {
		s.mu.Lock()
		s.receivedGood = true
		s.mu.Unlock()
		return false, err
	} else if wasGood {
		// a previous RecvMsg in the stream succeeded; no retry logic should interfere.
		return false, err
	}
	if isContextError(err) {
		if s.parentCtx.Err() != nil {
			return false, err
		} else if s.callOpts.perCallTimeout != 0 {
			// We have set a perCallTimeout in the retry middleware, which would result in a context error if
			// the deadline was exceeded, in which case try again.
			return true, err
		}
	}
	return isRetriable(err, s.callOpts), err
}

func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer(callCtx context.Context) (grpc.ClientStream, error) {
	s.mu.RLock()
	bufferedSends := s.bufferedSends
	s.mu.RUnlock()
	newStream, err := s.streamerCall(callCtx)
	if err != nil {
		return nil, err
	}
	for _, msg := range bufferedSends {
		if err := newStream.SendMsg(msg); err != nil {
			return nil, err
		}
	}
	if err := newStream.CloseSend(); err != nil {
		return nil, err
	}
	return newStream, nil
}

func waitRetryBackoff(parentCtx context.Context, attempt uint, callOpts *retryOptions) error {
	var waitTime time.Duration
	if attempt > 0 {
		waitTime = callOpts.backoffFunc(parentCtx, attempt)
	}
	if waitTime > 0 {
		timer := time.NewTimer(waitTime)
		select {
		case <-parentCtx.Done():
			timer.Stop()
			return contextErrToGrpcErr(parentCtx.Err())
		case <-timer.C:
		}
	}
	return nil
}

func isRetriable(err error, callOpts *retryOptions) bool {
	if isContextError(err) {
		return false
	}

	errCode := status.Code(err)
	for _, code := range callOpts.codes {
		if code == errCode {
			return true
		}
	}
	return !callOpts.abortOnFailure
}

func isContextError(err error) bool {
	code := status.Code(err)
	return code == codes.DeadlineExceeded || code == codes.Canceled
}

func perCallContext(parentCtx context.Context, callOpts *retryOptions, attempt uint) (ctx context.Context, cancel func()) {
	ctx = parentCtx
	cancel = func() {}
	if callOpts.perCallTimeout != 0 {
		ctx, cancel = context.WithTimeout(ctx, callOpts.perCallTimeout)
	}
	if attempt > 0 && callOpts.includeHeader {
		mdClone := cloneMetadata(extractOutgoingMetadata(ctx))
		mdClone = setMetadata(mdClone, MetadataKeyAttempt, fmt.Sprintf("%d", attempt))
		ctx = toOutgoing(ctx, mdClone)
	}
	return
}

func contextErrToGrpcErr(err error) error {
	switch err {
	case context.DeadlineExceeded:
		return status.Error(codes.DeadlineExceeded, err.Error())
	case context.Canceled:
		return status.Error(codes.Canceled, err.Error())
	default:
		return status.Error(codes.Unknown, err.Error())
	}
}

// extractOutgoingMetadata extracts outbound metadata from the client-side context.
//
// It always returns a metadata.MD; if the context doesn't have metadata attached, it
// returns a new empty metadata.MD.
func extractOutgoingMetadata(ctx context.Context) metadata.MD {
	md, ok := metadata.FromOutgoingContext(ctx)
	if !ok {
		return metadata.Pairs() // empty md set
	}
	return md
}

// cloneMetadata clones a given md set; if copiedKeys are provided, only those keys are copied.
func cloneMetadata(md metadata.MD, copiedKeys ...string) metadata.MD {
	newMd := make(metadata.MD)
	for k, vv := range md {
		var found bool
		if len(copiedKeys) == 0 {
			found = true
		} else {
			for _, allowedKey := range copiedKeys {
				if strings.EqualFold(allowedKey, k) {
					found = true
					break
				}
			}
		}
		if !found {
			continue
		}
		newMd[k] = make([]string, len(vv))
		copy(newMd[k], vv)
	}
	return newMd
}

func setMetadata(md metadata.MD, key string, value string) metadata.MD {
	k, v := encodeMetadataKeyValue(key, value)
	md[k] = []string{v}
	return md
}

// toOutgoing sets the given metadata.MD on a client-side context for dispatching.
func toOutgoing(ctx context.Context, md metadata.MD) context.Context {
	return metadata.NewOutgoingContext(ctx, md)
}

const (
	binHdrSuffix = "-bin"
)

func encodeMetadataKeyValue(k, v string) (string, string) {
	k = strings.ToLower(k)
	if strings.HasSuffix(k, binHdrSuffix) {
		v = base64.StdEncoding.EncodeToString([]byte(v))
	}
	return k, v
}
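// A small illustration of the "-bin" handling above; per the gRPC metadata
// convention, values for keys with the "-bin" suffix are base64-encoded:
//
//	k, v := encodeMetadataKeyValue("X-Trace-Bin", "\x01\x02")
//	// k == "x-trace-bin", v == "AQI="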