github.com/gofiber/fiber/v2@v2.47.0/middleware/cors/cors_test.go (about)

     1  package cors
     2  
     3  import (
     4  	"net/http/httptest"
     5  	"strings"
     6  	"testing"
     7  
     8  	"github.com/gofiber/fiber/v2"
     9  	"github.com/gofiber/fiber/v2/utils"
    10  
    11  	"github.com/valyala/fasthttp"
    12  )
    13  
    14  func Test_CORS_Defaults(t *testing.T) {
    15  	t.Parallel()
    16  	app := fiber.New()
    17  	app.Use(New())
    18  
    19  	testDefaultOrEmptyConfig(t, app)
    20  }
    21  
    22  func Test_CORS_Empty_Config(t *testing.T) {
    23  	t.Parallel()
    24  	app := fiber.New()
    25  	app.Use(New(Config{}))
    26  
    27  	testDefaultOrEmptyConfig(t, app)
    28  }
    29  
    30  func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
    31  	t.Helper()
    32  
    33  	h := app.Handler()
    34  
    35  	// Test default GET response headers
    36  	ctx := &fasthttp.RequestCtx{}
    37  	ctx.Request.Header.SetMethod(fiber.MethodGet)
    38  	h(ctx)
    39  
    40  	utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
    41  	utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
    42  	utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
    43  
    44  	// Test default OPTIONS (preflight) response headers
    45  	ctx = &fasthttp.RequestCtx{}
    46  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
    47  	h(ctx)
    48  
    49  	utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)))
    50  	utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
    51  	utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
    52  }
    53  
    54  // go test -run -v Test_CORS_Wildcard
    55  func Test_CORS_Wildcard(t *testing.T) {
    56  	t.Parallel()
    57  	// New fiber instance
    58  	app := fiber.New()
    59  	// OPTIONS (preflight) response headers when AllowOrigins is *
    60  	app.Use(New(Config{
    61  		AllowOrigins:     "*",
    62  		AllowCredentials: true,
    63  		MaxAge:           3600,
    64  		ExposeHeaders:    "X-Request-ID",
    65  		AllowHeaders:     "Authentication",
    66  	}))
    67  	// Get handler pointer
    68  	handler := app.Handler()
    69  
    70  	// Make request
    71  	ctx := &fasthttp.RequestCtx{}
    72  	ctx.Request.SetRequestURI("/")
    73  	ctx.Request.Header.Set(fiber.HeaderOrigin, "localhost")
    74  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
    75  
    76  	// Perform request
    77  	handler(ctx)
    78  
    79  	// Check result
    80  	utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
    81  	utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
    82  	utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
    83  	utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
    84  
    85  	// Test non OPTIONS (preflight) response headers
    86  	ctx = &fasthttp.RequestCtx{}
    87  	ctx.Request.Header.SetMethod(fiber.MethodGet)
    88  	handler(ctx)
    89  
    90  	utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
    91  	utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
    92  }
    93  
    94  // go test -run -v Test_CORS_Subdomain
    95  func Test_CORS_Subdomain(t *testing.T) {
    96  	t.Parallel()
    97  	// New fiber instance
    98  	app := fiber.New()
    99  	// OPTIONS (preflight) response headers when AllowOrigins is set to a subdomain
   100  	app.Use("/", New(Config{AllowOrigins: "http://*.example.com"}))
   101  
   102  	// Get handler pointer
   103  	handler := app.Handler()
   104  
   105  	// Make request with disallowed origin
   106  	ctx := &fasthttp.RequestCtx{}
   107  	ctx.Request.SetRequestURI("/")
   108  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
   109  	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
   110  
   111  	// Perform request
   112  	handler(ctx)
   113  
   114  	// Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com
   115  	utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   116  
   117  	ctx.Request.Reset()
   118  	ctx.Response.Reset()
   119  
   120  	// Make request with allowed origin
   121  	ctx.Request.SetRequestURI("/")
   122  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
   123  	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")
   124  
   125  	handler(ctx)
   126  
   127  	utils.AssertEqual(t, "http://test.example.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   128  }
   129  
   130  func Test_CORS_AllowOriginScheme(t *testing.T) {
   131  	t.Parallel()
   132  	tests := []struct {
   133  		reqOrigin, pattern string
   134  		shouldAllowOrigin  bool
   135  	}{
   136  		{
   137  			pattern:           "http://example.com",
   138  			reqOrigin:         "http://example.com",
   139  			shouldAllowOrigin: true,
   140  		},
   141  		{
   142  			pattern:           "https://example.com",
   143  			reqOrigin:         "https://example.com",
   144  			shouldAllowOrigin: true,
   145  		},
   146  		{
   147  			pattern:           "http://example.com",
   148  			reqOrigin:         "https://example.com",
   149  			shouldAllowOrigin: false,
   150  		},
   151  		{
   152  			pattern:           "http://*.example.com",
   153  			reqOrigin:         "http://aaa.example.com",
   154  			shouldAllowOrigin: true,
   155  		},
   156  		{
   157  			pattern:           "http://*.example.com",
   158  			reqOrigin:         "http://bbb.aaa.example.com",
   159  			shouldAllowOrigin: true,
   160  		},
   161  		{
   162  			pattern:           "http://*.aaa.example.com",
   163  			reqOrigin:         "http://bbb.aaa.example.com",
   164  			shouldAllowOrigin: true,
   165  		},
   166  		{
   167  			pattern:           "http://*.example.com:8080",
   168  			reqOrigin:         "http://aaa.example.com:8080",
   169  			shouldAllowOrigin: true,
   170  		},
   171  		{
   172  			pattern:           "http://example.com",
   173  			reqOrigin:         "http://gofiber.com",
   174  			shouldAllowOrigin: false,
   175  		},
   176  		{
   177  			pattern:           "http://*.aaa.example.com",
   178  			reqOrigin:         "http://ccc.bbb.example.com",
   179  			shouldAllowOrigin: false,
   180  		},
   181  		{
   182  			pattern: "http://*.example.com",
   183  			reqOrigin: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
   184  		  .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
   185  		  .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
   186  			.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
   187  			shouldAllowOrigin: false,
   188  		},
   189  		{
   190  			pattern:           "http://example.com",
   191  			reqOrigin:         "http://ccc.bbb.example.com",
   192  			shouldAllowOrigin: false,
   193  		},
   194  		{
   195  			pattern:           "https://*--aaa.bbb.com",
   196  			reqOrigin:         "https://prod-preview--aaa.bbb.com",
   197  			shouldAllowOrigin: false,
   198  		},
   199  		{
   200  			pattern:           "http://*.example.com",
   201  			reqOrigin:         "http://ccc.bbb.example.com",
   202  			shouldAllowOrigin: true,
   203  		},
   204  		{
   205  			pattern:           "http://foo.[a-z]*.example.com",
   206  			reqOrigin:         "http://ccc.bbb.example.com",
   207  			shouldAllowOrigin: false,
   208  		},
   209  	}
   210  
   211  	for _, tt := range tests {
   212  		app := fiber.New()
   213  		app.Use("/", New(Config{AllowOrigins: tt.pattern}))
   214  
   215  		handler := app.Handler()
   216  
   217  		ctx := &fasthttp.RequestCtx{}
   218  		ctx.Request.SetRequestURI("/")
   219  		ctx.Request.Header.SetMethod(fiber.MethodOptions)
   220  		ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)
   221  
   222  		handler(ctx)
   223  
   224  		if tt.shouldAllowOrigin {
   225  			utils.AssertEqual(t, tt.reqOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   226  		} else {
   227  			utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   228  		}
   229  	}
   230  }
   231  
   232  // go test -run Test_CORS_Next
   233  func Test_CORS_Next(t *testing.T) {
   234  	t.Parallel()
   235  	app := fiber.New()
   236  	app.Use(New(Config{
   237  		Next: func(_ *fiber.Ctx) bool {
   238  			return true
   239  		},
   240  	}))
   241  
   242  	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
   243  	utils.AssertEqual(t, nil, err)
   244  	utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
   245  }
   246  
   247  func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
   248  	t.Parallel()
   249  	// New fiber instance
   250  	app := fiber.New()
   251  	app.Use("/", New(Config{
   252  		AllowOrigins: "http://example-1.com",
   253  		AllowOriginsFunc: func(origin string) bool {
   254  			return strings.Contains(origin, "example-2")
   255  		},
   256  	}))
   257  
   258  	// Get handler pointer
   259  	handler := app.Handler()
   260  
   261  	// Make request with disallowed origin
   262  	ctx := &fasthttp.RequestCtx{}
   263  	ctx.Request.SetRequestURI("/")
   264  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
   265  	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
   266  
   267  	// Perform request
   268  	handler(ctx)
   269  
   270  	// Allow-Origin header should be "" because http://google.com does not satisfy http://example-1.com or 'strings.Contains(origin, "example-2")'
   271  	utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   272  
   273  	ctx.Request.Reset()
   274  	ctx.Response.Reset()
   275  
   276  	// Make request with allowed origin
   277  	ctx.Request.SetRequestURI("/")
   278  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
   279  	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")
   280  
   281  	handler(ctx)
   282  
   283  	utils.AssertEqual(t, "http://example-1.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   284  
   285  	ctx.Request.Reset()
   286  	ctx.Response.Reset()
   287  
   288  	// Make request with allowed origin
   289  	ctx.Request.SetRequestURI("/")
   290  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
   291  	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
   292  
   293  	handler(ctx)
   294  
   295  	utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   296  }
   297  
   298  func Test_CORS_AllowOriginsFunc(t *testing.T) {
   299  	t.Parallel()
   300  	// New fiber instance
   301  	app := fiber.New()
   302  	app.Use("/", New(Config{
   303  		AllowOriginsFunc: func(origin string) bool {
   304  			return strings.Contains(origin, "example-2")
   305  		},
   306  	}))
   307  
   308  	// Get handler pointer
   309  	handler := app.Handler()
   310  
   311  	// Make request with disallowed origin
   312  	ctx := &fasthttp.RequestCtx{}
   313  	ctx.Request.SetRequestURI("/")
   314  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
   315  	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
   316  
   317  	// Perform request
   318  	handler(ctx)
   319  
   320  	// Allow-Origin header should be "*" because http://google.com does not satisfy 'strings.Contains(origin, "example-2")'
   321  	// and AllowOrigins has not been set so the default "*" is used
   322  	utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   323  
   324  	ctx.Request.Reset()
   325  	ctx.Response.Reset()
   326  
   327  	// Make request with allowed origin
   328  	ctx.Request.SetRequestURI("/")
   329  	ctx.Request.Header.SetMethod(fiber.MethodOptions)
   330  	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
   331  
   332  	handler(ctx)
   333  
   334  	// Allow-Origin header should be "http://example-2.com"
   335  	utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
   336  }