github.com/projecteru2/core@v0.0.0-20240321043226-06bcc1c23f58/client/interceptor/retry.go (about) 1 package interceptor 2 3 import ( 4 "context" 5 6 "github.com/cockroachdb/errors" 7 "github.com/projecteru2/core/log" 8 9 "github.com/cenkalti/backoff/v4" 10 "google.golang.org/grpc" 11 ) 12 13 // NewUnaryRetry makes unary RPC retry on error 14 func NewUnaryRetry(retryOpts RetryOptions) grpc.UnaryClientInterceptor { 15 return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 16 return backoff.Retry(func() error { 17 return invoker(ctx, method, req, reply, cc, opts...) 18 }, backoff.WithMaxRetries(backoff.WithContext(backoff.NewExponentialBackOff(), ctx), uint64(retryOpts.Max))) 19 } 20 } 21 22 // RPCNeedRetry records rpc stream methods to retry 23 var RPCNeedRetry = map[string]struct{}{ 24 "/pb.CoreRPC/WorkloadStatusStream": {}, 25 "/pb.CoreRPC/WatchServiceStatus": {}, 26 } 27 28 // NewStreamRetry make specific stream retry on error 29 func NewStreamRetry(retryOpts RetryOptions) grpc.StreamClientInterceptor { 30 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 31 stream, err := streamer(ctx, desc, cc, method, opts...) 32 if _, ok := RPCNeedRetry[method]; !ok { 33 return stream, err 34 } 35 logger := log.WithFunc("client.NewStreamRetry") 36 logger.Debugf(ctx, "return retryStream for method %s", method) 37 return &retryStream{ 38 ctx: ctx, 39 ClientStream: stream, 40 newStream: func() (grpc.ClientStream, error) { 41 return streamer(ctx, desc, cc, method, opts...) 42 }, 43 retryOpts: retryOpts, 44 }, err 45 } 46 } 47 48 func (s *retryStream) SendMsg(m any) error { 49 s.mux.Lock() 50 s.sent = m 51 s.mux.Unlock() 52 return s.getStream().SendMsg(m) 53 } 54 55 func (s *retryStream) RecvMsg(m any) (err error) { 56 if err = s.ClientStream.RecvMsg(m); err == nil || errors.Is(err, context.Canceled) { 57 return 58 } 59 logger := log.WithFunc("client.RecvMsg") 60 61 return backoff.Retry(func() error { 62 logger.Debug(s.ctx, "retry on new stream") 63 stream, err := s.newStream() 64 if err != nil { 65 // even io.EOF triggers retry, and it's what we want! 66 return err 67 } 68 s.setStream(stream) 69 s.mux.RLock() 70 err = s.getStream().SendMsg(s.sent) 71 s.mux.RUnlock() 72 if err != nil { 73 return err 74 } 75 return s.getStream().RecvMsg(m) 76 }, backoff.WithMaxRetries(backoff.WithContext(backoff.NewExponentialBackOff(), s.ctx), uint64(s.retryOpts.Max))) 77 }