github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/fileserver_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safehttp_test 16 17 import ( 18 "io/ioutil" 19 "net/http/httptest" 20 "os" 21 "testing" 22 23 "github.com/google/go-cmp/cmp" 24 "github.com/google/go-safeweb/safehttp" 25 ) 26 27 func TestFileServer(t *testing.T) { 28 tmpDir, err := ioutil.TempDir("", "go-safehttp-test") 29 if err != nil { 30 t.Fatalf("ioutil.TempDir(): %v", err) 31 } 32 defer os.RemoveAll(tmpDir) 33 34 if err := ioutil.WriteFile(tmpDir+"/foo.html", []byte("<h1>Hello world</h1>"), 0644); err != nil { 35 t.Fatalf("ioutil.WriteFile: %v", err) 36 } 37 38 tests := []struct { 39 name string 40 path string 41 wantCode safehttp.StatusCode 42 wantCT string 43 wantBody string 44 }{ 45 { 46 name: "missing file", 47 path: "failure", 48 wantCode: safehttp.StatusNotFound, 49 wantCT: "text/plain; charset=utf-8", 50 wantBody: "Not Found\n", 51 }, 52 { 53 name: "file", 54 path: "foo.html", 55 wantCode: safehttp.StatusOK, 56 wantCT: "text/html; charset=utf-8", 57 wantBody: "<h1>Hello world</h1>", 58 }, 59 } 60 61 mb := safehttp.NewServeMuxConfig(nil) 62 m := mb.Mux() 63 m.Handle("/", safehttp.MethodGet, safehttp.FileServer(tmpDir)) 64 65 for _, tt := range tests { 66 t.Run(tt.name, func(t *testing.T) { 67 rr := httptest.NewRecorder() 68 69 req := httptest.NewRequest(safehttp.MethodGet, "https://test.science/"+tt.path, nil) 70 m.ServeHTTP(rr, req) 71 72 if got, want := rr.Code, tt.wantCode; got != int(tt.wantCode) { 73 t.Errorf("status code got: %v want: %v", got, want) 74 } 75 if got := rr.Header().Get("Content-Type"); tt.wantCT != got { 76 t.Errorf("Content-Type: got %q want %q", got, tt.wantCT) 77 } 78 if diff := cmp.Diff(tt.wantBody, rr.Body.String()); diff != "" { 79 t.Errorf("Response body diff (-want,+got): \n%s", diff) 80 } 81 }) 82 } 83 } 84 85 // TODO(kele): Add tests including interceptors once we have 86 // https://github.com/google/go-safeweb/issues/261.