github.com/tyler-smith/grab@v2.0.1-0.20190224022517-abcee96e61b1+incompatible/grab_test.go (about)

     1  package grab
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"log"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"strconv"
    12  	"testing"
    13  	"time"
    14  )
    15  
    16  // ts is a test HTTP server that serves configurable content for all test
    17  // functions.
    18  //
    19  // The following URL query parameters are supported:
    20  //
    21  // * filename=[string]	return a filename in the Content-Disposition header of
    22  //                      the response
    23  //
    24  // * lastmod=[unix]		  set the Last-Modified header
    25  //
    26  // * nohead						  disabled support for HEAD requests
    27  //
    28  // * range=[bool]				allow byte range requests (default: yes)
    29  //
    30  // * rate=[int]					throttle file transfer to the given limit as
    31  // 							        bytes per second
    32  //
    33  // * size=[int]					return a file of the specified size in bytes
    34  //
    35  // * sleep=[int]				delay the response by the given number of
    36  // 								      milliseconds (before sending headers)
    37  //
    38  // * status=[int]       return the given status code
    39  //
    40  var ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    41  	// set status code
    42  	statusCode := http.StatusOK
    43  	if v := r.URL.Query().Get("status"); v != "" {
    44  		if _, err := fmt.Sscanf(v, "%d", &statusCode); err != nil {
    45  			panic(err)
    46  		}
    47  	}
    48  	if r.Method == "HEAD" {
    49  		if v := r.URL.Query().Get("headStatus"); v != "" {
    50  			if _, err := fmt.Sscanf(v, "%d", &statusCode); err != nil {
    51  				panic(err)
    52  			}
    53  		}
    54  	}
    55  
    56  	// allow HEAD requests?
    57  	if _, ok := r.URL.Query()["nohead"]; ok && r.Method == "HEAD" {
    58  		http.Error(w, "HEAD method not allowed", http.StatusMethodNotAllowed)
    59  		return
    60  	}
    61  
    62  	// compute transfer size from 'size' parameter (default 1Mb)
    63  	size := 1048576
    64  	if sizep := r.URL.Query().Get("size"); sizep != "" {
    65  		if _, err := fmt.Sscanf(sizep, "%d", &size); err != nil {
    66  			panic(err)
    67  		}
    68  	}
    69  
    70  	// support ranged requests (default yes)?
    71  	ranged := true
    72  	if rangep := r.URL.Query().Get("ranged"); rangep != "" {
    73  		if _, err := fmt.Sscanf(rangep, "%t", &ranged); err != nil {
    74  			panic(err)
    75  		}
    76  	}
    77  
    78  	// set filename in headers (default no)?
    79  	if filenamep := r.URL.Query().Get("filename"); filenamep != "" {
    80  		w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", filenamep))
    81  	}
    82  
    83  	// set Last-Modified header
    84  	if lastmodp := r.URL.Query().Get("lastmod"); lastmodp != "" {
    85  		lastmodi, err := strconv.ParseInt(lastmodp, 10, 64)
    86  		if err != nil {
    87  			panic(err)
    88  		}
    89  		lastmodt := time.Unix(lastmodi, 0).UTC()
    90  		lastmod := lastmodt.Format("Mon, 02 Jan 2006 15:04:05") + " GMT"
    91  		w.Header().Set("Last-Modified", lastmod)
    92  	}
    93  
    94  	// sleep before responding?
    95  	sleep := 0
    96  	if sleepp := r.URL.Query().Get("sleep"); sleepp != "" {
    97  		if _, err := fmt.Sscanf(sleepp, "%d", &sleep); err != nil {
    98  			panic(err)
    99  		}
   100  	}
   101  
   102  	// throttle rate to n bps
   103  	rate := 0 // bps
   104  	var throttle *time.Ticker
   105  	defer func() {
   106  		if throttle != nil {
   107  			throttle.Stop()
   108  		}
   109  	}()
   110  
   111  	if ratep := r.URL.Query().Get("rate"); ratep != "" {
   112  		if _, err := fmt.Sscanf(ratep, "%d", &rate); err != nil {
   113  			panic(err)
   114  		}
   115  
   116  		if rate > 0 {
   117  			throttle = time.NewTicker(time.Second / time.Duration(rate))
   118  		}
   119  	}
   120  
   121  	// compute offset
   122  	offset := 0
   123  	if rangeh := r.Header.Get("Range"); rangeh != "" {
   124  		if _, err := fmt.Sscanf(rangeh, "bytes=%d-", &offset); err != nil {
   125  			panic(err)
   126  		}
   127  
   128  		// make sure range is in range
   129  		if offset >= size {
   130  			w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
   131  			return
   132  		}
   133  	}
   134  
   135  	// delay response
   136  	if sleep > 0 {
   137  		time.Sleep(time.Duration(sleep) * time.Millisecond)
   138  	}
   139  
   140  	// set response headers
   141  	w.Header().Set("Content-Length", fmt.Sprintf("%d", size-offset))
   142  	if ranged {
   143  		w.Header().Set("Accept-Ranges", "bytes")
   144  	}
   145  	w.WriteHeader(statusCode)
   146  
   147  	// serve content body if method == "GET"
   148  	if r.Method == "GET" {
   149  		// use buffered io to reduce overhead on the reader
   150  		bw := bufio.NewWriterSize(w, 4096)
   151  		for i := offset; i < size; i++ {
   152  			bw.Write([]byte{byte(i)})
   153  			if throttle != nil {
   154  				<-throttle.C
   155  			}
   156  		}
   157  		bw.Flush()
   158  	}
   159  }))
   160  
   161  func TestMain(m *testing.M) {
   162  	os.Exit(func() int {
   163  		// clean up test web server
   164  		defer ts.Close()
   165  
   166  		// chdir to temp so test files downloaded to pwd are isolated and cleaned up
   167  		cwd, err := os.Getwd()
   168  		if err != nil {
   169  			panic(err)
   170  		}
   171  		tmpDir, err := ioutil.TempDir("", "grab-")
   172  		if err != nil {
   173  			panic(err)
   174  		}
   175  		if err := os.Chdir(tmpDir); err != nil {
   176  			panic(err)
   177  		}
   178  		defer func() {
   179  			os.Chdir(cwd)
   180  			if err := os.RemoveAll(tmpDir); err != nil {
   181  				panic(err)
   182  			}
   183  		}()
   184  		return m.Run()
   185  	}())
   186  }
   187  
   188  // TestTestServer ensures that the test server behaves as expected so that it
   189  // does not pollute other tests.
   190  func TestTestServer(t *testing.T) {
   191  	t.Run("default", func(t *testing.T) {
   192  		req, _ := http.NewRequest("GET", ts.URL+"?nohead", nil)
   193  		resp, _ := http.DefaultClient.Do(req)
   194  		defer resp.Body.Close()
   195  
   196  		expectSize := 1048576
   197  		if h := resp.ContentLength; h != int64(expectSize) {
   198  			t.Fatalf("expected Content-Length: %v, got %v", expectSize, h)
   199  		}
   200  		b, _ := ioutil.ReadAll(resp.Body)
   201  		if len(b) != expectSize {
   202  			t.Fatalf("expected body length: %v, got %v", expectSize, len(b))
   203  		}
   204  
   205  		if h := resp.Header.Get("Accept-Ranges"); h != "bytes" {
   206  			t.Fatalf("expected Accept-Ranges: bytes, got: %v", h)
   207  		}
   208  	})
   209  
   210  	t.Run("nohead", func(t *testing.T) {
   211  		req, _ := http.NewRequest("HEAD", ts.URL+"?nohead", nil)
   212  		resp, _ := http.DefaultClient.Do(req)
   213  		defer resp.Body.Close()
   214  		if resp.StatusCode != http.StatusMethodNotAllowed {
   215  			panic("HEAD request was allowed despite ?nohead being set")
   216  		}
   217  	})
   218  
   219  	t.Run("headStatus", func(t *testing.T) {
   220  		expect := http.StatusTeapot
   221  		req, _ := http.NewRequest(
   222  			"HEAD",
   223  			fmt.Sprintf("%s?headStatus=%d", ts.URL, expect),
   224  			nil)
   225  		resp, _ := http.DefaultClient.Do(req)
   226  		defer resp.Body.Close()
   227  		if resp.StatusCode != expect {
   228  			t.Fatalf("expected status: %v, got: %v", expect, resp.StatusCode)
   229  		}
   230  	})
   231  
   232  	t.Run("size", func(t *testing.T) {
   233  		req, _ := http.NewRequest("GET", ts.URL+"?size=321", nil)
   234  		resp, _ := http.DefaultClient.Do(req)
   235  		defer resp.Body.Close()
   236  		if resp.ContentLength != 321 {
   237  			t.Fatalf("expected Content-Length: %v, got %v", 321, resp.ContentLength)
   238  		}
   239  		b, _ := ioutil.ReadAll(resp.Body)
   240  		if len(b) != 321 {
   241  			t.Fatalf("expected body length: %v, got %v", 321, len(b))
   242  		}
   243  	})
   244  
   245  	t.Run("ranged=false", func(t *testing.T) {
   246  		req, _ := http.NewRequest("GET", ts.URL+"?ranged=false", nil)
   247  		resp, _ := http.DefaultClient.Do(req)
   248  		defer resp.Body.Close()
   249  		if h := resp.Header.Get("Accept-Ranges"); h != "" {
   250  			t.Fatalf("expected empty Accept-Ranges header, got: %v", h)
   251  		}
   252  	})
   253  
   254  	t.Run("filename", func(t *testing.T) {
   255  		req, _ := http.NewRequest("GET", ts.URL+"?filename=test", nil)
   256  		resp, _ := http.DefaultClient.Do(req)
   257  		defer resp.Body.Close()
   258  
   259  		expect := "attachment;filename=\"test\""
   260  		if h := resp.Header.Get("Content-Disposition"); h != expect {
   261  			t.Fatalf("expected Content-Disposition header: %v, got %v", expect, h)
   262  		}
   263  	})
   264  
   265  	t.Run("lastmod", func(t *testing.T) {
   266  		req, _ := http.NewRequest("GET", ts.URL+"?lastmod=123456789", nil)
   267  		resp, _ := http.DefaultClient.Do(req)
   268  		defer resp.Body.Close()
   269  
   270  		expect := "Thu, 29 Nov 1973 21:33:09 GMT"
   271  		if h := resp.Header.Get("Last-Modified"); h != expect {
   272  			t.Fatalf("expected Last-Modified header: %v, got %v", expect, h)
   273  		}
   274  	})
   275  }
   276  
   277  // TestGet tests grab.Get
   278  func TestGet(t *testing.T) {
   279  	filename := ".testGet"
   280  	defer os.Remove(filename)
   281  
   282  	resp, err := Get(filename, ts.URL)
   283  	if err != nil {
   284  		t.Fatalf("error in Get(): %v", err)
   285  	}
   286  
   287  	testComplete(t, resp)
   288  }
   289  
   290  func ExampleGet() {
   291  	// download a file to /tmp
   292  	resp, err := Get("/tmp", "http://example.com/example.zip")
   293  	if err != nil {
   294  		log.Fatal(err)
   295  	}
   296  
   297  	fmt.Println("Download saved to", resp.Filename)
   298  }