
     1  /*
     3  Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     6  */
     8  package breaker
    10  import (
    11  	"context"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    16  	""
    17  	""
    18  )
    20  var (
    21  	_ async.Interceptor = (*Breaker)(nil)
    22  )
    24  type (
    25  	// OnStateChangeHandler is called when the state changes.
    26  	OnStateChangeHandler func(ctx context.Context, from, to State, generation int64)
    27  	// ShouldOpenProvider returns if the breaker should open.
    28  	ShouldOpenProvider func(ctx context.Context, counts Counts) bool
    29  	// NowProvider returns the current time.
    30  	NowProvider func() time.Time
    31  )
    33  // New creates a new breaker with the given options.
    34  func New(options ...Option) *Breaker {
    35  	b := Breaker{
    36  		ClosedExpiryInterval: DefaultClosedExpiryInterval,
    37  		OpenExpiryInterval:   DefaultOpenExpiryInterval,
    38  		HalfOpenMaxActions:   DefaultHalfOpenMaxActions,
    39  		OpenFailureThreshold: DefaultOpenFailureThreshold,
    40  	}
    41  	for _, opt := range options {
    42  		opt(&b)
    43  	}
    44  	return &b
    45  }
// Breaker is a state machine to prevent performing actions that are likely to fail.
//
// It is closed while calls are let through, open while calls are rejected,
// and half-open while a limited number of trial calls are let through to
// probe whether the protected action has recovered. Use New to obtain a
// Breaker populated with the package defaults. The embedded mutex guards the
// breaker's state; EvaluateState and intercepted calls acquire it themselves.
type Breaker struct {
	sync.Mutex
	// OpenAction is an optional actioner called in place of the intercepted
	// action(er) whenever the breaker rejects a call (i.e. while it is open, or
	// half-open with no trial calls left). Its result is returned to the caller.
	OpenAction Actioner
	// OnStateChange is an optional handler called when the breaker transitions
	// state. It receives the previous state, the new state and the new
	// generation.
	//
	// NOTE: it is invoked while the breaker's mutex is held, so it must not
	// call methods that lock the breaker (such as EvaluateState) or it will
	// deadlock.
	OnStateChange OnStateChangeHandler
	// ShouldOpenProvider optionally decides, after each failure in the closed
	// state, whether the breaker should open, given the current Counts. When
	// set, it replaces the OpenFailureThreshold check.
	ShouldOpenProvider ShouldOpenProvider
	// NowProvider optionally supplies the current time (e.g. to inject a clock
	// for testing). time.Now is used when it is nil.
	NowProvider NowProvider

	// OpenFailureThreshold is the number of consecutive failures the closed
	// breaker tolerates: it opens once Counts.ConsecutiveFailures strictly
	// exceeds this value, so a threshold of N opens on the (N+1)th consecutive
	// failure. It is ignored when ShouldOpenProvider is set.
	OpenFailureThreshold int64
	// HalfOpenMaxActions is the maximum number of requests admitted while the
	// breaker is half-open. It is also the number of consecutive successes
	// required in the half-open state before the breaker closes again.
	HalfOpenMaxActions int64
	// ClosedExpiryInterval is how often the closed state clears its Counts
	// (starting a new generation without changing state).
	// If it is 0, Counts are not cleared periodically while closed.
	ClosedExpiryInterval time.Duration
	// OpenExpiryInterval is how long the breaker stays open before it becomes
	// half-open. The transition happens lazily, on the first call (or
	// EvaluateState) after the interval has elapsed. There is no implicit
	// fallback for a zero value: New sets DefaultOpenExpiryInterval, but a
	// zero interval on a hand-built Breaker lets the open state expire almost
	// immediately.
	OpenExpiryInterval time.Duration
	// Counts are stats for the current generation. They are reset whenever
	// the state changes and whenever the closed state expires.
	Counts Counts

	// state is the current Breaker state (Closed, HalfOpen, Open etc.)
	state State
	// generation is the current state generation. It increases on every state
	// change and on every closed-state expiry; outcomes reported for an older
	// generation are discarded.
	generation int64
	// stateExpiresAt is the time when the current state will expire.
	// It is set when we change state according to the interval
	// and the current time. The zero time means the state never expires.
	stateExpiresAt time.Time
}
    87  // Intercept implements the breaker by returning a wrapper for a given action(er).
    88  /*
    89  It returns an error instantly if the Breaker rejects the request, otherwise,
    90  it returns the result of the request.
    92  If a panic occurs in the request, the Breaker handles it as an error.
    93  */
    94  func (b *Breaker) Intercept(action Actioner) Actioner {
    95  	return ActionerFunc(func(ctx context.Context, args interface{}) (res interface{}, err error) {
    96  		var generation int64
    97  		generation, err = b.beforeAction(ctx)
    98  		if err != nil {
    99  			if b.OpenAction != nil {
   100  				res, err = b.OpenAction.Action(ctx, args)
   101  				return
   102  			}
   103  			return
   104  		}
   105  		defer func() {
   106  			if r := recover(); r != nil {
   107  				b.afterAction(ctx, generation, false)
   108  			} else {
   109  				b.afterAction(ctx, generation, err == nil)
   110  			}
   111  		}()
   112  		res, err = action.Action(ctx, args)
   113  		return
   114  	})
   115  }
   117  // EvaluateState returns the current state of the CircuitBreaker.
   118  //
   119  // It takes a context because there is a chance that evaluating
   120  // the state causes the state to change, which would
   121  // result in calling the `OnStateChange` handler.
   122  func (b *Breaker) EvaluateState(ctx context.Context) State {
   123  	b.Lock()
   124  	defer b.Unlock()
   126  	now := time.Now()
   127  	state, _ := b.evaluateStateUnsafe(ctx, now)
   128  	return state
   129  }
   131  //
   132  // internal methods
   133  //
   135  func (b *Breaker) beforeAction(ctx context.Context) (int64, error) {
   136  	b.Lock()
   137  	defer b.Unlock()
   139  	now :=
   140  	state, generation := b.evaluateStateUnsafe(ctx, now)
   142  	if state == StateOpen {
   143  		return generation, ex.New(ErrOpenState)
   144  	} else if state == StateHalfOpen && b.Counts.Requests >= b.HalfOpenMaxActions {
   145  		return generation, ex.New(ErrTooManyRequests)
   146  	}
   148  	atomic.AddInt64(&b.Counts.Requests, 1)
   149  	return generation, nil
   150  }
   152  func (b *Breaker) afterAction(ctx context.Context, currentGeneration int64, success bool) {
   153  	b.Lock()
   154  	defer b.Unlock()
   156  	now :=
   157  	state, generation := b.evaluateStateUnsafe(ctx, now)
   158  	if generation != currentGeneration {
   159  		return
   160  	}
   161  	if success {
   162  		b.success(ctx, state, now)
   163  		return
   164  	}
   165  	b.failure(ctx, state, now)
   166  }
   168  func (b *Breaker) success(ctx context.Context, state State, now time.Time) {
   169  	switch state {
   170  	case StateClosed:
   171  		atomic.AddInt64(&b.Counts.TotalSuccesses, 1)
   172  		atomic.AddInt64(&b.Counts.ConsecutiveSuccesses, 1)
   173  		atomic.StoreInt64(&b.Counts.ConsecutiveFailures, 0)
   174  	case StateHalfOpen:
   175  		atomic.AddInt64(&b.Counts.TotalSuccesses, 1)
   176  		atomic.AddInt64(&b.Counts.ConsecutiveSuccesses, 1)
   177  		atomic.StoreInt64(&b.Counts.ConsecutiveFailures, 0)
   178  		if b.Counts.ConsecutiveSuccesses >= b.HalfOpenMaxActions {
   179  			b.setStateUnsafe(ctx, StateClosed, now)
   180  		}
   181  	}
   182  }
   184  func (b *Breaker) failure(ctx context.Context, state State, now time.Time) {
   185  	switch state {
   186  	case StateClosed:
   187  		atomic.AddInt64(&b.Counts.TotalFailures, 1)
   188  		atomic.AddInt64(&b.Counts.ConsecutiveFailures, 1)
   189  		atomic.StoreInt64(&b.Counts.ConsecutiveSuccesses, 0)
   190  		if b.shouldOpen(ctx) {
   191  			b.setStateUnsafe(ctx, StateOpen, now)
   192  		}
   193  	case StateHalfOpen:
   194  		b.setStateUnsafe(ctx, StateOpen, now)
   195  	}
   196  }
   198  func (b *Breaker) evaluateStateUnsafe(ctx context.Context, t time.Time) (state State, generation int64) {
   199  	switch b.state {
   200  	case StateClosed:
   201  		if !b.stateExpiresAt.IsZero() && b.stateExpiresAt.Before(t) {
   202  			b.incrementGeneration(t)
   203  		}
   204  	case StateOpen:
   205  		if b.stateExpiresAt.Before(t) {
   206  			b.setStateUnsafe(ctx, StateHalfOpen, t)
   207  		}
   208  	}
   209  	return b.state, b.generation
   210  }
   212  func (b *Breaker) setStateUnsafe(ctx context.Context, state State, now time.Time) {
   213  	if b.state == state {
   214  		return
   215  	}
   217  	previousState := b.state
   218  	b.state = state
   219  	b.incrementGeneration(now)
   220  	if b.OnStateChange != nil {
   221  		b.OnStateChange(ctx, previousState, b.state, b.generation)
   222  	}
   223  }
   225  func (b *Breaker) incrementGeneration(now time.Time) {
   226  	atomic.AddInt64(&b.generation, 1)
   227  	b.Counts = Counts{}
   229  	var zero time.Time
   230  	switch b.state {
   231  	case StateClosed:
   232  		if b.ClosedExpiryInterval == 0 {
   233  			b.stateExpiresAt = zero
   234  		} else {
   235  			b.stateExpiresAt = now.Add(b.ClosedExpiryInterval)
   236  		}
   237  	case StateOpen:
   238  		b.stateExpiresAt = now.Add(b.OpenExpiryInterval)
   239  	case StateHalfOpen:
   240  		b.stateExpiresAt = zero
   241  	default:
   242  		b.stateExpiresAt = zero
   243  	}
   244  }
   246  func (b *Breaker) shouldOpen(ctx context.Context) bool {
   247  	if b.ShouldOpenProvider != nil {
   248  		return b.ShouldOpenProvider(ctx, b.Counts)
   249  	}
   250  	return b.Counts.ConsecutiveFailures > b.OpenFailureThreshold
   251  }
   253  func (b *Breaker) now() time.Time {
   254  	if b.NowProvider != nil {
   255  		return b.NowProvider()
   256  	}
   257  	return time.Now()
   258  }