github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/fileserver_1_16_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build go1.16
    16  // +build go1.16
    17  
    18  package safehttp_test
    19  
    20  import (
    21  	"embed"
    22  	"io"
    23  	"net/http/httptest"
    24  	"testing"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  	"github.com/google/go-safeweb/safehttp"
    28  )
    29  
    30  //go:embed testdata
    31  var testEmbeddedFS embed.FS
    32  
    33  func TestFileServerEmbed(t *testing.T) {
    34  	wantFile, err := testEmbeddedFS.Open("testdata/embed.html")
    35  	if err != nil {
    36  		t.Fatalf("Could not open embedded test files: %v", err)
    37  	}
    38  	wantFileContent, err := io.ReadAll(wantFile)
    39  	if err != nil {
    40  		t.Fatalf("Could not read embedded test files: %v", err)
    41  	}
    42  
    43  	tests := []struct {
    44  		name     string
    45  		path     string
    46  		wantCode safehttp.StatusCode
    47  		wantCT   string
    48  		wantBody string
    49  	}{
    50  		{
    51  			name:     "missing file",
    52  			path:     "failure",
    53  			wantCode: 404,
    54  			wantCT:   "text/plain; charset=utf-8",
    55  			wantBody: "Not Found\n",
    56  		},
    57  		{
    58  			name:     "embedded file",
    59  			path:     "testdata/embed.html",
    60  			wantCode: 200,
    61  			wantCT:   "text/html; charset=utf-8",
    62  			wantBody: string(wantFileContent),
    63  		},
    64  	}
    65  
    66  	mb := safehttp.NewServeMuxConfig(nil)
    67  	m := mb.Mux()
    68  	m.Handle("/", safehttp.MethodGet, safehttp.FileServerEmbed(testEmbeddedFS))
    69  
    70  	for _, tt := range tests {
    71  		t.Run(tt.name, func(t *testing.T) {
    72  			rr := httptest.NewRecorder()
    73  
    74  			req := httptest.NewRequest(safehttp.MethodGet, "https://test.science/"+tt.path, nil)
    75  			m.ServeHTTP(rr, req)
    76  
    77  			if got, want := rr.Result().StatusCode, tt.wantCode; safehttp.StatusCode(got) != tt.wantCode {
    78  				t.Errorf("status code got: %v want: %v", got, want)
    79  			}
    80  			if got := rr.Header().Get("Content-Type"); tt.wantCT != got {
    81  				t.Errorf("Content-Type: got %q want %q", got, tt.wantCT)
    82  			}
    83  			body, err := io.ReadAll(rr.Result().Body)
    84  			if err != nil {
    85  				t.Errorf("Can't read response body: %v", err)
    86  			}
    87  			if diff := cmp.Diff(tt.wantBody, string(body)); diff != "" {
    88  				t.Errorf("Response body diff (-want,+got): \n%s", diff)
    89  			}
    90  		})
    91  	}
    92  }