github.com/gofiber/fiber/v2@v2.47.0/middleware/timeout/timeout_test.go (about)

     1  package timeout
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/http/httptest"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/gofiber/fiber/v2"
    12  	"github.com/gofiber/fiber/v2/utils"
    13  )
    14  
    15  // go test -run Test_WithContextTimeout
    16  func Test_WithContextTimeout(t *testing.T) {
    17  	t.Parallel()
    18  	// fiber instance
    19  	app := fiber.New()
    20  	h := NewWithContext(func(c *fiber.Ctx) error {
    21  		sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
    22  		utils.AssertEqual(t, nil, err)
    23  		if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
    24  			return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
    25  		}
    26  		return nil
    27  	}, 100*time.Millisecond)
    28  	app.Get("/test/:sleepTime", h)
    29  	testTimeout := func(timeoutStr string) {
    30  		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
    31  		utils.AssertEqual(t, nil, err, "app.Test(req)")
    32  		utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
    33  	}
    34  	testSucces := func(timeoutStr string) {
    35  		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
    36  		utils.AssertEqual(t, nil, err, "app.Test(req)")
    37  		utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
    38  	}
    39  	testTimeout("300")
    40  	testTimeout("500")
    41  	testSucces("50")
    42  	testSucces("30")
    43  }
    44  
    45  var ErrFooTimeOut = errors.New("foo context canceled")
    46  
    47  // go test -run Test_WithContextTimeoutWithCustomError
    48  func Test_WithContextTimeoutWithCustomError(t *testing.T) {
    49  	t.Parallel()
    50  	// fiber instance
    51  	app := fiber.New()
    52  	h := NewWithContext(func(c *fiber.Ctx) error {
    53  		sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
    54  		utils.AssertEqual(t, nil, err)
    55  		if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
    56  			return fmt.Errorf("%w: execution error", err)
    57  		}
    58  		return nil
    59  	}, 100*time.Millisecond, ErrFooTimeOut)
    60  	app.Get("/test/:sleepTime", h)
    61  	testTimeout := func(timeoutStr string) {
    62  		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
    63  		utils.AssertEqual(t, nil, err, "app.Test(req)")
    64  		utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
    65  	}
    66  	testSucces := func(timeoutStr string) {
    67  		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
    68  		utils.AssertEqual(t, nil, err, "app.Test(req)")
    69  		utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
    70  	}
    71  	testTimeout("300")
    72  	testTimeout("500")
    73  	testSucces("50")
    74  	testSucces("30")
    75  }
    76  
    77  func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
    78  	timer := time.NewTimer(d)
    79  	select {
    80  	case <-ctx.Done():
    81  		if !timer.Stop() {
    82  			<-timer.C
    83  		}
    84  		return te
    85  	case <-timer.C:
    86  	}
    87  	return nil
    88  }