github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/breaker/breaker_test.go (about)

     1  // Copyright 2022 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package breaker
    16  
    17  import (
    18  	"errors"
    19  	"net/http"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/gravitational/trace"
    24  	"github.com/jonboulle/clockwork"
    25  	"github.com/stretchr/testify/require"
    26  	"google.golang.org/grpc/codes"
    27  	"google.golang.org/grpc/status"
    28  )
    29  
    30  func TestCircuitBreaker_generation(t *testing.T) {
    31  	t.Parallel()
    32  	clock := clockwork.NewFakeClock()
    33  
    34  	cb, err := New(Config{
    35  		Clock:    clock,
    36  		Interval: time.Second,
    37  		Trip:     StaticTripper(false),
    38  		Recover:  StaticTripper(false),
    39  	})
    40  	require.NoError(t, err)
    41  
    42  	generation, state := cb.currentState(clock.Now())
    43  	require.Equal(t, uint64(1), generation)
    44  	require.Equal(t, StateStandby, state)
    45  	require.Equal(t, clock.Now().Add(time.Second), cb.expiry)
    46  
    47  	clock.Advance(500 * time.Millisecond)
    48  	generation, state = cb.currentState(clock.Now())
    49  	require.Equal(t, uint64(1), generation)
    50  	require.Equal(t, StateStandby, state)
    51  	clock.Advance(501 * time.Millisecond)
    52  	generation, state = cb.currentState(clock.Now())
    53  	require.Equal(t, uint64(2), generation)
    54  	require.Equal(t, StateStandby, state)
    55  	require.Equal(t, clock.Now().Add(time.Second), cb.expiry)
    56  
    57  	for i := 0; i < 1000; i++ {
    58  		prevGeneration, prevState := cb.currentState(clock.Now())
    59  		cb.nextGeneration(clock.Now())
    60  		generation, state := cb.currentState(clock.Now())
    61  		require.NotEqual(t, prevGeneration, generation)
    62  		require.Equal(t, prevState, state)
    63  	}
    64  
    65  	generation, state = cb.currentState(clock.Now())
    66  	require.Equal(t, uint64(1002), generation)
    67  	require.Equal(t, StateStandby, state)
    68  }
    69  
    70  func TestCircuitBreaker_beforeRequest(t *testing.T) {
    71  	t.Parallel()
    72  	cases := []struct {
    73  		desc       string
    74  		generation uint64
    75  		executions uint32
    76  		advance    time.Duration
    77  		state      State
    78  		errorCheck require.ErrorAssertionFunc
    79  	}{
    80  		{
    81  			desc:       "standby allows execution",
    82  			generation: 1,
    83  			executions: 1,
    84  			state:      StateStandby,
    85  			errorCheck: require.NoError,
    86  		},
    87  		{
    88  			desc:       "tripped prevents executions",
    89  			generation: 1,
    90  			executions: 0,
    91  			state:      StateTripped,
    92  			errorCheck: func(t require.TestingT, err error, i ...interface{}) {
    93  				require.Error(t, err)
    94  				require.ErrorIs(t, ErrStateTripped, err)
    95  			},
    96  		},
    97  		{
    98  			desc:       "recovering after allows executions",
    99  			generation: 1,
   100  			executions: 1,
   101  			state:      StateRecovering,
   102  			advance:    3 * time.Second,
   103  			errorCheck: require.NoError,
   104  		},
   105  	}
   106  
   107  	for _, tt := range cases {
   108  		t.Run(tt.desc, func(t *testing.T) {
   109  			clock := clockwork.NewFakeClock()
   110  
   111  			cb, err := New(Config{
   112  				Clock:         clock,
   113  				Interval:      time.Second,
   114  				Trip:          StaticTripper(false),
   115  				Recover:       StaticTripper(false),
   116  				RecoveryLimit: 1,
   117  			})
   118  			require.NoError(t, err)
   119  			cb.state = tt.state
   120  
   121  			clock.Advance(tt.advance)
   122  
   123  			generation, err := cb.beforeExecution()
   124  			tt.errorCheck(t, err)
   125  			require.Equal(t, tt.generation, generation)
   126  			require.Equal(t, tt.executions, cb.metrics.Executions)
   127  
   128  		})
   129  	}
   130  }
   131  
   132  func TestCircuitBreaker_afterExecution(t *testing.T) {
   133  	t.Parallel()
   134  	cases := []struct {
   135  		desc            string
   136  		err             error
   137  		priorGeneration uint64
   138  		checkMetrics    require.ValueAssertionFunc
   139  		trip            TripFn
   140  		recover         TripFn
   141  		expectedState   State
   142  	}{
   143  		{
   144  			desc:            "successful execution",
   145  			priorGeneration: 1,
   146  			checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) {
   147  				m, ok := i.(Metrics)
   148  				require.True(t, ok)
   149  				require.Equal(t, uint32(1), m.Successes)
   150  				require.Equal(t, uint32(0), m.Failures)
   151  			},
   152  			trip:          StaticTripper(false),
   153  			recover:       StaticTripper(false),
   154  			expectedState: StateStandby,
   155  		},
   156  		{
   157  			desc:            "generation change",
   158  			priorGeneration: 0,
   159  			trip:            StaticTripper(false),
   160  			recover:         StaticTripper(false),
   161  			checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) {
   162  				m, ok := i.(Metrics)
   163  				require.True(t, ok)
   164  				require.Equal(t, uint32(0), m.Successes)
   165  				require.Equal(t, uint32(0), m.Failures)
   166  			},
   167  			expectedState: StateStandby,
   168  		},
   169  		{
   170  			desc:            "failed execution with out tripping",
   171  			priorGeneration: 1,
   172  			err:             errors.New("failure"),
   173  			trip:            StaticTripper(false),
   174  			recover:         StaticTripper(false),
   175  			checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) {
   176  				m, ok := i.(Metrics)
   177  				require.True(t, ok)
   178  				require.Equal(t, uint32(0), m.Successes)
   179  				require.Equal(t, uint32(1), m.Failures)
   180  			},
   181  			expectedState: StateStandby,
   182  		},
   183  		{
   184  			desc:            "failed execution causing a trip",
   185  			priorGeneration: 1,
   186  			err:             errors.New("failure"),
   187  			trip:            StaticTripper(true),
   188  			recover:         StaticTripper(false),
   189  			checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) {
   190  				m, ok := i.(Metrics)
   191  				require.True(t, ok)
   192  				require.Equal(t, uint32(0), m.Successes)
   193  				require.Equal(t, uint32(0), m.Failures)
   194  			},
   195  			expectedState: StateTripped,
   196  		},
   197  	}
   198  
   199  	for _, tt := range cases {
   200  		t.Run(tt.desc, func(t *testing.T) {
   201  			clock := clockwork.NewFakeClock()
   202  			cb, err := New(Config{
   203  				Clock:    clock,
   204  				Interval: time.Second,
   205  				Trip:     tt.trip,
   206  				Recover:  tt.recover,
   207  			})
   208  			require.NoError(t, err)
   209  
   210  			cb.afterExecution(tt.priorGeneration, nil, tt.err)
   211  			tt.checkMetrics(t, cb.metrics)
   212  			require.Equal(t, tt.expectedState, cb.state)
   213  		})
   214  	}
   215  }
   216  
   217  func TestCircuitBreaker_success(t *testing.T) {
   218  	t.Parallel()
   219  	cases := []struct {
   220  		desc          string
   221  		initialState  State
   222  		successState  State
   223  		expectedState State
   224  		recoveryLimit uint32
   225  	}{
   226  		{
   227  			desc:          "success in standby",
   228  			initialState:  StateStandby,
   229  			successState:  StateStandby,
   230  			expectedState: StateStandby,
   231  		},
   232  		{
   233  			desc:          "success in recovery below limit",
   234  			initialState:  StateRecovering,
   235  			successState:  StateRecovering,
   236  			expectedState: StateRecovering,
   237  			recoveryLimit: 2,
   238  		},
   239  		{
   240  			desc:          "success in recovery above limit",
   241  			initialState:  StateRecovering,
   242  			successState:  StateRecovering,
   243  			expectedState: StateStandby,
   244  			recoveryLimit: 1,
   245  		},
   246  	}
   247  
   248  	for _, tt := range cases {
   249  		t.Run(tt.desc, func(t *testing.T) {
   250  			clock := clockwork.NewFakeClock()
   251  			cb, err := New(Config{
   252  				Clock:         clock,
   253  				Interval:      time.Second,
   254  				RecoveryLimit: tt.recoveryLimit,
   255  				Trip:          StaticTripper(false),
   256  				Recover:       StaticTripper(false),
   257  			})
   258  			require.NoError(t, err)
   259  			cb.state = tt.initialState
   260  
   261  			generation, state := cb.currentState(clock.Now())
   262  			cb.successLocked(tt.successState, clock.Now())
   263  			require.Equal(t, tt.expectedState, cb.state)
   264  			if tt.expectedState != state {
   265  				require.NotEqual(t, generation, cb.generation)
   266  			}
   267  		})
   268  	}
   269  }
   270  
   271  func TestCircuitBreaker_failure(t *testing.T) {
   272  	t.Parallel()
   273  	cases := []struct {
   274  		desc           string
   275  		initialState   State
   276  		failureState   State
   277  		expectedState  State
   278  		tripFn         TripFn
   279  		recover        TripFn
   280  		onTrip         func(ch chan bool) func()
   281  		tripped        bool
   282  		requireTripped require.BoolAssertionFunc
   283  	}{
   284  		{
   285  			desc:           "failure in recovering transitions to tripped",
   286  			initialState:   StateRecovering,
   287  			failureState:   StateRecovering,
   288  			expectedState:  StateTripped,
   289  			tripFn:         StaticTripper(false),
   290  			recover:        StaticTripper(true),
   291  			requireTripped: require.False,
   292  		},
   293  		{
   294  			desc:           "failure in standby without tripping",
   295  			initialState:   StateStandby,
   296  			failureState:   StateStandby,
   297  			expectedState:  StateStandby,
   298  			tripFn:         StaticTripper(false),
   299  			recover:        StaticTripper(false),
   300  			requireTripped: require.False,
   301  		},
   302  		{
   303  			desc:           "failure in standby causes tripping",
   304  			initialState:   StateStandby,
   305  			failureState:   StateStandby,
   306  			expectedState:  StateTripped,
   307  			tripFn:         StaticTripper(true),
   308  			recover:        StaticTripper(false),
   309  			requireTripped: require.True,
   310  			onTrip: func(ch chan bool) func() {
   311  				return func() {
   312  					ch <- true
   313  				}
   314  			},
   315  		},
   316  	}
   317  
   318  	for _, tt := range cases {
   319  		tt := tt
   320  		t.Run(tt.desc, func(t *testing.T) {
   321  			t.Parallel()
   322  			clock := clockwork.NewFakeClock()
   323  
   324  			if tt.onTrip == nil {
   325  				tt.onTrip = func(ch chan bool) func() {
   326  					ch <- false
   327  					return func() {}
   328  				}
   329  			}
   330  
   331  			trippedCh := make(chan bool, 1)
   332  
   333  			cb, err := New(Config{
   334  				Clock:     clock,
   335  				Interval:  time.Second,
   336  				Trip:      tt.tripFn,
   337  				OnTripped: tt.onTrip(trippedCh),
   338  				Recover:   tt.recover,
   339  			})
   340  			require.NoError(t, err)
   341  			cb.state = tt.initialState
   342  
   343  			generation, state := cb.currentState(clock.Now())
   344  			cb.failureLocked(tt.failureState, clock.Now())
   345  			require.Equal(t, tt.expectedState, cb.state)
   346  			if tt.expectedState != state {
   347  				require.NotEqual(t, generation, cb.generation)
   348  			}
   349  
   350  			tripped := <-trippedCh
   351  
   352  			tt.requireTripped(t, tripped)
   353  		})
   354  	}
   355  }
   356  
   357  func TestCircuitBreaker_Execute(t *testing.T) {
   358  	t.Parallel()
   359  
   360  	clock := clockwork.NewFakeClock()
   361  
   362  	trippedCh := make(chan struct{})
   363  	onTripped := func(ch chan struct{}) func() {
   364  		return func() {
   365  			ch <- struct{}{}
   366  		}
   367  	}
   368  
   369  	cb, err := New(Config{
   370  		Clock:         clock,
   371  		Interval:      time.Second,
   372  		Trip:          ConsecutiveFailureTripper(3),
   373  		Recover:       ConsecutiveFailureTripper(1),
   374  		OnTripped:     onTripped(trippedCh),
   375  		TrippedPeriod: 2 * time.Second,
   376  		RecoveryLimit: 2,
   377  	})
   378  	require.NoError(t, err)
   379  
   380  	testErr := errors.New("failure")
   381  	errorFn := func() (interface{}, error) { return nil, testErr }
   382  	noErrorFn := func() (interface{}, error) { return nil, nil }
   383  	cases := []struct {
   384  		desc               string
   385  		exec               func() (interface{}, error)
   386  		advance            time.Duration
   387  		errorAssertion     require.ErrorAssertionFunc
   388  		expectedState      State
   389  		expectedGeneration uint64
   390  	}{
   391  		{
   392  			desc:               "no errors remain in standby",
   393  			exec:               noErrorFn,
   394  			errorAssertion:     require.NoError,
   395  			expectedState:      StateStandby,
   396  			expectedGeneration: 1,
   397  		},
   398  		{
   399  			desc:               "error below limit remain in standby",
   400  			exec:               errorFn,
   401  			errorAssertion:     require.Error,
   402  			expectedState:      StateStandby,
   403  			expectedGeneration: 1,
   404  		},
   405  		{
   406  			desc:               "another error below limit remain in standby",
   407  			exec:               errorFn,
   408  			errorAssertion:     require.Error,
   409  			expectedState:      StateStandby,
   410  			expectedGeneration: 1,
   411  		},
   412  		{
   413  			desc:               "last error below limit remain in standby",
   414  			exec:               errorFn,
   415  			errorAssertion:     require.Error,
   416  			expectedState:      StateStandby,
   417  			expectedGeneration: 1,
   418  		},
   419  		{
   420  			desc:               "transition from standby to tripped",
   421  			exec:               errorFn,
   422  			errorAssertion:     require.Error,
   423  			expectedState:      StateTripped,
   424  			expectedGeneration: 2,
   425  		},
   426  		{
   427  			desc:               "error remain tripped",
   428  			exec:               errorFn,
   429  			errorAssertion:     require.Error,
   430  			expectedState:      StateTripped,
   431  			expectedGeneration: 2,
   432  		},
   433  		{
   434  			desc:               "no error remain tripped",
   435  			exec:               noErrorFn,
   436  			errorAssertion:     require.Error,
   437  			expectedState:      StateTripped,
   438  			expectedGeneration: 2,
   439  		},
   440  		{
   441  			desc:               "transition from tripped to recovering",
   442  			exec:               noErrorFn,
   443  			errorAssertion:     require.NoError,
   444  			expectedState:      StateRecovering,
   445  			expectedGeneration: 3,
   446  			advance:            3 * time.Second,
   447  		},
   448  		{
   449  			desc:               "first failed execution recovering remains in recovering",
   450  			exec:               errorFn,
   451  			errorAssertion:     require.Error,
   452  			expectedState:      StateRecovering,
   453  			expectedGeneration: 3,
   454  			advance:            250 * time.Millisecond,
   455  		},
   456  		{
   457  			desc:               "second failed execution recovering transitions to tripped",
   458  			exec:               errorFn,
   459  			errorAssertion:     require.Error,
   460  			expectedState:      StateTripped,
   461  			expectedGeneration: 4,
   462  			advance:            450 * time.Millisecond,
   463  		},
   464  		{
   465  			desc:               "transition from tripped to recovering",
   466  			exec:               noErrorFn,
   467  			errorAssertion:     require.NoError,
   468  			expectedState:      StateRecovering,
   469  			expectedGeneration: 5,
   470  			advance:            3 * time.Second,
   471  		},
   472  		{
   473  			desc:               "transition from recovering to standby",
   474  			exec:               noErrorFn,
   475  			errorAssertion:     require.NoError,
   476  			expectedState:      StateStandby,
   477  			expectedGeneration: 6,
   478  			advance:            450 * time.Millisecond,
   479  		},
   480  		{
   481  			desc:               "remain in standby while in new generation",
   482  			exec:               noErrorFn,
   483  			errorAssertion:     require.NoError,
   484  			expectedState:      StateStandby,
   485  			expectedGeneration: 7,
   486  			advance:            time.Minute,
   487  		},
   488  	}
   489  
   490  	for i, tt := range cases {
   491  		t.Run(tt.desc, func(t *testing.T) {
   492  			clock.Advance(tt.advance)
   493  			_, err := cb.Execute(tt.exec)
   494  			tt.errorAssertion(t, err)
   495  			generation, state := cb.currentState(clock.Now())
   496  			require.Equal(t, tt.expectedGeneration, generation, "incorrect generation")
   497  			require.Equal(t, tt.expectedState, state, "incorrect state")
   498  
   499  			if state != StateTripped && tt.expectedState == StateTripped {
   500  				select {
   501  				case <-trippedCh:
   502  				default:
   503  					t.Fatalf("step %d expected to get tripped, but wasn't", i)
   504  				}
   505  			}
   506  		})
   507  	}
   508  
   509  }
   510  
   511  func TestMetrics(t *testing.T) {
   512  	m := Metrics{}
   513  
   514  	zero := uint32(0)
   515  	one := uint32(1)
   516  	require.Equal(t, zero, m.Executions)
   517  	require.Equal(t, zero, m.Successes)
   518  	require.Equal(t, zero, m.Failures)
   519  	require.Equal(t, zero, m.ConsecutiveSuccesses)
   520  	require.Equal(t, zero, m.ConsecutiveFailures)
   521  
   522  	m.success()
   523  
   524  	require.Equal(t, zero, m.Executions)
   525  	require.Equal(t, one, m.Successes)
   526  	require.Equal(t, zero, m.Failures)
   527  	require.Equal(t, one, m.ConsecutiveSuccesses)
   528  	require.Equal(t, zero, m.ConsecutiveFailures)
   529  
   530  	m.execute()
   531  
   532  	require.Equal(t, one, m.Executions)
   533  	require.Equal(t, one, m.Successes)
   534  	require.Equal(t, zero, m.Failures)
   535  	require.Equal(t, one, m.ConsecutiveSuccesses)
   536  	require.Equal(t, zero, m.ConsecutiveFailures)
   537  
   538  	m.failure()
   539  
   540  	require.Equal(t, one, m.Executions)
   541  	require.Equal(t, one, m.Successes)
   542  	require.Equal(t, one, m.Failures)
   543  	require.Equal(t, zero, m.ConsecutiveSuccesses)
   544  	require.Equal(t, one, m.ConsecutiveFailures)
   545  
   546  	m.reset()
   547  
   548  	require.Equal(t, zero, m.Executions)
   549  	require.Equal(t, zero, m.Successes)
   550  	require.Equal(t, zero, m.Failures)
   551  	require.Equal(t, zero, m.ConsecutiveSuccesses)
   552  	require.Equal(t, zero, m.ConsecutiveFailures)
   553  }
   554  
   555  func TestIsResponseSuccessful(t *testing.T) {
   556  	cases := []struct {
   557  		name      string
   558  		err       error
   559  		response  *http.Response
   560  		assertion require.BoolAssertionFunc
   561  	}{
   562  		{
   563  			name:      "nil error",
   564  			assertion: require.True,
   565  		},
   566  		{
   567  			name:      "codes.Canceled error",
   568  			err:       status.Error(codes.Canceled, ""),
   569  			assertion: require.False,
   570  		},
   571  		{
   572  			name:      "codes.Unknown error",
   573  			err:       status.Error(codes.Unknown, ""),
   574  			assertion: require.False,
   575  		},
   576  		{
   577  			name:      "codes.Unavailable error",
   578  			err:       status.Error(codes.Unavailable, ""),
   579  			assertion: require.False,
   580  		},
   581  		{
   582  			name:      "codes.Unavailable error",
   583  			err:       status.Error(codes.DeadlineExceeded, ""),
   584  			assertion: require.False,
   585  		},
   586  		{
   587  			name:      "other error",
   588  			err:       trace.NotFound("not found"),
   589  			assertion: require.False,
   590  		},
   591  		{
   592  			name:      "error",
   593  			err:       trace.NotFound(""),
   594  			assertion: require.False,
   595  		},
   596  		{
   597  			name:      "200",
   598  			response:  &http.Response{StatusCode: http.StatusOK},
   599  			assertion: require.True,
   600  		},
   601  		{
   602  			name:      "500",
   603  			response:  &http.Response{StatusCode: http.StatusBadGateway},
   604  			assertion: require.False,
   605  		},
   606  		{
   607  			name:      "404",
   608  			response:  &http.Response{StatusCode: http.StatusNotFound},
   609  			assertion: require.True,
   610  		},
   611  	}
   612  
   613  	for _, tt := range cases {
   614  		t.Run(tt.name, func(t *testing.T) {
   615  			tt.assertion(t, IsResponseSuccessful(tt.response, tt.err))
   616  		})
   617  	}
   618  }