github.com/milanaleksic/devd@v1.0.4/fileserver/fileserver_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fileserver
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"io"
    11  	"io/ioutil"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"os"
    16  	"path"
    17  	"path/filepath"
    18  	"reflect"
    19  	"runtime"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	rice "github.com/GeertJohan/go.rice"
    27  	"github.com/milanaleksic/devd/inject"
    28  	"github.com/milanaleksic/devd/ricetemp"
    29  	"github.com/milanaleksic/devd/routespec"
    30  	"github.com/cortesi/termlog"
    31  )
    32  
    33  // ServeFile replies to the request with the contents of the named file or directory.
    34  func ServeFile(w http.ResponseWriter, r *http.Request, name string) {
    35  	dir, file := filepath.Split(name)
    36  	logger := termlog.NewLog()
    37  	logger.Quiet()
    38  
    39  	fs := FileServer{
    40  		"version",
    41  		http.Dir(dir),
    42  		inject.CopyInject{},
    43  		ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
    44  		[]routespec.RouteSpec{},
    45  		"",
    46  	}
    47  	fs.serveFile(logger, w, r, file, false)
    48  }
    49  
    50  func ServeContent(w http.ResponseWriter, req *http.Request, name string, modtime time.Time, content io.ReadSeeker) error {
    51  	sizeFunc := func() (int64, error) {
    52  		size, err := content.Seek(0, os.SEEK_END)
    53  		if err != nil {
    54  			return 0, errSeeker
    55  		}
    56  		_, err = content.Seek(0, os.SEEK_SET)
    57  		if err != nil {
    58  			return 0, errSeeker
    59  		}
    60  		return size, nil
    61  	}
    62  	return serveContent(inject.CopyInject{}, w, req, name, modtime, sizeFunc, content)
    63  }
    64  
// testFile is a fixture path used by tests in this file; testFileLen is
// presumably its length in bytes (not checked anywhere visible here).
const (
	testFile    = "testdata/file"
	testFileLen = 11
)

// wantRange describes an expected half-open byte range. NOTE(review): not
// referenced in the visible part of this file.
type wantRange struct {
	start, end int64 // range [start,end)
}

// itoa is a short alias for strconv.Itoa. NOTE(review): not referenced in the
// visible part of this file.
var itoa = strconv.Itoa
    75  
// notFoundSearchPathsSpecs is the table for TestNotFoundSearchPaths. Each case
// gives a request path, a not-found spec, and the ordered list of candidate
// paths notFoundSearchPaths is expected to return. As the cases show, an
// absolute spec yields just that (cleaned) path, a relative spec is tried in
// the request's directory and then at the root, and ".." components never
// climb above the root.
var notFoundSearchPathsSpecs = []struct {
	path   string
	spec   string
	result []string
}{
	{"/index.html", "/foo.html", []string{"/foo.html"}},
	{"/dir/index.html", "/", []string{"/"}},
	{"/dir/index.html", "foo.html", []string{"/dir/foo.html", "/foo.html"}},
	{"/", "foo.html", []string{"/foo.html"}},
	{"/", "../../foo.html", []string{"/foo.html"}},
	{"/", "/../../foo.html", []string{"/foo.html"}},
}
    88  
    89  func TestNotFoundSearchPaths(t *testing.T) {
    90  	for _, tt := range notFoundSearchPathsSpecs {
    91  		paths := notFoundSearchPaths(tt.path, tt.spec)
    92  		if !reflect.DeepEqual(paths, tt.result) {
    93  			t.Errorf("Wanted %#v, got %#v", tt.result, paths)
    94  		}
    95  	}
    96  }
    97  
// matchTypesSpecs is the table for TestMatchTypes. Each case gives a not-found
// spec, a requested path, and whether matchTypes should report a match. The
// only non-match in the table is a .html spec against a .png path; specs or
// paths without an extension, directory paths, unknown extensions, and
// .html versus .htm all match.
var matchTypesSpecs = []struct {
	spec   string
	path   string
	result bool
}{
	{"/index.html", "/foo.png", false},
	{"/index.html", "/foo.html", true},
	{"/index/", "/foo.html", true},
	{"/index", "/foo.html", true},
	{"/index.unknown", "/foo.unknown", true},
	{"/index.html", "/foo/", true},
	{"/index.html", "/foo/bar.htm", true},
	{"/index", "/foo/bar.html", true},
	{"/index", "/foo/bar.htm", true},
	{"/index", "/foo", true},
	{"/usr/bob.foo", "/foo", true},
}
   115  
   116  func TestMatchTypes(t *testing.T) {
   117  	for _, tt := range matchTypesSpecs {
   118  		m := matchTypes(tt.spec, tt.path)
   119  		if m != tt.result {
   120  			t.Errorf("Wanted %#v, got %#v", tt.result, m)
   121  		}
   122  	}
   123  }
   124  
   125  func TestServeFile(t *testing.T) {
   126  	defer afterTest(t)
   127  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   128  		ServeFile(w, r, "testdata/file")
   129  	}))
   130  	defer ts.Close()
   131  
   132  	var err error
   133  
   134  	file, err := ioutil.ReadFile(testFile)
   135  	if err != nil {
   136  		t.Fatal("reading file:", err)
   137  	}
   138  
   139  	// set up the Request (re-used for all tests)
   140  	var req http.Request
   141  	req.Header = make(http.Header)
   142  	if req.URL, err = url.Parse(ts.URL); err != nil {
   143  		t.Fatal("ParseURL:", err)
   144  	}
   145  	req.Method = "GET"
   146  
   147  	// straight GET
   148  	_, body := getBody(t, "straight get", req)
   149  	if !bytes.Equal(body, file) {
   150  		t.Fatalf("body mismatch: got %q, want %q", body, file)
   151  	}
   152  }
   153  
// fsRedirectTestData pairs a request path (under the /test prefix used by
// TestFSRedirect) with the path the client should end up at after following
// redirects. In these cases index.html collapses to its directory, a
// directory gains a trailing slash, and a file loses one.
var fsRedirectTestData = []struct {
	original, redirect string
}{
	{"/test/index.html", "/test/"},
	{"/test/testdata", "/test/testdata/"},
	{"/test/testdata/file/", "/test/testdata/file"},
}
   161  
   162  func TestFSRedirect(t *testing.T) {
   163  	defer afterTest(t)
   164  	ts := httptest.NewServer(
   165  		http.StripPrefix(
   166  			"/test",
   167  			&FileServer{
   168  				"version",
   169  				http.Dir("."),
   170  				inject.CopyInject{},
   171  				ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
   172  				[]routespec.RouteSpec{},
   173  				"",
   174  			},
   175  		),
   176  	)
   177  	defer ts.Close()
   178  
   179  	for _, data := range fsRedirectTestData {
   180  		res, err := http.Get(ts.URL + data.original)
   181  		if err != nil {
   182  			t.Fatal(err)
   183  		}
   184  		_ = res.Body.Close()
   185  		if g, e := res.Request.URL.Path, data.redirect; g != e {
   186  			t.Errorf("redirect from %s: got %s, want %s", data.original, g, e)
   187  		}
   188  	}
   189  }
   190  
   191  type testFileSystem struct {
   192  	open func(name string) (http.File, error)
   193  }
   194  
   195  func (fs *testFileSystem) Open(name string) (http.File, error) {
   196  	return fs.open(name)
   197  }
   198  
// _TestFileServerCleans is disabled: go test only runs functions named
// TestXxx, and the leading underscore keeps this one from matching.
// NOTE(review): the reason it was disabled is not recorded here — confirm
// whether it can be re-enabled against this FileServer.
//
// When enabled, it checks that the FileServer cleans request paths (for
// example "/../foo.txt" -> "/foo.txt") before handing them to the
// FileSystem's Open. The fake FileSystem reports every name it is asked to
// open and fails every open.
func _TestFileServerCleans(t *testing.T) {
	defer afterTest(t)
	// Capacity 1 and one receive per request: the comparison below assumes
	// Open is called exactly once per request; extra calls would make the
	// received names fall out of step with the test cases.
	ch := make(chan string, 1)
	fs := &FileServer{
		"version",
		&testFileSystem{
			func(name string) (http.File, error) {
				ch <- name
				return nil, errors.New("file does not exist")
			},
		},
		inject.CopyInject{},
		ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
		[]routespec.RouteSpec{},
		"",
	}
	tests := []struct {
		reqPath, openArg string
	}{
		{"/foo.txt", "/foo.txt"},
		{"/../foo.txt", "/foo.txt"},
	}
	// One request is reused; only its URL path changes per case.
	req, _ := http.NewRequest("GET", "http://example.com", nil)
	for n, test := range tests {
		rec := httptest.NewRecorder()
		req.URL.Path = test.reqPath
		fs.ServeHTTP(rec, req)
		if got := <-ch; got != test.openArg {
			t.Errorf("test %d: got %q, want %q", n, got, test.openArg)
		}
	}
}
   231  
   232  func mustRemoveAll(dir string) {
   233  	err := os.RemoveAll(dir)
   234  	if err != nil {
   235  		panic(err)
   236  	}
   237  }
   238  
   239  func TestFileServerImplicitLeadingSlash(t *testing.T) {
   240  	defer afterTest(t)
   241  	tempDir, err := ioutil.TempDir("", "")
   242  	if err != nil {
   243  		t.Fatalf("TempDir: %v", err)
   244  	}
   245  	defer mustRemoveAll(tempDir)
   246  	if err := ioutil.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil {
   247  		t.Fatalf("WriteFile: %v", err)
   248  	}
   249  	fs := &FileServer{
   250  		"version",
   251  		http.Dir(tempDir),
   252  		inject.CopyInject{},
   253  		ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
   254  		[]routespec.RouteSpec{},
   255  		"",
   256  	}
   257  
   258  	ts := httptest.NewServer(http.StripPrefix("/bar/", fs))
   259  	defer ts.Close()
   260  	get := func(suffix string) string {
   261  		res, err := http.Get(ts.URL + suffix)
   262  		if err != nil {
   263  			t.Fatalf("Get %s: %v", suffix, err)
   264  		}
   265  		b, err := ioutil.ReadAll(res.Body)
   266  		if err != nil {
   267  			t.Fatalf("ReadAll %s: %v", suffix, err)
   268  		}
   269  		_ = res.Body.Close()
   270  		return string(b)
   271  	}
   272  	if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") {
   273  		t.Logf("expected a directory listing with foo.txt, got %q", s)
   274  	}
   275  	if s := get("/bar/foo.txt"); s != "Hello world" {
   276  		t.Logf("expected %q, got %q", "Hello world", s)
   277  	}
   278  }
   279  
   280  func TestDirJoin(t *testing.T) {
   281  	if runtime.GOOS == "windows" {
   282  		t.Skip("skipping test on windows")
   283  	}
   284  	wfi, err := os.Stat("/etc/hosts")
   285  	if err != nil {
   286  		t.Skip("skipping test; no /etc/hosts file")
   287  	}
   288  	test := func(d http.Dir, name string) {
   289  		f, err := d.Open(name)
   290  		if err != nil {
   291  			t.Fatalf("open of %s: %v", name, err)
   292  		}
   293  		defer func() { _ = f.Close() }()
   294  		gfi, err := f.Stat()
   295  		if err != nil {
   296  			t.Fatalf("stat of %s: %v", name, err)
   297  		}
   298  		if !os.SameFile(gfi, wfi) {
   299  			t.Errorf("%s got different file", name)
   300  		}
   301  	}
   302  	test(http.Dir("/etc/"), "/hosts")
   303  	test(http.Dir("/etc/"), "hosts")
   304  	test(http.Dir("/etc/"), "../../../../hosts")
   305  	test(http.Dir("/etc"), "/hosts")
   306  	test(http.Dir("/etc"), "hosts")
   307  	test(http.Dir("/etc"), "../../../../hosts")
   308  
   309  	// Not really directories, but since we use this trick in
   310  	// ServeFile, test it:
   311  	test(http.Dir("/etc/hosts"), "")
   312  	test(http.Dir("/etc/hosts"), "/")
   313  	test(http.Dir("/etc/hosts"), "../")
   314  }
   315  
   316  func TestEmptyDirOpenCWD(t *testing.T) {
   317  	test := func(d http.Dir) {
   318  		name := "fileserver_test.go"
   319  		f, err := d.Open(name)
   320  		if err != nil {
   321  			t.Fatalf("open of %s: %v", name, err)
   322  		}
   323  		defer func() { _ = f.Close() }()
   324  	}
   325  	test(http.Dir(""))
   326  	test(http.Dir("."))
   327  	test(http.Dir("./"))
   328  }
   329  
   330  func TestServeFileContentType(t *testing.T) {
   331  	defer afterTest(t)
   332  	const ctype = "icecream/chocolate"
   333  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   334  		switch r.FormValue("override") {
   335  		case "1":
   336  			w.Header().Set("Content-Type", ctype)
   337  		case "2":
   338  			// Explicitly inhibit sniffing.
   339  			w.Header()["Content-Type"] = []string{}
   340  		}
   341  		ServeFile(w, r, "testdata/file")
   342  	}))
   343  	defer ts.Close()
   344  	get := func(override string, want []string) {
   345  		resp, err := http.Get(ts.URL + "?override=" + override)
   346  		if err != nil {
   347  			t.Fatal(err)
   348  		}
   349  		if h := resp.Header["Content-Type"]; !reflect.DeepEqual(h, want) {
   350  			t.Errorf("Content-Type mismatch: got %v, want %v", h, want)
   351  		}
   352  		_ = resp.Body.Close()
   353  	}
   354  	get("0", []string{"text/plain; charset=utf-8"})
   355  	get("1", []string{ctype})
   356  	get("2", nil)
   357  }
   358  
   359  func TestServeFileMimeType(t *testing.T) {
   360  	defer afterTest(t)
   361  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   362  		ServeFile(w, r, "testdata/style.css")
   363  	}))
   364  	defer ts.Close()
   365  	resp, err := http.Get(ts.URL)
   366  	if err != nil {
   367  		t.Fatal(err)
   368  	}
   369  	_ = resp.Body.Close()
   370  	want := "text/css; charset=utf-8"
   371  	if h := resp.Header.Get("Content-Type"); h != want {
   372  		t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
   373  	}
   374  }
   375  
   376  func TestServeFileFromCWD(t *testing.T) {
   377  	defer afterTest(t)
   378  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   379  		ServeFile(w, r, "fileserver_test.go")
   380  	}))
   381  	defer ts.Close()
   382  	r, err := http.Get(ts.URL)
   383  	if err != nil {
   384  		t.Fatal(err)
   385  	}
   386  	_ = r.Body.Close()
   387  	if r.StatusCode != 200 {
   388  		t.Fatalf("expected 200 OK, got %s", r.Status)
   389  	}
   390  }
   391  
   392  func TestServeFileWithContentEncoding(t *testing.T) {
   393  	defer afterTest(t)
   394  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   395  		w.Header().Set("Content-Encoding", "foo")
   396  		ServeFile(w, r, "testdata/file")
   397  	}))
   398  	defer ts.Close()
   399  	resp, err := http.Get(ts.URL)
   400  	if err != nil {
   401  		t.Fatal(err)
   402  	}
   403  	_ = resp.Body.Close()
   404  	if g, e := resp.ContentLength, int64(-1); g != e {
   405  		t.Errorf("Content-Length mismatch: got %d, want %d", g, e)
   406  	}
   407  }
   408  
   409  func TestServeIndexHtml(t *testing.T) {
   410  	defer afterTest(t)
   411  	const want = "index.html says hello"
   412  
   413  	fs := &FileServer{
   414  		"version",
   415  		http.Dir("."),
   416  		inject.CopyInject{},
   417  		ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
   418  		[]routespec.RouteSpec{},
   419  		"",
   420  	}
   421  	ts := httptest.NewServer(fs)
   422  	defer ts.Close()
   423  
   424  	for _, path := range []string{"/testdata/", "/testdata/index.html"} {
   425  		res, err := http.Get(ts.URL + path)
   426  		if err != nil {
   427  			t.Fatal(err)
   428  		}
   429  		b, err := ioutil.ReadAll(res.Body)
   430  		if err != nil {
   431  			t.Fatal("reading Body:", err)
   432  		}
   433  		if s := strings.TrimSpace(string(b)); s != want {
   434  			t.Errorf("for path %q got %q, want %q", path, s, want)
   435  		}
   436  		_ = res.Body.Close()
   437  	}
   438  }
   439  
   440  func TestFileServerZeroByte(t *testing.T) {
   441  	defer afterTest(t)
   442  	fs := &FileServer{
   443  		"version",
   444  		http.Dir("."),
   445  		inject.CopyInject{},
   446  		ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
   447  		[]routespec.RouteSpec{},
   448  		"",
   449  	}
   450  	ts := httptest.NewServer(fs)
   451  	defer ts.Close()
   452  
   453  	res, err := http.Get(ts.URL + "/" + url.PathEscape("..\x00"))
   454  	if err != nil {
   455  		t.Fatal(err)
   456  	}
   457  	b, err := ioutil.ReadAll(res.Body)
   458  	if err != nil {
   459  		t.Fatal("reading Body:", err)
   460  	}
   461  	if res.StatusCode == 200 {
   462  		t.Errorf("got status 200; want an error. Body is:\n%s", string(b))
   463  	}
   464  }
   465  
   466  type fakeFileInfo struct {
   467  	dir      bool
   468  	basename string
   469  	modtime  time.Time
   470  	ents     []*fakeFileInfo
   471  	contents string
   472  }
   473  
   474  func (f *fakeFileInfo) Name() string       { return f.basename }
   475  func (f *fakeFileInfo) Sys() interface{}   { return nil }
   476  func (f *fakeFileInfo) ModTime() time.Time { return f.modtime }
   477  func (f *fakeFileInfo) IsDir() bool        { return f.dir }
   478  func (f *fakeFileInfo) Size() int64        { return int64(len(f.contents)) }
   479  func (f *fakeFileInfo) Mode() os.FileMode {
   480  	if f.dir {
   481  		return 0755 | os.ModeDir
   482  	}
   483  	return 0644
   484  }
   485  
   486  type fakeFile struct {
   487  	io.ReadSeeker
   488  	fi   *fakeFileInfo
   489  	path string // as opened
   490  }
   491  
   492  func (f *fakeFile) Close() error               { return nil }
   493  func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil }
   494  func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
   495  	if !f.fi.dir {
   496  		return nil, os.ErrInvalid
   497  	}
   498  	var fis []os.FileInfo
   499  	for _, fi := range f.fi.ents {
   500  		fis = append(fis, fi)
   501  	}
   502  	return fis, nil
   503  }
   504  
   505  type fakeFS map[string]*fakeFileInfo
   506  
   507  func (fs fakeFS) Open(name string) (http.File, error) {
   508  	name = path.Clean(name)
   509  	f, ok := fs[name]
   510  	if !ok {
   511  		return nil, os.ErrNotExist
   512  	}
   513  	return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil
   514  }
   515  
   516  func TestNotFoundOverride(t *testing.T) {
   517  	defer afterTest(t)
   518  	ffile := &fakeFileInfo{
   519  		basename: "foo.html",
   520  		modtime:  time.Unix(1000000000, 0).UTC(),
   521  		contents: "I am a fake file",
   522  	}
   523  	fsys := fakeFS{
   524  		"/": &fakeFileInfo{
   525  			dir:     true,
   526  			modtime: time.Unix(123, 0).UTC(),
   527  			ents:    []*fakeFileInfo{},
   528  		},
   529  		"/one": &fakeFileInfo{
   530  			dir:     true,
   531  			modtime: time.Unix(123, 0).UTC(),
   532  			ents:    []*fakeFileInfo{ffile},
   533  		},
   534  		"/one/foo.html": ffile,
   535  	}
   536  
   537  	fs := &FileServer{
   538  		"version",
   539  		fsys,
   540  		inject.CopyInject{},
   541  		ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
   542  		[]routespec.RouteSpec{
   543  			{Host: "", Path: "/", Value: "foo.html"},
   544  		},
   545  		"",
   546  	}
   547  
   548  	ts := httptest.NewServer(fs)
   549  	defer ts.Close()
   550  
   551  	res, err := http.Get(ts.URL + "/one/nonexistent.html")
   552  	if err != nil {
   553  		t.Fatal(err)
   554  	}
   555  	_ = res.Body.Close()
   556  	if res.StatusCode != 200 {
   557  		t.Error("Expected to find over-ride file.")
   558  	}
   559  
   560  	res, err = http.Get(ts.URL + "/one/two/nonexistent.html")
   561  	if err != nil {
   562  		t.Fatal(err)
   563  	}
   564  	_ = res.Body.Close()
   565  	if res.StatusCode != 200 {
   566  		t.Error("Expected to find over-ride file.")
   567  	}
   568  
   569  	res, err = http.Get(ts.URL + "/nonexistent.html")
   570  	if err != nil {
   571  		t.Fatal(err)
   572  	}
   573  	_ = res.Body.Close()
   574  	if res.StatusCode != 404 {
   575  		t.Error("Expected to find over-ride file.")
   576  	}
   577  
   578  	res, err = http.Get(ts.URL + "/two/nonexistent.html")
   579  	if err != nil {
   580  		t.Fatal(err)
   581  	}
   582  	_ = res.Body.Close()
   583  	if res.StatusCode != 404 {
   584  		t.Error("Expected to find over-ride file.")
   585  	}
   586  
   587  }
   588  
   589  func TestDirectoryIfNotModified(t *testing.T) {
   590  	defer afterTest(t)
   591  	const indexContents = "I am a fake index.html file"
   592  	fileMod := time.Unix(1000000000, 0).UTC()
   593  	fileModStr := fileMod.Format(http.TimeFormat)
   594  	dirMod := time.Unix(123, 0).UTC()
   595  	indexFile := &fakeFileInfo{
   596  		basename: "index.html",
   597  		modtime:  fileMod,
   598  		contents: indexContents,
   599  	}
   600  	fsys := fakeFS{
   601  		"/": &fakeFileInfo{
   602  			dir:     true,
   603  			modtime: dirMod,
   604  			ents:    []*fakeFileInfo{indexFile},
   605  		},
   606  		"/index.html": indexFile,
   607  	}
   608  
   609  	fs := &FileServer{
   610  		"version",
   611  		fsys,
   612  		inject.CopyInject{},
   613  		ricetemp.MustMakeTemplates(rice.MustFindBox("../templates")),
   614  		[]routespec.RouteSpec{},
   615  		"",
   616  	}
   617  
   618  	ts := httptest.NewServer(fs)
   619  	defer ts.Close()
   620  
   621  	res, err := http.Get(ts.URL)
   622  	if err != nil {
   623  		t.Fatal(err)
   624  	}
   625  	b, err := ioutil.ReadAll(res.Body)
   626  	if err != nil {
   627  		t.Fatal(err)
   628  	}
   629  	if string(b) != indexContents {
   630  		t.Fatalf("Got body %q; want %q", b, indexContents)
   631  	}
   632  	_ = res.Body.Close()
   633  
   634  	lastMod := res.Header.Get("Last-Modified")
   635  	if lastMod != fileModStr {
   636  		t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr)
   637  	}
   638  
   639  	req, _ := http.NewRequest("GET", ts.URL, nil)
   640  	req.Header.Set("If-Modified-Since", lastMod)
   641  
   642  	res, err = http.DefaultClient.Do(req)
   643  	if err != nil {
   644  		t.Fatal(err)
   645  	}
   646  	if res.StatusCode != 304 {
   647  		t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode)
   648  	}
   649  	_ = res.Body.Close()
   650  
   651  	// Advance the index.html file's modtime, but not the directory's.
   652  	indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
   653  
   654  	res, err = http.DefaultClient.Do(req)
   655  	if err != nil {
   656  		t.Fatal(err)
   657  	}
   658  	if res.StatusCode != 200 {
   659  		t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res)
   660  	}
   661  	_ = res.Body.Close()
   662  }
   663  
   664  func mustStat(t *testing.T, fileName string) os.FileInfo {
   665  	fi, err := os.Stat(fileName)
   666  	if err != nil {
   667  		t.Fatal(err)
   668  	}
   669  	return fi
   670  }
   671  
// TestServeContent drives the ServeContent helper through a table of cases.
// For each case, the loop sends the handler its parameters over servec, then
// issues one GET carrying the case's request headers. It checks the
// response's status code, Content-Type and Last-Modified. The cases live in a
// map, so they run in an unspecified order.
func TestServeContent(t *testing.T) {
	defer afterTest(t)
	// serveParam carries one case's arguments from the test loop to the handler.
	type serveParam struct {
		name        string
		modtime     time.Time
		content     io.ReadSeeker
		contentType string
		etag        string
	}
	// Buffered so the loop can queue a case's parameters before it sends the
	// request that the handler will serve with them.
	servec := make(chan serveParam, 1)
	// lock is held by the handler while ServeContent runs and by the deferred
	// file closes below, so a file is never closed mid-serve.
	lock := sync.Mutex{}
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		p := <-servec
		// Optional headers are set before serving, as a real caller would.
		if p.etag != "" {
			w.Header().Set("ETag", p.etag)
		}
		if p.contentType != "" {
			w.Header().Set("Content-Type", p.contentType)
		}
		lock.Lock()
		defer lock.Unlock()
		err := ServeContent(w, r, p.name, p.modtime, p.content)
		if err != nil {
			// This runs on a server goroutine. t.Fail is safe to call there;
			// t.FailNow/t.Fatal would not be.
			t.Fail()
		}
	}))
	defer ts.Close()

	type testCase struct {
		// One of file or content must be set:
		file    string
		content io.ReadSeeker

		modtime          time.Time
		serveETag        string // optional
		serveContentType string // optional
		reqHeader        map[string]string
		wantLastMod      string
		wantContentType  string
		wantStatus       int
	}
	htmlModTime := mustStat(t, "testdata/index.html").ModTime()
	tests := map[string]testCase{
		"no_last_modified": {
			file:            "testdata/style.css",
			wantContentType: "text/css; charset=utf-8",
			wantStatus:      200,
		},
		"with_last_modified": {
			file:            "testdata/index.html",
			wantContentType: "text/html; charset=utf-8",
			modtime:         htmlModTime,
			wantLastMod:     htmlModTime.UTC().Format(http.TimeFormat),
			wantStatus:      200,
		},
		"not_modified_modtime": {
			file:    "testdata/style.css",
			modtime: htmlModTime,
			reqHeader: map[string]string{
				"If-Modified-Since": htmlModTime.UTC().Format(http.TimeFormat),
			},
			wantStatus: 304,
		},
		// wantContentType is empty, so the explicitly served Content-Type is
		// expected to be absent from the 304 response.
		"not_modified_modtime_with_contenttype": {
			file:             "testdata/style.css",
			serveContentType: "text/css", // explicit content type
			modtime:          htmlModTime,
			reqHeader: map[string]string{
				"If-Modified-Since": htmlModTime.UTC().Format(http.TimeFormat),
			},
			wantStatus: 304,
		},
		"not_modified_etag": {
			file:      "testdata/style.css",
			serveETag: `"foo"`,
			reqHeader: map[string]string{
				"If-None-Match": `"foo"`,
			},
			wantStatus: 304,
		},
		// panicOnSeek{nil} embeds a nil ReadSeeker: the case only passes if
		// the 304 is produced without reading or seeking the content.
		"not_modified_etag_no_seek": {
			content:   panicOnSeek{nil}, // should never be called
			serveETag: `"foo"`,
			reqHeader: map[string]string{
				"If-None-Match": `"foo"`,
			},
			wantStatus: 304,
		},
		// An If-Range resource for entity "A", but entity "B" is now current.
		// The Range request should be ignored.
		"range_no_match": {
			file:      "testdata/style.css",
			serveETag: `"A"`,
			reqHeader: map[string]string{
				"Range":    "bytes=0-4",
				"If-Range": `"B"`,
			},
			wantStatus:      200,
			wantContentType: "text/css; charset=utf-8",
		},
	}
	for testName, tt := range tests {
		var content io.ReadSeeker
		if tt.file != "" {
			f, err := os.Open(tt.file)
			if err != nil {
				t.Fatalf("test %q: %v", testName, err)
			}
			// Deferred inside the loop on purpose: every opened file stays
			// open until TestServeContent returns, and each close waits on lock.
			defer func() {
				lock.Lock()
				defer lock.Unlock()
				_ = f.Close()
			}()
			content = f
		} else {
			content = tt.content
		}

		// For content-only cases tt.file is "", so name is filepath.Base("") == ".".
		servec <- serveParam{
			name:        filepath.Base(tt.file),
			content:     content,
			modtime:     tt.modtime,
			etag:        tt.serveETag,
			contentType: tt.serveContentType,
		}
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		for k, v := range tt.reqHeader {
			req.Header.Set(k, v)
		}
		res, err := http.DefaultClient.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		// Drain and close the body before checking headers.
		_, err = io.Copy(ioutil.Discard, res.Body)
		if err != nil {
			t.Fatal(err)
		}
		_ = res.Body.Close()
		if res.StatusCode != tt.wantStatus {
			t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus)
		}
		if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e {
			t.Errorf("test %q: content-type = %q, want %q", testName, g, e)
		}
		if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e {
			t.Errorf("test %q: last-modified = %q, want %q", testName, g, e)
		}
	}
}
   824  
   825  func getBody(t *testing.T, testName string, req http.Request) (*http.Response, []byte) {
   826  	r, err := http.DefaultClient.Do(&req)
   827  	if err != nil {
   828  		t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err)
   829  	}
   830  	b, err := ioutil.ReadAll(r.Body)
   831  	if err != nil {
   832  		t.Fatalf("%s: for URL %q, reading body: %v", testName, req.URL.String(), err)
   833  	}
   834  	return r, b
   835  }
   836  
// panicOnSeek wraps an io.ReadSeeker. TestServeContent builds it as
// panicOnSeek{nil}, so the embedded interface is nil and any Read or Seek
// through it panics with a nil dereference. That case relies on the content
// never being touched.
type panicOnSeek struct{ io.ReadSeeker }