github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/retry/middleware_test.go (about)

     1  package retry
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/hellofresh/janus/pkg/test"
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  func TestMiddleware(t *testing.T) {
    14  	t.Parallel()
    15  
    16  	tests := []struct {
    17  		scenario string
    18  		function func(*testing.T, *http.Request, *httptest.ResponseRecorder)
    19  	}{
    20  		{
    21  			scenario: "with wrong predicate given",
    22  			function: testWrongPredicate,
    23  		},
    24  		{
    25  			scenario: "when the upstream respond successfully",
    26  			function: testSuccessfulUpstreamRetry,
    27  		},
    28  		{
    29  			scenario: "when the upstream fails to respond",
    30  			function: testFailedUpstreamRetry,
    31  		},
    32  	}
    33  
    34  	for _, test := range tests {
    35  		t.Run(test.scenario, func(t *testing.T) {
    36  			r := httptest.NewRequest(http.MethodGet, "/", nil)
    37  			w := httptest.NewRecorder()
    38  			test.function(t, r, w)
    39  		})
    40  	}
    41  }
    42  
    43  func testWrongPredicate(t *testing.T, r *http.Request, w *httptest.ResponseRecorder) {
    44  	cfg := Config{
    45  		Predicate: "this is wrong",
    46  	}
    47  	mw := NewRetryMiddleware(cfg)
    48  
    49  	mw(http.HandlerFunc(test.Ping)).ServeHTTP(w, r)
    50  
    51  	assert.Equal(t, http.StatusOK, w.Code)
    52  }
    53  
    54  func testSuccessfulUpstreamRetry(t *testing.T, r *http.Request, w *httptest.ResponseRecorder) {
    55  	mw := NewRetryMiddleware(Config{})
    56  
    57  	mw(http.HandlerFunc(test.Ping)).ServeHTTP(w, r)
    58  
    59  	assert.Equal(t, http.StatusOK, w.Code)
    60  }
    61  
    62  func testFailedUpstreamRetry(t *testing.T, r *http.Request, w *httptest.ResponseRecorder) {
    63  	mw := NewRetryMiddleware(Config{Attempts: 2, Backoff: Duration(time.Second)})
    64  
    65  	mw(test.FailWith(http.StatusBadGateway)).ServeHTTP(w, r)
    66  
    67  	assert.Equal(t, http.StatusBadGateway, w.Code)
    68  }