github.com/gofiber/fiber/v2@v2.47.0/middleware/adaptor/adaptor_test.go (about) 1 //nolint:bodyclose, contextcheck, revive // Much easier to just ignore memory leaks in tests 2 package adaptor 3 4 import ( 5 "context" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "net/http/httptest" 11 "net/url" 12 "reflect" 13 "testing" 14 15 "github.com/gofiber/fiber/v2" 16 "github.com/gofiber/fiber/v2/utils" 17 "github.com/valyala/fasthttp" 18 ) 19 20 func Test_HTTPHandler(t *testing.T) { 21 expectedMethod := fiber.MethodPost 22 expectedProto := "HTTP/1.1" 23 expectedProtoMajor := 1 24 expectedProtoMinor := 1 25 expectedRequestURI := "/foo/bar?baz=123" 26 expectedBody := "body 123 foo bar baz" 27 expectedContentLength := len(expectedBody) 28 expectedHost := "foobar.com" 29 expectedRemoteAddr := "1.2.3.4:6789" 30 expectedHeader := map[string]string{ 31 "Foo-Bar": "baz", 32 "Abc": "defg", 33 "XXX-Remote-Addr": "123.43.4543.345", 34 } 35 expectedURL, err := url.ParseRequestURI(expectedRequestURI) 36 if err != nil { 37 t.Fatalf("unexpected error: %s", err) 38 } 39 expectedContextKey := "contextKey" 40 expectedContextValue := "contextValue" 41 42 callsCount := 0 43 nethttpH := func(w http.ResponseWriter, r *http.Request) { 44 callsCount++ 45 if r.Method != expectedMethod { 46 t.Fatalf("unexpected method %q. Expecting %q", r.Method, expectedMethod) 47 } 48 if r.Proto != expectedProto { 49 t.Fatalf("unexpected proto %q. Expecting %q", r.Proto, expectedProto) 50 } 51 if r.ProtoMajor != expectedProtoMajor { 52 t.Fatalf("unexpected protoMajor %d. Expecting %d", r.ProtoMajor, expectedProtoMajor) 53 } 54 if r.ProtoMinor != expectedProtoMinor { 55 t.Fatalf("unexpected protoMinor %d. Expecting %d", r.ProtoMinor, expectedProtoMinor) 56 } 57 if r.RequestURI != expectedRequestURI { 58 t.Fatalf("unexpected requestURI %q. Expecting %q", r.RequestURI, expectedRequestURI) 59 } 60 if r.ContentLength != int64(expectedContentLength) { 61 t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength) 62 } 63 if len(r.TransferEncoding) != 0 { 64 t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding) 65 } 66 if r.Host != expectedHost { 67 t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost) 68 } 69 if r.RemoteAddr != expectedRemoteAddr { 70 t.Fatalf("unexpected remoteAddr %q. Expecting %q", r.RemoteAddr, expectedRemoteAddr) 71 } 72 body, err := io.ReadAll(r.Body) 73 if err != nil { 74 t.Fatalf("unexpected error when reading request body: %s", err) 75 } 76 if string(body) != expectedBody { 77 t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) 78 } 79 if !reflect.DeepEqual(r.URL, expectedURL) { 80 t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL) 81 } 82 if r.Context().Value(expectedContextKey) != expectedContextValue { 83 t.Fatalf("unexpected context value for key %q. Expecting %q", expectedContextKey, expectedContextValue) 84 } 85 86 for k, expectedV := range expectedHeader { 87 v := r.Header.Get(k) 88 if v != expectedV { 89 t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV) 90 } 91 } 92 93 w.Header().Set("Header1", "value1") 94 w.Header().Set("Header2", "value2") 95 w.WriteHeader(http.StatusBadRequest) 96 fmt.Fprintf(w, "request body is %q", body) 97 } 98 fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH)) 99 fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue) 100 101 var fctx fasthttp.RequestCtx 102 var req fasthttp.Request 103 104 req.Header.SetMethod(expectedMethod) 105 req.SetRequestURI(expectedRequestURI) 106 req.Header.SetHost(expectedHost) 107 req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck, gosec // not needed 108 for k, v := range expectedHeader { 109 req.Header.Set(k, v) 110 } 111 112 remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr) 113 if err != nil { 114 t.Fatalf("unexpected error: %s", err) 115 } 116 fctx.Init(&req, remoteAddr, nil) 117 app := fiber.New() 118 ctx := app.AcquireCtx(&fctx) 119 defer app.ReleaseCtx(ctx) 120 121 err = fiberH(ctx) 122 if err != nil { 123 t.Fatalf("unexpected error: %s", err) 124 } 125 126 if callsCount != 1 { 127 t.Fatalf("unexpected callsCount: %d. Expecting 1", callsCount) 128 } 129 130 resp := &fctx.Response 131 if resp.StatusCode() != fiber.StatusBadRequest { 132 t.Fatalf("unexpected statusCode: %d. Expecting %d", resp.StatusCode(), fiber.StatusBadRequest) 133 } 134 if string(resp.Header.Peek("Header1")) != "value1" { 135 t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header1"), "value1") 136 } 137 if string(resp.Header.Peek("Header2")) != "value2" { 138 t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header2"), "value2") 139 } 140 expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) 141 if string(resp.Body()) != expectedResponseBody { 142 t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedResponseBody) 143 } 144 } 145 146 type contextKey string 147 148 func (c contextKey) String() string { 149 return "test-" + string(c) 150 } 151 152 var ( 153 TestContextKey = contextKey("TestContextKey") 154 TestContextSecondKey = contextKey("TestContextSecondKey") 155 ) 156 157 func Test_HTTPMiddleware(t *testing.T) { 158 tests := []struct { 159 name string 160 url string 161 method string 162 statusCode int 163 }{ 164 { 165 name: "Should return 200", 166 url: "/", 167 method: "POST", 168 statusCode: 200, 169 }, 170 { 171 name: "Should return 405", 172 url: "/", 173 method: "GET", 174 statusCode: 405, 175 }, 176 { 177 name: "Should return 400", 178 url: "/unknown", 179 method: "POST", 180 statusCode: 404, 181 }, 182 } 183 184 nethttpMW := func(next http.Handler) http.Handler { 185 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 186 if r.Method != http.MethodPost { 187 w.WriteHeader(http.StatusMethodNotAllowed) 188 return 189 } 190 r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay")) 191 r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay")) 192 r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay")) 193 194 next.ServeHTTP(w, r) 195 }) 196 } 197 198 app := fiber.New() 199 app.Use(HTTPMiddleware(nethttpMW)) 200 app.Post("/", func(c *fiber.Ctx) error { 201 value := c.Context().Value(TestContextKey) 202 val, ok := value.(string) 203 if !ok { 204 t.Error("unexpected error on type-assertion") 205 } 206 if value != nil { 207 c.Set("context_okay", val) 208 } 209 value = c.Context().Value(TestContextSecondKey) 210 if value != nil { 211 val, ok := value.(string) 212 if !ok { 213 t.Error("unexpected error on type-assertion") 214 } 215 c.Set("context_second_okay", val) 216 } 217 return c.SendStatus(fiber.StatusOK) 218 }) 219 220 for _, tt := range tests { 221 req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, nil) 222 if err != nil { 223 t.Fatalf(`%s: %s`, t.Name(), err) 224 } 225 resp, err := app.Test(req) 226 if err != nil { 227 t.Fatalf(`%s: %s`, t.Name(), err) 228 } 229 if resp.StatusCode != tt.statusCode { 230 t.Fatalf(`%s: StatusCode: got %v - expected %v`, t.Name(), resp.StatusCode, tt.statusCode) 231 } 232 } 233 234 req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil) 235 if err != nil { 236 t.Fatalf(`%s: %s`, t.Name(), err) 237 } 238 resp, err := app.Test(req) 239 if err != nil { 240 t.Fatalf(`%s: %s`, t.Name(), err) 241 } 242 if resp.Header.Get("context_okay") != "okay" { 243 t.Fatalf(`%s: Header context_okay: got %v - expected %v`, t.Name(), resp.Header.Get("context_okay"), "okay") 244 } 245 if resp.Header.Get("context_second_okay") != "okay" { 246 t.Fatalf(`%s: Header context_second_okay: got %v - expected %v`, t.Name(), resp.Header.Get("context_second_okay"), "okay") 247 } 248 } 249 250 func Test_FiberHandler(t *testing.T) { 251 testFiberToHandlerFunc(t, false) 252 } 253 254 func Test_FiberApp(t *testing.T) { 255 testFiberToHandlerFunc(t, false, fiber.New()) 256 } 257 258 func Test_FiberHandlerDefaultPort(t *testing.T) { 259 testFiberToHandlerFunc(t, true) 260 } 261 262 func Test_FiberAppDefaultPort(t *testing.T) { 263 testFiberToHandlerFunc(t, true, fiber.New()) 264 } 265 266 func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) { 267 t.Helper() 268 269 expectedMethod := fiber.MethodPost 270 expectedRequestURI := "/foo/bar?baz=123" 271 expectedBody := "body 123 foo bar baz" 272 expectedContentLength := len(expectedBody) 273 expectedHost := "foobar.com" 274 expectedRemoteAddr := "1.2.3.4:6789" 275 if checkDefaultPort { 276 expectedRemoteAddr = "1.2.3.4:80" 277 } 278 expectedHeader := map[string]string{ 279 "Foo-Bar": "baz", 280 "Abc": "defg", 281 "XXX-Remote-Addr": "123.43.4543.345", 282 } 283 expectedURL, err := url.ParseRequestURI(expectedRequestURI) 284 if err != nil { 285 t.Fatalf("unexpected error: %s", err) 286 } 287 288 callsCount := 0 289 fiberH := func(c *fiber.Ctx) error { 290 callsCount++ 291 if c.Method() != expectedMethod { 292 t.Fatalf("unexpected method %q. Expecting %q", c.Method(), expectedMethod) 293 } 294 if string(c.Context().RequestURI()) != expectedRequestURI { 295 t.Fatalf("unexpected requestURI %q. Expecting %q", string(c.Context().RequestURI()), expectedRequestURI) 296 } 297 contentLength := c.Context().Request.Header.ContentLength() 298 if contentLength != expectedContentLength { 299 t.Fatalf("unexpected contentLength %d. Expecting %d", contentLength, expectedContentLength) 300 } 301 if c.Hostname() != expectedHost { 302 t.Fatalf("unexpected host %q. Expecting %q", c.Hostname(), expectedHost) 303 } 304 remoteAddr := c.Context().RemoteAddr().String() 305 if remoteAddr != expectedRemoteAddr { 306 t.Fatalf("unexpected remoteAddr %q. Expecting %q", remoteAddr, expectedRemoteAddr) 307 } 308 body := string(c.Body()) 309 if body != expectedBody { 310 t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) 311 } 312 if c.OriginalURL() != expectedURL.String() { 313 t.Fatalf("unexpected URL: %#v. Expecting %#v", c.OriginalURL(), expectedURL) 314 } 315 316 for k, expectedV := range expectedHeader { 317 v := c.Get(k) 318 if v != expectedV { 319 t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV) 320 } 321 } 322 323 c.Set("Header1", "value1") 324 c.Set("Header2", "value2") 325 c.Status(fiber.StatusBadRequest) 326 _, err := c.Write([]byte(fmt.Sprintf("request body is %q", body))) 327 return err 328 } 329 330 var handlerFunc http.HandlerFunc 331 if len(app) > 0 { 332 app[0].Post("/foo/bar", fiberH) 333 handlerFunc = FiberApp(app[0]) 334 } else { 335 handlerFunc = FiberHandlerFunc(fiberH) 336 } 337 338 var r http.Request 339 340 r.Method = expectedMethod 341 r.Body = &netHTTPBody{[]byte(expectedBody)} 342 r.RequestURI = expectedRequestURI 343 r.ContentLength = int64(expectedContentLength) 344 r.Host = expectedHost 345 r.RemoteAddr = expectedRemoteAddr 346 if checkDefaultPort { 347 r.RemoteAddr = "1.2.3.4" 348 } 349 350 hdr := make(http.Header) 351 for k, v := range expectedHeader { 352 hdr.Set(k, v) 353 } 354 r.Header = hdr 355 356 var w netHTTPResponseWriter 357 handlerFunc.ServeHTTP(&w, &r) 358 359 if w.StatusCode() != http.StatusBadRequest { 360 t.Fatalf("unexpected statusCode: %d. Expecting %d", w.StatusCode(), http.StatusBadRequest) 361 } 362 if w.Header().Get("Header1") != "value1" { 363 t.Fatalf("unexpected header value: %q. Expecting %q", w.Header().Get("Header1"), "value1") 364 } 365 if w.Header().Get("Header2") != "value2" { 366 t.Fatalf("unexpected header value: %q. Expecting %q", w.Header().Get("Header2"), "value2") 367 } 368 expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) 369 if string(w.body) != expectedResponseBody { 370 t.Fatalf("unexpected response body %q. Expecting %q", string(w.body), expectedResponseBody) 371 } 372 } 373 374 func setFiberContextValueMiddleware(next fiber.Handler, key string, value interface{}) fiber.Handler { 375 return func(c *fiber.Ctx) error { 376 c.Locals(key, value) 377 return next(c) 378 } 379 } 380 381 func Test_FiberHandler_RequestNilBody(t *testing.T) { 382 expectedMethod := fiber.MethodGet 383 expectedRequestURI := "/foo/bar" 384 expectedContentLength := 0 385 386 callsCount := 0 387 fiberH := func(c *fiber.Ctx) error { 388 callsCount++ 389 if c.Method() != expectedMethod { 390 t.Fatalf("unexpected method %q. Expecting %q", c.Method(), expectedMethod) 391 } 392 if string(c.Request().RequestURI()) != expectedRequestURI { 393 t.Fatalf("unexpected requestURI %q. Expecting %q", string(c.Request().RequestURI()), expectedRequestURI) 394 } 395 contentLength := c.Request().Header.ContentLength() 396 if contentLength != expectedContentLength { 397 t.Fatalf("unexpected contentLength %d. Expecting %d", contentLength, expectedContentLength) 398 } 399 400 _, err := c.Write([]byte("request body is nil")) 401 return err 402 } 403 nethttpH := FiberHandler(fiberH) 404 405 var r http.Request 406 407 r.Method = expectedMethod 408 r.RequestURI = expectedRequestURI 409 410 var w netHTTPResponseWriter 411 nethttpH.ServeHTTP(&w, &r) 412 413 expectedResponseBody := "request body is nil" 414 if string(w.body) != expectedResponseBody { 415 t.Fatalf("unexpected response body %q. Expecting %q", string(w.body), expectedResponseBody) 416 } 417 } 418 419 type netHTTPBody struct { 420 b []byte 421 } 422 423 func (r *netHTTPBody) Read(p []byte) (int, error) { 424 if len(r.b) == 0 { 425 return 0, io.EOF 426 } 427 n := copy(p, r.b) 428 r.b = r.b[n:] 429 return n, nil 430 } 431 432 func (r *netHTTPBody) Close() error { 433 r.b = r.b[:0] 434 return nil 435 } 436 437 type netHTTPResponseWriter struct { 438 statusCode int 439 h http.Header 440 body []byte 441 } 442 443 func (w *netHTTPResponseWriter) StatusCode() int { 444 if w.statusCode == 0 { 445 return http.StatusOK 446 } 447 return w.statusCode 448 } 449 450 func (w *netHTTPResponseWriter) Header() http.Header { 451 if w.h == nil { 452 w.h = make(http.Header) 453 } 454 return w.h 455 } 456 457 func (w *netHTTPResponseWriter) WriteHeader(statusCode int) { 458 w.statusCode = statusCode 459 } 460 461 func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { 462 w.body = append(w.body, p...) 463 return len(p), nil 464 } 465 466 func Test_ConvertRequest(t *testing.T) { 467 t.Parallel() 468 469 app := fiber.New() 470 471 app.Get("/test", func(c *fiber.Ctx) error { 472 httpReq, err := ConvertRequest(c, false) 473 if err != nil { 474 return err 475 } 476 477 return c.SendString("Request URL: " + httpReq.URL.String()) 478 }) 479 480 resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", http.NoBody)) 481 utils.AssertEqual(t, nil, err, "app.Test(req)") 482 utils.AssertEqual(t, http.StatusOK, resp.StatusCode, "Status code") 483 484 body, err := io.ReadAll(resp.Body) 485 utils.AssertEqual(t, nil, err) 486 utils.AssertEqual(t, "Request URL: /test?hello=world&another=test", string(body)) 487 }