github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/hostcheck/hostcheck_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package hostcheck_test 16 17 import ( 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 22 "github.com/google/go-safeweb/safehttp" 23 "github.com/google/go-safeweb/safehttp/plugins/hostcheck" 24 "github.com/google/safehtml" 25 ) 26 27 func TestInterceptor(t *testing.T) { 28 var test = []struct { 29 name string 30 req *http.Request 31 wantStatus safehttp.StatusCode 32 }{ 33 { 34 name: "Valid Host", 35 req: httptest.NewRequest(safehttp.MethodGet, "http://foo.com/", nil), 36 wantStatus: safehttp.StatusOK, 37 }, 38 { 39 name: "Invalid Host", 40 req: httptest.NewRequest(safehttp.MethodGet, "http://bar.com/", nil), 41 wantStatus: safehttp.StatusNotFound, 42 }, 43 } 44 45 for _, tt := range test { 46 t.Run(tt.name, func(t *testing.T) { 47 mb := safehttp.NewServeMuxConfig(nil) 48 mb.Intercept(hostcheck.New("foo.com")) 49 mux := mb.Mux() 50 51 h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result { 52 return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>")) 53 }) 54 mux.Handle("/", safehttp.MethodGet, h) 55 56 rw := httptest.NewRecorder() 57 mux.ServeHTTP(rw, tt.req) 58 59 if rw.Code != int(tt.wantStatus) { 60 t.Errorf("rw.Code: got %v want %v", rw.Code, tt.wantStatus) 61 } 62 }) 63 } 64 }