github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/tests/integration/hsts/hsts_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package hsts_test 16 17 import ( 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 22 "github.com/google/go-cmp/cmp" 23 "github.com/google/go-safeweb/safehttp" 24 "github.com/google/go-safeweb/safehttp/plugins/hsts" 25 "github.com/google/safehtml" 26 ) 27 28 func TestHSTSServeMuxInstall(t *testing.T) { 29 mb := safehttp.NewServeMuxConfig(nil) 30 mb.Intercept(hsts.Default()) 31 mux := mb.Mux() 32 33 handler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 34 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 35 }) 36 mux.Handle("/asdf", safehttp.MethodGet, handler) 37 38 rw := httptest.NewRecorder() 39 req := httptest.NewRequest(http.MethodGet, "https://foo.com/asdf", nil) 40 41 mux.ServeHTTP(rw, req) 42 43 if want := safehttp.StatusOK; rw.Code != int(want) { 44 t.Errorf("rw.Code got: %v want: %v", rw.Code, want) 45 } 46 47 wantHeaders := map[string][]string{ 48 "Content-Type": {"text/html; charset=utf-8"}, 49 "Strict-Transport-Security": {"max-age=63072000; includeSubDomains"}, 50 } 51 if diff := cmp.Diff(wantHeaders, map[string][]string(rw.Header())); diff != "" { 52 t.Errorf("rw.Header() mismatch (-want +got):\n%s", diff) 53 } 54 55 if got, want := rw.Body.String(), "<h1>Hello World!</h1>"; got != want { 56 t.Errorf("Body got: %v want: %v", got, want) 57 } 58 } 59 60 func TestHSTSOnErrors(t *testing.T) { 61 mb := safehttp.NewServeMuxConfig(nil) 62 mb.Intercept(hsts.Default()) 63 mux := mb.Mux() 64 65 handler := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 66 return w.WriteError(safehttp.StatusNotFound) 67 }) 68 mux.Handle("/asdf", safehttp.MethodGet, handler) 69 70 rw := httptest.NewRecorder() 71 req := httptest.NewRequest(http.MethodGet, "https://foo.com/asdf", nil) 72 73 mux.ServeHTTP(rw, req) 74 75 if want := safehttp.StatusNotFound; rw.Code != int(want) { 76 t.Errorf("rw.Code got: %v want: %v", rw.Code, want) 77 } 78 if got, want := rw.Header()["Strict-Transport-Security"], []string{"max-age=63072000; includeSubDomains"}; !cmp.Equal(got, want) { 79 t.Errorf("rw.Header()[\"Strict-Transport-Security\"] = %q, want %q", got, want) 80 } 81 }