github.com/projecteru2/core@v0.0.0-20240321043226-06bcc1c23f58/client/interceptor/retry.go (about)

     1  package interceptor
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/cockroachdb/errors"
     7  	"github.com/projecteru2/core/log"
     8  
     9  	"github.com/cenkalti/backoff/v4"
    10  	"google.golang.org/grpc"
    11  )
    12  
    13  // NewUnaryRetry makes unary RPC retry on error
    14  func NewUnaryRetry(retryOpts RetryOptions) grpc.UnaryClientInterceptor {
    15  	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    16  		return backoff.Retry(func() error {
    17  			return invoker(ctx, method, req, reply, cc, opts...)
    18  		}, backoff.WithMaxRetries(backoff.WithContext(backoff.NewExponentialBackOff(), ctx), uint64(retryOpts.Max)))
    19  	}
    20  }
    21  
    22  // RPCNeedRetry records rpc stream methods to retry
    23  var RPCNeedRetry = map[string]struct{}{
    24  	"/pb.CoreRPC/WorkloadStatusStream": {},
    25  	"/pb.CoreRPC/WatchServiceStatus":   {},
    26  }
    27  
    28  // NewStreamRetry make specific stream retry on error
    29  func NewStreamRetry(retryOpts RetryOptions) grpc.StreamClientInterceptor {
    30  	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    31  		stream, err := streamer(ctx, desc, cc, method, opts...)
    32  		if _, ok := RPCNeedRetry[method]; !ok {
    33  			return stream, err
    34  		}
    35  		logger := log.WithFunc("client.NewStreamRetry")
    36  		logger.Debugf(ctx, "return retryStream for method %s", method)
    37  		return &retryStream{
    38  			ctx:          ctx,
    39  			ClientStream: stream,
    40  			newStream: func() (grpc.ClientStream, error) {
    41  				return streamer(ctx, desc, cc, method, opts...)
    42  			},
    43  			retryOpts: retryOpts,
    44  		}, err
    45  	}
    46  }
    47  
    48  func (s *retryStream) SendMsg(m any) error {
    49  	s.mux.Lock()
    50  	s.sent = m
    51  	s.mux.Unlock()
    52  	return s.getStream().SendMsg(m)
    53  }
    54  
    55  func (s *retryStream) RecvMsg(m any) (err error) {
    56  	if err = s.ClientStream.RecvMsg(m); err == nil || errors.Is(err, context.Canceled) {
    57  		return
    58  	}
    59  	logger := log.WithFunc("client.RecvMsg")
    60  
    61  	return backoff.Retry(func() error {
    62  		logger.Debug(s.ctx, "retry on new stream")
    63  		stream, err := s.newStream()
    64  		if err != nil {
    65  			// even io.EOF triggers retry, and it's what we want!
    66  			return err
    67  		}
    68  		s.setStream(stream)
    69  		s.mux.RLock()
    70  		err = s.getStream().SendMsg(s.sent)
    71  		s.mux.RUnlock()
    72  		if err != nil {
    73  			return err
    74  		}
    75  		return s.getStream().RecvMsg(m)
    76  	}, backoff.WithMaxRetries(backoff.WithContext(backoff.NewExponentialBackOff(), s.ctx), uint64(s.retryOpts.Max)))
    77  }