github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/cors/cors_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package cors_test 16 17 import ( 18 "testing" 19 20 "github.com/google/go-cmp/cmp" 21 "github.com/google/go-safeweb/safehttp" 22 "github.com/google/go-safeweb/safehttp/plugins/cors" 23 "github.com/google/go-safeweb/safehttp/safehttptest" 24 ) 25 26 func TestRequest(t *testing.T) { 27 tests := []struct { 28 name string 29 req *safehttp.IncomingRequest 30 allowCredentials bool 31 exposedHeaders []string 32 want map[string][]string 33 }{ 34 { 35 name: "Basic GET", 36 req: func() *safehttp.IncomingRequest { 37 r := safehttptest.NewRequest(safehttp.MethodGet, "http://bar.com", nil) 38 r.Header.Set("Origin", "https://foo.com") 39 r.Header.Set("X-Cors", "1") 40 r.Header.Set("Content-Type", "application/json") 41 return r 42 }(), 43 want: map[string][]string{ 44 "Access-Control-Allow-Origin": {"https://foo.com"}, 45 "Vary": {"Origin"}, 46 }, 47 }, 48 { 49 name: "Basic PUT", 50 req: func() *safehttp.IncomingRequest { 51 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 52 r.Header.Set("Origin", "https://foo.com") 53 r.Header.Set("X-Cors", "1") 54 r.Header.Set("Content-Type", "application/json") 55 return r 56 }(), 57 want: map[string][]string{ 58 "Access-Control-Allow-Origin": {"https://foo.com"}, 59 "Vary": {"Origin"}, 60 }, 61 }, 62 { 63 name: "Basic POST", 64 req: func() *safehttp.IncomingRequest { 65 r := safehttptest.NewRequest(safehttp.MethodPost, "http://bar.com", nil) 66 r.Header.Set("Origin", "https://foo.com") 67 r.Header.Set("X-Cors", "1") 68 r.Header.Set("Content-Type", "application/json") 69 return r 70 }(), 71 want: map[string][]string{ 72 "Access-Control-Allow-Origin": {"https://foo.com"}, 73 "Vary": {"Origin"}, 74 }, 75 }, 76 { 77 name: "No Origin header", 78 req: func() *safehttp.IncomingRequest { 79 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 80 r.Header.Set("X-Cors", "1") 81 r.Header.Set("Content-Type", "application/json") 82 return r 83 }(), 84 want: map[string][]string{}, 85 }, 86 { 87 name: "AllowCredentials but no cookies", 88 req: func() *safehttp.IncomingRequest { 89 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 90 r.Header.Set("Origin", "https://foo.com") 91 r.Header.Set("X-Cors", "1") 92 r.Header.Set("Content-Type", "application/json") 93 return r 94 }(), 95 allowCredentials: true, 96 want: map[string][]string{ 97 "Access-Control-Allow-Origin": {"https://foo.com"}, 98 "Vary": {"Origin"}, 99 }, 100 }, 101 { 102 name: "AllowCredentials with cookies", 103 req: func() *safehttp.IncomingRequest { 104 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 105 r.Header.Set("Origin", "https://foo.com") 106 r.Header.Set("X-Cors", "1") 107 r.Header.Set("Content-Type", "application/json") 108 r.Header.Set("Cookie", "a=b") 109 return r 110 }(), 111 allowCredentials: true, 112 want: map[string][]string{ 113 "Access-Control-Allow-Credentials": {"true"}, 114 "Access-Control-Allow-Origin": {"https://foo.com"}, 115 "Vary": {"Origin"}, 116 }, 117 }, 118 { 119 name: "Expose one header", 120 req: func() *safehttp.IncomingRequest { 121 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 122 r.Header.Set("Origin", "https://foo.com") 123 r.Header.Set("X-Cors", "1") 124 r.Header.Set("Content-Type", "application/json") 125 return r 126 }(), 127 exposedHeaders: []string{"Aaaa"}, 128 want: map[string][]string{ 129 "Access-Control-Expose-Headers": {"Aaaa"}, 130 "Access-Control-Allow-Origin": {"https://foo.com"}, 131 "Vary": {"Origin"}, 132 }, 133 }, 134 { 135 name: "Expose multiple headers", 136 req: func() *safehttp.IncomingRequest { 137 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 138 r.Header.Set("Origin", "https://foo.com") 139 r.Header.Set("X-Cors", "1") 140 r.Header.Set("Content-Type", "application/json") 141 return r 142 }(), 143 exposedHeaders: []string{"Aaaa", "Bbbb", "Cccc"}, 144 want: map[string][]string{ 145 "Access-Control-Expose-Headers": {"Aaaa, Bbbb, Cccc"}, 146 "Access-Control-Allow-Origin": {"https://foo.com"}, 147 "Vary": {"Origin"}, 148 }, 149 }, 150 } 151 152 for _, tt := range tests { 153 t.Run(tt.name, func(t *testing.T) { 154 fakeRW, rr := safehttptest.NewFakeResponseWriter() 155 156 it := cors.Default("https://foo.com") 157 it.AllowCredentials = tt.allowCredentials 158 it.ExposedHeaders = tt.exposedHeaders 159 it.Before(fakeRW, tt.req, nil) 160 161 if diff := cmp.Diff(tt.want, map[string][]string(rr.Header())); diff != "" { 162 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 163 } 164 if got := rr.Body.String(); got != "" { 165 t.Errorf(`rr.Body.String() got: %q want: ""`, got) 166 } 167 }) 168 } 169 } 170 171 func TestVaryHeaderAppending(t *testing.T) { 172 req := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 173 req.Header.Set("Origin", "https://foo.com") 174 req.Header.Set("X-Cors", "1") 175 req.Header.Set("Content-Type", "application/json") 176 177 fakeRW, rr := safehttptest.NewFakeResponseWriter() 178 rr.Header().Set("Vary", "a") 179 180 it := cors.Default("https://foo.com") 181 it.Before(fakeRW, req, nil) 182 183 wantHeaders := map[string][]string{ 184 "Access-Control-Allow-Origin": {"https://foo.com"}, 185 "Vary": {"a, Origin"}, 186 } 187 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 188 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 189 } 190 if got := rr.Body.String(); got != "" { 191 t.Errorf(`rr.Body.String() got: %q want: ""`, got) 192 } 193 } 194 195 func TestHeadRequest(t *testing.T) { 196 req := safehttptest.NewRequest(safehttp.MethodHead, "http://bar.com", nil) 197 req.Header.Set("Origin", "https://foo.com") 198 req.Header.Set("X-Cors", "1") 199 req.Header.Set("Content-Type", "application/json") 200 201 fakeRW, rr := safehttptest.NewFakeResponseWriter() 202 203 it := cors.Default("https://foo.com") 204 it.Before(fakeRW, req, nil) 205 206 if got, want := rr.Code, int(safehttp.StatusMethodNotAllowed); got != want { 207 t.Errorf("rr.Code got: %v want: %v", got, want) 208 } 209 wantHeaders := map[string][]string{} 210 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 211 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 212 } 213 } 214 215 func TestInvalidRequest(t *testing.T) { 216 tests := []struct { 217 name string 218 req *safehttp.IncomingRequest 219 }{ 220 { 221 name: "No X-Cors: 1, but Sec-Fetch-Mode: cors", 222 req: func() *safehttp.IncomingRequest { 223 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com", nil) 224 r.Header.Set("Origin", "https://foo.com") 225 r.Header.Set("Sec-Fetch-Mode", "cors") 226 return r 227 }(), 228 }, 229 { 230 name: "No X-Cors: 1", 231 req: func() *safehttp.IncomingRequest { 232 r := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com/asdf", nil) 233 r.Header.Set("Origin", "https://foo.com") 234 return r 235 }(), 236 }, 237 } 238 239 for _, tt := range tests { 240 t.Run(tt.name, func(t *testing.T) { 241 fakeRW, rr := safehttptest.NewFakeResponseWriter() 242 243 it := cors.Default("https://foo.com") 244 it.Before(fakeRW, tt.req, nil) 245 246 if want := safehttp.StatusPreconditionFailed; rr.Code != int(want) { 247 t.Errorf("rr.Code got: %v want: %v", rr.Code, want) 248 } 249 wantHeaders := map[string][]string{} 250 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 251 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 252 } 253 }) 254 } 255 } 256 257 func TestRequestDisallowedContentTypes(t *testing.T) { 258 contentTypes := []string{ 259 "application/x-www-form-urlencoded", 260 "multipart/form-data", 261 "text/plain", 262 "", 263 } 264 265 for _, ct := range contentTypes { 266 t.Run(ct, func(t *testing.T) { 267 req := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com/asdf", nil) 268 req.Header.Set("Origin", "https://foo.com") 269 req.Header.Set("X-Cors", "1") 270 if ct != "" { 271 req.Header.Set("Content-Type", ct) 272 } 273 274 fakeRW, rr := safehttptest.NewFakeResponseWriter() 275 276 it := cors.Default("https://foo.com") 277 it.Before(fakeRW, req, nil) 278 279 if want := safehttp.StatusUnsupportedMediaType; rr.Code != int(want) { 280 t.Errorf("rr.Code got: %v want: %v", rr.Code, want) 281 } 282 wantHeaders := map[string][]string{} 283 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 284 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 285 } 286 }) 287 } 288 } 289 290 func TestDisallowedOrigin(t *testing.T) { 291 req := safehttptest.NewRequest(safehttp.MethodPut, "http://bar.com/asdf", nil) 292 req.Header.Set("Origin", "https://pizza.com") 293 294 fakeRW, rr := safehttptest.NewFakeResponseWriter() 295 296 it := cors.Default("https://foo.com") 297 it.Before(fakeRW, req, nil) 298 299 if want := safehttp.StatusForbidden; rr.Code != int(want) { 300 t.Errorf("rr.Code got: %v want: %v", rr.Code, want) 301 } 302 wantHeaders := map[string][]string{} 303 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 304 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 305 } 306 } 307 308 func TestPreflight(t *testing.T) { 309 tests := []struct { 310 name string 311 req *safehttp.IncomingRequest 312 allowedHeaders []string 313 maxAge int 314 wantHeaders map[string][]string 315 }{ 316 { 317 name: "Basic", 318 req: func() *safehttp.IncomingRequest { 319 r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 320 r.Header.Set("Origin", "https://foo.com") 321 r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut) 322 return r 323 }(), 324 wantHeaders: map[string][]string{ 325 "Access-Control-Allow-Methods": {"PUT"}, 326 "Access-Control-Allow-Origin": {"https://foo.com"}, 327 "Access-Control-Max-Age": {"5"}, 328 "Vary": {"Origin"}, 329 }, 330 }, 331 { 332 name: "Request X-Cors header", 333 req: func() *safehttp.IncomingRequest { 334 r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 335 r.Header.Set("Origin", "https://foo.com") 336 r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut) 337 r.Header.Set("Access-Control-Request-Headers", "X-Cors") 338 return r 339 }(), 340 wantHeaders: map[string][]string{ 341 "Access-Control-Allow-Headers": {"X-Cors"}, 342 "Access-Control-Allow-Methods": {"PUT"}, 343 "Access-Control-Allow-Origin": {"https://foo.com"}, 344 "Access-Control-Max-Age": {"5"}, 345 "Vary": {"Origin"}, 346 }, 347 }, 348 { 349 name: "Request custom header", 350 req: func() *safehttp.IncomingRequest { 351 r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 352 r.Header.Set("Origin", "https://foo.com") 353 r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut) 354 r.Header.Set("Access-Control-Request-Headers", "Aaaa") 355 return r 356 }(), 357 allowedHeaders: []string{"Aaaa"}, 358 wantHeaders: map[string][]string{ 359 "Access-Control-Allow-Headers": {"Aaaa"}, 360 "Access-Control-Allow-Methods": {"PUT"}, 361 "Access-Control-Allow-Origin": {"https://foo.com"}, 362 "Access-Control-Max-Age": {"5"}, 363 "Vary": {"Origin"}, 364 }, 365 }, 366 { 367 name: "Request multiple headers", 368 req: func() *safehttp.IncomingRequest { 369 r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 370 r.Header.Set("Origin", "https://foo.com") 371 r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut) 372 r.Header.Set("Access-Control-Request-Headers", "X-Cors, Aaaa") 373 return r 374 }(), 375 allowedHeaders: []string{"Aaaa"}, 376 wantHeaders: map[string][]string{ 377 "Access-Control-Allow-Headers": {"X-Cors, Aaaa"}, 378 "Access-Control-Allow-Methods": {"PUT"}, 379 "Access-Control-Allow-Origin": {"https://foo.com"}, 380 "Access-Control-Max-Age": {"5"}, 381 "Vary": {"Origin"}, 382 }, 383 }, 384 { 385 name: "Request headers test canonicalization", 386 req: func() *safehttp.IncomingRequest { 387 r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 388 r.Header.Set("Origin", "https://foo.com") 389 r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut) 390 r.Header.Set("Access-Control-Request-Headers", "x-coRS, aaAA") 391 return r 392 }(), 393 allowedHeaders: []string{"AAaa"}, 394 wantHeaders: map[string][]string{ 395 "Access-Control-Allow-Headers": {"x-coRS, aaAA"}, 396 "Access-Control-Allow-Methods": {"PUT"}, 397 "Access-Control-Allow-Origin": {"https://foo.com"}, 398 "Access-Control-Max-Age": {"5"}, 399 "Vary": {"Origin"}, 400 }, 401 }, 402 { 403 name: "Custom Max age", 404 req: func() *safehttp.IncomingRequest { 405 r := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 406 r.Header.Set("Origin", "https://foo.com") 407 r.Header.Set("Access-Control-Request-Method", safehttp.MethodPut) 408 return r 409 }(), 410 maxAge: 3600, 411 wantHeaders: map[string][]string{ 412 "Access-Control-Allow-Methods": {"PUT"}, 413 "Access-Control-Allow-Origin": {"https://foo.com"}, 414 "Access-Control-Max-Age": {"3600"}, 415 "Vary": {"Origin"}, 416 }, 417 }, 418 } 419 420 for _, tt := range tests { 421 t.Run(tt.name, func(t *testing.T) { 422 fakeRW, rr := safehttptest.NewFakeResponseWriter() 423 424 it := cors.Default("https://foo.com") 425 it.MaxAge = tt.maxAge 426 it.SetAllowedHeaders(tt.allowedHeaders...) 427 it.Before(fakeRW, tt.req, nil) 428 429 if rr.Code != int(safehttp.StatusNoContent) { 430 t.Errorf("rr.Code got: %v want: %v", rr.Code, safehttp.StatusNoContent) 431 } 432 if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rr.Header())); diff != "" { 433 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 434 } 435 if got := rr.Body.String(); got != "" { 436 t.Errorf(`rr.Body.String() got: %q want: ""`, got) 437 } 438 }) 439 } 440 } 441 442 func TestInvalidAccessControlRequestHeaders(t *testing.T) { 443 tests := []struct { 444 name string 445 headers string 446 }{ 447 { 448 name: "B is not allowed", 449 headers: "B", 450 }, 451 { 452 name: "One in list is not allowed", 453 headers: "X-Cors, B", 454 }, 455 { 456 name: "Empty at the end", 457 headers: "X-Cors, ", 458 }, 459 } 460 461 for _, tt := range tests { 462 t.Run(tt.name, func(t *testing.T) { 463 req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 464 rh := req.Header 465 rh.Set("Origin", "https://foo.com") 466 rh.Set("Access-Control-Request-Method", safehttp.MethodPut) 467 rh.Set("Access-Control-Request-Headers", tt.headers) 468 469 fakeRW, rr := safehttptest.NewFakeResponseWriter() 470 471 it := cors.Default("https://foo.com") 472 it.Before(fakeRW, req, nil) 473 474 if want := safehttp.StatusForbidden; rr.Code != int(want) { 475 t.Errorf("rr.Code got: %v want: %v", rr.Code, want) 476 } 477 wantHeaders := map[string][]string{} 478 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 479 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 480 } 481 }) 482 } 483 } 484 485 func TestEmptyAccessControlRequestMethod(t *testing.T) { 486 req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 487 rh := req.Header 488 rh.Set("Origin", "https://foo.com") 489 490 fakeRW, rr := safehttptest.NewFakeResponseWriter() 491 492 it := cors.Default("https://foo.com") 493 it.Before(fakeRW, req, nil) 494 495 if want := safehttp.StatusForbidden; rr.Code != int(want) { 496 t.Errorf("rr.Code got: %v want: %v", rr.Code, want) 497 } 498 wantHeaders := map[string][]string{} 499 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 500 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 501 } 502 } 503 504 func TestAccessControlRequestMethodHead(t *testing.T) { 505 req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 506 req.Header.Set("Origin", "https://foo.com") 507 req.Header.Set("Access-Control-Request-Method", safehttp.MethodHead) 508 509 fakeRW, rr := safehttptest.NewFakeResponseWriter() 510 511 it := cors.Default("https://foo.com") 512 it.Before(fakeRW, req, nil) 513 514 if want := safehttp.StatusForbidden; rr.Code != int(want) { 515 t.Errorf("rr.Code got: %v want: %v", rr.Code, want) 516 } 517 wantHeaders := map[string][]string{} 518 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 519 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 520 } 521 } 522 523 func TestPreflightEmptyOrigin(t *testing.T) { 524 req := safehttptest.NewRequest(safehttp.MethodOptions, "http://bar.com/asdf", nil) 525 req.Header.Set("Access-Control-Request-Method", safehttp.MethodHead) 526 527 fakeRW, rr := safehttptest.NewFakeResponseWriter() 528 529 it := cors.Default("https://foo.com") 530 it.Before(fakeRW, req, nil) 531 532 if want := safehttp.StatusForbidden; rr.Code != int(want) { 533 t.Errorf("rr.Code got: %v want: %v", rr.Code, want) 534 } 535 wantHeaders := map[string][]string{} 536 if diff := cmp.Diff(wantHeaders, map[string][]string(rr.Header())); diff != "" { 537 t.Errorf("rr.Header() mismatch (-want +got):\n%s", diff) 538 } 539 }