github.com/isi-lincoln/grab@v2.0.1-0.20200331080741-9f014744ee41+incompatible/client_test.go (about)

     1  package grab
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/md5"
     7  	"crypto/sha1"
     8  	"crypto/sha256"
     9  	"crypto/sha512"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"math/rand"
    14  	"net/http"
    15  	"os"
    16  	"path/filepath"
    17  	"strings"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/cavaliercoder/grab/grabtest"
    22  )
    23  
    24  // TestFilenameResolutions tests that the destination filename for Requests can
    25  // be determined correctly, using an explicitly requested path,
    26  // Content-Disposition headers or a URL path - with or without an existing
    27  // target directory.
    28  func TestFilenameResolution(t *testing.T) {
    29  	tests := []struct {
    30  		Name               string
    31  		Filename           string
    32  		URL                string
    33  		AttachmentFilename string
    34  		Expect             string
    35  	}{
    36  		{"Using Request.Filename", ".testWithFilename", "/url-filename", "header-filename", ".testWithFilename"},
    37  		{"Using Content-Disposition Header", "", "/url-filename", ".testWithHeaderFilename", ".testWithHeaderFilename"},
    38  		{"Using Content-Disposition Header with target directory", ".test", "/url-filename", "header-filename", ".test/header-filename"},
    39  		{"Using URL Path", "", "/.testWithURLFilename?params-filename", "", ".testWithURLFilename"},
    40  		{"Using URL Path with target directory", ".test", "/url-filename?garbage", "", ".test/url-filename"},
    41  		{"Failure", "", "", "", ""},
    42  	}
    43  
    44  	err := os.Mkdir(".test", 0777)
    45  	if err != nil {
    46  		panic(err)
    47  	}
    48  	defer os.RemoveAll(".test")
    49  
    50  	for _, test := range tests {
    51  		t.Run(test.Name, func(t *testing.T) {
    52  			opts := []grabtest.HandlerOption{}
    53  			if test.AttachmentFilename != "" {
    54  				opts = append(opts, grabtest.AttachmentFilename(test.AttachmentFilename))
    55  			}
    56  			grabtest.WithTestServer(t, func(url string) {
    57  				req := mustNewRequest(test.Filename, url+test.URL)
    58  				resp := DefaultClient.Do(req)
    59  				defer os.Remove(resp.Filename)
    60  				if err := resp.Err(); err != nil {
    61  					if test.Expect != "" || err != ErrNoFilename {
    62  						panic(err)
    63  					}
    64  				} else {
    65  					if test.Expect == "" {
    66  						t.Errorf("expected: %v, got: %v", ErrNoFilename, err)
    67  					}
    68  				}
    69  				if resp.Filename != test.Expect {
    70  					t.Errorf("Filename mismatch. Expected '%s', got '%s'.", test.Expect, resp.Filename)
    71  				}
    72  				testComplete(t, resp)
    73  			}, opts...)
    74  		})
    75  	}
    76  }
    77  
    78  // TestChecksums checks that checksum validation behaves as expected for valid
    79  // and corrupted downloads.
    80  func TestChecksums(t *testing.T) {
    81  	tests := []struct {
    82  		size  int
    83  		hash  hash.Hash
    84  		sum   string
    85  		match bool
    86  	}{
    87  		{128, md5.New(), "37eff01866ba3f538421b30b7cbefcac", true},
    88  		{128, md5.New(), "37eff01866ba3f538421b30b7cbefcad", false},
    89  		{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855b", true},
    90  		{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855c", false},
    91  		{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef372", true},
    92  		{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef373", false},
    93  		{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d535", true},
    94  		{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d536", false},
    95  		{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b77", true},
    96  		{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b78", false},
    97  		{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923be", true},
    98  		{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923bf", false},
    99  		{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true},
   100  		{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be4", false},
   101  		{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true},
   102  		{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c8", false},
   103  		{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true},
   104  		{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c82", false},
   105  		{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f7", true},
   106  		{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f8", false},
   107  		{1024, sha512.New(), "37f652be867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566c", true},
   108  		{1024, sha512.New(), "37f652bf867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566d", false},
   109  		{1048576, sha512.New(), "ac1d097b4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", true},
   110  		{1048576, sha512.New(), "ac1d097c4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", false},
   111  	}
   112  
   113  	for _, test := range tests {
   114  		var expect error
   115  		comparison := "Match"
   116  		if !test.match {
   117  			comparison = "Mismatch"
   118  			expect = ErrBadChecksum
   119  		}
   120  
   121  		t.Run(fmt.Sprintf("With%s%s", comparison, test.sum[:8]), func(t *testing.T) {
   122  			filename := fmt.Sprintf(".testChecksum-%s-%s", comparison, test.sum[:8])
   123  			defer os.Remove(filename)
   124  
   125  			grabtest.WithTestServer(t, func(url string) {
   126  				req := mustNewRequest(filename, url)
   127  				req.SetChecksum(test.hash, grabtest.MustHexDecodeString(test.sum), true)
   128  
   129  				resp := DefaultClient.Do(req)
   130  				err := resp.Err()
   131  				if err != expect {
   132  					t.Errorf("expected error: %v, got: %v", expect, err)
   133  				}
   134  
   135  				// ensure mismatch file was deleted
   136  				if !test.match {
   137  					if _, err := os.Stat(filename); err == nil {
   138  						t.Errorf("checksum failure not cleaned up: %s", filename)
   139  					} else if !os.IsNotExist(err) {
   140  						panic(err)
   141  					}
   142  				}
   143  
   144  				testComplete(t, resp)
   145  			}, grabtest.ContentLength(test.size))
   146  		})
   147  	}
   148  }
   149  
   150  // TestContentLength ensures that ErrBadLength is returned if a server response
   151  // does not match the requested length.
   152  func TestContentLength(t *testing.T) {
   153  	size := int64(32768)
   154  	testCases := []struct {
   155  		Name   string
   156  		NoHead bool
   157  		Size   int64
   158  		Expect int64
   159  		Match  bool
   160  	}{
   161  		{"Good size in HEAD request", false, size, size, true},
   162  		{"Good size in GET request", true, size, size, true},
   163  		{"Bad size in HEAD request", false, size - 1, size, false},
   164  		{"Bad size in GET request", true, size - 1, size, false},
   165  	}
   166  
   167  	for _, test := range testCases {
   168  		t.Run(test.Name, func(t *testing.T) {
   169  			opts := []grabtest.HandlerOption{
   170  				grabtest.ContentLength(int(test.Size)),
   171  			}
   172  			if test.NoHead {
   173  				opts = append(opts, grabtest.MethodWhitelist("GET"))
   174  			}
   175  
   176  			grabtest.WithTestServer(t, func(url string) {
   177  				req := mustNewRequest(".testSize-mismatch-head", url)
   178  				req.Size = size
   179  				resp := DefaultClient.Do(req)
   180  				defer os.Remove(resp.Filename)
   181  				err := resp.Err()
   182  				if test.Match {
   183  					if err == ErrBadLength {
   184  						t.Errorf("error: %v", err)
   185  					} else if err != nil {
   186  						panic(err)
   187  					} else if resp.Size() != size {
   188  						t.Errorf("expected %v bytes, got %v bytes", size, resp.Size())
   189  					}
   190  				} else {
   191  					if err == nil {
   192  						t.Errorf("expected: %v, got %v", ErrBadLength, err)
   193  					} else if err != ErrBadLength {
   194  						panic(err)
   195  					}
   196  				}
   197  				testComplete(t, resp)
   198  			}, opts...)
   199  		})
   200  	}
   201  }
   202  
// TestAutoResume tests segmented downloading of a large file.
//
// All subtests share the local file `filename` and must run in the order
// written: each one depends on the file contents left behind by the previous
// subtest. Running a single subtest in isolation (e.g. with -run) is therefore
// not expected to pass.
func TestAutoResume(t *testing.T) {
	segs := 8
	size := 1048576
	// SHA-256 of the default handler's first `size` bytes (fbbab289..., the
	// same 1048576-byte sum used in TestChecksums).
	sum := grabtest.DefaultHandlerSHA256ChecksumBytes
	filename := ".testAutoResume"

	defer os.Remove(filename)

	// Grow the remote file by size/segs bytes per iteration. Every download
	// after the first must resume from the bytes already on disk; the checksum
	// is only validated on the final iteration, once all `size` bytes are served.
	for i := 0; i < segs; i++ {
		segsize := (i + 1) * (size / segs)
		t.Run(fmt.Sprintf("With%vBytes", segsize), func(t *testing.T) {
			grabtest.WithTestServer(t, func(url string) {
				req := mustNewRequest(filename, url)
				if i == segs-1 {
					req.SetChecksum(sha256.New(), sum, false)
				}
				resp := mustDo(req)
				if i > 0 && !resp.DidResume {
					t.Errorf("expected Response.DidResume to be true")
				}
				testComplete(t, resp)
			},
				grabtest.ContentLength(segsize),
			)
		})
	}

	// The local file now holds `size` bytes; a remote that is 128 bytes
	// shorter cannot be resumed and must fail with ErrBadLength.
	t.Run("WithFailure", func(t *testing.T) {
		grabtest.WithTestServer(t, func(url string) {
			// request smaller segment
			req := mustNewRequest(filename, url)
			resp := DefaultClient.Do(req)
			if err := resp.Err(); err != ErrBadLength {
				t.Errorf("expected ErrBadLength for smaller request, got: %v", err)
			}
		},
			grabtest.ContentLength(size-128),
		)
	})

	// NoResume forces a fresh download despite the existing local copy; the
	// file is rewritten with size+128 bytes.
	t.Run("WithNoResume", func(t *testing.T) {
		grabtest.WithTestServer(t, func(url string) {
			req := mustNewRequest(filename, url)
			req.NoResume = true
			resp := mustDo(req)
			if resp.DidResume {
				t.Errorf("expected Response.DidResume to be false")
			}
			testComplete(t, resp)
		},
			grabtest.ContentLength(size+128),
		)
	})

	// A smaller remote file combined with NoResume must truncate the local
	// copy (size-128 bytes afterwards) instead of resuming into it.
	t.Run("WithNoResumeAndTruncate", func(t *testing.T) {
		size := size - 128
		grabtest.WithTestServer(t, func(url string) {
			req := mustNewRequest(filename, url)
			req.NoResume = true
			resp := mustDo(req)
			if resp.DidResume {
				t.Errorf("expected Response.DidResume to be false")
			}
			if v := resp.BytesComplete(); v != int64(size) {
				t.Errorf("expected Response.BytesComplete: %d, got: %d", size, v)
			}
			testComplete(t, resp)
		},
			grabtest.ContentLength(size),
		)
	})

	// With the Content-Length header suppressed, the client must still resume
	// the size-128 byte local file and end up with Response.Size == size.
	t.Run("WithNoContentLengthHeader", func(t *testing.T) {
		grabtest.WithTestServer(t, func(url string) {
			req := mustNewRequest(filename, url)
			req.SetChecksum(sha256.New(), sum, false)
			resp := mustDo(req)
			if !resp.DidResume {
				t.Errorf("expected Response.DidResume to be true")
			}
			if actual := resp.Size(); actual != int64(size) {
				t.Errorf("expected Response.Size: %d, got: %d", size, actual)
			}
			testComplete(t, resp)
		},
			grabtest.ContentLength(size),
			grabtest.HeaderBlacklist("Content-Length"),
		)
	})

	// Resume once more against a remote twice as large with no Content-Length
	// header; the checksum for `size` bytes no longer matches, so the transfer
	// completes but reports ErrBadChecksum.
	t.Run("WithNoContentLengthHeaderAndChecksumFailure", func(t *testing.T) {
		// ref: https://github.com/cavaliercoder/grab/pull/27
		size := size * 2
		grabtest.WithTestServer(t, func(url string) {
			req := mustNewRequest(filename, url)
			req.SetChecksum(sha256.New(), sum, false)
			resp := DefaultClient.Do(req)
			if err := resp.Err(); err != ErrBadChecksum {
				t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
			}
			if !resp.DidResume {
				t.Errorf("expected Response.DidResume to be true")
			}
			if actual := resp.BytesComplete(); actual != int64(size) {
				t.Errorf("expected Response.BytesComplete: %d, got: %d", size, actual)
			}
			if actual := resp.Size(); actual != int64(size) {
				t.Errorf("expected Response.Size: %d, got: %d", size, actual)
			}
			testComplete(t, resp)
		},
			grabtest.ContentLength(size),
			grabtest.HeaderBlacklist("Content-Length"),
		)
	})
	// TODO: test when existing file is corrupted
}
   321  
   322  func TestSkipExisting(t *testing.T) {
   323  	filename := ".testSkipExisting"
   324  	defer os.Remove(filename)
   325  
   326  	// download a file
   327  	grabtest.WithTestServer(t, func(url string) {
   328  		resp := mustDo(mustNewRequest(filename, url))
   329  		testComplete(t, resp)
   330  	})
   331  
   332  	// redownload
   333  	grabtest.WithTestServer(t, func(url string) {
   334  		resp := mustDo(mustNewRequest(filename, url))
   335  		testComplete(t, resp)
   336  
   337  		// ensure download was resumed
   338  		if !resp.DidResume {
   339  			t.Fatalf("Expected download to skip existing file, but it did not")
   340  		}
   341  
   342  		// ensure all bytes were resumed
   343  		if resp.Size() == 0 || resp.Size() != resp.bytesResumed {
   344  			t.Fatalf("Expected to skip %d bytes in redownload; got %d", resp.Size(), resp.bytesResumed)
   345  		}
   346  	})
   347  
   348  	// ensure checksum is performed on pre-existing file
   349  	grabtest.WithTestServer(t, func(url string) {
   350  		req := mustNewRequest(filename, url)
   351  		req.SetChecksum(sha256.New(), []byte{0x01, 0x02, 0x03, 0x04}, true)
   352  		resp := DefaultClient.Do(req)
   353  		if err := resp.Err(); err != ErrBadChecksum {
   354  			t.Fatalf("Expected checksum error, got: %v", err)
   355  		}
   356  	})
   357  }
   358  
   359  // TestBatch executes multiple requests simultaneously and validates the
   360  // responses.
   361  func TestBatch(t *testing.T) {
   362  	tests := 32
   363  	size := 32768
   364  	sum := grabtest.MustHexDecodeString("e11360251d1173650cdcd20f111d8f1ca2e412f572e8b36a4dc067121c1799b8")
   365  
   366  	// test with 4 workers and with one per request
   367  	grabtest.WithTestServer(t, func(url string) {
   368  		for _, workerCount := range []int{4, 0} {
   369  			// create requests
   370  			reqs := make([]*Request, tests)
   371  			for i := 0; i < len(reqs); i++ {
   372  				filename := fmt.Sprintf(".testBatch.%d", i+1)
   373  				reqs[i] = mustNewRequest(filename, url+fmt.Sprintf("/request_%d?", i+1))
   374  				reqs[i].Label = fmt.Sprintf("Test %d", i+1)
   375  				reqs[i].SetChecksum(sha256.New(), sum, false)
   376  			}
   377  
   378  			// batch run
   379  			responses := DefaultClient.DoBatch(workerCount, reqs...)
   380  
   381  			// listen for responses
   382  		Loop:
   383  			for i := 0; i < len(reqs); {
   384  				select {
   385  				case resp := <-responses:
   386  					if resp == nil {
   387  						break Loop
   388  					}
   389  					testComplete(t, resp)
   390  					if err := resp.Err(); err != nil {
   391  						t.Errorf("%s: %v", resp.Filename, err)
   392  					}
   393  
   394  					// remove test file
   395  					if resp.IsComplete() {
   396  						os.Remove(resp.Filename) // ignore errors
   397  					}
   398  					i++
   399  				}
   400  			}
   401  		}
   402  	},
   403  		grabtest.ContentLength(size),
   404  	)
   405  }
   406  
   407  // TestCancelContext tests that a batch of requests can be cancel using a
   408  // context.Context cancellation. Requests are cancelled in multiple states:
   409  // in-progress and unstarted.
   410  func TestCancelContext(t *testing.T) {
   411  	fileSize := 134217728
   412  	tests := 256
   413  	client := NewClient()
   414  	ctx, cancel := context.WithCancel(context.Background())
   415  	defer cancel()
   416  
   417  	grabtest.WithTestServer(t, func(url string) {
   418  		reqs := make([]*Request, tests)
   419  		for i := 0; i < tests; i++ {
   420  			req := mustNewRequest("", fmt.Sprintf("%s/.testCancelContext%d", url, i))
   421  			reqs[i] = req.WithContext(ctx)
   422  		}
   423  
   424  		respch := client.DoBatch(8, reqs...)
   425  		time.Sleep(time.Millisecond * 500)
   426  		cancel()
   427  		for resp := range respch {
   428  			defer os.Remove(resp.Filename)
   429  
   430  			// err should be context.Canceled or http.errRequestCanceled
   431  			if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") {
   432  				t.Errorf("expected '%v', got '%v'", context.Canceled, resp.Err())
   433  			}
   434  			if resp.BytesComplete() >= int64(fileSize) {
   435  				t.Errorf("expected Response.BytesComplete: < %d, got: %d", fileSize, resp.BytesComplete())
   436  			}
   437  		}
   438  	},
   439  		grabtest.ContentLength(fileSize),
   440  	)
   441  }
   442  
   443  // TestNestedDirectory tests that missing subdirectories are created.
   444  func TestNestedDirectory(t *testing.T) {
   445  	dir := "./.testNested/one/two/three"
   446  	filename := ".testNestedFile"
   447  	expect := dir + "/" + filename
   448  
   449  	t.Run("Create", func(t *testing.T) {
   450  		grabtest.WithTestServer(t, func(url string) {
   451  			resp := mustDo(mustNewRequest(expect, url+"/"+filename))
   452  			defer os.RemoveAll("./.testNested/")
   453  			if resp.Filename != expect {
   454  				t.Errorf("expected nested Request.Filename to be %v, got %v", expect, resp.Filename)
   455  			}
   456  		})
   457  	})
   458  
   459  	t.Run("No create", func(t *testing.T) {
   460  		grabtest.WithTestServer(t, func(url string) {
   461  			req := mustNewRequest(expect, url+"/"+filename)
   462  			req.NoCreateDirectories = true
   463  			resp := DefaultClient.Do(req)
   464  			err := resp.Err()
   465  			if !os.IsNotExist(err) {
   466  				t.Errorf("expected: %v, got: %v", os.ErrNotExist, err)
   467  			}
   468  		})
   469  	})
   470  }
   471  
   472  // TestRemoteTime tests that the timestamp of the downloaded file can be set
   473  // according to the timestamp of the remote file.
   474  func TestRemoteTime(t *testing.T) {
   475  	filename := "./.testRemoteTime"
   476  	defer os.Remove(filename)
   477  
   478  	// random time between epoch and now
   479  	expect := time.Unix(rand.Int63n(time.Now().Unix()), 0)
   480  	grabtest.WithTestServer(t, func(url string) {
   481  		resp := mustDo(mustNewRequest(filename, url))
   482  		fi, err := os.Stat(resp.Filename)
   483  		if err != nil {
   484  			panic(err)
   485  		}
   486  		actual := fi.ModTime()
   487  		if !actual.Equal(expect) {
   488  			t.Errorf("expected %v, got %v", expect, actual)
   489  		}
   490  	},
   491  		grabtest.LastModified(expect),
   492  	)
   493  }
   494  
   495  func TestResponseCode(t *testing.T) {
   496  	filename := "./.testResponseCode"
   497  
   498  	t.Run("With404", func(t *testing.T) {
   499  		defer os.Remove(filename)
   500  		grabtest.WithTestServer(t, func(url string) {
   501  			req := mustNewRequest(filename, url)
   502  			resp := DefaultClient.Do(req)
   503  			expect := StatusCodeError(http.StatusNotFound)
   504  			err := resp.Err()
   505  			if err != expect {
   506  				t.Errorf("expected %v, got '%v'", expect, err)
   507  			}
   508  			if !IsStatusCodeError(err) {
   509  				t.Errorf("expected IsStatusCodeError to return true for %T: %v", err, err)
   510  			}
   511  		},
   512  			grabtest.StatusCodeStatic(http.StatusNotFound),
   513  		)
   514  	})
   515  
   516  	t.Run("WithIgnoreNon2XX", func(t *testing.T) {
   517  		defer os.Remove(filename)
   518  		grabtest.WithTestServer(t, func(url string) {
   519  			req := mustNewRequest(filename, url)
   520  			req.IgnoreBadStatusCodes = true
   521  			resp := DefaultClient.Do(req)
   522  			if err := resp.Err(); err != nil {
   523  				t.Errorf("expected nil, got '%v'", err)
   524  			}
   525  		},
   526  			grabtest.StatusCodeStatic(http.StatusNotFound),
   527  		)
   528  	})
   529  }
   530  
   531  func TestBeforeCopyHook(t *testing.T) {
   532  	filename := "./.testBeforeCopy"
   533  	t.Run("Noop", func(t *testing.T) {
   534  		defer os.RemoveAll(filename)
   535  		grabtest.WithTestServer(t, func(url string) {
   536  			called := false
   537  			req := mustNewRequest(filename, url)
   538  			req.BeforeCopy = func(resp *Response) error {
   539  				called = true
   540  				if resp.IsComplete() {
   541  					t.Error("Response object passed to BeforeCopy hook has already been closed")
   542  				}
   543  				if resp.Progress() != 0 {
   544  					t.Error("Download progress already > 0 when BeforeCopy hook was called")
   545  				}
   546  				if resp.Duration() == 0 {
   547  					t.Error("Duration was zero when BeforeCopy was called")
   548  				}
   549  				if resp.BytesComplete() != 0 {
   550  					t.Error("BytesComplete already > 0 when BeforeCopy hook was called")
   551  				}
   552  				return nil
   553  			}
   554  			resp := DefaultClient.Do(req)
   555  			if err := resp.Err(); err != nil {
   556  				t.Errorf("unexpected error using BeforeCopy hook: %v", err)
   557  			}
   558  			testComplete(t, resp)
   559  			if !called {
   560  				t.Error("BeforeCopy hook was never called")
   561  			}
   562  		})
   563  	})
   564  
   565  	t.Run("WithError", func(t *testing.T) {
   566  		defer os.RemoveAll(filename)
   567  		grabtest.WithTestServer(t, func(url string) {
   568  			testError := errors.New("test")
   569  			req := mustNewRequest(filename, url)
   570  			req.BeforeCopy = func(resp *Response) error {
   571  				return testError
   572  			}
   573  			resp := DefaultClient.Do(req)
   574  			if err := resp.Err(); err != testError {
   575  				t.Errorf("expected error '%v', got '%v'", testError, err)
   576  			}
   577  			if resp.BytesComplete() != 0 {
   578  				t.Errorf("expected 0 bytes completed for canceled BeforeCopy hook, got %d",
   579  					resp.BytesComplete())
   580  			}
   581  			testComplete(t, resp)
   582  		})
   583  	})
   584  }
   585  
   586  func TestAfterCopyHook(t *testing.T) {
   587  	filename := "./.testAfterCopy"
   588  	t.Run("Noop", func(t *testing.T) {
   589  		defer os.RemoveAll(filename)
   590  		grabtest.WithTestServer(t, func(url string) {
   591  			called := false
   592  			req := mustNewRequest(filename, url)
   593  			req.AfterCopy = func(resp *Response) error {
   594  				called = true
   595  				if resp.IsComplete() {
   596  					t.Error("Response object passed to AfterCopy hook has already been closed")
   597  				}
   598  				if resp.Progress() <= 0 {
   599  					t.Error("Download progress was 0 when AfterCopy hook was called")
   600  				}
   601  				if resp.Duration() == 0 {
   602  					t.Error("Duration was zero when AfterCopy was called")
   603  				}
   604  				if resp.BytesComplete() <= 0 {
   605  					t.Error("BytesComplete was 0 when AfterCopy hook was called")
   606  				}
   607  				return nil
   608  			}
   609  			resp := DefaultClient.Do(req)
   610  			if err := resp.Err(); err != nil {
   611  				t.Errorf("unexpected error using AfterCopy hook: %v", err)
   612  			}
   613  			testComplete(t, resp)
   614  			if !called {
   615  				t.Error("AfterCopy hook was never called")
   616  			}
   617  		})
   618  	})
   619  
   620  	t.Run("WithError", func(t *testing.T) {
   621  		defer os.RemoveAll(filename)
   622  		grabtest.WithTestServer(t, func(url string) {
   623  			testError := errors.New("test")
   624  			req := mustNewRequest(filename, url)
   625  			req.AfterCopy = func(resp *Response) error {
   626  				return testError
   627  			}
   628  			resp := DefaultClient.Do(req)
   629  			if err := resp.Err(); err != testError {
   630  				t.Errorf("expected error '%v', got '%v'", testError, err)
   631  			}
   632  			if resp.BytesComplete() <= 0 {
   633  				t.Errorf("ByteCompleted was %d after AfterCopy hook was called",
   634  					resp.BytesComplete())
   635  			}
   636  			testComplete(t, resp)
   637  		})
   638  	})
   639  }
   640  
   641  func TestIssue37(t *testing.T) {
   642  	// ref: https://github.com/cavaliercoder/grab/issues/37
   643  	filename := "./.testIssue37"
   644  	largeSize := int64(2097152)
   645  	smallSize := int64(1048576)
   646  	defer os.RemoveAll(filename)
   647  
   648  	// download large file
   649  	grabtest.WithTestServer(t, func(url string) {
   650  		resp := mustDo(mustNewRequest(filename, url))
   651  		if resp.Size() != largeSize {
   652  			t.Errorf("expected response size: %d, got: %d", largeSize, resp.Size())
   653  		}
   654  	}, grabtest.ContentLength(int(largeSize)))
   655  
   656  	// download new, smaller version of same file
   657  	grabtest.WithTestServer(t, func(url string) {
   658  		req := mustNewRequest(filename, url)
   659  		req.NoResume = true
   660  		resp := mustDo(req)
   661  		if resp.Size() != smallSize {
   662  			t.Errorf("expected response size: %d, got: %d", smallSize, resp.Size())
   663  		}
   664  
   665  		// local file should have truncated and not resumed
   666  		if resp.DidResume {
   667  			t.Errorf("expected download to truncate, resumed instead")
   668  		}
   669  	}, grabtest.ContentLength(int(smallSize)))
   670  
   671  	fi, err := os.Stat(filename)
   672  	if err != nil {
   673  		t.Fatal(err)
   674  	}
   675  	if fi.Size() != int64(smallSize) {
   676  		t.Errorf("expected file size %d, got %d", smallSize, fi.Size())
   677  	}
   678  }
   679  
   680  // TestHeadBadStatus validates that HEAD requests that return non-200 can be
   681  // ignored and succeed if the GET requests succeeeds.
   682  //
   683  // Fixes: https://github.com/cavaliercoder/grab/issues/43
   684  func TestHeadBadStatus(t *testing.T) {
   685  	expect := http.StatusOK
   686  	filename := ".testIssue43"
   687  
   688  	statusFunc := func(r *http.Request) int {
   689  		if r.Method == "HEAD" {
   690  			return http.StatusForbidden
   691  		}
   692  		return http.StatusOK
   693  	}
   694  
   695  	grabtest.WithTestServer(t, func(url string) {
   696  		testURL := fmt.Sprintf("%s/%s", url, filename)
   697  		resp := mustDo(mustNewRequest("", testURL))
   698  		if resp.HTTPResponse.StatusCode != expect {
   699  			t.Errorf(
   700  				"expected status code: %d, got:% d",
   701  				expect,
   702  				resp.HTTPResponse.StatusCode)
   703  		}
   704  	},
   705  		grabtest.StatusCode(statusFunc),
   706  	)
   707  }
   708  
   709  // TestMissingContentLength ensures that the Response.Size is correct for
   710  // transfers where the remote server does not send a Content-Length header.
   711  //
   712  // TestAutoResume also covers cases with checksum validation.
   713  //
   714  // Kudos to Setnička Jiří <Jiri.Setnicka@ysoft.com> for identifying and raising
   715  // a solution to this issue. Ref: https://github.com/cavaliercoder/grab/pull/27
   716  func TestMissingContentLength(t *testing.T) {
   717  	// expectSize must be sufficiently large that DefaultClient.Do won't prefetch
   718  	// the entire body and compute ContentLength before returning a Response.
   719  	expectSize := 1048576
   720  	opts := []grabtest.HandlerOption{
   721  		grabtest.ContentLength(expectSize),
   722  		grabtest.HeaderBlacklist("Content-Length"),
   723  		grabtest.TimeToFirstByte(time.Millisecond * 100), // delay for initial read
   724  	}
   725  	grabtest.WithTestServer(t, func(url string) {
   726  		req := mustNewRequest(".testMissingContentLength", url)
   727  		req.SetChecksum(
   728  			md5.New(),
   729  			grabtest.DefaultHandlerMD5ChecksumBytes,
   730  			false)
   731  		resp := DefaultClient.Do(req)
   732  
   733  		// ensure remote server is not sending content-length header
   734  		if v := resp.HTTPResponse.Header.Get("Content-Length"); v != "" {
   735  			panic(fmt.Sprintf("http header content length must be empty, got: %s", v))
   736  		}
   737  		if v := resp.HTTPResponse.ContentLength; v != -1 {
   738  			panic(fmt.Sprintf("http response content length must be -1, got: %d", v))
   739  		}
   740  
   741  		// before completion, response size should be -1
   742  		if resp.Size() != -1 {
   743  			t.Errorf("expected response size: -1, got: %d", resp.Size())
   744  		}
   745  
   746  		// block for completion
   747  		if err := resp.Err(); err != nil {
   748  			panic(err)
   749  		}
   750  
   751  		// on completion, response size should be actual transfer size
   752  		if resp.Size() != int64(expectSize) {
   753  			t.Errorf("expected response size: %d, got: %d", expectSize, resp.Size())
   754  		}
   755  	}, opts...)
   756  }
   757  
   758  func TestNoStore(t *testing.T) {
   759  	filename := ".testSubdir/testNoStore"
   760  	t.Run("DefaultCase", func(t *testing.T) {
   761  		grabtest.WithTestServer(t, func(url string) {
   762  			req := mustNewRequest(filename, url)
   763  			req.NoStore = true
   764  			req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true)
   765  			resp := mustDo(req)
   766  
   767  			// ensure Response.Bytes is correct and can be reread
   768  			b, err := resp.Bytes()
   769  			if err != nil {
   770  				panic(err)
   771  			}
   772  			grabtest.AssertSHA256Sum(
   773  				t,
   774  				grabtest.DefaultHandlerSHA256ChecksumBytes,
   775  				bytes.NewReader(b),
   776  			)
   777  
   778  			// ensure Response.Open stream is correct and can be reread
   779  			r, err := resp.Open()
   780  			if err != nil {
   781  				panic(err)
   782  			}
   783  			defer r.Close()
   784  			grabtest.AssertSHA256Sum(
   785  				t,
   786  				grabtest.DefaultHandlerSHA256ChecksumBytes,
   787  				r,
   788  			)
   789  
   790  			// Response.Filename should still be set
   791  			if resp.Filename != filename {
   792  				t.Errorf("expected Response.Filename: %s, got: %s", filename, resp.Filename)
   793  			}
   794  
   795  			// ensure no files were written
   796  			paths := []string{
   797  				filename,
   798  				filepath.Base(filename),
   799  				filepath.Dir(filename),
   800  				resp.Filename,
   801  				filepath.Base(resp.Filename),
   802  				filepath.Dir(resp.Filename),
   803  			}
   804  			for _, path := range paths {
   805  				_, err := os.Stat(path)
   806  				if !os.IsNotExist(err) {
   807  					t.Errorf(
   808  						"expect error: %v, got: %v, for path: %s",
   809  						os.ErrNotExist,
   810  						err,
   811  						path)
   812  				}
   813  			}
   814  		})
   815  	})
   816  
   817  	t.Run("ChecksumValidation", func(t *testing.T) {
   818  		grabtest.WithTestServer(t, func(url string) {
   819  			req := mustNewRequest("", url)
   820  			req.NoStore = true
   821  			req.SetChecksum(
   822  				md5.New(),
   823  				grabtest.MustHexDecodeString("deadbeefcafebabe"),
   824  				true)
   825  			resp := DefaultClient.Do(req)
   826  			if err := resp.Err(); err != ErrBadChecksum {
   827  				t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
   828  			}
   829  		})
   830  	})
   831  }