github.com/gofiber/fiber/v2@v2.47.0/middleware/earlydata/earlydata_test.go (about)

     1  //nolint:bodyclose // Much easier to just ignore memory leaks in tests
     2  package earlydata_test
     3  
     4  import (
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  
    11  	"github.com/gofiber/fiber/v2"
    12  	"github.com/gofiber/fiber/v2/middleware/earlydata"
    13  	"github.com/gofiber/fiber/v2/utils"
    14  )
    15  
    16  const (
    17  	headerName   = "Early-Data"
    18  	headerValOn  = "1"
    19  	headerValOff = "0"
    20  )
    21  
    22  func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {
    23  	t.Helper()
    24  	t.Parallel()
    25  
    26  	var app *fiber.App
    27  	if c == nil {
    28  		app = fiber.New()
    29  	} else {
    30  		app = fiber.New(*c)
    31  	}
    32  
    33  	app.Use(earlydata.New())
    34  
    35  	// Middleware to test IsEarly func
    36  	const localsKeyTestValid = "earlydata_testvalid"
    37  	app.Use(func(c *fiber.Ctx) error {
    38  		isEarly := earlydata.IsEarly(c)
    39  
    40  		switch h := c.Get(headerName); h {
    41  		case "", headerValOff:
    42  			if isEarly {
    43  				return errors.New("is early-data even though it's not")
    44  			}
    45  
    46  		case headerValOn:
    47  			switch {
    48  			case fiber.IsMethodSafe(c.Method()):
    49  				if !isEarly {
    50  					return errors.New("should be early-data on safe HTTP methods")
    51  				}
    52  			default:
    53  				if isEarly {
    54  					return errors.New("early-data unsuported on unsafe HTTP methods")
    55  				}
    56  			}
    57  
    58  		default:
    59  			return fmt.Errorf("header has unsupported value: %s", h)
    60  		}
    61  
    62  		_ = c.Locals(localsKeyTestValid, true)
    63  
    64  		return c.Next()
    65  	})
    66  
    67  	{
    68  		{
    69  			handler := func(c *fiber.Ctx) error {
    70  				if !c.Locals(localsKeyTestValid).(bool) { //nolint:forcetypeassert // We store nothing else in the pool
    71  					return errors.New("handler called even though validation failed")
    72  				}
    73  
    74  				return nil
    75  			}
    76  
    77  			app.Get("/", handler)
    78  			app.Post("/", handler)
    79  		}
    80  	}
    81  
    82  	return app
    83  }
    84  
    85  // go test -run Test_EarlyData
    86  func Test_EarlyData(t *testing.T) {
    87  	t.Parallel()
    88  
    89  	trustedRun := func(t *testing.T, app *fiber.App) {
    90  		t.Helper()
    91  
    92  		{
    93  			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
    94  
    95  			resp, err := app.Test(req)
    96  			utils.AssertEqual(t, nil, err)
    97  			utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
    98  
    99  			req.Header.Set(headerName, headerValOff)
   100  			resp, err = app.Test(req)
   101  			utils.AssertEqual(t, nil, err)
   102  			utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   103  
   104  			req.Header.Set(headerName, headerValOn)
   105  			resp, err = app.Test(req)
   106  			utils.AssertEqual(t, nil, err)
   107  			utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   108  		}
   109  
   110  		{
   111  			req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
   112  
   113  			resp, err := app.Test(req)
   114  			utils.AssertEqual(t, nil, err)
   115  			utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   116  
   117  			req.Header.Set(headerName, headerValOff)
   118  			resp, err = app.Test(req)
   119  			utils.AssertEqual(t, nil, err)
   120  			utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   121  
   122  			req.Header.Set(headerName, headerValOn)
   123  			resp, err = app.Test(req)
   124  			utils.AssertEqual(t, nil, err)
   125  			utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
   126  		}
   127  	}
   128  
   129  	untrustedRun := func(t *testing.T, app *fiber.App) {
   130  		t.Helper()
   131  
   132  		{
   133  			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
   134  
   135  			resp, err := app.Test(req)
   136  			utils.AssertEqual(t, nil, err)
   137  			utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
   138  
   139  			req.Header.Set(headerName, headerValOff)
   140  			resp, err = app.Test(req)
   141  			utils.AssertEqual(t, nil, err)
   142  			utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
   143  
   144  			req.Header.Set(headerName, headerValOn)
   145  			resp, err = app.Test(req)
   146  			utils.AssertEqual(t, nil, err)
   147  			utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
   148  		}
   149  
   150  		{
   151  			req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
   152  
   153  			resp, err := app.Test(req)
   154  			utils.AssertEqual(t, nil, err)
   155  			utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
   156  
   157  			req.Header.Set(headerName, headerValOff)
   158  			resp, err = app.Test(req)
   159  			utils.AssertEqual(t, nil, err)
   160  			utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
   161  
   162  			req.Header.Set(headerName, headerValOn)
   163  			resp, err = app.Test(req)
   164  			utils.AssertEqual(t, nil, err)
   165  			utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
   166  		}
   167  	}
   168  
   169  	t.Run("empty config", func(t *testing.T) {
   170  		app := appWithConfig(t, nil)
   171  		trustedRun(t, app)
   172  	})
   173  	t.Run("default config", func(t *testing.T) {
   174  		app := appWithConfig(t, &fiber.Config{})
   175  		trustedRun(t, app)
   176  	})
   177  
   178  	t.Run("config with EnableTrustedProxyCheck", func(t *testing.T) {
   179  		app := appWithConfig(t, &fiber.Config{
   180  			EnableTrustedProxyCheck: true,
   181  		})
   182  		untrustedRun(t, app)
   183  	})
   184  	t.Run("config with EnableTrustedProxyCheck and trusted TrustedProxies", func(t *testing.T) {
   185  		app := appWithConfig(t, &fiber.Config{
   186  			EnableTrustedProxyCheck: true,
   187  			TrustedProxies: []string{
   188  				"0.0.0.0",
   189  			},
   190  		})
   191  		trustedRun(t, app)
   192  	})
   193  }