github.com/gofiber/fiber/v2@v2.47.0/middleware/rewrite/rewrite_test.go (about)

     1  //nolint:bodyclose // Much easier to just ignore memory leaks in tests
     2  package rewrite
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"testing"
    10  
    11  	"github.com/gofiber/fiber/v2"
    12  	"github.com/gofiber/fiber/v2/utils"
    13  )
    14  
    15  func Test_New(t *testing.T) {
    16  	// Test with no config
    17  	m := New()
    18  
    19  	if m == nil {
    20  		t.Error("Expected middleware to be returned, got nil")
    21  	}
    22  
    23  	// Test with config
    24  	m = New(Config{
    25  		Rules: map[string]string{
    26  			"/old": "/new",
    27  		},
    28  	})
    29  
    30  	if m == nil {
    31  		t.Error("Expected middleware to be returned, got nil")
    32  	}
    33  
    34  	// Test with full config
    35  	m = New(Config{
    36  		Next: func(*fiber.Ctx) bool {
    37  			return true
    38  		},
    39  		Rules: map[string]string{
    40  			"/old": "/new",
    41  		},
    42  	})
    43  
    44  	if m == nil {
    45  		t.Error("Expected middleware to be returned, got nil")
    46  	}
    47  }
    48  
    49  func Test_Rewrite(t *testing.T) {
    50  	// Case 1: Next function always returns true
    51  	app := fiber.New()
    52  	app.Use(New(Config{
    53  		Next: func(*fiber.Ctx) bool {
    54  			return true
    55  		},
    56  		Rules: map[string]string{
    57  			"/old": "/new",
    58  		},
    59  	}))
    60  
    61  	app.Get("/old", func(c *fiber.Ctx) error {
    62  		return c.SendString("Rewrite Successful")
    63  	})
    64  
    65  	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/old", nil)
    66  	utils.AssertEqual(t, err, nil)
    67  	resp, err := app.Test(req)
    68  	utils.AssertEqual(t, err, nil)
    69  	body, err := io.ReadAll(resp.Body)
    70  	utils.AssertEqual(t, err, nil)
    71  	bodyString := string(body)
    72  
    73  	utils.AssertEqual(t, err, nil)
    74  	utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
    75  	utils.AssertEqual(t, "Rewrite Successful", bodyString)
    76  
    77  	// Case 2: Next function always returns false
    78  	app = fiber.New()
    79  	app.Use(New(Config{
    80  		Next: func(*fiber.Ctx) bool {
    81  			return false
    82  		},
    83  		Rules: map[string]string{
    84  			"/old": "/new",
    85  		},
    86  	}))
    87  
    88  	app.Get("/new", func(c *fiber.Ctx) error {
    89  		return c.SendString("Rewrite Successful")
    90  	})
    91  
    92  	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/old", nil)
    93  	utils.AssertEqual(t, err, nil)
    94  	resp, err = app.Test(req)
    95  	utils.AssertEqual(t, err, nil)
    96  	body, err = io.ReadAll(resp.Body)
    97  	utils.AssertEqual(t, err, nil)
    98  	bodyString = string(body)
    99  
   100  	utils.AssertEqual(t, err, nil)
   101  	utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   102  	utils.AssertEqual(t, "Rewrite Successful", bodyString)
   103  
   104  	// Case 3: check for captured tokens in rewrite rule
   105  	app = fiber.New()
   106  	app.Use(New(Config{
   107  		Rules: map[string]string{
   108  			"/users/*/orders/*": "/user/$1/order/$2",
   109  		},
   110  	}))
   111  
   112  	app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error {
   113  		return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
   114  	})
   115  
   116  	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/users/123/orders/456", nil)
   117  	utils.AssertEqual(t, err, nil)
   118  	resp, err = app.Test(req)
   119  	utils.AssertEqual(t, err, nil)
   120  	body, err = io.ReadAll(resp.Body)
   121  	utils.AssertEqual(t, err, nil)
   122  	bodyString = string(body)
   123  
   124  	utils.AssertEqual(t, err, nil)
   125  	utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   126  	utils.AssertEqual(t, "User ID: 123, Order ID: 456", bodyString)
   127  
   128  	// Case 4: Send non-matching request, handled by default route
   129  	app = fiber.New()
   130  	app.Use(New(Config{
   131  		Rules: map[string]string{
   132  			"/users/*/orders/*": "/user/$1/order/$2",
   133  		},
   134  	}))
   135  
   136  	app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error {
   137  		return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
   138  	})
   139  
   140  	app.Use(func(c *fiber.Ctx) error {
   141  		return c.SendStatus(fiber.StatusOK)
   142  	})
   143  
   144  	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/not-matching-any-rule", nil)
   145  	utils.AssertEqual(t, err, nil)
   146  	resp, err = app.Test(req)
   147  	utils.AssertEqual(t, err, nil)
   148  	body, err = io.ReadAll(resp.Body)
   149  	utils.AssertEqual(t, err, nil)
   150  	bodyString = string(body)
   151  
   152  	utils.AssertEqual(t, err, nil)
   153  	utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   154  	utils.AssertEqual(t, "OK", bodyString)
   155  
   156  	// Case 4: Send non-matching request, with no default route
   157  	app = fiber.New()
   158  	app.Use(New(Config{
   159  		Rules: map[string]string{
   160  			"/users/*/orders/*": "/user/$1/order/$2",
   161  		},
   162  	}))
   163  
   164  	app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error {
   165  		return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
   166  	})
   167  
   168  	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/not-matching-any-rule", nil)
   169  	utils.AssertEqual(t, err, nil)
   170  	resp, err = app.Test(req)
   171  	utils.AssertEqual(t, err, nil)
   172  	utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
   173  }