github.com/blend/go-sdk@v1.20220411.3/grpcutil/client_retry.go

/*

Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
Use of this source code is governed by a MIT license that can be found in the LICENSE file.

*/

package grpcutil

import (
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"strings"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

var (
	// DefaultRetriableCodes is a set of well-known gRPC codes that should be retriable.
	//
	// `ResourceExhausted` means that the user quota, e.g. per-RPC limits, has been reached.
	// `Unavailable` means that the system is currently unavailable and the client should retry.
	DefaultRetriableCodes = []codes.Code{codes.ResourceExhausted, codes.Unavailable}

	defaultRetryOptions = &retryOptions{
		max:            0, // disabled
		perCallTimeout: 0, // disabled
		includeHeader:  true,
		codes:          DefaultRetriableCodes,
		backoffFunc: BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration {
			return BackoffLinearWithJitter(50*time.Millisecond, 0.10)(attempt)
		}),
	}
)

// Metadata Keys
const (
	MetadataKeyAttempt = "x-retry-attempt"
)

// WithRetriesDisabled disables the retry behavior on this call, or this interceptor.
//
// It is semantically equivalent to `WithClientRetries(0)`.
func WithRetriesDisabled() CallOption {
	return WithClientRetries(0)
}

// WithClientRetries sets the maximum number of retries on this call, or this interceptor.
func WithClientRetries(maxRetries uint) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.max = maxRetries
	}}
}
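
// As a usage sketch, the retry budget can be overridden per call by passing the
// option alongside the regular grpc.CallOptions. Here `client`, `ctx`, and `req`
// are hypothetical placeholders for a generated gRPC client and its inputs:
//
//	resp, err := client.Ping(ctx, req, WithClientRetries(3))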

// WithClientRetryBackoffLinear sets the retry backoff to a fixed duration.
func WithClientRetryBackoffLinear(d time.Duration) CallOption {
	return WithClientRetryBackoffFunc(BackoffLinear(d))
}

// WithClientRetryBackoffFunc sets the `BackoffFunc` used to control time between retries.
func WithClientRetryBackoffFunc(bf BackoffFunc) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.backoffFunc = BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration {
			return bf(attempt)
		})
	}}
}
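
// For illustration, a minimal exponential backoff sketch built on the BackoffFunc
// signature; the doubling policy is an assumption for the example, not a package default:
//
//	WithClientRetryBackoffFunc(func(attempt uint) time.Duration {
//		return (50 * time.Millisecond) << attempt // grows as 50ms * 2^attempt
//	})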

// WithClientRetryBackoffContext sets the `BackoffFuncContext` used to control time between retries.
func WithClientRetryBackoffContext(bf BackoffFuncContext) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.backoffFunc = bf
	}}
}
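
// A minimal sketch of a context-aware policy; capping the wait at the caller's
// remaining deadline is an assumption for illustration, not package behavior:
//
//	WithClientRetryBackoffContext(func(ctx context.Context, attempt uint) time.Duration {
//		wait := BackoffLinearWithJitter(50*time.Millisecond, 0.10)(attempt)
//		if deadline, ok := ctx.Deadline(); ok && time.Until(deadline) < wait {
//			return time.Until(deadline) / 2 // don't sleep past the caller's deadline
//		}
//		return wait
//	})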

// WithClientRetryCodes sets which codes should be retried.
//
// Please *use with care*, as you may be retrying non-idempotent calls.
//
// You cannot automatically retry on Canceled and DeadlineExceeded; use
// `WithClientRetryPerRetryTimeout` for those cases.
func WithClientRetryCodes(retryCodes ...codes.Code) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.codes = retryCodes
	}}
}
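
// For example, a call site that also treats Aborted as retriable; the wider code
// set is an assumption here, and is only safe if the RPC in question is idempotent:
//
//	WithClientRetryCodes(codes.ResourceExhausted, codes.Unavailable, codes.Aborted)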

// WithClientRetryPerRetryTimeout sets the RPC timeout per call (including the initial call) on this call, or this interceptor.
//
// The context.Deadline of the call takes precedence and sets the maximum time the whole invocation
// will take, but WithClientRetryPerRetryTimeout can be used to limit the RPC time per each call.
//
// For example, with context.Deadline = now + 10s, and WithClientRetryPerRetryTimeout(3*time.Second), each
// of the retry calls (including the initial one) will have a deadline of now + 3s.
//
// A value of 0 disables the per-call timeout entirely; each retry call then falls back to the
// parent `context.Deadline`.
//
// Note that when this is enabled, any DeadlineExceeded errors that are propagated up will be retried.
func WithClientRetryPerRetryTimeout(timeout time.Duration) CallOption {
	return CallOption{applyFunc: func(o *retryOptions) {
		o.perCallTimeout = timeout
	}}
}
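
// Encoding the example above as code; the 10s parent deadline, 3s per-retry
// timeout, and the `client.Ping` call are illustrative placeholders, not defaults:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	resp, err := client.Ping(ctx, req,
//		WithClientRetries(3),
//		WithClientRetryPerRetryTimeout(3*time.Second))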

type retryOptions struct {
	max            uint
	perCallTimeout time.Duration
	includeHeader  bool
	codes          []codes.Code
	backoffFunc    BackoffFuncContext
	abortOnFailure bool
}

// CallOption is a grpc.CallOption that is local to the retry interceptors in this package.
type CallOption struct {
	grpc.EmptyCallOption // make sure we implement private after() and before() methods so we don't panic.
	applyFunc            func(opt *retryOptions)
}

func reuseOrNewWithCallOptions(opt *retryOptions, callOptions []CallOption) *retryOptions {
	if len(callOptions) == 0 {
		return opt
	}
	optCopy := new(retryOptions)
	*optCopy = *opt
	for _, f := range callOptions {
		f.applyFunc(optCopy)
	}
	return optCopy
}

func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOption, retryOptions []CallOption) {
	for _, opt := range callOptions {
		if co, ok := opt.(CallOption); ok {
			retryOptions = append(retryOptions, co)
		} else {
			grpcOptions = append(grpcOptions, opt)
		}
	}
	return grpcOptions, retryOptions
}

// RetryUnaryClientInterceptor returns a new retrying unary client interceptor.
//
// The default configuration of the interceptor is to not retry *at all*. This behavior can be
// changed through options (e.g. WithClientRetries) on creation of the interceptor or per call (through grpc.CallOptions).
func RetryUnaryClientInterceptor(optFuncs ...CallOption) grpc.UnaryClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultRetryOptions, optFuncs)
	return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		if callOpts.max == 0 {
			return invoker(parentCtx, method, req, reply, cc, grpcOpts...)
		}
		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			callCtx, cancel := perCallContext(parentCtx, callOpts, attempt)
			func() {
				defer cancel()
				lastErr = invoker(callCtx, method, req, reply, cc, grpcOpts...)
			}()
			if lastErr == nil {
				return nil
			}
			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					// it's the parent context's deadline or cancellation.
					return lastErr
				} else if callOpts.perCallTimeout != 0 {
					// We set a perCallTimeout in the retry middleware, which would result in a context error if
					// the deadline was exceeded, in which case we try again.
					continue
				}
			}
			if !isRetriable(lastErr, callOpts) {
				return lastErr
			}
			if err := waitRetryBackoff(parentCtx, attempt, callOpts); err != nil {
				return err
			}
		}
		return lastErr
	}
}
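
// A minimal wiring sketch; the target address, credentials choice, and option
// values are placeholders for the example, not defaults of this package:
//
//	conn, err := grpc.Dial("localhost:8080",
//		grpc.WithInsecure(),
//		grpc.WithUnaryInterceptor(RetryUnaryClientInterceptor(
//			WithClientRetries(3),
//			WithClientRetryBackoffLinear(100*time.Millisecond),
//		)),
//	)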

// RetryStreamClientInterceptor returns a new retrying stream client interceptor for server side streaming calls.
//
// The default configuration of the interceptor is to not retry *at all*. This behavior can be
// changed through options (e.g. WithClientRetries) on creation of the interceptor or per call (through grpc.CallOptions).
//
// Retry logic is available *only for ServerStreams*, i.e. 1:n streams, as the internal logic needs
// to buffer the messages sent by the client. If retry is enabled on any other streams (ClientStreams,
// BidiStreams), the retry interceptor will fail the call.
func RetryStreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultRetryOptions, optFuncs)
	return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		// short circuit for simplicity, and avoiding allocations.
		if callOpts.max == 0 {
			return streamer(parentCtx, desc, cc, method, grpcOpts...)
		}
		if desc.ClientStreams {
			return nil, status.Errorf(codes.Unimplemented, "grpcutil: cannot retry on ClientStreams, set WithRetriesDisabled()")
		}

		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			if err := waitRetryBackoff(parentCtx, attempt, callOpts); err != nil {
				return nil, err
			}
			callCtx, cancel := perCallContext(parentCtx, callOpts, 0)

			var newStreamer grpc.ClientStream
			func() {
				defer cancel()
				newStreamer, lastErr = streamer(callCtx, desc, cc, method, grpcOpts...)
			}()
			if lastErr == nil {
				retryingStreamer := &serverStreamingRetryingStream{
					ClientStream: newStreamer,
					callOpts:     callOpts,
					parentCtx:    parentCtx,
					streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
						return streamer(ctx, desc, cc, method, grpcOpts...)
					},
				}
				return retryingStreamer, nil
			}

			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					// it's the parent context's deadline or cancellation.
					return nil, lastErr
				} else if callOpts.perCallTimeout != 0 {
					// We set a perCallTimeout in the retry middleware, which would result in a context error if
					// the deadline was exceeded, in which case we try again.
					continue
				}
			}
			if !isRetriable(lastErr, callOpts) {
				return nil, lastErr
			}
		}
		return nil, lastErr
	}
}
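
// A minimal wiring sketch for server-streaming calls; the address and option
// values are placeholders (remember that ClientStreams and BidiStreams cannot be retried):
//
//	conn, err := grpc.Dial("localhost:8080",
//		grpc.WithInsecure(),
//		grpc.WithStreamInterceptor(RetryStreamClientInterceptor(
//			WithClientRetries(3),
//		)),
//	)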

// serverStreamingRetryingStream is the implementation of grpc.ClientStream that acts as a
// proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish
// a new ClientStream according to the retry policy.
type serverStreamingRetryingStream struct {
	grpc.ClientStream
	bufferedSends []interface{} // messages sent by the client, buffered so they can be replayed on retry
	receivedGood  bool          // indicates whether any prior receives were successful
	wasClosedSend bool          // indicates that CloseSend was called
	parentCtx     context.Context
	callOpts      *retryOptions
	streamerCall  func(ctx context.Context) (grpc.ClientStream, error)
	mu            sync.RWMutex
}

func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) {
	s.mu.Lock()
	s.ClientStream = clientStream
	s.mu.Unlock()
}

func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.ClientStream
}

func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error {
	s.mu.Lock()
	s.bufferedSends = append(s.bufferedSends, m)
	s.mu.Unlock()
	return s.getStream().SendMsg(m)
}

func (s *serverStreamingRetryingStream) CloseSend() error {
	s.mu.Lock()
	s.wasClosedSend = true
	s.mu.Unlock()
	return s.getStream().CloseSend()
}

func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
	return s.getStream().Header()
}

func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
	return s.getStream().Trailer()
}

func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error {
	attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m)
	if !attemptRetry {
		return lastErr // success or hard failure
	}
	// Start from attempt 1; the zeroth attempt was the receiveMsgAndIndicateRetry call above.
	for attempt := uint(1); attempt < s.callOpts.max; attempt++ {
		if err := waitRetryBackoff(s.parentCtx, attempt, s.callOpts); err != nil {
			return err
		}
		callCtx, cancel := perCallContext(s.parentCtx, s.callOpts, attempt)

		var newStream grpc.ClientStream
		var err error
		func() {
			defer cancel()
			newStream, err = s.reestablishStreamAndResendBuffer(callCtx)
		}()
		if err != nil {
			// TODO(mwitkow): Maybe dial and transport errors should be retriable?
			return err
		}
		s.setStream(newStream)
		attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m)
		if !attemptRetry {
			return lastErr
		}
	}
	return lastErr
}

func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) {
	s.mu.RLock()
	wasGood := s.receivedGood
	s.mu.RUnlock()
	err := s.getStream().RecvMsg(m)
	if err == nil || err == io.EOF {
		s.mu.Lock()
		s.receivedGood = true
		s.mu.Unlock()
		return false, err
	} else if wasGood {
		// a previous RecvMsg in the stream succeeded, so no retry logic should interfere.
		return false, err
	}
	if isContextError(err) {
		if s.parentCtx.Err() != nil {
			return false, err
		} else if s.callOpts.perCallTimeout != 0 {
			// We set a perCallTimeout in the retry middleware, which would result in a context error if
			// the deadline was exceeded, in which case we try again.
			return true, err
		}
	}
	return isRetriable(err, s.callOpts), err
}

func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer(callCtx context.Context) (grpc.ClientStream, error) {
	s.mu.RLock()
	bufferedSends := s.bufferedSends
	s.mu.RUnlock()
	newStream, err := s.streamerCall(callCtx)
	if err != nil {
		return nil, err
	}
	for _, msg := range bufferedSends {
		if err := newStream.SendMsg(msg); err != nil {
			return nil, err
		}
	}
	if err := newStream.CloseSend(); err != nil {
		return nil, err
	}
	return newStream, nil
}

func waitRetryBackoff(parentCtx context.Context, attempt uint, callOpts *retryOptions) error {
	var waitTime time.Duration
	if attempt > 0 {
		waitTime = callOpts.backoffFunc(parentCtx, attempt)
	}
	if waitTime > 0 {
		timer := time.NewTimer(waitTime)
		select {
		case <-parentCtx.Done():
			timer.Stop()
			return contextErrToGrpcErr(parentCtx.Err())
		case <-timer.C:
		}
	}
	return nil
}

func isRetriable(err error, callOpts *retryOptions) bool {
	if isContextError(err) {
		return false
	}

	errCode := status.Code(err)
	for _, code := range callOpts.codes {
		if code == errCode {
			return true
		}
	}
	return !callOpts.abortOnFailure
}

func isContextError(err error) bool {
	code := status.Code(err)
	return code == codes.DeadlineExceeded || code == codes.Canceled
}

func perCallContext(parentCtx context.Context, callOpts *retryOptions, attempt uint) (ctx context.Context, cancel func()) {
	ctx = parentCtx
	cancel = func() {}
	if callOpts.perCallTimeout != 0 {
		ctx, cancel = context.WithTimeout(ctx, callOpts.perCallTimeout)
	}
	if attempt > 0 && callOpts.includeHeader {
		mdClone := cloneMetadata(extractOutgoingMetadata(ctx))
		mdClone = setMetadata(mdClone, MetadataKeyAttempt, fmt.Sprintf("%d", attempt))
		ctx = toOutgoing(ctx, mdClone)
	}
	return
}
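
// On the server side, the attempt header written above can be read back from the
// incoming metadata; a minimal sketch (the handler shape is illustrative):
//
//	md, _ := metadata.FromIncomingContext(ctx)
//	if attempts := md.Get(MetadataKeyAttempt); len(attempts) > 0 {
//		// attempts[0] is the retry attempt number as a decimal string, e.g. "2".
//	}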

func contextErrToGrpcErr(err error) error {
	switch err {
	case context.DeadlineExceeded:
		return status.Error(codes.DeadlineExceeded, err.Error())
	case context.Canceled:
		return status.Error(codes.Canceled, err.Error())
	default:
		return status.Error(codes.Unknown, err.Error())
	}
}

// extractOutgoingMetadata extracts outbound metadata from the client-side context.
//
// It always returns a metadata.MD; if the context carries no metadata, it returns a new empty metadata.MD.
func extractOutgoingMetadata(ctx context.Context) metadata.MD {
	md, ok := metadata.FromOutgoingContext(ctx)
	if !ok {
		return metadata.Pairs() // empty md set
	}
	return md
}

// cloneMetadata clones a given md set, optionally restricted to the given copied keys.
func cloneMetadata(md metadata.MD, copiedKeys ...string) metadata.MD {
	newMd := make(metadata.MD)
	for k, vv := range md {
		var found bool
		if len(copiedKeys) == 0 {
			found = true
		} else {
			for _, allowedKey := range copiedKeys {
				if strings.EqualFold(allowedKey, k) {
					found = true
					break
				}
			}
		}
		if !found {
			continue
		}
		newMd[k] = make([]string, len(vv))
		copy(newMd[k], vv)
	}
	return newMd
}

// setMetadata sets a single key to a single value in the given md set, encoding the pair if necessary.
func setMetadata(md metadata.MD, key string, value string) metadata.MD {
	k, v := encodeMetadataKeyValue(key, value)
	md[k] = []string{v}
	return md
}

// toOutgoing sets the given metadata.MD on a client-side context for dispatching.
func toOutgoing(ctx context.Context, md metadata.MD) context.Context {
	return metadata.NewOutgoingContext(ctx, md)
}

const (
	binHdrSuffix = "-bin"
)

// encodeMetadataKeyValue lowercases the key and base64-encodes the value for `-bin` binary headers.
func encodeMetadataKeyValue(k, v string) (string, string) {
	k = strings.ToLower(k)
	if strings.HasSuffix(k, binHdrSuffix) {
		v = base64.StdEncoding.EncodeToString([]byte(v))
	}
	return k, v
}