github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/cors/cors_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cors_test
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/google/go-cmp/cmp"
    21  	"github.com/google/go-safeweb/safehttp"
    22  	"github.com/google/go-safeweb/safehttp/plugins/cors"
    23  	"github.com/google/go-safeweb/safehttp/safehttptest"
    24  )
    25  
    26  func TestRequest(t *testing.T) {
    27  	tests := []struct {
    28  		name             string
    29  		req              *safehttp.IncomingRequest
    30  		allowCredentials bool
    31  		exposedHeaders   []string
    32  		want             map[string][]string
    33  	}{
    34  		{
    35  			name: "Basic GET",
    36  			req: func() *safehttp.IncomingRequest {
    37  				r := safehttptest.NewRequest(safehttp.MethodGet, "http://bar.com", nil)
    38  				r.Header.Set("Origin", "https://foo.com")
    39  				r.Header.Set("X-Cors", "1")
    40  				r.Header.Set("Content-Type", "application/json")
    41  				return r
    42  			}(),
    43  			want: map[string][]string{
    44  				"Access-Control-Allow-Origin": {"https://foo.com"},
    45  				"Vary":                        {"Origin"},
    46  			},
    47  		},
    48  		{
    49  			name: "Basic PUT",
    50  			req: func() *safehttp.IncomingRequest {
    51  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
    52  				r.Header.Set("Origin", "https://foo.com")
    53  				r.Header.Set("X-Cors", "1")
    54  				r.Header.Set("Content-Type", "application/json")
    55  				return r
    56  			}(),
    57  			want: map[string][]string{
    58  				"Access-Control-Allow-Origin": {"https://foo.com"},
    59  				"Vary":                        {"Origin"},
    60  			},
    61  		},
    62  		{
    63  			name: "Basic POST",
    64  			req: func() *safehttp.IncomingRequest {
    65  				r := safehttptest.NewRequest(safehttp.MethodPost, "http://bar.com", nil)
    66  				r.Header.Set("Origin", "https://foo.com")
    67  				r.Header.Set("X-Cors", "1")
    68  				r.Header.Set("Content-Type", "application/json")
    69  				return r
    70  			}(),
    71  			want: map[string][]string{
    72  				"Access-Control-Allow-Origin": {"https://foo.com"},
    73  				"Vary":                        {"Origin"},
    74  			},
    75  		},
    76  		{
    77  			name: "No Origin header",
    78  			req: func() *safehttp.IncomingRequest {
    79  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
    80  				r.Header.Set("X-Cors", "1")
    81  				r.Header.Set("Content-Type", "application/json")
    82  				return r
    83  			}(),
    84  			want: map[string][]string{},
    85  		},
    86  		{
    87  			name: "AllowCredentials but no cookies",
    88  			req: func() *safehttp.IncomingRequest {
    89  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
    90  				r.Header.Set("Origin", "https://foo.com")
    91  				r.Header.Set("X-Cors", "1")
    92  				r.Header.Set("Content-Type", "application/json")
    93  				return r
    94  			}(),
    95  			allowCredentials: true,
    96  			want: map[string][]string{
    97  				"Access-Control-Allow-Origin": {"https://foo.com"},
    98  				"Vary":                        {"Origin"},
    99  			},
   100  		},
   101  		{
   102  			name: "AllowCredentials with cookies",
   103  			req: func() *safehttp.IncomingRequest {
   104  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
   105  				r.Header.Set("Origin", "https://foo.com")
   106  				r.Header.Set("X-Cors", "1")
   107  				r.Header.Set("Content-Type", "application/json")
   108  				r.Header.Set("Cookie", "a=b")
   109  				return r
   110  			}(),
   111  			allowCredentials: true,
   112  			want: map[string][]string{
   113  				"Access-Control-Allow-Credentials": {"true"},
   114  				"Access-Control-Allow-Origin":      {"https://foo.com"},
   115  				"Vary":                             {"Origin"},
   116  			},
   117  		},
   118  		{
   119  			name: "Expose one header",
   120  			req: func() *safehttp.IncomingRequest {
   121  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
   122  				r.Header.Set("Origin", "https://foo.com")
   123  				r.Header.Set("X-Cors", "1")
   124  				r.Header.Set("Content-Type", "application/json")
   125  				return r
   126  			}(),
   127  			exposedHeaders: []string{"Aaaa"},
   128  			want: map[string][]string{
   129  				"Access-Control-Expose-Headers": {"Aaaa"},
   130  				"Access-Control-Allow-Origin":   {"https://foo.com"},
   131  				"Vary":                          {"Origin"},
   132  			},
   133  		},
   134  		{
   135  			name: "Expose multiple headers",
   136  			req: func() *safehttp.IncomingRequest {
   137  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
   138  				r.Header.Set("Origin", "https://foo.com")
   139  				r.Header.Set("X-Cors", "1")
   140  				r.Header.Set("Content-Type", "application/json")
   141  				return r
   142  			}(),
   143  			exposedHeaders: []string{"Aaaa", "Bbbb", "Cccc"},
   144  			want: map[string][]string{
   145  				"Access-Control-Expose-Headers": {"Aaaa, Bbbb, Cccc"},
   146  				"Access-Control-Allow-Origin":   {"https://foo.com"},
   147  				"Vary":                          {"Origin"},
   148  			},
   149  		},
   150  	}
   151  
   152  	for _, tt := range tests {
   153  		t.Run(tt.name, func(t *testing.T) {
   154  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   155  
   156  			it := cors.Default("https://foo.com")
   157  			it.AllowCredentials = tt.allowCredentials
   158  			it.ExposedHeaders = tt.exposedHeaders
   159  			it.Before(fakeRW, tt.req, nil)
   160  
   161  			if diff := cmp.Diff(tt.want, map[string][]string(rr.Header())); diff != "" {
   162  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   163  			}
   164  			if got := rr.Body.String(); got != "" {
   165  				t.Errorf(`rr.Body.String() got: %q want: ""`, got)
   166  			}
   167  		})
   168  	}
   169  }
   170  
   171  func TestVaryHeaderAppending(t *testing.T) {
   172  	req := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
   173  	req.Header.Set("Origin", "https://foo.com")
   174  	req.Header.Set("X-Cors", "1")
   175  	req.Header.Set("Content-Type", "application/json")
   176  
   177  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   178  	rr.Header().Set("Vary", "a")
   179  
   180  	it := cors.Default("https://foo.com")
   181  	it.Before(fakeRW, req, nil)
   182  
   183  	wantHeaders := map[string][]string{
   184  		"Access-Control-Allow-Origin": {"https://foo.com"},
   185  		"Vary":                        {"a, Origin"},
   186  	}
   187  	if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   188  		t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   189  	}
   190  	if got := rr.Body.String(); got != "" {
   191  		t.Errorf(`rr.Body.String() got: %q want: ""`, got)
   192  	}
   193  }
   194  
   195  func TestHeadRequest(t *testing.T) {
   196  	req := safehttptest.NewRequest(safehttp.MethodHead, "http://bar.com", nil)
   197  	req.Header.Set("Origin", "https://foo.com")
   198  	req.Header.Set("X-Cors", "1")
   199  	req.Header.Set("Content-Type", "application/json")
   200  
   201  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   202  
   203  	it := cors.Default("https://foo.com")
   204  	it.Before(fakeRW, req, nil)
   205  
   206  	if got, want := rr.Code, int(safehttp.StatusMethodNotAllowed); got != want {
   207  		t.Errorf("rr.Code got: %v want: %v", got, want)
   208  	}
   209  	wantHeaders := map[string][]string{}
   210  	if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   211  		t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   212  	}
   213  }
   214  
   215  func TestInvalidRequest(t *testing.T) {
   216  	tests := []struct {
   217  		name string
   218  		req  *safehttp.IncomingRequest
   219  	}{
   220  		{
   221  			name: "No X-Cors: 1, but Sec-Fetch-Mode: cors",
   222  			req: func() *safehttp.IncomingRequest {
   223  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil)
   224  				r.Header.Set("Origin", "https://foo.com")
   225  				r.Header.Set("Sec-Fetch-Mode", "cors")
   226  				return r
   227  			}(),
   228  		},
   229  		{
   230  			name: "No X-Cors: 1",
   231  			req: func() *safehttp.IncomingRequest {
   232  				r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com/asdf", nil)
   233  				r.Header.Set("Origin", "https://foo.com")
   234  				return r
   235  			}(),
   236  		},
   237  	}
   238  
   239  	for _, tt := range tests {
   240  		t.Run(tt.name, func(t *testing.T) {
   241  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   242  
   243  			it := cors.Default("https://foo.com")
   244  			it.Before(fakeRW, tt.req, nil)
   245  
   246  			if want := safehttp.StatusPreconditionFailed; rr.Code != int(want) {
   247  				t.Errorf("rr.Code got: %v want: %v", rr.Code, want)
   248  			}
   249  			wantHeaders := map[string][]string{}
   250  			if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   251  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   252  			}
   253  		})
   254  	}
   255  }
   256  
   257  func TestRequestDisallowedContentTypes(t *testing.T) {
   258  	contentTypes := []string{
   259  		"application/x-www-form-urlencoded",
   260  		"multipart/form-data",
   261  		"text/plain",
   262  		"",
   263  	}
   264  
   265  	for _, ct := range contentTypes {
   266  		t.Run(ct, func(t *testing.T) {
   267  			req := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com/asdf", nil)
   268  			req.Header.Set("Origin", "https://foo.com")
   269  			req.Header.Set("X-Cors", "1")
   270  			if ct != "" {
   271  				req.Header.Set("Content-Type", ct)
   272  			}
   273  
   274  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   275  
   276  			it := cors.Default("https://foo.com")
   277  			it.Before(fakeRW, req, nil)
   278  
   279  			if want := safehttp.StatusUnsupportedMediaType; rr.Code != int(want) {
   280  				t.Errorf("rr.Code got: %v want: %v", rr.Code, want)
   281  			}
   282  			wantHeaders := map[string][]string{}
   283  			if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   284  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   285  			}
   286  		})
   287  	}
   288  }
   289  
   290  func TestDisallowedOrigin(t *testing.T) {
   291  	req := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com/asdf", nil)
   292  	req.Header.Set("Origin", "https://pizza.com")
   293  
   294  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   295  
   296  	it := cors.Default("https://foo.com")
   297  	it.Before(fakeRW, req, nil)
   298  
   299  	if want := safehttp.StatusForbidden; rr.Code != int(want) {
   300  		t.Errorf("rr.Code got: %v want: %v", rr.Code, want)
   301  	}
   302  	wantHeaders := map[string][]string{}
   303  	if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   304  		t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   305  	}
   306  }
   307  
   308  func TestPreflight(t *testing.T) {
   309  	tests := []struct {
   310  		name           string
   311  		req            *safehttp.IncomingRequest
   312  		allowedHeaders []string
   313  		maxAge         int
   314  		wantHeaders    map[string][]string
   315  	}{
   316  		{
   317  			name: "Basic",
   318  			req: func() *safehttp.IncomingRequest {
   319  				r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   320  				r.Header.Set("Origin", "https://foo.com")
   321  				r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut)
   322  				return r
   323  			}(),
   324  			wantHeaders: map[string][]string{
   325  				"Access-Control-Allow-Methods": {"PUT"},
   326  				"Access-Control-Allow-Origin":  {"https://foo.com"},
   327  				"Access-Control-Max-Age":       {"5"},
   328  				"Vary":                         {"Origin"},
   329  			},
   330  		},
   331  		{
   332  			name: "Request X-Cors header",
   333  			req: func() *safehttp.IncomingRequest {
   334  				r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   335  				r.Header.Set("Origin", "https://foo.com")
   336  				r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut)
   337  				r.Header.Set("Access-Control-Request-Headers", "X-Cors")
   338  				return r
   339  			}(),
   340  			wantHeaders: map[string][]string{
   341  				"Access-Control-Allow-Headers": {"X-Cors"},
   342  				"Access-Control-Allow-Methods": {"PUT"},
   343  				"Access-Control-Allow-Origin":  {"https://foo.com"},
   344  				"Access-Control-Max-Age":       {"5"},
   345  				"Vary":                         {"Origin"},
   346  			},
   347  		},
   348  		{
   349  			name: "Request custom header",
   350  			req: func() *safehttp.IncomingRequest {
   351  				r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   352  				r.Header.Set("Origin", "https://foo.com")
   353  				r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut)
   354  				r.Header.Set("Access-Control-Request-Headers", "Aaaa")
   355  				return r
   356  			}(),
   357  			allowedHeaders: []string{"Aaaa"},
   358  			wantHeaders: map[string][]string{
   359  				"Access-Control-Allow-Headers": {"Aaaa"},
   360  				"Access-Control-Allow-Methods": {"PUT"},
   361  				"Access-Control-Allow-Origin":  {"https://foo.com"},
   362  				"Access-Control-Max-Age":       {"5"},
   363  				"Vary":                         {"Origin"},
   364  			},
   365  		},
   366  		{
   367  			name: "Request multiple headers",
   368  			req: func() *safehttp.IncomingRequest {
   369  				r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   370  				r.Header.Set("Origin", "https://foo.com")
   371  				r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut)
   372  				r.Header.Set("Access-Control-Request-Headers", "X-Cors, Aaaa")
   373  				return r
   374  			}(),
   375  			allowedHeaders: []string{"Aaaa"},
   376  			wantHeaders: map[string][]string{
   377  				"Access-Control-Allow-Headers": {"X-Cors, Aaaa"},
   378  				"Access-Control-Allow-Methods": {"PUT"},
   379  				"Access-Control-Allow-Origin":  {"https://foo.com"},
   380  				"Access-Control-Max-Age":       {"5"},
   381  				"Vary":                         {"Origin"},
   382  			},
   383  		},
   384  		{
   385  			name: "Request headers test canonicalization",
   386  			req: func() *safehttp.IncomingRequest {
   387  				r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   388  				r.Header.Set("Origin", "https://foo.com")
   389  				r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut)
   390  				r.Header.Set("Access-Control-Request-Headers", "x-coRS, aaAA")
   391  				return r
   392  			}(),
   393  			allowedHeaders: []string{"AAaa"},
   394  			wantHeaders: map[string][]string{
   395  				"Access-Control-Allow-Headers": {"x-coRS, aaAA"},
   396  				"Access-Control-Allow-Methods": {"PUT"},
   397  				"Access-Control-Allow-Origin":  {"https://foo.com"},
   398  				"Access-Control-Max-Age":       {"5"},
   399  				"Vary":                         {"Origin"},
   400  			},
   401  		},
   402  		{
   403  			name: "Custom Max age",
   404  			req: func() *safehttp.IncomingRequest {
   405  				r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   406  				r.Header.Set("Origin", "https://foo.com")
   407  				r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut)
   408  				return r
   409  			}(),
   410  			maxAge: 3600,
   411  			wantHeaders: map[string][]string{
   412  				"Access-Control-Allow-Methods": {"PUT"},
   413  				"Access-Control-Allow-Origin":  {"https://foo.com"},
   414  				"Access-Control-Max-Age":       {"3600"},
   415  				"Vary":                         {"Origin"},
   416  			},
   417  		},
   418  	}
   419  
   420  	for _, tt := range tests {
   421  		t.Run(tt.name, func(t *testing.T) {
   422  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   423  
   424  			it := cors.Default("https://foo.com")
   425  			it.MaxAge = tt.maxAge
   426  			it.SetAllowedHeaders(tt.allowedHeaders...)
   427  			it.Before(fakeRW, tt.req, nil)
   428  
   429  			if rr.Code != int(safehttp.StatusNoContent) {
   430  				t.Errorf("rr.Code got: %v want: %v", rr.Code, safehttp.StatusNoContent)
   431  			}
   432  			if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rr.Header())); diff != "" {
   433  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   434  			}
   435  			if got := rr.Body.String(); got != "" {
   436  				t.Errorf(`rr.Body.String() got: %q want: ""`, got)
   437  			}
   438  		})
   439  	}
   440  }
   441  
   442  func TestInvalidAccessControlRequestHeaders(t *testing.T) {
   443  	tests := []struct {
   444  		name    string
   445  		headers string
   446  	}{
   447  		{
   448  			name:    "B is not allowed",
   449  			headers: "B",
   450  		},
   451  		{
   452  			name:    "One in list is not allowed",
   453  			headers: "X-Cors, B",
   454  		},
   455  		{
   456  			name:    "Empty at the end",
   457  			headers: "X-Cors, ",
   458  		},
   459  	}
   460  
   461  	for _, tt := range tests {
   462  		t.Run(tt.name, func(t *testing.T) {
   463  			req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   464  			rh := req.Header
   465  			rh.Set("Origin", "https://foo.com")
   466  			rh.Set("Access-Control-Request-Method", safehttp.MethodPut)
   467  			rh.Set("Access-Control-Request-Headers", tt.headers)
   468  
   469  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   470  
   471  			it := cors.Default("https://foo.com")
   472  			it.Before(fakeRW, req, nil)
   473  
   474  			if want := safehttp.StatusForbidden; rr.Code != int(want) {
   475  				t.Errorf("rr.Code got: %v want: %v", rr.Code, want)
   476  			}
   477  			wantHeaders := map[string][]string{}
   478  			if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   479  				t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   480  			}
   481  		})
   482  	}
   483  }
   484  
   485  func TestEmptyAccessControlRequestMethod(t *testing.T) {
   486  	req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   487  	rh := req.Header
   488  	rh.Set("Origin", "https://foo.com")
   489  
   490  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   491  
   492  	it := cors.Default("https://foo.com")
   493  	it.Before(fakeRW, req, nil)
   494  
   495  	if want := safehttp.StatusForbidden; rr.Code != int(want) {
   496  		t.Errorf("rr.Code got: %v want: %v", rr.Code, want)
   497  	}
   498  	wantHeaders := map[string][]string{}
   499  	if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   500  		t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   501  	}
   502  }
   503  
   504  func TestAccessControlRequestMethodHead(t *testing.T) {
   505  	req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   506  	req.Header.Set("Origin", "https://foo.com")
   507  	req.Header.Set("Access-Control-Request-Method", safehttp.MethodHead)
   508  
   509  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   510  
   511  	it := cors.Default("https://foo.com")
   512  	it.Before(fakeRW, req, nil)
   513  
   514  	if want := safehttp.StatusForbidden; rr.Code != int(want) {
   515  		t.Errorf("rr.Code got: %v want: %v", rr.Code, want)
   516  	}
   517  	wantHeaders := map[string][]string{}
   518  	if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   519  		t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   520  	}
   521  }
   522  
   523  func TestPreflightEmptyOrigin(t *testing.T) {
   524  	req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil)
   525  	req.Header.Set("Access-Control-Request-Method", safehttp.MethodHead)
   526  
   527  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   528  
   529  	it := cors.Default("https://foo.com")
   530  	it.Before(fakeRW, req, nil)
   531  
   532  	if want := safehttp.StatusForbidden; rr.Code != int(want) {
   533  		t.Errorf("rr.Code got: %v want: %v", rr.Code, want)
   534  	}
   535  	wantHeaders := map[string][]string{}
   536  	if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" {
   537  		t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff)
   538  	}
   539  }