github.com/gofiber/fiber/v2@v2.47.0/middleware/keyauth/keyauth_test.go (about)

     1  //nolint:bodyclose // Much easier to just ignore memory leaks in tests
     2  package keyauth
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"testing"
    12  
    13  	"github.com/gofiber/fiber/v2"
    14  	"github.com/gofiber/fiber/v2/utils"
    15  )
    16  
    17  const CorrectKey = "specials: !$%,.#\"!?~`<>@$^*(){}[]|/\\123"
    18  
    19  func TestAuthSources(t *testing.T) {
    20  	// define test cases
    21  	testSources := []string{"header", "cookie", "query", "param", "form"}
    22  
    23  	tests := []struct {
    24  		route         string
    25  		authTokenName string
    26  		description   string
    27  		APIKey        string
    28  		expectedCode  int
    29  		expectedBody  string
    30  	}{
    31  		{
    32  			route:         "/",
    33  			authTokenName: "access_token",
    34  			description:   "auth with correct key",
    35  			APIKey:        CorrectKey,
    36  			expectedCode:  200,
    37  			expectedBody:  "Success!",
    38  		},
    39  		{
    40  			route:         "/",
    41  			authTokenName: "access_token",
    42  			description:   "auth with no key",
    43  			APIKey:        "",
    44  			expectedCode:  401, // 404 in case of param authentication
    45  			expectedBody:  "missing or malformed API Key",
    46  		},
    47  		{
    48  			route:         "/",
    49  			authTokenName: "access_token",
    50  			description:   "auth with wrong key",
    51  			APIKey:        "WRONGKEY",
    52  			expectedCode:  401,
    53  			expectedBody:  "missing or malformed API Key",
    54  		},
    55  	}
    56  
    57  	for _, authSource := range testSources {
    58  		t.Run(authSource, func(t *testing.T) {
    59  			for _, test := range tests {
    60  				// setup the fiber endpoint
    61  				// note that if UnescapePath: false (the default)
    62  				// escaped characters (such as `\"`) will not be handled correctly in the tests
    63  				app := fiber.New(fiber.Config{UnescapePath: true})
    64  
    65  				authMiddleware := New(Config{
    66  					KeyLookup: authSource + ":" + test.authTokenName,
    67  					Validator: func(c *fiber.Ctx, key string) (bool, error) {
    68  						if key == CorrectKey {
    69  							return true, nil
    70  						}
    71  						return false, ErrMissingOrMalformedAPIKey
    72  					},
    73  				})
    74  
    75  				var route string
    76  				if authSource == param {
    77  					route = test.route + ":" + test.authTokenName
    78  					app.Use(route, authMiddleware)
    79  				} else {
    80  					route = test.route
    81  					app.Use(authMiddleware)
    82  				}
    83  
    84  				app.Get(route, func(c *fiber.Ctx) error {
    85  					return c.SendString("Success!")
    86  				})
    87  
    88  				// construct the test HTTP request
    89  				var req *http.Request
    90  				req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
    91  				utils.AssertEqual(t, err, nil)
    92  
    93  				// setup the apikey for the different auth schemes
    94  				if authSource == "header" {
    95  					req.Header.Set(test.authTokenName, test.APIKey)
    96  				} else if authSource == "cookie" {
    97  					req.Header.Set("Cookie", test.authTokenName+"="+test.APIKey)
    98  				} else if authSource == "query" || authSource == "form" {
    99  					q := req.URL.Query()
   100  					q.Add(test.authTokenName, test.APIKey)
   101  					req.URL.RawQuery = q.Encode()
   102  				} else if authSource == "param" {
   103  					r := req.URL.Path
   104  					r += url.PathEscape(test.APIKey)
   105  					req.URL.Path = r
   106  				}
   107  
   108  				res, err := app.Test(req, -1)
   109  
   110  				utils.AssertEqual(t, nil, err, test.description)
   111  
   112  				// test the body of the request
   113  				body, err := io.ReadAll(res.Body)
   114  				// for param authentication, the route would be /:access_token
   115  				// when the access_token is empty, it leads to a 404 (not found)
   116  				// not a 401 (auth error)
   117  				if authSource == "param" && test.APIKey == "" {
   118  					test.expectedCode = 404
   119  					test.expectedBody = "Cannot GET /"
   120  				}
   121  				utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description)
   122  
   123  				// body
   124  				utils.AssertEqual(t, nil, err, test.description)
   125  				utils.AssertEqual(t, test.expectedBody, string(body), test.description)
   126  
   127  				err = res.Body.Close()
   128  				utils.AssertEqual(t, err, nil)
   129  			}
   130  		})
   131  	}
   132  }
   133  
   134  func TestMultipleKeyAuth(t *testing.T) {
   135  	// setup the fiber endpoint
   136  	app := fiber.New()
   137  
   138  	// setup keyauth for /auth1
   139  	app.Use(New(Config{
   140  		Next: func(c *fiber.Ctx) bool {
   141  			return c.OriginalURL() != "/auth1"
   142  		},
   143  		KeyLookup: "header:key",
   144  		Validator: func(c *fiber.Ctx, key string) (bool, error) {
   145  			if key == "password1" {
   146  				return true, nil
   147  			}
   148  			return false, ErrMissingOrMalformedAPIKey
   149  		},
   150  	}))
   151  
   152  	// setup keyauth for /auth2
   153  	app.Use(New(Config{
   154  		Next: func(c *fiber.Ctx) bool {
   155  			return c.OriginalURL() != "/auth2"
   156  		},
   157  		KeyLookup: "header:key",
   158  		Validator: func(c *fiber.Ctx, key string) (bool, error) {
   159  			if key == "password2" {
   160  				return true, nil
   161  			}
   162  			return false, ErrMissingOrMalformedAPIKey
   163  		},
   164  	}))
   165  
   166  	app.Get("/", func(c *fiber.Ctx) error {
   167  		return c.SendString("No auth needed!")
   168  	})
   169  
   170  	app.Get("/auth1", func(c *fiber.Ctx) error {
   171  		return c.SendString("Successfully authenticated for auth1!")
   172  	})
   173  
   174  	app.Get("/auth2", func(c *fiber.Ctx) error {
   175  		return c.SendString("Successfully authenticated for auth2!")
   176  	})
   177  
   178  	// define test cases
   179  	tests := []struct {
   180  		route        string
   181  		description  string
   182  		APIKey       string
   183  		expectedCode int
   184  		expectedBody string
   185  	}{
   186  		// No auth needed for /
   187  		{
   188  			route:        "/",
   189  			description:  "No password needed",
   190  			APIKey:       "",
   191  			expectedCode: 200,
   192  			expectedBody: "No auth needed!",
   193  		},
   194  
   195  		// auth needed for auth1
   196  		{
   197  			route:        "/auth1",
   198  			description:  "Normal Authentication Case",
   199  			APIKey:       "password1",
   200  			expectedCode: 200,
   201  			expectedBody: "Successfully authenticated for auth1!",
   202  		},
   203  		{
   204  			route:        "/auth1",
   205  			description:  "Wrong API Key",
   206  			APIKey:       "WRONG KEY",
   207  			expectedCode: 401,
   208  			expectedBody: "missing or malformed API Key",
   209  		},
   210  		{
   211  			route:        "/auth1",
   212  			description:  "Wrong API Key",
   213  			APIKey:       "", // NO KEY
   214  			expectedCode: 401,
   215  			expectedBody: "missing or malformed API Key",
   216  		},
   217  
   218  		// Auth 2 has a different password
   219  		{
   220  			route:        "/auth2",
   221  			description:  "Normal Authentication Case for auth2",
   222  			APIKey:       "password2",
   223  			expectedCode: 200,
   224  			expectedBody: "Successfully authenticated for auth2!",
   225  		},
   226  		{
   227  			route:        "/auth2",
   228  			description:  "Wrong API Key",
   229  			APIKey:       "WRONG KEY",
   230  			expectedCode: 401,
   231  			expectedBody: "missing or malformed API Key",
   232  		},
   233  		{
   234  			route:        "/auth2",
   235  			description:  "Wrong API Key",
   236  			APIKey:       "", // NO KEY
   237  			expectedCode: 401,
   238  			expectedBody: "missing or malformed API Key",
   239  		},
   240  	}
   241  
   242  	// run the tests
   243  	for _, test := range tests {
   244  		var req *http.Request
   245  		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
   246  		utils.AssertEqual(t, err, nil)
   247  		if test.APIKey != "" {
   248  			req.Header.Set("key", test.APIKey)
   249  		}
   250  
   251  		res, err := app.Test(req, -1)
   252  
   253  		utils.AssertEqual(t, nil, err, test.description)
   254  
   255  		// test the body of the request
   256  		body, err := io.ReadAll(res.Body)
   257  		utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description)
   258  
   259  		// body
   260  		utils.AssertEqual(t, nil, err, test.description)
   261  		utils.AssertEqual(t, test.expectedBody, string(body), test.description)
   262  	}
   263  }
   264  
   265  func TestCustomSuccessAndFailureHandlers(t *testing.T) {
   266  	app := fiber.New()
   267  
   268  	app.Use(New(Config{
   269  		SuccessHandler: func(c *fiber.Ctx) error {
   270  			return c.Status(fiber.StatusOK).SendString("API key is valid and request was handled by custom success handler")
   271  		},
   272  		ErrorHandler: func(c *fiber.Ctx, err error) error {
   273  			return c.Status(fiber.StatusUnauthorized).SendString("API key is invalid and request was handled by custom error handler")
   274  		},
   275  		Validator: func(c *fiber.Ctx, key string) (bool, error) {
   276  			if key == CorrectKey {
   277  				return true, nil
   278  			}
   279  			return false, ErrMissingOrMalformedAPIKey
   280  		},
   281  	}))
   282  
   283  	// Define a test handler that should not be called
   284  	app.Get("/", func(c *fiber.Ctx) error {
   285  		t.Error("Test handler should not be called")
   286  		return nil
   287  	})
   288  
   289  	// Create a request without an API key and send it to the app
   290  	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
   291  	utils.AssertEqual(t, err, nil)
   292  
   293  	// Read the response body into a string
   294  	body, err := io.ReadAll(res.Body)
   295  	utils.AssertEqual(t, err, nil)
   296  
   297  	// Check that the response has the expected status code and body
   298  	utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
   299  	utils.AssertEqual(t, string(body), "API key is invalid and request was handled by custom error handler")
   300  
   301  	// Create a request with a valid API key in the Authorization header
   302  	req := httptest.NewRequest(fiber.MethodGet, "/", nil)
   303  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", CorrectKey))
   304  
   305  	// Send the request to the app
   306  	res, err = app.Test(req)
   307  	utils.AssertEqual(t, err, nil)
   308  
   309  	// Read the response body into a string
   310  	body, err = io.ReadAll(res.Body)
   311  	utils.AssertEqual(t, err, nil)
   312  
   313  	// Check that the response has the expected status code and body
   314  	utils.AssertEqual(t, res.StatusCode, http.StatusOK)
   315  	utils.AssertEqual(t, string(body), "API key is valid and request was handled by custom success handler")
   316  }
   317  
   318  func TestCustomNextFunc(t *testing.T) {
   319  	app := fiber.New()
   320  
   321  	app.Use(New(Config{
   322  		Next: func(c *fiber.Ctx) bool {
   323  			return c.Path() == "/allowed"
   324  		},
   325  		Validator: func(c *fiber.Ctx, key string) (bool, error) {
   326  			if key == CorrectKey {
   327  				return true, nil
   328  			}
   329  			return false, ErrMissingOrMalformedAPIKey
   330  		},
   331  	}))
   332  
   333  	// Define a test handler
   334  	app.Get("/allowed", func(c *fiber.Ctx) error {
   335  		return c.SendString("API key is valid and request was allowed by custom filter")
   336  	})
   337  
   338  	// Create a request with the "/allowed" path and send it to the app
   339  	req := httptest.NewRequest(fiber.MethodGet, "/allowed", nil)
   340  	res, err := app.Test(req)
   341  	utils.AssertEqual(t, err, nil)
   342  
   343  	// Read the response body into a string
   344  	body, err := io.ReadAll(res.Body)
   345  	utils.AssertEqual(t, err, nil)
   346  
   347  	// Check that the response has the expected status code and body
   348  	utils.AssertEqual(t, res.StatusCode, http.StatusOK)
   349  	utils.AssertEqual(t, string(body), "API key is valid and request was allowed by custom filter")
   350  
   351  	// Create a request with a different path and send it to the app without correct key
   352  	req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
   353  	res, err = app.Test(req)
   354  	utils.AssertEqual(t, err, nil)
   355  
   356  	// Read the response body into a string
   357  	body, err = io.ReadAll(res.Body)
   358  	utils.AssertEqual(t, err, nil)
   359  
   360  	// Check that the response has the expected status code and body
   361  	utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
   362  	utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
   363  
   364  	// Create a request with a different path and send it to the app with correct key
   365  	req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
   366  	req.Header.Add("Authorization", fmt.Sprintf("Basic %s", CorrectKey))
   367  
   368  	res, err = app.Test(req)
   369  	utils.AssertEqual(t, err, nil)
   370  
   371  	// Read the response body into a string
   372  	body, err = io.ReadAll(res.Body)
   373  	utils.AssertEqual(t, err, nil)
   374  
   375  	// Check that the response has the expected status code and body
   376  	utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
   377  	utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
   378  }
   379  
   380  func TestAuthSchemeToken(t *testing.T) {
   381  	app := fiber.New()
   382  
   383  	app.Use(New(Config{
   384  		AuthScheme: "Token",
   385  		Validator: func(c *fiber.Ctx, key string) (bool, error) {
   386  			if key == CorrectKey {
   387  				return true, nil
   388  			}
   389  			return false, ErrMissingOrMalformedAPIKey
   390  		},
   391  	}))
   392  
   393  	// Define a test handler
   394  	app.Get("/", func(c *fiber.Ctx) error {
   395  		return c.SendString("API key is valid")
   396  	})
   397  
   398  	// Create a request with a valid API key in the "Token" Authorization header
   399  	req := httptest.NewRequest(fiber.MethodGet, "/", nil)
   400  	req.Header.Add("Authorization", fmt.Sprintf("Token %s", CorrectKey))
   401  
   402  	// Send the request to the app
   403  	res, err := app.Test(req)
   404  	utils.AssertEqual(t, err, nil)
   405  
   406  	// Read the response body into a string
   407  	body, err := io.ReadAll(res.Body)
   408  	utils.AssertEqual(t, err, nil)
   409  
   410  	// Check that the response has the expected status code and body
   411  	utils.AssertEqual(t, res.StatusCode, http.StatusOK)
   412  	utils.AssertEqual(t, string(body), "API key is valid")
   413  }
   414  
   415  func TestAuthSchemeBasic(t *testing.T) {
   416  	app := fiber.New()
   417  
   418  	app.Use(New(Config{
   419  		KeyLookup:  "header:Authorization",
   420  		AuthScheme: "Basic",
   421  		Validator: func(c *fiber.Ctx, key string) (bool, error) {
   422  			if key == CorrectKey {
   423  				return true, nil
   424  			}
   425  			return false, ErrMissingOrMalformedAPIKey
   426  		},
   427  	}))
   428  
   429  	// Define a test handler
   430  	app.Get("/", func(c *fiber.Ctx) error {
   431  		return c.SendString("API key is valid")
   432  	})
   433  
   434  	// Create a request without an API key and  Send the request to the app
   435  	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
   436  	utils.AssertEqual(t, err, nil)
   437  
   438  	// Read the response body into a string
   439  	body, err := io.ReadAll(res.Body)
   440  	utils.AssertEqual(t, err, nil)
   441  
   442  	// Check that the response has the expected status code and body
   443  	utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
   444  	utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
   445  
   446  	// Create a request with a valid API key in the "Authorization" header using the "Basic" scheme
   447  	req := httptest.NewRequest(fiber.MethodGet, "/", nil)
   448  	req.Header.Add("Authorization", fmt.Sprintf("Basic %s", CorrectKey))
   449  
   450  	// Send the request to the app
   451  	res, err = app.Test(req)
   452  	utils.AssertEqual(t, err, nil)
   453  
   454  	// Read the response body into a string
   455  	body, err = io.ReadAll(res.Body)
   456  	utils.AssertEqual(t, err, nil)
   457  
   458  	// Check that the response has the expected status code and body
   459  	utils.AssertEqual(t, res.StatusCode, http.StatusOK)
   460  	utils.AssertEqual(t, string(body), "API key is valid")
   461  }