github.com/rudderlabs/rudder-go-kit@v0.30.0/testhelper/assert/assert.go (about)

     1  package assert
     2  
import (
	"context"
	"io"
	"net/http"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)
    11  
    12  type (
    13  	ResponseBody            = string
    14  	RequireStatusCodeOption func(*requireStatusCodeConfig)
    15  )
    16  
    17  type requireStatusCodeConfig struct {
    18  	waitFor      time.Duration
    19  	pollInterval time.Duration
    20  	httpClient   *http.Client
    21  }
    22  
    23  func (c *requireStatusCodeConfig) reset() {
    24  	c.waitFor = 10 * time.Second
    25  	c.pollInterval = 100 * time.Millisecond
    26  	c.httpClient = http.DefaultClient
    27  }
    28  
    29  func WithRequireStatusCodeWaitFor(waitFor time.Duration) RequireStatusCodeOption {
    30  	return func(c *requireStatusCodeConfig) {
    31  		c.waitFor = waitFor
    32  	}
    33  }
    34  
    35  func WithRequireStatusCodePollInterval(pollInterval time.Duration) RequireStatusCodeOption {
    36  	return func(c *requireStatusCodeConfig) {
    37  		c.pollInterval = pollInterval
    38  	}
    39  }
    40  
    41  func WithRequireStatusCodeHTTPClient(httpClient *http.Client) RequireStatusCodeOption {
    42  	return func(c *requireStatusCodeConfig) {
    43  		c.httpClient = httpClient
    44  	}
    45  }
    46  
    47  // RequireEventuallyStatusCode is a helper function that retries a request until the expected status code is returned.
    48  func RequireEventuallyStatusCode(
    49  	t *testing.T, expectedStatusCode int, r *http.Request, opts ...RequireStatusCodeOption,
    50  ) ResponseBody {
    51  	t.Helper()
    52  
    53  	var config requireStatusCodeConfig
    54  	config.reset()
    55  	for _, opt := range opts {
    56  		opt(&config)
    57  	}
    58  
    59  	var (
    60  		body             []byte
    61  		actualStatusCode int
    62  	)
    63  	require.Eventuallyf(t,
    64  		func() bool {
    65  			resp, err := config.httpClient.Do(r)
    66  			if err != nil {
    67  				return false
    68  			}
    69  
    70  			defer func() { _ = resp.Body.Close() }()
    71  
    72  			body, err = io.ReadAll(resp.Body)
    73  			if err != nil {
    74  				return false
    75  			}
    76  
    77  			actualStatusCode = resp.StatusCode
    78  			return expectedStatusCode == actualStatusCode
    79  		},
    80  		config.waitFor, config.pollInterval,
    81  		"Expected status code %d, got %d. Body: %s",
    82  		expectedStatusCode, actualStatusCode, string(body),
    83  	)
    84  
    85  	return string(body)
    86  }