github.com/msales/pkg/v3@v3.24.0/grpcx/middleware/breaker_test.go (about) 1 package middleware_test 2 3 import ( 4 "context" 5 "errors" 6 "testing" 7 "time" 8 9 "github.com/msales/pkg/v3/breaker" 10 "github.com/msales/pkg/v3/stats" 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/mock" 13 "google.golang.org/grpc" 14 15 . "github.com/msales/pkg/v3/grpcx/middleware" 16 ) 17 18 var breakerErr = errors.New("breaker: circuit breaker is open") 19 20 func TestWithBreaker(t *testing.T) { 21 s := new(mockStats) 22 s.AssertNotCalled(t, "Inc") 23 ctx := context.Background() 24 ctx = stats.WithStats(ctx, s) 25 26 br := breaker.NewBreaker( 27 breaker.RateFuse(1), 28 breaker.WithSleep(1*time.Second), 29 breaker.WithTestRequests(1), 30 ) 31 interceptor := WithClientBreaker(br, "test") 32 err := interceptor(ctx, "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 33 return nil 34 }) 35 36 assert.Nil(t, err) 37 s.AssertExpectations(t) 38 } 39 40 func TestWithBreaker_Errored(t *testing.T) { 41 s := new(mockStats) 42 s.On("Inc", "breaker.error", int64(1), float32(1.0), []interface{}{"state", "open", "name", "test"}).Return(nil).Once() 43 ctx := context.Background() 44 ctx = stats.WithStats(ctx, s) 45 46 br := breaker.NewBreaker( 47 breaker.RateFuse(10), 48 breaker.WithSleep(1*time.Second), 49 breaker.WithTestRequests(1), 50 ) 51 52 interceptor := WithClientBreaker(br, "test") 53 err := interceptor(ctx, "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 54 return testErr 55 }) 56 57 assert.Equal(t, testErr, err) 58 59 err = interceptor(ctx, "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 60 return testErr 61 }) 62 63 assert.Equal(t, breakerErr, err) 64 s.AssertExpectations(t) 65 } 66 67 type mockStats struct { 68 mock.Mock 69 } 70 71 func (m *mockStats) Inc(name string, value int64, rate float32, tags ...interface{}) error { 72 args := m.Called(name, value, rate, tags) 73 return args.Error(0) 74 } 75 76 func (m *mockStats) Dec(name string, value int64, rate float32, tags ...interface{}) error { 77 args := m.Called(name, value, rate, tags) 78 return args.Error(0) 79 } 80 81 func (m *mockStats) Gauge(name string, value float64, rate float32, tags ...interface{}) error { 82 args := m.Called(name, value, rate, tags) 83 return args.Error(0) 84 } 85 86 func (m *mockStats) Timing(name string, value time.Duration, rate float32, tags ...interface{}) error { 87 args := m.Called(name, value, rate, tags) 88 return args.Error(0) 89 } 90 91 func (m *mockStats) Close() error { 92 args := m.Called() 93 return args.Error(0) 94 }