github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/tests/integration/staticheaders/staticheaders_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package staticheaders_test 16 17 import ( 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 22 "github.com/google/go-cmp/cmp" 23 "github.com/google/go-safeweb/safehttp" 24 "github.com/google/go-safeweb/safehttp/plugins/staticheaders" 25 "github.com/google/safehtml" 26 ) 27 28 func TestServeMuxInstallStaticHeaders(t *testing.T) { 29 mb := safehttp.NewServeMuxConfig(nil) 30 mb.Intercept(staticheaders.Interceptor{}) 31 mux := mb.Mux() 32 33 handler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 34 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 35 }) 36 mux.Handle("/asdf", safehttp.MethodGet, handler) 37 38 rw := httptest.NewRecorder() 39 40 req := httptest.NewRequest(http.MethodGet, "https://foo.com/asdf", nil) 41 42 mux.ServeHTTP(rw, req) 43 44 if want := safehttp.StatusOK; rw.Code != int(want) { 45 t.Errorf("rw.Code got: %v want: %v", rw.Code, want) 46 } 47 48 wantHeaders := map[string][]string{ 49 "Content-Type": {"text/html; charset=utf-8"}, 50 "X-Content-Type-Options": {"nosniff"}, 51 "X-Xss-Protection": {"0"}, 52 } 53 if diff := cmp.Diff(wantHeaders, map[string][]string(rw.Header())); diff != "" { 54 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 55 } 56 57 if got, want := rw.Body.String(), "<h1>Hello World!</h1>"; got != want { 58 t.Errorf("response body got: %v want: %v", got, want) 59 } 60 } 61 62 func TestStaticHeadersOnError(t *testing.T) { 63 mb := safehttp.NewServeMuxConfig(nil) 64 mb.Intercept(staticheaders.Interceptor{}) 65 mux := mb.Mux() 66 67 handler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 68 return w.WriteError(safehttp.StatusNotFound) 69 }) 70 mux.Handle("/asdf", safehttp.MethodGet, handler) 71 72 rw := httptest.NewRecorder() 73 74 req := httptest.NewRequest(http.MethodGet, "https://foo.com/asdf", nil) 75 76 mux.ServeHTTP(rw, req) 77 78 if want := safehttp.StatusNotFound; rw.Code != int(want) { 79 t.Errorf("rw.Status() got: %v want: %v", rw.Code, want) 80 } 81 82 wantHeaders := map[string][]string{ 83 "X-Content-Type-Options": {"nosniff"}, 84 "X-Xss-Protection": {"0"}, 85 } 86 for h, want := range wantHeaders { 87 if got := rw.Header()[h]; !cmp.Equal(got, want) { 88 t.Errorf("rw.Header()[%q] got %q, want %q", h, got, want) 89 } 90 } 91 }