go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/breaker/breaker.go (about)

     1  package breaker
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"go.charczuk.com/sdk/errutil"
    10  )
    11  
    12  // New creates a new breaker with the given options.
    13  func New[A, B any](action Action[A, B], options ...Option[A, B]) *Breaker[A, B] {
    14  	b := Breaker[A, B]{
    15  		Action: action,
    16  	}
    17  	for _, opt := range options {
    18  		opt(&b)
    19  	}
    20  	return &b
    21  }
    22  
    23  type (
    24  	// OnStateChangeFunc is called when the state changes.
    25  	OnStateChangeFunc func(ctx context.Context, from, to State, generation int64)
    26  	// ShouldOpenFunc returns if the breaker should open.
    27  	ShouldOpenFunc[A any] func(ctx context.Context, counts Counts, args A) bool
    28  	// NowFunc returns the current time.
    29  	NowFunc func() time.Time
    30  )
    31  
    32  type Option[A, B any] func(b *Breaker[A, B])
    33  
    34  func OptConfig[A, B any](cfg Config) Option[A, B] {
    35  	return func(b *Breaker[A, B]) {
    36  		b.Config = cfg
    37  	}
    38  }
    39  
    40  func OptOpenAction[A, B any](action Action[A, B]) Option[A, B] {
    41  	return func(b *Breaker[A, B]) {
    42  		b.OpenAction = action
    43  	}
    44  }
    45  
    46  func OptShouldOpen[A, B any](shouldOpenFunc ShouldOpenFunc[A]) Option[A, B] {
    47  	return func(b *Breaker[A, B]) {
    48  		b.ShouldOpen = shouldOpenFunc
    49  	}
    50  }
    51  
    52  func OptNow[A, B any](nowFn NowFunc) Option[A, B] {
    53  	return func(b *Breaker[A, B]) {
    54  		b.Now = nowFn
    55  	}
    56  }
    57  
    58  var _ Action[any, any] = (*Breaker[any, any])(nil)
    59  
    60  // Breaker is a state machine to prevent performing actions that are likely to fail.
    61  type Breaker[A, B any] struct {
    62  	Action        Action[A, B]
    63  	OpenAction    Action[A, B]
    64  	OnStateChange OnStateChangeFunc
    65  	ShouldOpen    ShouldOpenFunc[A]
    66  	Now           NowFunc
    67  	Config        Config
    68  
    69  	mu                     sync.Mutex
    70  	state                  State
    71  	generation             int64
    72  	openStateExpiresAt     time.Time
    73  	closedFailuresExpireAt time.Time
    74  	counts                 Counts
    75  }
    76  
    77  func (b *Breaker[A, B]) Counts() Counts { return b.counts }
    78  
    79  // Call invokes the action and returns the result.
    80  func (b *Breaker[A, B]) Call(ctx context.Context, args A) (res B, err error) {
    81  	var generation int64
    82  	generation, err = b.beforeAction(ctx)
    83  	if err != nil {
    84  		if b.OpenAction != nil {
    85  			res, err = b.OpenAction.Call(ctx, args)
    86  			return
    87  		}
    88  		return
    89  	}
    90  	defer func() {
    91  		if r := recover(); r != nil {
    92  			b.afterAction(ctx, generation, false /*success*/, args)
    93  			err = errutil.New(r)
    94  		} else {
    95  			b.afterAction(ctx, generation, err == nil /*success*/, args)
    96  		}
    97  	}()
    98  	res, err = b.Action.Call(ctx, args)
    99  	return
   100  }
   101  
   102  // EvaluateState returns the current state of the CircuitBreaker.
   103  //
   104  // This method is a kind of idirect because can't know for sure
   105  // what the state is at a given time without evaluating expiration
   106  // times and potentially calling handlers if the state changes
   107  // after an expiry.
   108  //
   109  // As a result this method takes a context, and may call the `OnStateChange`
   110  // delegate value if it's set.
   111  func (b *Breaker[A, B]) EvaluateState(ctx context.Context) State {
   112  	b.mu.Lock()
   113  	defer b.mu.Unlock()
   114  
   115  	now := b.now()
   116  	state, _ := b.evaluateStateUnsafe(ctx, now)
   117  	return state
   118  }
   119  
   120  //
   121  // internal methods
   122  //
   123  
   124  func (b *Breaker[A, B]) beforeAction(ctx context.Context) (int64, error) {
   125  	b.mu.Lock()
   126  	defer b.mu.Unlock()
   127  
   128  	now := b.now()
   129  	state, generation := b.evaluateStateUnsafe(ctx, now)
   130  	if state == StateOpen {
   131  		return generation, ErrOpenState
   132  	} else if state == StateHalfOpen && b.counts.Requests >= b.Config.HalfOpenMaxActionsOrDefault() {
   133  		return generation, ErrTooManyRequests
   134  	}
   135  	atomic.AddUint64(&b.counts.Requests, 1)
   136  	return generation, nil
   137  }
   138  
   139  func (b *Breaker[A, B]) afterAction(ctx context.Context, currentGeneration int64, success bool, args A) {
   140  	b.mu.Lock()
   141  	defer b.mu.Unlock()
   142  
   143  	now := b.now()
   144  	state, generation := b.evaluateStateUnsafe(ctx, now)
   145  	if generation != currentGeneration {
   146  		return
   147  	}
   148  	if success {
   149  		b.success(ctx, state, now)
   150  		return
   151  	}
   152  	b.failure(ctx, state, now, args)
   153  }
   154  
   155  func (b *Breaker[A, B]) success(ctx context.Context, state State, now time.Time) {
   156  	switch state {
   157  	case StateClosed:
   158  		atomic.AddUint64(&b.counts.TotalSuccesses, 1)
   159  		atomic.AddUint64(&b.counts.ConsecutiveSuccesses, 1)
   160  		atomic.StoreUint64(&b.counts.ConsecutiveFailures, 0)
   161  	case StateHalfOpen:
   162  		atomic.AddUint64(&b.counts.TotalSuccesses, 1)
   163  		atomic.AddUint64(&b.counts.ConsecutiveSuccesses, 1)
   164  		atomic.StoreUint64(&b.counts.ConsecutiveFailures, 0)
   165  		if b.counts.ConsecutiveSuccesses >= b.Config.HalfOpenMaxActionsOrDefault() {
   166  			b.setStateUnsafe(ctx, StateClosed, now)
   167  		}
   168  	}
   169  }
   170  
   171  func (b *Breaker[A, B]) failure(ctx context.Context, state State, now time.Time, args A) {
   172  	switch state {
   173  	case StateClosed:
   174  		atomic.AddUint64(&b.counts.TotalFailures, 1)
   175  		atomic.AddUint64(&b.counts.ConsecutiveFailures, 1)
   176  		atomic.StoreUint64(&b.counts.ConsecutiveSuccesses, 0)
   177  		if b.shouldOpen(ctx, args) {
   178  			b.setStateUnsafe(ctx, StateOpen, now)
   179  		}
   180  	case StateHalfOpen:
   181  		b.setStateUnsafe(ctx, StateOpen, now)
   182  	}
   183  }
   184  
   185  func (b *Breaker[A, B]) evaluateStateUnsafe(ctx context.Context, now time.Time) (state State, generation int64) {
   186  	switch b.state {
   187  	case StateClosed:
   188  		if b.closedFailuresExpireAt.IsZero() {
   189  			b.closedFailuresExpireAt = now.Add(b.Config.ClosedFailureExpiryIntervalOrDefault())
   190  		} else if b.closedFailuresExpireAt.Before(now) {
   191  			b.counts = Counts{}
   192  			b.closedFailuresExpireAt = now.Add(b.Config.ClosedFailureExpiryIntervalOrDefault())
   193  		}
   194  	case StateOpen:
   195  		if !b.openStateExpiresAt.IsZero() && b.openStateExpiresAt.Before(now) {
   196  			b.setStateUnsafe(ctx, StateHalfOpen, now)
   197  		}
   198  	}
   199  	return b.state, b.generation
   200  }
   201  
   202  func (b *Breaker[A, B]) setStateUnsafe(ctx context.Context, state State, now time.Time) {
   203  	if b.state == state {
   204  		return
   205  	}
   206  	previousState := b.state
   207  	b.state = state
   208  	b.incrementGenerationAfterStateChange(now)
   209  	if b.OnStateChange != nil {
   210  		b.OnStateChange(ctx, previousState, b.state, b.generation)
   211  	}
   212  }
   213  
   214  func (b *Breaker[A, B]) incrementGenerationAfterStateChange(now time.Time) {
   215  	atomic.AddInt64(&b.generation, 1)
   216  	b.counts = Counts{}
   217  
   218  	var zero time.Time
   219  	switch b.state {
   220  	case StateClosed:
   221  		b.openStateExpiresAt = zero
   222  		b.closedFailuresExpireAt = now.Add(b.Config.ClosedFailureExpiryIntervalOrDefault())
   223  	case StateOpen:
   224  		b.openStateExpiresAt = now.Add(b.Config.OpenExpiryIntervalOrDefault())
   225  		b.closedFailuresExpireAt = zero
   226  	case StateHalfOpen:
   227  		b.openStateExpiresAt = zero
   228  		b.closedFailuresExpireAt = zero
   229  	}
   230  }
   231  
   232  func (b *Breaker[A, B]) shouldOpen(ctx context.Context, args A) bool {
   233  	if b.ShouldOpen != nil {
   234  		return b.ShouldOpen(ctx, b.counts, args)
   235  	}
   236  	return b.counts.ConsecutiveFailures >= b.Config.FailureThresholdOrDefault()
   237  }
   238  
   239  func (b *Breaker[A, B]) now() time.Time {
   240  	if b.Now != nil {
   241  		return b.Now()
   242  	}
   243  	return time.Now()
   244  }