github.com/gofiber/fiber/v2@v2.47.0/middleware/adaptor/adaptor_test.go (about)

     1  //nolint:bodyclose, contextcheck, revive // Much easier to just ignore memory leaks in tests
     2  package adaptor
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"net/url"
    12  	"reflect"
    13  	"testing"
    14  
    15  	"github.com/gofiber/fiber/v2"
    16  	"github.com/gofiber/fiber/v2/utils"
    17  	"github.com/valyala/fasthttp"
    18  )
    19  
    20  func Test_HTTPHandler(t *testing.T) {
    21  	expectedMethod := fiber.MethodPost
    22  	expectedProto := "HTTP/1.1"
    23  	expectedProtoMajor := 1
    24  	expectedProtoMinor := 1
    25  	expectedRequestURI := "/foo/bar?baz=123"
    26  	expectedBody := "body 123 foo bar baz"
    27  	expectedContentLength := len(expectedBody)
    28  	expectedHost := "foobar.com"
    29  	expectedRemoteAddr := "1.2.3.4:6789"
    30  	expectedHeader := map[string]string{
    31  		"Foo-Bar":         "baz",
    32  		"Abc":             "defg",
    33  		"XXX-Remote-Addr": "123.43.4543.345",
    34  	}
    35  	expectedURL, err := url.ParseRequestURI(expectedRequestURI)
    36  	if err != nil {
    37  		t.Fatalf("unexpected error: %s", err)
    38  	}
    39  	expectedContextKey := "contextKey"
    40  	expectedContextValue := "contextValue"
    41  
    42  	callsCount := 0
    43  	nethttpH := func(w http.ResponseWriter, r *http.Request) {
    44  		callsCount++
    45  		if r.Method != expectedMethod {
    46  			t.Fatalf("unexpected method %q. Expecting %q", r.Method, expectedMethod)
    47  		}
    48  		if r.Proto != expectedProto {
    49  			t.Fatalf("unexpected proto %q. Expecting %q", r.Proto, expectedProto)
    50  		}
    51  		if r.ProtoMajor != expectedProtoMajor {
    52  			t.Fatalf("unexpected protoMajor %d. Expecting %d", r.ProtoMajor, expectedProtoMajor)
    53  		}
    54  		if r.ProtoMinor != expectedProtoMinor {
    55  			t.Fatalf("unexpected protoMinor %d. Expecting %d", r.ProtoMinor, expectedProtoMinor)
    56  		}
    57  		if r.RequestURI != expectedRequestURI {
    58  			t.Fatalf("unexpected requestURI %q. Expecting %q", r.RequestURI, expectedRequestURI)
    59  		}
    60  		if r.ContentLength != int64(expectedContentLength) {
    61  			t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength)
    62  		}
    63  		if len(r.TransferEncoding) != 0 {
    64  			t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding)
    65  		}
    66  		if r.Host != expectedHost {
    67  			t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost)
    68  		}
    69  		if r.RemoteAddr != expectedRemoteAddr {
    70  			t.Fatalf("unexpected remoteAddr %q. Expecting %q", r.RemoteAddr, expectedRemoteAddr)
    71  		}
    72  		body, err := io.ReadAll(r.Body)
    73  		if err != nil {
    74  			t.Fatalf("unexpected error when reading request body: %s", err)
    75  		}
    76  		if string(body) != expectedBody {
    77  			t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
    78  		}
    79  		if !reflect.DeepEqual(r.URL, expectedURL) {
    80  			t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL)
    81  		}
    82  		if r.Context().Value(expectedContextKey) != expectedContextValue {
    83  			t.Fatalf("unexpected context value for key %q. Expecting %q", expectedContextKey, expectedContextValue)
    84  		}
    85  
    86  		for k, expectedV := range expectedHeader {
    87  			v := r.Header.Get(k)
    88  			if v != expectedV {
    89  				t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV)
    90  			}
    91  		}
    92  
    93  		w.Header().Set("Header1", "value1")
    94  		w.Header().Set("Header2", "value2")
    95  		w.WriteHeader(http.StatusBadRequest)
    96  		fmt.Fprintf(w, "request body is %q", body)
    97  	}
    98  	fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH))
    99  	fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue)
   100  
   101  	var fctx fasthttp.RequestCtx
   102  	var req fasthttp.Request
   103  
   104  	req.Header.SetMethod(expectedMethod)
   105  	req.SetRequestURI(expectedRequestURI)
   106  	req.Header.SetHost(expectedHost)
   107  	req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck, gosec // not needed
   108  	for k, v := range expectedHeader {
   109  		req.Header.Set(k, v)
   110  	}
   111  
   112  	remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
   113  	if err != nil {
   114  		t.Fatalf("unexpected error: %s", err)
   115  	}
   116  	fctx.Init(&req, remoteAddr, nil)
   117  	app := fiber.New()
   118  	ctx := app.AcquireCtx(&fctx)
   119  	defer app.ReleaseCtx(ctx)
   120  
   121  	err = fiberH(ctx)
   122  	if err != nil {
   123  		t.Fatalf("unexpected error: %s", err)
   124  	}
   125  
   126  	if callsCount != 1 {
   127  		t.Fatalf("unexpected callsCount: %d. Expecting 1", callsCount)
   128  	}
   129  
   130  	resp := &fctx.Response
   131  	if resp.StatusCode() != fiber.StatusBadRequest {
   132  		t.Fatalf("unexpected statusCode: %d. Expecting %d", resp.StatusCode(), fiber.StatusBadRequest)
   133  	}
   134  	if string(resp.Header.Peek("Header1")) != "value1" {
   135  		t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header1"), "value1")
   136  	}
   137  	if string(resp.Header.Peek("Header2")) != "value2" {
   138  		t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header2"), "value2")
   139  	}
   140  	expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
   141  	if string(resp.Body()) != expectedResponseBody {
   142  		t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedResponseBody)
   143  	}
   144  }
   145  
   146  type contextKey string
   147  
   148  func (c contextKey) String() string {
   149  	return "test-" + string(c)
   150  }
   151  
   152  var (
   153  	TestContextKey       = contextKey("TestContextKey")
   154  	TestContextSecondKey = contextKey("TestContextSecondKey")
   155  )
   156  
   157  func Test_HTTPMiddleware(t *testing.T) {
   158  	tests := []struct {
   159  		name       string
   160  		url        string
   161  		method     string
   162  		statusCode int
   163  	}{
   164  		{
   165  			name:       "Should return 200",
   166  			url:        "/",
   167  			method:     "POST",
   168  			statusCode: 200,
   169  		},
   170  		{
   171  			name:       "Should return 405",
   172  			url:        "/",
   173  			method:     "GET",
   174  			statusCode: 405,
   175  		},
   176  		{
   177  			name:       "Should return 400",
   178  			url:        "/unknown",
   179  			method:     "POST",
   180  			statusCode: 404,
   181  		},
   182  	}
   183  
   184  	nethttpMW := func(next http.Handler) http.Handler {
   185  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   186  			if r.Method != http.MethodPost {
   187  				w.WriteHeader(http.StatusMethodNotAllowed)
   188  				return
   189  			}
   190  			r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay"))
   191  			r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay"))
   192  			r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay"))
   193  
   194  			next.ServeHTTP(w, r)
   195  		})
   196  	}
   197  
   198  	app := fiber.New()
   199  	app.Use(HTTPMiddleware(nethttpMW))
   200  	app.Post("/", func(c *fiber.Ctx) error {
   201  		value := c.Context().Value(TestContextKey)
   202  		val, ok := value.(string)
   203  		if !ok {
   204  			t.Error("unexpected error on type-assertion")
   205  		}
   206  		if value != nil {
   207  			c.Set("context_okay", val)
   208  		}
   209  		value = c.Context().Value(TestContextSecondKey)
   210  		if value != nil {
   211  			val, ok := value.(string)
   212  			if !ok {
   213  				t.Error("unexpected error on type-assertion")
   214  			}
   215  			c.Set("context_second_okay", val)
   216  		}
   217  		return c.SendStatus(fiber.StatusOK)
   218  	})
   219  
   220  	for _, tt := range tests {
   221  		req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, nil)
   222  		if err != nil {
   223  			t.Fatalf(`%s: %s`, t.Name(), err)
   224  		}
   225  		resp, err := app.Test(req)
   226  		if err != nil {
   227  			t.Fatalf(`%s: %s`, t.Name(), err)
   228  		}
   229  		if resp.StatusCode != tt.statusCode {
   230  			t.Fatalf(`%s: StatusCode: got %v - expected %v`, t.Name(), resp.StatusCode, tt.statusCode)
   231  		}
   232  	}
   233  
   234  	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil)
   235  	if err != nil {
   236  		t.Fatalf(`%s: %s`, t.Name(), err)
   237  	}
   238  	resp, err := app.Test(req)
   239  	if err != nil {
   240  		t.Fatalf(`%s: %s`, t.Name(), err)
   241  	}
   242  	if resp.Header.Get("context_okay") != "okay" {
   243  		t.Fatalf(`%s: Header context_okay: got %v - expected %v`, t.Name(), resp.Header.Get("context_okay"), "okay")
   244  	}
   245  	if resp.Header.Get("context_second_okay") != "okay" {
   246  		t.Fatalf(`%s: Header context_second_okay: got %v - expected %v`, t.Name(), resp.Header.Get("context_second_okay"), "okay")
   247  	}
   248  }
   249  
   250  func Test_FiberHandler(t *testing.T) {
   251  	testFiberToHandlerFunc(t, false)
   252  }
   253  
   254  func Test_FiberApp(t *testing.T) {
   255  	testFiberToHandlerFunc(t, false, fiber.New())
   256  }
   257  
   258  func Test_FiberHandlerDefaultPort(t *testing.T) {
   259  	testFiberToHandlerFunc(t, true)
   260  }
   261  
   262  func Test_FiberAppDefaultPort(t *testing.T) {
   263  	testFiberToHandlerFunc(t, true, fiber.New())
   264  }
   265  
   266  func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) {
   267  	t.Helper()
   268  
   269  	expectedMethod := fiber.MethodPost
   270  	expectedRequestURI := "/foo/bar?baz=123"
   271  	expectedBody := "body 123 foo bar baz"
   272  	expectedContentLength := len(expectedBody)
   273  	expectedHost := "foobar.com"
   274  	expectedRemoteAddr := "1.2.3.4:6789"
   275  	if checkDefaultPort {
   276  		expectedRemoteAddr = "1.2.3.4:80"
   277  	}
   278  	expectedHeader := map[string]string{
   279  		"Foo-Bar":         "baz",
   280  		"Abc":             "defg",
   281  		"XXX-Remote-Addr": "123.43.4543.345",
   282  	}
   283  	expectedURL, err := url.ParseRequestURI(expectedRequestURI)
   284  	if err != nil {
   285  		t.Fatalf("unexpected error: %s", err)
   286  	}
   287  
   288  	callsCount := 0
   289  	fiberH := func(c *fiber.Ctx) error {
   290  		callsCount++
   291  		if c.Method() != expectedMethod {
   292  			t.Fatalf("unexpected method %q. Expecting %q", c.Method(), expectedMethod)
   293  		}
   294  		if string(c.Context().RequestURI()) != expectedRequestURI {
   295  			t.Fatalf("unexpected requestURI %q. Expecting %q", string(c.Context().RequestURI()), expectedRequestURI)
   296  		}
   297  		contentLength := c.Context().Request.Header.ContentLength()
   298  		if contentLength != expectedContentLength {
   299  			t.Fatalf("unexpected contentLength %d. Expecting %d", contentLength, expectedContentLength)
   300  		}
   301  		if c.Hostname() != expectedHost {
   302  			t.Fatalf("unexpected host %q. Expecting %q", c.Hostname(), expectedHost)
   303  		}
   304  		remoteAddr := c.Context().RemoteAddr().String()
   305  		if remoteAddr != expectedRemoteAddr {
   306  			t.Fatalf("unexpected remoteAddr %q. Expecting %q", remoteAddr, expectedRemoteAddr)
   307  		}
   308  		body := string(c.Body())
   309  		if body != expectedBody {
   310  			t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
   311  		}
   312  		if c.OriginalURL() != expectedURL.String() {
   313  			t.Fatalf("unexpected URL: %#v. Expecting %#v", c.OriginalURL(), expectedURL)
   314  		}
   315  
   316  		for k, expectedV := range expectedHeader {
   317  			v := c.Get(k)
   318  			if v != expectedV {
   319  				t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV)
   320  			}
   321  		}
   322  
   323  		c.Set("Header1", "value1")
   324  		c.Set("Header2", "value2")
   325  		c.Status(fiber.StatusBadRequest)
   326  		_, err := c.Write([]byte(fmt.Sprintf("request body is %q", body)))
   327  		return err
   328  	}
   329  
   330  	var handlerFunc http.HandlerFunc
   331  	if len(app) > 0 {
   332  		app[0].Post("/foo/bar", fiberH)
   333  		handlerFunc = FiberApp(app[0])
   334  	} else {
   335  		handlerFunc = FiberHandlerFunc(fiberH)
   336  	}
   337  
   338  	var r http.Request
   339  
   340  	r.Method = expectedMethod
   341  	r.Body = &netHTTPBody{[]byte(expectedBody)}
   342  	r.RequestURI = expectedRequestURI
   343  	r.ContentLength = int64(expectedContentLength)
   344  	r.Host = expectedHost
   345  	r.RemoteAddr = expectedRemoteAddr
   346  	if checkDefaultPort {
   347  		r.RemoteAddr = "1.2.3.4"
   348  	}
   349  
   350  	hdr := make(http.Header)
   351  	for k, v := range expectedHeader {
   352  		hdr.Set(k, v)
   353  	}
   354  	r.Header = hdr
   355  
   356  	var w netHTTPResponseWriter
   357  	handlerFunc.ServeHTTP(&w, &r)
   358  
   359  	if w.StatusCode() != http.StatusBadRequest {
   360  		t.Fatalf("unexpected statusCode: %d. Expecting %d", w.StatusCode(), http.StatusBadRequest)
   361  	}
   362  	if w.Header().Get("Header1") != "value1" {
   363  		t.Fatalf("unexpected header value: %q. Expecting %q", w.Header().Get("Header1"), "value1")
   364  	}
   365  	if w.Header().Get("Header2") != "value2" {
   366  		t.Fatalf("unexpected header value: %q. Expecting %q", w.Header().Get("Header2"), "value2")
   367  	}
   368  	expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
   369  	if string(w.body) != expectedResponseBody {
   370  		t.Fatalf("unexpected response body %q. Expecting %q", string(w.body), expectedResponseBody)
   371  	}
   372  }
   373  
   374  func setFiberContextValueMiddleware(next fiber.Handler, key string, value interface{}) fiber.Handler {
   375  	return func(c *fiber.Ctx) error {
   376  		c.Locals(key, value)
   377  		return next(c)
   378  	}
   379  }
   380  
   381  func Test_FiberHandler_RequestNilBody(t *testing.T) {
   382  	expectedMethod := fiber.MethodGet
   383  	expectedRequestURI := "/foo/bar"
   384  	expectedContentLength := 0
   385  
   386  	callsCount := 0
   387  	fiberH := func(c *fiber.Ctx) error {
   388  		callsCount++
   389  		if c.Method() != expectedMethod {
   390  			t.Fatalf("unexpected method %q. Expecting %q", c.Method(), expectedMethod)
   391  		}
   392  		if string(c.Request().RequestURI()) != expectedRequestURI {
   393  			t.Fatalf("unexpected requestURI %q. Expecting %q", string(c.Request().RequestURI()), expectedRequestURI)
   394  		}
   395  		contentLength := c.Request().Header.ContentLength()
   396  		if contentLength != expectedContentLength {
   397  			t.Fatalf("unexpected contentLength %d. Expecting %d", contentLength, expectedContentLength)
   398  		}
   399  
   400  		_, err := c.Write([]byte("request body is nil"))
   401  		return err
   402  	}
   403  	nethttpH := FiberHandler(fiberH)
   404  
   405  	var r http.Request
   406  
   407  	r.Method = expectedMethod
   408  	r.RequestURI = expectedRequestURI
   409  
   410  	var w netHTTPResponseWriter
   411  	nethttpH.ServeHTTP(&w, &r)
   412  
   413  	expectedResponseBody := "request body is nil"
   414  	if string(w.body) != expectedResponseBody {
   415  		t.Fatalf("unexpected response body %q. Expecting %q", string(w.body), expectedResponseBody)
   416  	}
   417  }
   418  
   419  type netHTTPBody struct {
   420  	b []byte
   421  }
   422  
   423  func (r *netHTTPBody) Read(p []byte) (int, error) {
   424  	if len(r.b) == 0 {
   425  		return 0, io.EOF
   426  	}
   427  	n := copy(p, r.b)
   428  	r.b = r.b[n:]
   429  	return n, nil
   430  }
   431  
   432  func (r *netHTTPBody) Close() error {
   433  	r.b = r.b[:0]
   434  	return nil
   435  }
   436  
   437  type netHTTPResponseWriter struct {
   438  	statusCode int
   439  	h          http.Header
   440  	body       []byte
   441  }
   442  
   443  func (w *netHTTPResponseWriter) StatusCode() int {
   444  	if w.statusCode == 0 {
   445  		return http.StatusOK
   446  	}
   447  	return w.statusCode
   448  }
   449  
   450  func (w *netHTTPResponseWriter) Header() http.Header {
   451  	if w.h == nil {
   452  		w.h = make(http.Header)
   453  	}
   454  	return w.h
   455  }
   456  
   457  func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
   458  	w.statusCode = statusCode
   459  }
   460  
   461  func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
   462  	w.body = append(w.body, p...)
   463  	return len(p), nil
   464  }
   465  
   466  func Test_ConvertRequest(t *testing.T) {
   467  	t.Parallel()
   468  
   469  	app := fiber.New()
   470  
   471  	app.Get("/test", func(c *fiber.Ctx) error {
   472  		httpReq, err := ConvertRequest(c, false)
   473  		if err != nil {
   474  			return err
   475  		}
   476  
   477  		return c.SendString("Request URL: " + httpReq.URL.String())
   478  	})
   479  
   480  	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", http.NoBody))
   481  	utils.AssertEqual(t, nil, err, "app.Test(req)")
   482  	utils.AssertEqual(t, http.StatusOK, resp.StatusCode, "Status code")
   483  
   484  	body, err := io.ReadAll(resp.Body)
   485  	utils.AssertEqual(t, nil, err)
   486  	utils.AssertEqual(t, "Request URL: /test?hello=world&another=test", string(body))
   487  }