github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/mux_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safehttp_test 16 17 import ( 18 "io" 19 "net/http" 20 "net/http/httptest" 21 "testing" 22 23 "github.com/google/go-cmp/cmp" 24 "github.com/google/go-safeweb/safehttp" 25 "github.com/google/safehtml" 26 ) 27 28 func TestMuxOneHandlerOneRequest(t *testing.T) { 29 var test = []struct { 30 name string 31 req *http.Request 32 wantStatus safehttp.StatusCode 33 wantHeader map[string][]string 34 wantBody string 35 }{ 36 { 37 name: "Valid Request", 38 req: httptest.NewRequest(safehttp.MethodGet, "http://foo.com/", nil), 39 wantStatus: safehttp.StatusOK, 40 wantHeader: map[string][]string{ 41 "Content-Type": {"text/html; charset=utf-8"}, 42 }, 43 wantBody: "<h1>Hello World!</h1>", 44 }, 45 { 46 name: "Invalid Method", 47 req: httptest.NewRequest(safehttp.MethodPost, "http://foo.com/", nil), 48 wantStatus: safehttp.StatusMethodNotAllowed, 49 wantHeader: map[string][]string{ 50 "Content-Type": {"text/plain; charset=utf-8"}, 51 "X-Content-Type-Options": {"nosniff"}, 52 }, 53 wantBody: "Method Not Allowed\n", 54 }, 55 } 56 57 for _, tt := range test { 58 t.Run(tt.name, func(t *testing.T) { 59 mb := safehttp.NewServeMuxConfig(nil) 60 mux := mb.Mux() 61 62 h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 63 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 64 }) 65 mux.Handle("/", safehttp.MethodGet, h) 66 67 rw := httptest.NewRecorder() 68 69 mux.ServeHTTP(rw, tt.req) 70 71 if rw.Code != int(tt.wantStatus) { 72 t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus) 73 } 74 75 if diff := cmp.Diff(tt.wantHeader, map[string][]string(rw.Header())); diff != "" { 76 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 77 } 78 79 if got := rw.Body.String(); got != tt.wantBody { 80 t.Errorf("response body: got %q want %q", got, tt.wantBody) 81 } 82 }) 83 } 84 } 85 86 func TestMuxServeTwoHandlers(t *testing.T) { 87 var tests = []struct { 88 name string 89 req *http.Request 90 hf safehttp.Handler 91 wantStatus safehttp.StatusCode 92 wantHeaders map[string][]string 93 wantBody string 94 }{ 95 { 96 name: "GET Handler", 97 req: httptest.NewRequest(safehttp.MethodGet, "http://foo.com/bar", nil), 98 hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 99 return w.Write(safehtml.HTMLEscaped("<h1>Hello World! GET</h1>")) 100 }), 101 wantStatus: safehttp.StatusOK, 102 wantHeaders: map[string][]string{ 103 "Content-Type": {"text/html; charset=utf-8"}, 104 }, 105 wantBody: "<h1>Hello World! GET</h1>", 106 }, 107 { 108 name: "POST Handler", 109 req: httptest.NewRequest(safehttp.MethodPost, "http://foo.com/bar", nil), 110 hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 111 return w.Write(safehtml.HTMLEscaped("<h1>Hello World! POST</h1>")) 112 }), 113 wantStatus: safehttp.StatusOK, 114 wantHeaders: map[string][]string{ 115 "Content-Type": {"text/html; charset=utf-8"}, 116 }, 117 wantBody: "<h1>Hello World! POST</h1>", 118 }, 119 } 120 121 mb := safehttp.NewServeMuxConfig(nil) 122 mux := mb.Mux() 123 124 mux.Handle("/bar", safehttp.MethodGet, tests[0].hf) 125 mux.Handle("/bar", safehttp.MethodPost, tests[1].hf) 126 127 for _, test := range tests { 128 rw := httptest.NewRecorder() 129 mux.ServeHTTP(rw, test.req) 130 if want := int(test.wantStatus); rw.Code != want { 131 t.Errorf("rw.Code: got %v want %v", rw.Code, want) 132 } 133 134 if diff := cmp.Diff(test.wantHeaders, map[string][]string(rw.Header())); diff != "" { 135 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 136 } 137 138 if got, want := rw.Body.String(), test.wantBody; got != want { 139 t.Errorf("response body: got %q want %q", got, want) 140 } 141 } 142 } 143 144 func TestMuxRegisterCorrectHandlerAllPaths(t *testing.T) { 145 var tests = []struct { 146 name string 147 req *http.Request 148 hf safehttp.Handler 149 wantBody string 150 }{ 151 { 152 name: "GET Handler", 153 req: httptest.NewRequest(safehttp.MethodGet, "http://foo.com/get", nil), 154 hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 155 return w.Write(safehtml.HTMLEscaped("GET handler for /get")) 156 }), 157 wantBody: "GET handler for /get", 158 }, 159 { 160 name: "GET Handler #2", 161 req: httptest.NewRequest(safehttp.MethodGet, "http://foo.com/get2", nil), 162 hf: safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 163 return w.Write(safehtml.HTMLEscaped("GET handler for /get2")) 164 }), 165 wantBody: "GET handler for /get2", 166 }, 167 } 168 169 mb := safehttp.NewServeMuxConfig(nil) 170 mux := mb.Mux() 171 mux.Handle("/get", safehttp.MethodGet, tests[0].hf) 172 mux.Handle("/get2", safehttp.MethodGet, tests[1].hf) 173 174 for _, test := range tests { 175 rw := httptest.NewRecorder() 176 mux.ServeHTTP(rw, test.req) 177 178 if got, want := rw.Body.String(), test.wantBody; got != want { 179 t.Errorf("response body: got %q want %q", got, want) 180 } 181 } 182 } 183 184 func TestMuxHandleSameMethodTwice(t *testing.T) { 185 mb := safehttp.NewServeMuxConfig(nil) 186 mux := mb.Mux() 187 188 registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 189 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 190 }) 191 mux.Handle("/bar", safehttp.MethodGet, registeredHandler) 192 193 defer func() { 194 if r := recover(); r != nil { 195 return 196 } 197 t.Errorf(`mux.Handle("/bar", MethodGet, registeredHandler) expected panic`) 198 }() 199 200 mux.Handle("/bar", safehttp.MethodGet, registeredHandler) 201 } 202 203 type setHeaderInterceptor struct { 204 name string 205 value string 206 } 207 208 func (p setHeaderInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 209 w.Header().Set(p.name, p.value) 210 return safehttp.NotWritten() 211 } 212 213 func (p setHeaderInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 214 } 215 216 func (setHeaderInterceptor) Match(safehttp.InterceptorConfig) bool { 217 return false 218 } 219 220 type internalErrorInterceptor struct{} 221 222 func (internalErrorInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 223 return w.WriteError(safehttp.StatusInternalServerError) 224 } 225 226 func (internalErrorInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 227 } 228 229 func (internalErrorInterceptor) Match(safehttp.InterceptorConfig) bool { 230 return false 231 } 232 233 type claimHeaderInterceptor struct { 234 headerToClaim string 235 } 236 237 type claimKey struct{} 238 239 func (p *claimHeaderInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 240 f := w.Header().Claim(p.headerToClaim) 241 safehttp.FlightValues(r.Context()).Put(claimKey{}, f) 242 return safehttp.NotWritten() 243 } 244 245 func (p *claimHeaderInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 246 } 247 248 func (claimHeaderInterceptor) Match(safehttp.InterceptorConfig) bool { 249 return false 250 } 251 252 func claimInterceptorSetHeader(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, value string) { 253 f := safehttp.FlightValues(r.Context()).Get(claimKey{}).(func([]string)) 254 f([]string{value}) 255 } 256 257 type committerInterceptor struct{} 258 259 func (committerInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 260 return safehttp.NotWritten() 261 } 262 263 func (committerInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 264 w.Header().Set("foo", "bar") 265 } 266 267 func (committerInterceptor) Match(safehttp.InterceptorConfig) bool { 268 return false 269 } 270 271 type setHeaderErroringInterceptor struct{} 272 273 func (setHeaderErroringInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 274 return w.WriteError(safehttp.StatusForbidden) 275 } 276 277 func (setHeaderErroringInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 278 w.Header().Set("name", "foo") 279 } 280 281 func TestMuxInterceptors(t *testing.T) { 282 tests := []struct { 283 name string 284 mux *safehttp.ServeMux 285 wantStatus safehttp.StatusCode 286 wantHeaders map[string][]string 287 wantBody string 288 }{ 289 { 290 name: "Install ServeMux Interceptor before handler registration", 291 mux: func() *safehttp.ServeMux { 292 mb := safehttp.NewServeMuxConfig(nil) 293 mb.Intercept(setHeaderInterceptor{ 294 name: "Foo", 295 value: "bar", 296 }) 297 mux := mb.Mux() 298 299 registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 300 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 301 }) 302 mux.Handle("/bar", safehttp.MethodGet, registeredHandler) 303 return mux 304 }(), 305 wantStatus: safehttp.StatusOK, 306 wantHeaders: map[string][]string{ 307 "Content-Type": {"text/html; charset=utf-8"}, 308 "Foo": {"bar"}, 309 }, 310 wantBody: "<h1>Hello World!</h1>", 311 }, 312 { 313 name: "Install Interrupting Interceptor", 314 mux: func() *safehttp.ServeMux { 315 mb := safehttp.NewServeMuxConfig(nil) 316 mb.Intercept(internalErrorInterceptor{}) 317 mux := mb.Mux() 318 319 registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 320 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 321 }) 322 mux.Handle("/bar", safehttp.MethodGet, registeredHandler) 323 324 return mux 325 }(), 326 wantStatus: safehttp.StatusInternalServerError, 327 wantHeaders: map[string][]string{ 328 "Content-Type": {"text/plain; charset=utf-8"}, 329 "X-Content-Type-Options": {"nosniff"}, 330 }, 331 wantBody: "Internal Server Error\n", 332 }, 333 { 334 name: "Handler Communication With ServeMux Interceptor", 335 mux: func() *safehttp.ServeMux { 336 mb := safehttp.NewServeMuxConfig(nil) 337 mb.Intercept(&claimHeaderInterceptor{headerToClaim: "Foo"}) 338 mux := mb.Mux() 339 340 registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 341 claimInterceptorSetHeader(w, r, "bar") 342 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 343 }) 344 mux.Handle("/bar", safehttp.MethodGet, registeredHandler) 345 346 return mux 347 }(), 348 wantStatus: safehttp.StatusOK, 349 wantHeaders: map[string][]string{ 350 "Content-Type": {"text/html; charset=utf-8"}, 351 "Foo": {"bar"}, 352 }, 353 wantBody: "<h1>Hello World!</h1>", 354 }, 355 { 356 name: "Commit phase sets header", 357 mux: func() *safehttp.ServeMux { 358 mb := safehttp.NewServeMuxConfig(nil) 359 mb.Intercept(committerInterceptor{}) 360 mux := mb.Mux() 361 362 registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 363 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 364 }) 365 mux.Handle("/bar", safehttp.MethodGet, registeredHandler) 366 367 return mux 368 }(), 369 wantStatus: safehttp.StatusOK, 370 wantHeaders: map[string][]string{ 371 "Foo": {"bar"}, 372 "Content-Type": {"text/html; charset=utf-8"}, 373 }, 374 wantBody: "<h1>Hello World!</h1>", 375 }, 376 } 377 378 for _, tt := range tests { 379 t.Run(tt.name, func(t *testing.T) { 380 rw := httptest.NewRecorder() 381 req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/bar", nil) 382 383 tt.mux.ServeHTTP(rw, req) 384 385 if rw.Code != int(tt.wantStatus) { 386 t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus) 387 } 388 389 if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rw.Header())); diff != "" { 390 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 391 } 392 393 if got := rw.Body.String(); got != tt.wantBody { 394 t.Errorf("response body: got %q want %q", got, tt.wantBody) 395 } 396 }) 397 } 398 } 399 400 type setHeaderConfig struct { 401 name string 402 value string 403 } 404 405 type setHeaderConfigInterceptor struct{} 406 407 func (p setHeaderConfigInterceptor) Before(w safehttp.ResponseWriter, _ *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 408 name := "Pizza" 409 value := "Hawaii" 410 if c, ok := cfg.(setHeaderConfig); ok { 411 name = c.name 412 value = c.value 413 } 414 w.Header().Set(name, value) 415 return safehttp.NotWritten() 416 } 417 418 func (p setHeaderConfigInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 419 name := "Commit-Pizza" 420 value := "Hawaii" 421 if c, ok := cfg.(setHeaderConfig); ok { 422 name = "Commit-" + c.name 423 value = c.value 424 } 425 w.Header().Set(name, value) 426 } 427 428 func (setHeaderConfigInterceptor) Match(cfg safehttp.InterceptorConfig) bool { 429 _, ok := cfg.(setHeaderConfig) 430 return ok 431 } 432 433 type noInterceptorConfig struct{} 434 435 type wrappedInterceptor struct { 436 w safehttp.Interceptor 437 } 438 439 func (wi wrappedInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 440 return wi.w.Before(w, r, cfg) 441 } 442 443 func (wi wrappedInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 444 wi.w.Commit(w, r, resp, cfg) 445 } 446 447 func (wi wrappedInterceptor) Match(cfg safehttp.InterceptorConfig) bool { 448 return wi.w.Match(cfg) 449 } 450 451 func (noInterceptorConfig) Match(i safehttp.Interceptor) bool { 452 return false 453 } 454 455 func TestMuxInterceptorConfigs(t *testing.T) { 456 tests := []struct { 457 name string 458 interceptor safehttp.Interceptor 459 config safehttp.InterceptorConfig 460 wantStatus safehttp.StatusCode 461 wantHeaders map[string][]string 462 wantBody string 463 }{ 464 { 465 name: "SetHeaderInterceptor with config", 466 interceptor: setHeaderConfigInterceptor{}, 467 config: setHeaderConfig{name: "Foo", value: "Bar"}, 468 wantStatus: safehttp.StatusOK, 469 wantHeaders: map[string][]string{ 470 "Content-Type": {"text/html; charset=utf-8"}, 471 "Commit-Foo": {"Bar"}, 472 "Foo": {"Bar"}, 473 }, 474 wantBody: "<h1>Hello World!</h1>", 475 }, 476 { 477 name: "Wrapped SetHeaderInterceptor with config", 478 interceptor: wrappedInterceptor{w: setHeaderConfigInterceptor{}}, 479 config: setHeaderConfig{name: "Foo", value: "Bar"}, 480 wantStatus: safehttp.StatusOK, 481 wantHeaders: map[string][]string{ 482 "Content-Type": {"text/html; charset=utf-8"}, 483 "Commit-Foo": {"Bar"}, 484 "Foo": {"Bar"}, 485 }, 486 wantBody: "<h1>Hello World!</h1>", 487 }, 488 { 489 name: "SetHeaderInterceptor with mismatching config", 490 interceptor: setHeaderConfigInterceptor{}, 491 config: noInterceptorConfig{}, 492 wantStatus: safehttp.StatusOK, 493 wantHeaders: map[string][]string{ 494 "Content-Type": {"text/html; charset=utf-8"}, 495 "Pizza": {"Hawaii"}, 496 "Commit-Pizza": {"Hawaii"}, 497 }, 498 wantBody: "<h1>Hello World!</h1>", 499 }, 500 } 501 502 for _, tt := range tests { 503 t.Run(tt.name, func(t *testing.T) { 504 mb := safehttp.NewServeMuxConfig(nil) 505 mb.Intercept(tt.interceptor) 506 mux := mb.Mux() 507 508 registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 509 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 510 }) 511 mux.Handle("/bar", safehttp.MethodGet, registeredHandler, tt.config) 512 513 rw := httptest.NewRecorder() 514 req := httptest.NewRequest("GET", "http://foo.com/bar", nil) 515 516 mux.ServeHTTP(rw, req) 517 518 if rw.Code != int(tt.wantStatus) { 519 t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus) 520 } 521 522 if diff := cmp.Diff(tt.wantHeaders, map[string][]string(rw.Header())); diff != "" { 523 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 524 } 525 526 if got := rw.Body.String(); got != tt.wantBody { 527 t.Errorf("response body: got %q want %q", got, tt.wantBody) 528 } 529 }) 530 } 531 } 532 533 type interceptorOne struct{} 534 535 func (interceptorOne) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 536 w.Header().Set("pizza", "diavola") 537 return safehttp.NotWritten() 538 } 539 540 func (interceptorOne) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 541 if w.Header().Get("Commit2") != "b" { 542 panic("server bug") 543 } 544 w.Header().Set("Commit1", "a") 545 } 546 547 func (interceptorOne) Match(safehttp.InterceptorConfig) bool { 548 return false 549 } 550 551 type interceptorTwo struct{} 552 553 func (interceptorTwo) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 554 if w.Header().Get("pizza") != "diavola" { 555 panic("server bug") 556 } 557 w.Header().Set("spaghetti", "bolognese") 558 return safehttp.NotWritten() 559 } 560 561 func (interceptorTwo) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 562 if w.Header().Get("Commit3") != "c" { 563 panic("server bug") 564 } 565 w.Header().Set("Commit2", "b") 566 } 567 568 func (interceptorTwo) Match(safehttp.InterceptorConfig) bool { 569 return false 570 } 571 572 type interceptorThree struct{} 573 574 func (interceptorThree) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, cfg safehttp.InterceptorConfig) safehttp.Result { 575 if w.Header().Get("spaghetti") != "bolognese" { 576 panic("server bug") 577 } 578 w.Header().Set("dessert", "tiramisu") 579 return safehttp.NotWritten() 580 } 581 582 func (interceptorThree) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, cfg safehttp.InterceptorConfig) { 583 if w.Header().Get("Dessert") != "tiramisu" { 584 panic("server bug") 585 } 586 w.Header().Set("Commit3", "c") 587 } 588 589 func (interceptorThree) Match(safehttp.InterceptorConfig) bool { 590 return false 591 } 592 593 func TestMuxDeterministicInterceptorOrder(t *testing.T) { 594 mb := safehttp.NewServeMuxConfig(nil) 595 mb.Intercept(interceptorOne{}) 596 mb.Intercept(interceptorTwo{}) 597 mb.Intercept(interceptorThree{}) 598 mux := mb.Mux() 599 600 registeredHandler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 601 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 602 }) 603 mux.Handle("/bar", safehttp.MethodGet, registeredHandler) 604 605 rw := httptest.NewRecorder() 606 req := httptest.NewRequest("GET", "http://foo.com/bar", nil) 607 608 mux.ServeHTTP(rw, req) 609 610 if want := safehttp.StatusOK; rw.Code != int(want) { 611 t.Errorf("rw.Code: got %v want %v", rw.Code, want) 612 } 613 wantHeaders := map[string][]string{ 614 "Dessert": {"tiramisu"}, 615 "Pizza": {"diavola"}, 616 "Spaghetti": {"bolognese"}, 617 "Commit1": {"a"}, 618 "Commit2": {"b"}, 619 "Commit3": {"c"}, 620 "Content-Type": {"text/html; charset=utf-8"}, 621 } 622 if diff := cmp.Diff(wantHeaders, map[string][]string(rw.Header())); diff != "" { 623 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 624 } 625 if got, want := rw.Body.String(), "<h1>Hello World!</h1>"; got != want { 626 t.Errorf(`response body: got %q want %q`, got, want) 627 } 628 } 629 630 func TestMuxHandlerReturnsNotWritten(t *testing.T) { 631 mb := safehttp.NewServeMuxConfig(nil) 632 h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 633 return safehttp.NotWritten() 634 }) 635 mux := mb.Mux() 636 mux.Handle("/bar", safehttp.MethodGet, h) 637 638 req := httptest.NewRequest(safehttp.MethodGet, "http://foo.com/bar", nil) 639 rw := httptest.NewRecorder() 640 641 mux.ServeHTTP(rw, req) 642 643 if want := safehttp.StatusNoContent; rw.Code != int(want) { 644 t.Errorf("rw.Code: got %v want %v", rw.Code, want) 645 } 646 if diff := cmp.Diff(map[string][]string{}, map[string][]string(rw.Header())); diff != "" { 647 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 648 } 649 if got := rw.Body.String(); got != "" { 650 t.Errorf(`response body got: %q want: ""`, got) 651 } 652 } 653 654 func TestMuxMethodNotAllowedDefaults(t *testing.T) { 655 mb := safehttp.NewServeMuxConfig(nil) 656 mux := mb.Mux() 657 658 h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 659 panic("not tested") 660 }) 661 mux.Handle("/", safehttp.MethodGet, h) 662 663 rw := httptest.NewRecorder() 664 665 mux.ServeHTTP(rw, httptest.NewRequest(safehttp.MethodPost, "http://foo.com/", nil)) 666 667 if got, want := rw.Code, int(safehttp.StatusMethodNotAllowed); got != want { 668 t.Errorf("rw.Code: got %v want %v", got, want) 669 } 670 671 wantHeader := map[string][]string{ 672 "Content-Type": {"text/plain; charset=utf-8"}, 673 "X-Content-Type-Options": {"nosniff"}, 674 } 675 if diff := cmp.Diff(wantHeader, map[string][]string(rw.Header())); diff != "" { 676 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 677 } 678 679 wantBody := "Method Not Allowed\n" 680 if got := rw.Body.String(); got != wantBody { 681 t.Errorf("response body: got %q want %q", got, wantBody) 682 } 683 } 684 685 type methodNotAllowedError struct { 686 message string 687 } 688 689 func (err *methodNotAllowedError) Code() safehttp.StatusCode { 690 return safehttp.StatusMethodNotAllowed 691 } 692 693 type methodNotAllowedDispatcher struct { 694 safehttp.DefaultDispatcher 695 } 696 697 func (d *methodNotAllowedDispatcher) Error(rw http.ResponseWriter, resp safehttp.ErrorResponse) error { 698 x := resp.(*methodNotAllowedError) 699 rw.Header().Set("Content-Type", "text/html; charset=utf-8") 700 rw.WriteHeader(int(resp.Code())) 701 _, err := io.WriteString(rw, "<h1>"+http.StatusText(int(resp.Code()))+"</h1>"+"<p>"+x.message+"</p>") 702 return err 703 } 704 705 type methodNotAllowedInterceptor struct{} 706 707 func (ip *methodNotAllowedInterceptor) Before(w safehttp.ResponseWriter, r *safehttp.IncomingRequest, ipcfg safehttp.InterceptorConfig) safehttp.Result { 708 cfg := ipcfg.(methodNotAllowedInterceptorConfig) 709 w.Header().Set("Before-Interceptor", cfg.before) 710 return safehttp.NotWritten() 711 } 712 713 // Commit runs before the response is written by the Dispatcher. If an error 714 // is written to the ResponseWriter, then the Commit phases from the 715 // remaining interceptors won't execute. 716 func (ip *methodNotAllowedInterceptor) Commit(w safehttp.ResponseHeadersWriter, r *safehttp.IncomingRequest, resp safehttp.Response, ipcfg safehttp.InterceptorConfig) { 717 cfg := ipcfg.(methodNotAllowedInterceptorConfig) 718 w.Header().Set("Commit-Interceptor", cfg.commit) 719 } 720 721 func (*methodNotAllowedInterceptor) Match(cfg safehttp.InterceptorConfig) bool { 722 _, ok := cfg.(methodNotAllowedInterceptorConfig) 723 return ok 724 } 725 726 type methodNotAllowedInterceptorConfig struct { 727 before, commit string 728 } 729 730 func TestMuxMethodNotAllowedCustom(t *testing.T) { 731 mb := safehttp.NewServeMuxConfig(&methodNotAllowedDispatcher{}) 732 mb.Intercept(&methodNotAllowedInterceptor{}) 733 mb.HandleMethodNotAllowed(safehttp.HandlerFunc(func(rw safehttp.ResponseWriter, ir *safehttp.IncomingRequest) safehttp.Result { 734 return rw.WriteError(&methodNotAllowedError{"custom message"}) 735 }), methodNotAllowedInterceptorConfig{before: "foo", commit: "bar"}) 736 mux := mb.Mux() 737 738 mux.Handle("/", safehttp.MethodGet, safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 739 panic("not tested") 740 })) 741 742 rw := httptest.NewRecorder() 743 744 mux.ServeHTTP(rw, httptest.NewRequest(safehttp.MethodPost, "http://foo.com/", nil)) 745 746 if got, want := rw.Code, int(safehttp.StatusMethodNotAllowed); got != want { 747 t.Errorf("rw.Code: got %v want %v", got, want) 748 } 749 750 wantHeader := map[string][]string{ 751 "Content-Type": {"text/html; charset=utf-8"}, 752 "Before-Interceptor": {"foo"}, 753 "Commit-Interceptor": {"bar"}, 754 } 755 if diff := cmp.Diff(wantHeader, map[string][]string(rw.Header())); diff != "" { 756 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 757 } 758 759 wantBody := "<h1>Method Not Allowed</h1><p>custom message</p>" 760 if got := rw.Body.String(); got != wantBody { 761 t.Errorf("response body: got %q want %q", got, wantBody) 762 } 763 }