github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/fileserver_1_16_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 //go:build go1.16 16 // +build go1.16 17 18 package safehttp_test 19 20 import ( 21 "embed" 22 "io" 23 "net/http/httptest" 24 "testing" 25 26 "github.com/google/go-cmp/cmp" 27 "github.com/google/go-safeweb/safehttp" 28 ) 29 30 //go:embed testdata 31 var testEmbeddedFS embed.FS 32 33 func TestFileServerEmbed(t *testing.T) { 34 wantFile, err := testEmbeddedFS.Open("testdata/embed.html") 35 if err != nil { 36 t.Fatalf("Could not open embedded test files: %v", err) 37 } 38 wantFileContent, err := io.ReadAll(wantFile) 39 if err != nil { 40 t.Fatalf("Could not read embedded test files: %v", err) 41 } 42 43 tests := []struct { 44 name string 45 path string 46 wantCode safehttp.StatusCode 47 wantCT string 48 wantBody string 49 }{ 50 { 51 name: "missing file", 52 path: "failure", 53 wantCode: 404, 54 wantCT: "text/plain; charset=utf-8", 55 wantBody: "Not Found\n", 56 }, 57 { 58 name: "embedded file", 59 path: "testdata/embed.html", 60 wantCode: 200, 61 wantCT: "text/html; charset=utf-8", 62 wantBody: string(wantFileContent), 63 }, 64 } 65 66 mb := safehttp.NewServeMuxConfig(nil) 67 m := mb.Mux() 68 m.Handle("/", safehttp.MethodGet, safehttp.FileServerEmbed(testEmbeddedFS)) 69 70 for _, tt := range tests { 71 t.Run(tt.name, func(t *testing.T) { 72 rr := httptest.NewRecorder() 73 74 req := httptest.NewRequest(safehttp.MethodGet, "https://test.science/"+tt.path, nil) 75 m.ServeHTTP(rr, req) 76 77 if got, want := rr.Result().StatusCode, tt.wantCode; safehttp.StatusCode(got) != tt.wantCode { 78 t.Errorf("status code got: %v want: %v", got, want) 79 } 80 if got := rr.Header().Get("Content-Type"); tt.wantCT != got { 81 t.Errorf("Content-Type: got %q want %q", got, tt.wantCT) 82 } 83 body, err := io.ReadAll(rr.Result().Body) 84 if err != nil { 85 t.Errorf("Can't read response body: %v", err) 86 } 87 if diff := cmp.Diff(tt.wantBody, string(body)); diff != "" { 88 t.Errorf("Response body diff (-want,+got): \n%s", diff) 89 } 90 }) 91 } 92 }