
     1  // Copyright 2013 The Gorilla Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     5  package handlers
import (
	"bufio"
	"bytes"
	"compress/gzip"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"testing"
)
    23  var contentType = "text/plain; charset=utf-8"
    25  func compressedRequest(w *httptest.ResponseRecorder, compression string) {
    26  	CompressHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    27  		w.Header().Set("Content-Length", strconv.Itoa(9*1024))
    28  		w.Header().Set("Content-Type", contentType)
    29  		for i := 0; i < 1024; i++ {
    30  			io.WriteString(w, "Gorilla!\n")
    31  		}
    32  	})).ServeHTTP(w, &http.Request{
    33  		Method: "GET",
    34  		Header: http.Header{
    35  			acceptEncoding: []string{compression},
    36  		},
    37  	})
    38  }
    40  func TestCompressHandlerNoCompression(t *testing.T) {
    41  	w := httptest.NewRecorder()
    42  	compressedRequest(w, "")
    43  	if enc := w.HeaderMap.Get("Content-Encoding"); enc != "" {
    44  		t.Errorf("wrong content encoding, got %q want %q", enc, "")
    45  	}
    46  	if ct := w.HeaderMap.Get("Content-Type"); ct != contentType {
    47  		t.Errorf("wrong content type, got %q want %q", ct, contentType)
    48  	}
    49  	if w.Body.Len() != 1024*9 {
    50  		t.Errorf("wrong len, got %d want %d", w.Body.Len(), 1024*9)
    51  	}
    52  	if l := w.HeaderMap.Get("Content-Length"); l != "9216" {
    53  		t.Errorf("wrong content-length. got %q expected %d", l, 1024*9)
    54  	}
    55  	if v := w.HeaderMap.Get("Vary"); v != acceptEncoding {
    56  		t.Errorf("wrong vary. got %s expected %s", v, acceptEncoding)
    57  	}
    58  }
    60  func TestAcceptEncodingIsDropped(t *testing.T) {
    61  	tCases := []struct {
    62  		name,
    63  		compression,
    64  		expect string
    65  		isPresent bool
    66  	}{
    67  		{
    68  			"accept-encoding-gzip",
    69  			"gzip",
    70  			"",
    71  			false,
    72  		},
    73  		{
    74  			"accept-encoding-deflate",
    75  			"deflate",
    76  			"",
    77  			false,
    78  		},
    79  		{
    80  			"accept-encoding-gzip,deflate",
    81  			"gzip,deflate",
    82  			"",
    83  			false,
    84  		},
    85  		{
    86  			"accept-encoding-gzip,deflate,something",
    87  			"gzip,deflate,something",
    88  			"",
    89  			false,
    90  		},
    91  		{
    92  			"accept-encoding-unknown",
    93  			"unknown",
    94  			"unknown",
    95  			true,
    96  		},
    97  	}
    99  	for _, tCase := range tCases {
   100  		ch := CompressHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   101  			acceptEnc := r.Header.Get(acceptEncoding)
   102  			if acceptEnc == "" && tCase.isPresent {
   103  				t.Fatalf("%s: expected 'Accept-Encoding' header to be present but was not",
   104  			}
   105  			if acceptEnc != "" {
   106  				if !tCase.isPresent {
   107  					t.Fatalf("%s: expected 'Accept-Encoding' header to be dropped but was still present having value %q",, acceptEnc)
   108  				}
   109  				if acceptEnc != tCase.expect {
   110  					t.Fatalf("%s: expected 'Accept-Encoding' to be %q but was %q",, tCase.expect, acceptEnc)
   111  				}
   112  			}
   113  		}))
   115  		w := httptest.NewRecorder()
   116  		ch.ServeHTTP(w, &http.Request{
   117  			Method: "GET",
   118  			Header: http.Header{
   119  				acceptEncoding: []string{tCase.compression},
   120  			},
   121  		})
   122  	}
   123  }
   125  func TestCompressHandlerGzip(t *testing.T) {
   126  	w := httptest.NewRecorder()
   127  	compressedRequest(w, "gzip")
   128  	if w.HeaderMap.Get("Content-Encoding") != "gzip" {
   129  		t.Errorf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "gzip")
   130  	}
   131  	if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" {
   132  		t.Errorf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8")
   133  	}
   134  	if w.Body.Len() != 72 {
   135  		t.Errorf("wrong len, got %d want %d", w.Body.Len(), 72)
   136  	}
   137  	if l := w.HeaderMap.Get("Content-Length"); l != "" {
   138  		t.Errorf("wrong content-length. got %q expected %q", l, "")
   139  	}
   140  }
   142  func TestCompressHandlerDeflate(t *testing.T) {
   143  	w := httptest.NewRecorder()
   144  	compressedRequest(w, "deflate")
   145  	if w.HeaderMap.Get("Content-Encoding") != "deflate" {
   146  		t.Fatalf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "deflate")
   147  	}
   148  	if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" {
   149  		t.Fatalf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8")
   150  	}
   151  	if w.Body.Len() != 54 {
   152  		t.Fatalf("wrong len, got %d want %d", w.Body.Len(), 54)
   153  	}
   154  }
   156  func TestCompressHandlerGzipDeflate(t *testing.T) {
   157  	w := httptest.NewRecorder()
   158  	compressedRequest(w, "gzip, deflate ")
   159  	if w.HeaderMap.Get("Content-Encoding") != "gzip" {
   160  		t.Fatalf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "gzip")
   161  	}
   162  	if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" {
   163  		t.Fatalf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8")
   164  	}
   165  }
   167  // Make sure we can compress and serve an *os.File properly. We need
   168  // to use a real http server to trigger the net/http sendfile special
   169  // case.
   170  func TestCompressFile(t *testing.T) {
   171  	dir, err := ioutil.TempDir("", "gorilla_compress")
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  	defer os.RemoveAll(dir)
   177  	err = ioutil.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello"), 0644)
   178  	if err != nil {
   179  		t.Fatal(err)
   180  	}
   182  	s := httptest.NewServer(CompressHandler(http.FileServer(http.Dir(dir))))
   183  	defer s.Close()
   185  	url := &url.URL{Scheme: "http", Host: s.Listener.Addr().String(), Path: "/hello.txt"}
   186  	req, err := http.NewRequest("GET", url.String(), nil)
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  	req.Header.Set(acceptEncoding, "gzip")
   191  	res, err := http.DefaultClient.Do(req)
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   196  	if res.StatusCode != http.StatusOK {
   197  		t.Fatalf("expected OK, got %q", res.Status)
   198  	}
   200  	var got bytes.Buffer
   201  	gr, err := gzip.NewReader(res.Body)
   202  	if err != nil {
   203  		t.Fatal(err)
   204  	}
   205  	_, err = io.Copy(&got, gr)
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   210  	if got.String() != "hello" {
   211  		t.Errorf("expected hello, got %q", got.String())
   212  	}
   213  }
   215  type fullyFeaturedResponseWriter struct{}
   217  // Header/Write/WriteHeader implement the http.ResponseWriter interface.
   218  func (fullyFeaturedResponseWriter) Header() http.Header {
   219  	return http.Header{}
   220  }
   222  func (fullyFeaturedResponseWriter) Write([]byte) (int, error) {
   223  	return 0, nil
   224  }
   225  func (fullyFeaturedResponseWriter) WriteHeader(int) {}
   227  // Flush implements the http.Flusher interface.
   228  func (fullyFeaturedResponseWriter) Flush() {}
   230  // Hijack implements the http.Hijacker interface.
   231  func (fullyFeaturedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   232  	return nil, nil, nil
   233  }
   235  // CloseNotify implements the http.CloseNotifier interface.
   236  func (fullyFeaturedResponseWriter) CloseNotify() <-chan bool {
   237  	return nil
   238  }
   240  func TestCompressHandlerPreserveInterfaces(t *testing.T) {
   241  	// Compile time validation fullyFeaturedResponseWriter implements all the
   242  	// interfaces we're asserting in the test case below.
   243  	var (
   244  		_ http.Flusher       = fullyFeaturedResponseWriter{}
   245  		_ http.CloseNotifier = fullyFeaturedResponseWriter{}
   246  		_ http.Hijacker      = fullyFeaturedResponseWriter{}
   247  	)
   248  	var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   249  		comp := r.Header.Get(acceptEncoding)
   250  		if _, ok := rw.(http.Flusher); !ok {
   251  			t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp)
   252  		}
   253  		if _, ok := rw.(http.CloseNotifier); !ok {
   254  			t.Errorf("ResponseWriter lost http.CloseNotifier interface for %q", comp)
   255  		}
   256  		if _, ok := rw.(http.Hijacker); !ok {
   257  			t.Errorf("ResponseWriter lost http.Hijacker interface for %q", comp)
   258  		}
   259  	})
   260  	h = CompressHandler(h)
   261  	var rw fullyFeaturedResponseWriter
   262  	r, err := http.NewRequest("GET", "/", nil)
   263  	if err != nil {
   264  		t.Fatalf("Failed to create test request: %v", err)
   265  	}
   266  	r.Header.Set(acceptEncoding, "gzip")
   267  	h.ServeHTTP(rw, r)
   269  	r.Header.Set(acceptEncoding, "deflate")
   270  	h.ServeHTTP(rw, r)
   271  }
   273  type paltryResponseWriter struct{}
   275  func (paltryResponseWriter) Header() http.Header {
   276  	return http.Header{}
   277  }
   279  func (paltryResponseWriter) Write([]byte) (int, error) {
   280  	return 0, nil
   281  }
   282  func (paltryResponseWriter) WriteHeader(int) {}
   284  func TestCompressHandlerDoesntInventInterfaces(t *testing.T) {
   285  	var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   286  		if _, ok := rw.(http.Hijacker); ok {
   287  			t.Error("ResponseWriter shouldn't implement http.Hijacker")
   288  		}
   289  	})
   291  	h = CompressHandler(h)
   293  	var rw paltryResponseWriter
   294  	r, err := http.NewRequest("GET", "/", nil)
   295  	if err != nil {
   296  		t.Fatalf("Failed to create test request: %v", err)
   297  	}
   298  	r.Header.Set(acceptEncoding, "gzip")
   299  	h.ServeHTTP(rw, r)
   300  }