github.com/gofiber/fiber/v2@v2.47.0/middleware/earlydata/earlydata_test.go (about) 1 //nolint:bodyclose // Much easier to just ignore memory leaks in tests 2 package earlydata_test 3 4 import ( 5 "errors" 6 "fmt" 7 "net/http" 8 "net/http/httptest" 9 "testing" 10 11 "github.com/gofiber/fiber/v2" 12 "github.com/gofiber/fiber/v2/middleware/earlydata" 13 "github.com/gofiber/fiber/v2/utils" 14 ) 15 16 const ( 17 headerName = "Early-Data" 18 headerValOn = "1" 19 headerValOff = "0" 20 ) 21 22 func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App { 23 t.Helper() 24 t.Parallel() 25 26 var app *fiber.App 27 if c == nil { 28 app = fiber.New() 29 } else { 30 app = fiber.New(*c) 31 } 32 33 app.Use(earlydata.New()) 34 35 // Middleware to test IsEarly func 36 const localsKeyTestValid = "earlydata_testvalid" 37 app.Use(func(c *fiber.Ctx) error { 38 isEarly := earlydata.IsEarly(c) 39 40 switch h := c.Get(headerName); h { 41 case "", headerValOff: 42 if isEarly { 43 return errors.New("is early-data even though it's not") 44 } 45 46 case headerValOn: 47 switch { 48 case fiber.IsMethodSafe(c.Method()): 49 if !isEarly { 50 return errors.New("should be early-data on safe HTTP methods") 51 } 52 default: 53 if isEarly { 54 return errors.New("early-data unsuported on unsafe HTTP methods") 55 } 56 } 57 58 default: 59 return fmt.Errorf("header has unsupported value: %s", h) 60 } 61 62 _ = c.Locals(localsKeyTestValid, true) 63 64 return c.Next() 65 }) 66 67 { 68 { 69 handler := func(c *fiber.Ctx) error { 70 if !c.Locals(localsKeyTestValid).(bool) { //nolint:forcetypeassert // We store nothing else in the pool 71 return errors.New("handler called even though validation failed") 72 } 73 74 return nil 75 } 76 77 app.Get("/", handler) 78 app.Post("/", handler) 79 } 80 } 81 82 return app 83 } 84 85 // go test -run Test_EarlyData 86 func Test_EarlyData(t *testing.T) { 87 t.Parallel() 88 89 trustedRun := func(t *testing.T, app *fiber.App) { 90 t.Helper() 91 92 { 93 req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) 94 95 resp, err := app.Test(req) 96 utils.AssertEqual(t, nil, err) 97 utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) 98 99 req.Header.Set(headerName, headerValOff) 100 resp, err = app.Test(req) 101 utils.AssertEqual(t, nil, err) 102 utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) 103 104 req.Header.Set(headerName, headerValOn) 105 resp, err = app.Test(req) 106 utils.AssertEqual(t, nil, err) 107 utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) 108 } 109 110 { 111 req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody) 112 113 resp, err := app.Test(req) 114 utils.AssertEqual(t, nil, err) 115 utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) 116 117 req.Header.Set(headerName, headerValOff) 118 resp, err = app.Test(req) 119 utils.AssertEqual(t, nil, err) 120 utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) 121 122 req.Header.Set(headerName, headerValOn) 123 resp, err = app.Test(req) 124 utils.AssertEqual(t, nil, err) 125 utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) 126 } 127 } 128 129 untrustedRun := func(t *testing.T, app *fiber.App) { 130 t.Helper() 131 132 { 133 req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) 134 135 resp, err := app.Test(req) 136 utils.AssertEqual(t, nil, err) 137 utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) 138 139 req.Header.Set(headerName, headerValOff) 140 resp, err = app.Test(req) 141 utils.AssertEqual(t, nil, err) 142 utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) 143 144 req.Header.Set(headerName, headerValOn) 145 resp, err = app.Test(req) 146 utils.AssertEqual(t, nil, err) 147 utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) 148 } 149 150 { 151 req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody) 152 153 resp, err := app.Test(req) 154 utils.AssertEqual(t, nil, err) 155 utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) 156 157 req.Header.Set(headerName, headerValOff) 158 resp, err = app.Test(req) 159 utils.AssertEqual(t, nil, err) 160 utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) 161 162 req.Header.Set(headerName, headerValOn) 163 resp, err = app.Test(req) 164 utils.AssertEqual(t, nil, err) 165 utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) 166 } 167 } 168 169 t.Run("empty config", func(t *testing.T) { 170 app := appWithConfig(t, nil) 171 trustedRun(t, app) 172 }) 173 t.Run("default config", func(t *testing.T) { 174 app := appWithConfig(t, &fiber.Config{}) 175 trustedRun(t, app) 176 }) 177 178 t.Run("config with EnableTrustedProxyCheck", func(t *testing.T) { 179 app := appWithConfig(t, &fiber.Config{ 180 EnableTrustedProxyCheck: true, 181 }) 182 untrustedRun(t, app) 183 }) 184 t.Run("config with EnableTrustedProxyCheck and trusted TrustedProxies", func(t *testing.T) { 185 app := appWithConfig(t, &fiber.Config{ 186 EnableTrustedProxyCheck: true, 187 TrustedProxies: []string{ 188 "0.0.0.0", 189 }, 190 }) 191 trustedRun(t, app) 192 }) 193 }