github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/fileserver_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp_test
    16  
    17  import (
    18  	"io/ioutil"
    19  	"net/http/httptest"
    20  	"os"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"github.com/google/go-safeweb/safehttp"
    25  )
    26  
    27  func TestFileServer(t *testing.T) {
    28  	tmpDir, err := ioutil.TempDir("", "go-safehttp-test")
    29  	if err != nil {
    30  		t.Fatalf("ioutil.TempDir(): %v", err)
    31  	}
    32  	defer os.RemoveAll(tmpDir)
    33  
    34  	if err := ioutil.WriteFile(tmpDir+"/foo.html", []byte("<h1>Hello world</h1>"), 0644); err != nil {
    35  		t.Fatalf("ioutil.WriteFile: %v", err)
    36  	}
    37  
    38  	tests := []struct {
    39  		name     string
    40  		path     string
    41  		wantCode safehttp.StatusCode
    42  		wantCT   string
    43  		wantBody string
    44  	}{
    45  		{
    46  			name:     "missing file",
    47  			path:     "failure",
    48  			wantCode: safehttp.StatusNotFound,
    49  			wantCT:   "text/plain; charset=utf-8",
    50  			wantBody: "Not Found\n",
    51  		},
    52  		{
    53  			name:     "file",
    54  			path:     "foo.html",
    55  			wantCode: safehttp.StatusOK,
    56  			wantCT:   "text/html; charset=utf-8",
    57  			wantBody: "<h1>Hello world</h1>",
    58  		},
    59  	}
    60  
    61  	mb := safehttp.NewServeMuxConfig(nil)
    62  	m := mb.Mux()
    63  	m.Handle("/", safehttp.MethodGet, safehttp.FileServer(tmpDir))
    64  
    65  	for _, tt := range tests {
    66  		t.Run(tt.name, func(t *testing.T) {
    67  			rr := httptest.NewRecorder()
    68  
    69  			req := httptest.NewRequest(safehttp.MethodGet, "https://test.science/"+tt.path, nil)
    70  			m.ServeHTTP(rr, req)
    71  
    72  			if got, want := rr.Code, tt.wantCode; got != int(tt.wantCode) {
    73  				t.Errorf("status code got: %v want: %v", got, want)
    74  			}
    75  			if got := rr.Header().Get("Content-Type"); tt.wantCT != got {
    76  				t.Errorf("Content-Type: got %q want %q", got, tt.wantCT)
    77  			}
    78  			if diff := cmp.Diff(tt.wantBody, rr.Body.String()); diff != "" {
    79  				t.Errorf("Response body diff (-want,+got): \n%s", diff)
    80  			}
    81  		})
    82  	}
    83  }
    84  
    85  // TODO(kele): Add tests including interceptors once we have
    86  // https://github.com/google/go-safeweb/issues/261.