github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/breaker/round_tripper_test.go (about)

     1  // Copyright 2022 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package breaker
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/jonboulle/clockwork"
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  func TestRoundTripper_RoundTrip(t *testing.T) {
    29  	t.Parallel()
    30  
    31  	clock := clockwork.NewFakeClock()
    32  	cb, err := New(Config{
    33  		Clock:         clock,
    34  		RecoveryLimit: 1,
    35  		Interval:      time.Second,
    36  		TrippedPeriod: time.Second,
    37  		IsSuccessful: func(v interface{}, err error) bool {
    38  			if err != nil {
    39  				return false
    40  			}
    41  
    42  			if v == nil {
    43  				return false
    44  			}
    45  
    46  			switch t := v.(type) {
    47  			case *http.Response:
    48  				return t.StatusCode < http.StatusInternalServerError
    49  			}
    50  
    51  			return true
    52  		},
    53  		Trip:    ConsecutiveFailureTripper(1),
    54  		Recover: ConsecutiveFailureTripper(0),
    55  	})
    56  	require.NoError(t, err)
    57  
    58  	mux := http.NewServeMux()
    59  	mux.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) {
    60  		w.WriteHeader(http.StatusOK)
    61  	})
    62  	mux.HandleFunc("/fail", func(w http.ResponseWriter, r *http.Request) {
    63  		w.WriteHeader(http.StatusBadGateway)
    64  	})
    65  	srv := httptest.NewServer(mux)
    66  
    67  	clt := srv.Client()
    68  	clt.Transport = NewRoundTripper(cb, clt.Transport)
    69  
    70  	ctx := context.Background()
    71  
    72  	cases := []struct {
    73  		desc              string
    74  		url               string
    75  		state             State
    76  		advance           time.Duration
    77  		errorAssertion    require.ErrorAssertionFunc
    78  		responseAssertion require.ValueAssertionFunc
    79  	}{
    80  		{
    81  			desc:           "success in standby",
    82  			url:            "/success",
    83  			state:          StateStandby,
    84  			errorAssertion: require.NoError,
    85  			responseAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
    86  				require.Equal(t, http.StatusOK, i.(*http.Response).StatusCode)
    87  			},
    88  		},
    89  		{
    90  			desc:  "error when tripped",
    91  			url:   "/success",
    92  			state: StateTripped,
    93  			errorAssertion: func(t require.TestingT, err error, i ...interface{}) {
    94  				require.Error(t, err)
    95  				require.ErrorIs(t, err, ErrStateTripped)
    96  			},
    97  			responseAssertion: require.Nil,
    98  		},
    99  		{
   100  			desc:           "allowed request when recovery progresses",
   101  			url:            "/success",
   102  			state:          StateRecovering,
   103  			advance:        time.Minute,
   104  			errorAssertion: require.NoError,
   105  			responseAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
   106  				require.Equal(t, http.StatusOK, i.(*http.Response).StatusCode)
   107  			},
   108  		},
   109  		{
   110  			desc:           "failure in standby",
   111  			url:            "/fail",
   112  			state:          StateStandby,
   113  			errorAssertion: require.NoError,
   114  			responseAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
   115  				require.Equal(t, http.StatusBadGateway, i.(*http.Response).StatusCode)
   116  			},
   117  		},
   118  	}
   119  
   120  	for _, tt := range cases {
   121  		t.Run(tt.desc, func(t *testing.T) {
   122  			cb.setState(tt.state, clock.Now())
   123  			clock.Advance(tt.advance)
   124  
   125  			r, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+tt.url, nil)
   126  			require.NoError(t, err)
   127  			resp, err := clt.Do(r)
   128  			if resp != nil {
   129  				t.Cleanup(func() {
   130  					require.NoError(t, resp.Body.Close())
   131  				})
   132  			}
   133  			tt.errorAssertion(t, err)
   134  			tt.responseAssertion(t, resp)
   135  		})
   136  	}
   137  }