github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/go-grpc-middleware/retry/retry.go

// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.

package grpc_retry

import (
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/hxx258456/ccgo/go-grpc-middleware/util/metautils"
	"github.com/hxx258456/ccgo/grpc"
	"github.com/hxx258456/ccgo/grpc/codes"
	"github.com/hxx258456/ccgo/grpc/metadata"
	"github.com/hxx258456/ccgo/net/context"
	"github.com/hxx258456/ccgo/net/trace"
)

const (
	// AttemptMetadataKey is the outgoing-metadata key under which the current retry
	// attempt number is sent to the server (see perCallContext). The historical
	// spelling matches upstream grpc_retry and is part of the wire format.
	AttemptMetadataKey = "x-retry-attempty"
)

// UnaryClientInterceptor returns a new retrying unary client interceptor.
//
// The default configuration of the interceptor is to not retry *at all*. This behaviour can be
// changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
func UnaryClientInterceptor(optFuncs ...CallOption) grpc.UnaryClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
	return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		// Short circuit for simplicity, and to avoid allocations.
		if callOpts.max == 0 {
			return invoker(parentCtx, method, req, reply, cc, grpcOpts...)
		}
		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
				return err
			}
			callCtx := perCallContext(parentCtx, callOpts, attempt)
			lastErr = invoker(callCtx, method, req, reply, cc, grpcOpts...)
			// TODO(mwitkow): Maybe dial and transport errors should be retriable?
			if lastErr == nil {
				return nil
			}
			logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
					// It's the parent context deadline or cancellation.
					return lastErr
				}
				logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
				// It's the callCtx deadline or cancellation, in which case try again.
				continue
			}
			if !isRetriable(lastErr, callOpts) {
				return lastErr
			}
		}
		return lastErr
	}
}

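// Illustrative sketch (not part of the original file): installing the unary interceptor at
// dial time and tightening the retry policy for a single call. WithCodes and
// WithPerRetryTimeout are assumed to be defined in this package's options, mirroring the
// upstream grpc_retry API that the doc comment above references via WithMax; the target
// address, the generated client stub, and the use of grpc.Dial/grpc.WithInsecure from the
// forked grpc package are likewise illustrative assumptions.
//
//	conn, err := grpc.Dial("dns:///my-service:443",
//		grpc.WithInsecure(),
//		grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
//			grpc_retry.WithMax(3),
//			grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted),
//			grpc_retry.WithPerRetryTimeout(time.Second),
//		)),
//	)
//	if err != nil { /* handle dial error */ }
//	client := pb.NewThingsClient(conn)
//	// Per-call override on the generated stub: allow more attempts for this one RPC.
//	resp, err := client.GetThing(ctx, req, grpc_retry.WithMax(5))
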
// StreamClientInterceptor returns a new retrying stream client interceptor for server side streaming calls.
//
// The default configuration of the interceptor is to not retry *at all*. This behaviour can be
// changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
//
// Retry logic is available *only for ServerStreams*, i.e. 1:n streams, as the internal logic needs
// to buffer the messages sent by the client. If retry is enabled on any other streams (ClientStreams,
// BidiStreams), the retry interceptor will fail the call.
func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
	return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		// Short circuit for simplicity, and to avoid allocations.
		if callOpts.max == 0 {
			return streamer(parentCtx, desc, cc, method, grpcOpts...)
		}
		if desc.ClientStreams {
			return nil, grpc.Errorf(codes.Unimplemented, "grpc_retry: cannot retry on ClientStreams, set grpc_retry.Disable()")
		}

		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
				return nil, err
			}
			callCtx := perCallContext(parentCtx, callOpts, 0)

			var newStreamer grpc.ClientStream
			newStreamer, lastErr = streamer(callCtx, desc, cc, method, grpcOpts...)
			if lastErr == nil {
				retryingStreamer := &serverStreamingRetryingStream{
					ClientStream: newStreamer,
					callOpts:     callOpts,
					parentCtx:    parentCtx,
					streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
						return streamer(ctx, desc, cc, method, grpcOpts...)
					},
				}
				return retryingStreamer, nil
			}

			logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
					// It's the parent context deadline or cancellation.
					return nil, lastErr
				}
				logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
				// It's the callCtx deadline or cancellation, in which case try again.
				continue
			}
			if !isRetriable(lastErr, callOpts) {
				return nil, lastErr
			}
		}
		return nil, lastErr
	}
}

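// Illustrative sketch (not part of the original file): consuming a server-streaming RPC
// through this interceptor. The interceptor must be registered with
// grpc.WithStreamInterceptor at dial time; the ListThings stub and its request type are
// hypothetical. Note that Recv below may transparently re-establish the stream and replay
// the buffered request until the first message has been received successfully.
//
//	stream, err := client.ListThings(ctx, &pb.ListThingsRequest{}, grpc_retry.WithMax(3))
//	if err != nil { /* handle error */ }
//	for {
//		msg, err := stream.Recv()
//		if err == io.EOF {
//			break
//		}
//		if err != nil { /* handle error */ }
//		_ = msg // process the message
//	}
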
// serverStreamingRetryingStream is an implementation of grpc.ClientStream that acts as a
// proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish
// a new ClientStream according to the retry policy.
type serverStreamingRetryingStream struct {
	grpc.ClientStream
	bufferedSends []interface{} // messages sent by the client, buffered for replay on retry (a single request for server-side streaming)
	receivedGood  bool          // indicates whether any prior receives were successful
	wasClosedSend bool          // indicates that CloseSend was called
	parentCtx     context.Context
	callOpts      *options
	streamerCall  func(ctx context.Context) (grpc.ClientStream, error)
	mu            sync.RWMutex
}

func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) {
	s.mu.Lock()
	s.ClientStream = clientStream
	s.mu.Unlock()
}

func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.ClientStream
}

func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error {
	s.mu.Lock()
	s.bufferedSends = append(s.bufferedSends, m)
	s.mu.Unlock()
	return s.getStream().SendMsg(m)
}

func (s *serverStreamingRetryingStream) CloseSend() error {
	s.mu.Lock()
	s.wasClosedSend = true
	s.mu.Unlock()
	return s.getStream().CloseSend()
}

func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
	return s.getStream().Header()
}

func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
	return s.getStream().Trailer()
}

func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error {
	attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m)
	if !attemptRetry {
		return lastErr // success or hard failure
	}
	// We start off from attempt 1, because the zeroth attempt was already made on the normal SendMsg().
	for attempt := uint(1); attempt < s.callOpts.max; attempt++ {
		if err := waitRetryBackoff(attempt, s.parentCtx, s.callOpts); err != nil {
			return err
		}
		callCtx := perCallContext(s.parentCtx, s.callOpts, attempt)
		newStream, err := s.reestablishStreamAndResendBuffer(callCtx)
		if err != nil {
			// TODO(mwitkow): Maybe dial and transport errors should be retriable?
			return err
		}
		s.setStream(newStream)
		attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m)
		if !attemptRetry {
			return lastErr
		}
	}
	return lastErr
}

func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) {
	s.mu.RLock()
	wasGood := s.receivedGood
	s.mu.RUnlock()
	err := s.getStream().RecvMsg(m)
	if err == nil || err == io.EOF {
		s.mu.Lock()
		s.receivedGood = true
		s.mu.Unlock()
		return false, err
	}
	if wasGood {
		// A previous RecvMsg in the stream succeeded, so no retry logic should interfere.
		return false, err
	}
	if isContextError(err) {
		if s.parentCtx.Err() != nil {
			logTrace(s.parentCtx, "grpc_retry parent context error: %v", s.parentCtx.Err())
			return false, err
		}
		logTrace(s.parentCtx, "grpc_retry context error from retry call")
		// It's the callCtx deadline or cancellation, in which case try again.
		return true, err
	}
	return isRetriable(err, s.callOpts), err
}

func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer(callCtx context.Context) (grpc.ClientStream, error) {
	s.mu.RLock()
	bufferedSends := s.bufferedSends
	s.mu.RUnlock()
	newStream, err := s.streamerCall(callCtx)
	if err != nil {
		logTrace(callCtx, "grpc_retry failed redialing new stream: %v", err)
		return nil, err
	}
	for _, msg := range bufferedSends {
		if err := newStream.SendMsg(msg); err != nil {
			logTrace(callCtx, "grpc_retry failed resending message: %v", err)
			return nil, err
		}
	}
	if err := newStream.CloseSend(); err != nil {
		logTrace(callCtx, "grpc_retry failed CloseSend on new stream: %v", err)
		return nil, err
	}
	return newStream, nil
}

func waitRetryBackoff(attempt uint, parentCtx context.Context, callOpts *options) error {
	var waitTime time.Duration
	if attempt > 0 {
		waitTime = callOpts.backoffFunc(parentCtx, attempt)
	}
	if waitTime > 0 {
		logTrace(parentCtx, "grpc_retry attempt: %d, backoff for %v", attempt, waitTime)
		timer := time.NewTimer(waitTime)
		select {
		case <-parentCtx.Done():
			timer.Stop()
			return contextErrToGrpcErr(parentCtx.Err())
		case <-timer.C:
		}
	}
	return nil
}

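// Illustrative sketch (not part of the original file): a custom backoff with the shape
// expected by callOpts.backoffFunc above, i.e. func(ctx context.Context, attempt uint)
// time.Duration. How it is installed (e.g. via a WithBackoff-style option) follows the
// upstream grpc_retry options and is an assumption here; the base delay and cap are
// arbitrary.
//
//	func exponentialBackoff(ctx context.Context, attempt uint) time.Duration {
//		const base = 50 * time.Millisecond
//		const max = 2 * time.Second
//		d := base << attempt // 50ms, 100ms, 200ms, ...
//		if d > max || d <= 0 {
//			d = max
//		}
//		return d
//	}
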
func isRetriable(err error, callOpts *options) bool {
	errCode := grpc.Code(err)
	if isContextError(err) {
		// Context errors are never retriable, regardless of user settings.
		return false
	}
	for _, code := range callOpts.codes {
		if code == errCode {
			return true
		}
	}
	return false
}

func isContextError(err error) bool {
	return grpc.Code(err) == codes.DeadlineExceeded || grpc.Code(err) == codes.Canceled
}

func perCallContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context {
	ctx := parentCtx
	if callOpts.perCallTimeout != 0 {
		// The cancel func is discarded here; the per-call context expires via its timeout.
		ctx, _ = context.WithTimeout(ctx, callOpts.perCallTimeout)
	}
	if attempt > 0 && callOpts.includeHeader {
		mdClone := metautils.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt))
		ctx = mdClone.ToOutgoing(ctx)
	}
	return ctx
}

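// Illustrative sketch (not part of the original file): a server-side handler reading the
// attempt number that perCallContext attaches under AttemptMetadataKey. The
// metautils.ExtractIncoming/Get helpers are assumed to exist in the imported metautils
// package, mirroring upstream go-grpc-middleware; the service and message types are
// hypothetical.
//
//	func (s *server) GetThing(ctx context.Context, req *pb.GetThingRequest) (*pb.Thing, error) {
//		attempt := metautils.ExtractIncoming(ctx).Get(grpc_retry.AttemptMetadataKey)
//		// attempt is "1", "2", ... on retries; empty on the first attempt or when the header is disabled.
//		_ = attempt
//		// ... handle the request ...
//		return &pb.Thing{}, nil
//	}
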
func contextErrToGrpcErr(err error) error {
	switch err {
	case context.DeadlineExceeded:
		return grpc.Errorf(codes.DeadlineExceeded, err.Error())
	case context.Canceled:
		return grpc.Errorf(codes.Canceled, err.Error())
	default:
		return grpc.Errorf(codes.Unknown, err.Error())
	}
}

func logTrace(ctx context.Context, format string, a ...interface{}) {
	tr, ok := trace.FromContext(ctx)
	if !ok {
		return
	}
	tr.LazyPrintf(format, a...)
}