github.com/cavaliergopher/grab/v3@v3.0.1/client_test.go (about)

     1  package grab
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/md5"
     7  	"crypto/sha1"
     8  	"crypto/sha256"
     9  	"crypto/sha512"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io/ioutil"
    14  	"math/rand"
    15  	"net/http"
    16  	"os"
    17  	"path/filepath"
    18  	"strings"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/cavaliergopher/grab/v3/pkg/grabtest"
    23  )
    24  
    25  // TestFilenameResolutions tests that the destination filename for Requests can
    26  // be determined correctly, using an explicitly requested path,
    27  // Content-Disposition headers or a URL path - with or without an existing
    28  // target directory.
    29  func TestFilenameResolution(t *testing.T) {
    30  	tests := []struct {
    31  		Name               string
    32  		Filename           string
    33  		URL                string
    34  		AttachmentFilename string
    35  		Expect             string
    36  	}{
    37  		{"Using Request.Filename", ".testWithFilename", "/url-filename", "header-filename", ".testWithFilename"},
    38  		{"Using Content-Disposition Header", "", "/url-filename", ".testWithHeaderFilename", ".testWithHeaderFilename"},
    39  		{"Using Content-Disposition Header with target directory", ".test", "/url-filename", "header-filename", ".test/header-filename"},
    40  		{"Using URL Path", "", "/.testWithURLFilename?params-filename", "", ".testWithURLFilename"},
    41  		{"Using URL Path with target directory", ".test", "/url-filename?garbage", "", ".test/url-filename"},
    42  		{"Failure", "", "", "", ""},
    43  	}
    44  
    45  	err := os.Mkdir(".test", 0777)
    46  	if err != nil {
    47  		panic(err)
    48  	}
    49  	defer os.RemoveAll(".test")
    50  
    51  	for _, test := range tests {
    52  		t.Run(test.Name, func(t *testing.T) {
    53  			opts := []grabtest.HandlerOption{}
    54  			if test.AttachmentFilename != "" {
    55  				opts = append(opts, grabtest.AttachmentFilename(test.AttachmentFilename))
    56  			}
    57  			grabtest.WithTestServer(t, func(url string) {
    58  				req := mustNewRequest(test.Filename, url+test.URL)
    59  				resp := DefaultClient.Do(req)
    60  				defer os.Remove(resp.Filename)
    61  				if err := resp.Err(); err != nil {
    62  					if test.Expect != "" || err != ErrNoFilename {
    63  						panic(err)
    64  					}
    65  				} else {
    66  					if test.Expect == "" {
    67  						t.Errorf("expected: %v, got: %v", ErrNoFilename, err)
    68  					}
    69  				}
    70  				if resp.Filename != test.Expect {
    71  					t.Errorf("Filename mismatch. Expected '%s', got '%s'.", test.Expect, resp.Filename)
    72  				}
    73  				testComplete(t, resp)
    74  			}, opts...)
    75  		})
    76  	}
    77  }
    78  
    79  // TestChecksums checks that checksum validation behaves as expected for valid
    80  // and corrupted downloads.
    81  func TestChecksums(t *testing.T) {
    82  	tests := []struct {
    83  		size  int
    84  		hash  hash.Hash
    85  		sum   string
    86  		match bool
    87  	}{
    88  		{128, md5.New(), "37eff01866ba3f538421b30b7cbefcac", true},
    89  		{128, md5.New(), "37eff01866ba3f538421b30b7cbefcad", false},
    90  		{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855b", true},
    91  		{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855c", false},
    92  		{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef372", true},
    93  		{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef373", false},
    94  		{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d535", true},
    95  		{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d536", false},
    96  		{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b77", true},
    97  		{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b78", false},
    98  		{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923be", true},
    99  		{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923bf", false},
   100  		{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true},
   101  		{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be4", false},
   102  		{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true},
   103  		{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c8", false},
   104  		{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true},
   105  		{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c82", false},
   106  		{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f7", true},
   107  		{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f8", false},
   108  		{1024, sha512.New(), "37f652be867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566c", true},
   109  		{1024, sha512.New(), "37f652bf867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566d", false},
   110  		{1048576, sha512.New(), "ac1d097b4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", true},
   111  		{1048576, sha512.New(), "ac1d097c4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", false},
   112  	}
   113  
   114  	for _, test := range tests {
   115  		var expect error
   116  		comparison := "Match"
   117  		if !test.match {
   118  			comparison = "Mismatch"
   119  			expect = ErrBadChecksum
   120  		}
   121  
   122  		t.Run(fmt.Sprintf("With%s%s", comparison, test.sum[:8]), func(t *testing.T) {
   123  			filename := fmt.Sprintf(".testChecksum-%s-%s", comparison, test.sum[:8])
   124  			defer os.Remove(filename)
   125  
   126  			grabtest.WithTestServer(t, func(url string) {
   127  				req := mustNewRequest(filename, url)
   128  				req.SetChecksum(test.hash, grabtest.MustHexDecodeString(test.sum), true)
   129  
   130  				resp := DefaultClient.Do(req)
   131  				err := resp.Err()
   132  				if err != expect {
   133  					t.Errorf("expected error: %v, got: %v", expect, err)
   134  				}
   135  
   136  				// ensure mismatch file was deleted
   137  				if !test.match {
   138  					if _, err := os.Stat(filename); err == nil {
   139  						t.Errorf("checksum failure not cleaned up: %s", filename)
   140  					} else if !os.IsNotExist(err) {
   141  						panic(err)
   142  					}
   143  				}
   144  
   145  				testComplete(t, resp)
   146  			}, grabtest.ContentLength(test.size))
   147  		})
   148  	}
   149  }
   150  
   151  // TestContentLength ensures that ErrBadLength is returned if a server response
   152  // does not match the requested length.
   153  func TestContentLength(t *testing.T) {
   154  	size := int64(32768)
   155  	testCases := []struct {
   156  		Name   string
   157  		NoHead bool
   158  		Size   int64
   159  		Expect int64
   160  		Match  bool
   161  	}{
   162  		{"Good size in HEAD request", false, size, size, true},
   163  		{"Good size in GET request", true, size, size, true},
   164  		{"Bad size in HEAD request", false, size - 1, size, false},
   165  		{"Bad size in GET request", true, size - 1, size, false},
   166  	}
   167  
   168  	for _, test := range testCases {
   169  		t.Run(test.Name, func(t *testing.T) {
   170  			opts := []grabtest.HandlerOption{
   171  				grabtest.ContentLength(int(test.Size)),
   172  			}
   173  			if test.NoHead {
   174  				opts = append(opts, grabtest.MethodWhitelist("GET"))
   175  			}
   176  
   177  			grabtest.WithTestServer(t, func(url string) {
   178  				req := mustNewRequest(".testSize-mismatch-head", url)
   179  				req.Size = size
   180  				resp := DefaultClient.Do(req)
   181  				defer os.Remove(resp.Filename)
   182  				err := resp.Err()
   183  				if test.Match {
   184  					if err == ErrBadLength {
   185  						t.Errorf("error: %v", err)
   186  					} else if err != nil {
   187  						panic(err)
   188  					} else if resp.Size() != size {
   189  						t.Errorf("expected %v bytes, got %v bytes", size, resp.Size())
   190  					}
   191  				} else {
   192  					if err == nil {
   193  						t.Errorf("expected: %v, got %v", ErrBadLength, err)
   194  					} else if err != ErrBadLength {
   195  						panic(err)
   196  					}
   197  				}
   198  				testComplete(t, resp)
   199  			}, opts...)
   200  		})
   201  	}
   202  }
   203  
   204  // TestAutoResume tests segmented downloading of a large file.
   205  func TestAutoResume(t *testing.T) {
   206  	segs := 8
   207  	size := 1048576
   208  	sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")
   209  	filename := ".testAutoResume"
   210  
   211  	defer os.Remove(filename)
   212  
   213  	for i := 0; i < segs; i++ {
   214  		segsize := (i + 1) * (size / segs)
   215  		t.Run(fmt.Sprintf("With%vBytes", segsize), func(t *testing.T) {
   216  			grabtest.WithTestServer(t, func(url string) {
   217  				req := mustNewRequest(filename, url)
   218  				if i == segs-1 {
   219  					req.SetChecksum(sha256.New(), sum, false)
   220  				}
   221  				resp := mustDo(req)
   222  				if i > 0 && !resp.DidResume {
   223  					t.Errorf("expected Response.DidResume to be true")
   224  				}
   225  				testComplete(t, resp)
   226  			},
   227  				grabtest.ContentLength(segsize),
   228  			)
   229  		})
   230  	}
   231  
   232  	t.Run("WithFailure", func(t *testing.T) {
   233  		grabtest.WithTestServer(t, func(url string) {
   234  			// request smaller segment
   235  			req := mustNewRequest(filename, url)
   236  			resp := DefaultClient.Do(req)
   237  			if err := resp.Err(); err != ErrBadLength {
   238  				t.Errorf("expected ErrBadLength for smaller request, got: %v", err)
   239  			}
   240  		},
   241  			grabtest.ContentLength(size-128),
   242  		)
   243  	})
   244  
   245  	t.Run("WithNoResume", func(t *testing.T) {
   246  		grabtest.WithTestServer(t, func(url string) {
   247  			req := mustNewRequest(filename, url)
   248  			req.NoResume = true
   249  			resp := mustDo(req)
   250  			if resp.DidResume {
   251  				t.Errorf("expected Response.DidResume to be false")
   252  			}
   253  			testComplete(t, resp)
   254  		},
   255  			grabtest.ContentLength(size+128),
   256  		)
   257  	})
   258  
   259  	t.Run("WithNoResumeAndTruncate", func(t *testing.T) {
   260  		size := size - 128
   261  		grabtest.WithTestServer(t, func(url string) {
   262  			req := mustNewRequest(filename, url)
   263  			req.NoResume = true
   264  			resp := mustDo(req)
   265  			if resp.DidResume {
   266  				t.Errorf("expected Response.DidResume to be false")
   267  			}
   268  			if v := resp.BytesComplete(); v != int64(size) {
   269  				t.Errorf("expected Response.BytesComplete: %d, got: %d", size, v)
   270  			}
   271  			testComplete(t, resp)
   272  		},
   273  			grabtest.ContentLength(size),
   274  		)
   275  	})
   276  
   277  	t.Run("WithNoContentLengthHeader", func(t *testing.T) {
   278  		grabtest.WithTestServer(t, func(url string) {
   279  			req := mustNewRequest(filename, url)
   280  			req.SetChecksum(sha256.New(), sum, false)
   281  			resp := mustDo(req)
   282  			if !resp.DidResume {
   283  				t.Errorf("expected Response.DidResume to be true")
   284  			}
   285  			if actual := resp.Size(); actual != int64(size) {
   286  				t.Errorf("expected Response.Size: %d, got: %d", size, actual)
   287  			}
   288  			testComplete(t, resp)
   289  		},
   290  			grabtest.ContentLength(size),
   291  			grabtest.HeaderBlacklist("Content-Length"),
   292  		)
   293  	})
   294  
   295  	t.Run("WithNoContentLengthHeaderAndChecksumFailure", func(t *testing.T) {
   296  		// ref: https://github.com/cavaliergopher/grab/v3/pull/27
   297  		size := size * 2
   298  		grabtest.WithTestServer(t, func(url string) {
   299  			req := mustNewRequest(filename, url)
   300  			req.SetChecksum(sha256.New(), sum, false)
   301  			resp := DefaultClient.Do(req)
   302  			if err := resp.Err(); err != ErrBadChecksum {
   303  				t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
   304  			}
   305  			if !resp.DidResume {
   306  				t.Errorf("expected Response.DidResume to be true")
   307  			}
   308  			if actual := resp.BytesComplete(); actual != int64(size) {
   309  				t.Errorf("expected Response.BytesComplete: %d, got: %d", size, actual)
   310  			}
   311  			if actual := resp.Size(); actual != int64(size) {
   312  				t.Errorf("expected Response.Size: %d, got: %d", size, actual)
   313  			}
   314  			testComplete(t, resp)
   315  		},
   316  			grabtest.ContentLength(size),
   317  			grabtest.HeaderBlacklist("Content-Length"),
   318  		)
   319  	})
   320  	// TODO: test when existing file is corrupted
   321  }
   322  
   323  func TestSkipExisting(t *testing.T) {
   324  	filename := ".testSkipExisting"
   325  	defer os.Remove(filename)
   326  
   327  	// download a file
   328  	grabtest.WithTestServer(t, func(url string) {
   329  		resp := mustDo(mustNewRequest(filename, url))
   330  		testComplete(t, resp)
   331  	})
   332  
   333  	// redownload
   334  	grabtest.WithTestServer(t, func(url string) {
   335  		resp := mustDo(mustNewRequest(filename, url))
   336  		testComplete(t, resp)
   337  
   338  		// ensure download was resumed
   339  		if !resp.DidResume {
   340  			t.Fatalf("Expected download to skip existing file, but it did not")
   341  		}
   342  
   343  		// ensure all bytes were resumed
   344  		if resp.Size() == 0 || resp.Size() != resp.bytesResumed {
   345  			t.Fatalf("Expected to skip %d bytes in redownload; got %d", resp.Size(), resp.bytesResumed)
   346  		}
   347  	})
   348  
   349  	// ensure checksum is performed on pre-existing file
   350  	grabtest.WithTestServer(t, func(url string) {
   351  		req := mustNewRequest(filename, url)
   352  		req.SetChecksum(sha256.New(), []byte{0x01, 0x02, 0x03, 0x04}, true)
   353  		resp := DefaultClient.Do(req)
   354  		if err := resp.Err(); err != ErrBadChecksum {
   355  			t.Fatalf("Expected checksum error, got: %v", err)
   356  		}
   357  	})
   358  }
   359  
   360  // TestBatch executes multiple requests simultaneously and validates the
   361  // responses.
   362  func TestBatch(t *testing.T) {
   363  	tests := 32
   364  	size := 32768
   365  	sum := grabtest.MustHexDecodeString("e11360251d1173650cdcd20f111d8f1ca2e412f572e8b36a4dc067121c1799b8")
   366  
   367  	// test with 4 workers and with one per request
   368  	grabtest.WithTestServer(t, func(url string) {
   369  		for _, workerCount := range []int{4, 0} {
   370  			// create requests
   371  			reqs := make([]*Request, tests)
   372  			for i := 0; i < len(reqs); i++ {
   373  				filename := fmt.Sprintf(".testBatch.%d", i+1)
   374  				reqs[i] = mustNewRequest(filename, url+fmt.Sprintf("/request_%d?", i+1))
   375  				reqs[i].Label = fmt.Sprintf("Test %d", i+1)
   376  				reqs[i].SetChecksum(sha256.New(), sum, false)
   377  			}
   378  
   379  			// batch run
   380  			responses := DefaultClient.DoBatch(workerCount, reqs...)
   381  
   382  			// listen for responses
   383  		Loop:
   384  			for i := 0; i < len(reqs); {
   385  				select {
   386  				case resp := <-responses:
   387  					if resp == nil {
   388  						break Loop
   389  					}
   390  					testComplete(t, resp)
   391  					if err := resp.Err(); err != nil {
   392  						t.Errorf("%s: %v", resp.Filename, err)
   393  					}
   394  
   395  					// remove test file
   396  					if resp.IsComplete() {
   397  						os.Remove(resp.Filename) // ignore errors
   398  					}
   399  					i++
   400  				}
   401  			}
   402  		}
   403  	},
   404  		grabtest.ContentLength(size),
   405  	)
   406  }
   407  
   408  // TestCancelContext tests that a batch of requests can be cancel using a
   409  // context.Context cancellation. Requests are cancelled in multiple states:
   410  // in-progress and unstarted.
   411  func TestCancelContext(t *testing.T) {
   412  	fileSize := 134217728
   413  	tests := 256
   414  	client := NewClient()
   415  	ctx, cancel := context.WithCancel(context.Background())
   416  	defer cancel()
   417  
   418  	grabtest.WithTestServer(t, func(url string) {
   419  		reqs := make([]*Request, tests)
   420  		for i := 0; i < tests; i++ {
   421  			req := mustNewRequest("", fmt.Sprintf("%s/.testCancelContext%d", url, i))
   422  			reqs[i] = req.WithContext(ctx)
   423  		}
   424  
   425  		respch := client.DoBatch(8, reqs...)
   426  		time.Sleep(time.Millisecond * 500)
   427  		cancel()
   428  		for resp := range respch {
   429  			defer os.Remove(resp.Filename)
   430  
   431  			// err should be context.Canceled or http.errRequestCanceled
   432  			if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") {
   433  				t.Errorf("expected '%v', got '%v'", context.Canceled, resp.Err())
   434  			}
   435  			if resp.BytesComplete() >= int64(fileSize) {
   436  				t.Errorf("expected Response.BytesComplete: < %d, got: %d", fileSize, resp.BytesComplete())
   437  			}
   438  		}
   439  	},
   440  		grabtest.ContentLength(fileSize),
   441  	)
   442  }
   443  
   444  // TestCancelHangingResponse tests that a never ending request is terminated
   445  // when the response is cancelled.
   446  func TestCancelHangingResponse(t *testing.T) {
   447  	fileSize := 10
   448  	client := NewClient()
   449  
   450  	grabtest.WithTestServer(t, func(url string) {
   451  		req := mustNewRequest("", fmt.Sprintf("%s/.testCancelHangingResponse", url))
   452  
   453  		resp := client.Do(req)
   454  		defer os.Remove(resp.Filename)
   455  
   456  		// Wait for some bytes to be transferred
   457  		for resp.BytesComplete() == 0 {
   458  			time.Sleep(50 * time.Millisecond)
   459  		}
   460  
   461  		done := make(chan error)
   462  		go func() {
   463  			done <- resp.Cancel()
   464  		}()
   465  
   466  		select {
   467  		case err := <-done:
   468  			if err != context.Canceled {
   469  				t.Errorf("Expected context.Canceled error, go: %v", err)
   470  			}
   471  		case <-time.After(time.Second):
   472  			t.Fatal("response was not cancelled within 1s")
   473  		}
   474  		if resp.BytesComplete() == int64(fileSize) {
   475  			t.Error("download was not supposed to be complete")
   476  		}
   477  	},
   478  		grabtest.RateLimiter(1),
   479  		grabtest.ContentLength(fileSize),
   480  	)
   481  }
   482  
   483  // TestNestedDirectory tests that missing subdirectories are created.
   484  func TestNestedDirectory(t *testing.T) {
   485  	dir := "./.testNested/one/two/three"
   486  	filename := ".testNestedFile"
   487  	expect := dir + "/" + filename
   488  
   489  	t.Run("Create", func(t *testing.T) {
   490  		grabtest.WithTestServer(t, func(url string) {
   491  			resp := mustDo(mustNewRequest(expect, url+"/"+filename))
   492  			defer os.RemoveAll("./.testNested/")
   493  			if resp.Filename != expect {
   494  				t.Errorf("expected nested Request.Filename to be %v, got %v", expect, resp.Filename)
   495  			}
   496  		})
   497  	})
   498  
   499  	t.Run("No create", func(t *testing.T) {
   500  		grabtest.WithTestServer(t, func(url string) {
   501  			req := mustNewRequest(expect, url+"/"+filename)
   502  			req.NoCreateDirectories = true
   503  			resp := DefaultClient.Do(req)
   504  			err := resp.Err()
   505  			if !os.IsNotExist(err) {
   506  				t.Errorf("expected: %v, got: %v", os.ErrNotExist, err)
   507  			}
   508  		})
   509  	})
   510  }
   511  
   512  // TestRemoteTime tests that the timestamp of the downloaded file can be set
   513  // according to the timestamp of the remote file.
   514  func TestRemoteTime(t *testing.T) {
   515  	filename := "./.testRemoteTime"
   516  	defer os.Remove(filename)
   517  
   518  	// random time between epoch and now
   519  	expect := time.Unix(rand.Int63n(time.Now().Unix()), 0)
   520  	grabtest.WithTestServer(t, func(url string) {
   521  		resp := mustDo(mustNewRequest(filename, url))
   522  		fi, err := os.Stat(resp.Filename)
   523  		if err != nil {
   524  			panic(err)
   525  		}
   526  		actual := fi.ModTime()
   527  		if !actual.Equal(expect) {
   528  			t.Errorf("expected %v, got %v", expect, actual)
   529  		}
   530  	},
   531  		grabtest.LastModified(expect),
   532  	)
   533  }
   534  
   535  func TestResponseCode(t *testing.T) {
   536  	filename := "./.testResponseCode"
   537  
   538  	t.Run("With404", func(t *testing.T) {
   539  		defer os.Remove(filename)
   540  		grabtest.WithTestServer(t, func(url string) {
   541  			req := mustNewRequest(filename, url)
   542  			resp := DefaultClient.Do(req)
   543  			expect := StatusCodeError(http.StatusNotFound)
   544  			err := resp.Err()
   545  			if err != expect {
   546  				t.Errorf("expected %v, got '%v'", expect, err)
   547  			}
   548  			if !IsStatusCodeError(err) {
   549  				t.Errorf("expected IsStatusCodeError to return true for %T: %v", err, err)
   550  			}
   551  		},
   552  			grabtest.StatusCodeStatic(http.StatusNotFound),
   553  		)
   554  	})
   555  
   556  	t.Run("WithIgnoreNon2XX", func(t *testing.T) {
   557  		defer os.Remove(filename)
   558  		grabtest.WithTestServer(t, func(url string) {
   559  			req := mustNewRequest(filename, url)
   560  			req.IgnoreBadStatusCodes = true
   561  			resp := DefaultClient.Do(req)
   562  			if err := resp.Err(); err != nil {
   563  				t.Errorf("expected nil, got '%v'", err)
   564  			}
   565  		},
   566  			grabtest.StatusCodeStatic(http.StatusNotFound),
   567  		)
   568  	})
   569  }
   570  
   571  func TestBeforeCopyHook(t *testing.T) {
   572  	filename := "./.testBeforeCopy"
   573  	t.Run("Noop", func(t *testing.T) {
   574  		defer os.RemoveAll(filename)
   575  		grabtest.WithTestServer(t, func(url string) {
   576  			called := false
   577  			req := mustNewRequest(filename, url)
   578  			req.BeforeCopy = func(resp *Response) error {
   579  				called = true
   580  				if resp.IsComplete() {
   581  					t.Error("Response object passed to BeforeCopy hook has already been closed")
   582  				}
   583  				if resp.Progress() != 0 {
   584  					t.Error("Download progress already > 0 when BeforeCopy hook was called")
   585  				}
   586  				if resp.Duration() == 0 {
   587  					t.Error("Duration was zero when BeforeCopy was called")
   588  				}
   589  				if resp.BytesComplete() != 0 {
   590  					t.Error("BytesComplete already > 0 when BeforeCopy hook was called")
   591  				}
   592  				return nil
   593  			}
   594  			resp := DefaultClient.Do(req)
   595  			if err := resp.Err(); err != nil {
   596  				t.Errorf("unexpected error using BeforeCopy hook: %v", err)
   597  			}
   598  			testComplete(t, resp)
   599  			if !called {
   600  				t.Error("BeforeCopy hook was never called")
   601  			}
   602  		})
   603  	})
   604  
   605  	t.Run("WithError", func(t *testing.T) {
   606  		defer os.RemoveAll(filename)
   607  		grabtest.WithTestServer(t, func(url string) {
   608  			testError := errors.New("test")
   609  			req := mustNewRequest(filename, url)
   610  			req.BeforeCopy = func(resp *Response) error {
   611  				return testError
   612  			}
   613  			resp := DefaultClient.Do(req)
   614  			if err := resp.Err(); err != testError {
   615  				t.Errorf("expected error '%v', got '%v'", testError, err)
   616  			}
   617  			if resp.BytesComplete() != 0 {
   618  				t.Errorf("expected 0 bytes completed for canceled BeforeCopy hook, got %d",
   619  					resp.BytesComplete())
   620  			}
   621  			testComplete(t, resp)
   622  		})
   623  	})
   624  
   625  	// Assert that an existing local file will not be truncated prior to the
   626  	// BeforeCopy hook has a chance to cancel the request
   627  	t.Run("NoTruncate", func(t *testing.T) {
   628  		tfile, err := ioutil.TempFile("", "grab_client_test.*.file")
   629  		if err != nil {
   630  			t.Fatal(err)
   631  		}
   632  		defer os.Remove(tfile.Name())
   633  
   634  		const size = 128
   635  		_, err = tfile.Write(bytes.Repeat([]byte("x"), size))
   636  		if err != nil {
   637  			t.Fatal(err)
   638  		}
   639  
   640  		grabtest.WithTestServer(t, func(url string) {
   641  			called := false
   642  			req := mustNewRequest(tfile.Name(), url)
   643  			req.NoResume = true
   644  			req.BeforeCopy = func(resp *Response) error {
   645  				called = true
   646  				fi, err := tfile.Stat()
   647  				if err != nil {
   648  					t.Errorf("failed to stat temp file: %v", err)
   649  					return nil
   650  				}
   651  				if fi.Size() != size {
   652  					t.Errorf("expected existing file size of %d bytes "+
   653  						"prior to BeforeCopy hook, got %d", size, fi.Size())
   654  				}
   655  				return nil
   656  			}
   657  			resp := DefaultClient.Do(req)
   658  			if err := resp.Err(); err != nil {
   659  				t.Errorf("unexpected error using BeforeCopy hook: %v", err)
   660  			}
   661  			testComplete(t, resp)
   662  			if !called {
   663  				t.Error("BeforeCopy hook was never called")
   664  			}
   665  		})
   666  	})
   667  }
   668  
   669  func TestAfterCopyHook(t *testing.T) {
   670  	filename := "./.testAfterCopy"
   671  	t.Run("Noop", func(t *testing.T) {
   672  		defer os.RemoveAll(filename)
   673  		grabtest.WithTestServer(t, func(url string) {
   674  			called := false
   675  			req := mustNewRequest(filename, url)
   676  			req.AfterCopy = func(resp *Response) error {
   677  				called = true
   678  				if resp.IsComplete() {
   679  					t.Error("Response object passed to AfterCopy hook has already been closed")
   680  				}
   681  				if resp.Progress() <= 0 {
   682  					t.Error("Download progress was 0 when AfterCopy hook was called")
   683  				}
   684  				if resp.Duration() == 0 {
   685  					t.Error("Duration was zero when AfterCopy was called")
   686  				}
   687  				if resp.BytesComplete() <= 0 {
   688  					t.Error("BytesComplete was 0 when AfterCopy hook was called")
   689  				}
   690  				return nil
   691  			}
   692  			resp := DefaultClient.Do(req)
   693  			if err := resp.Err(); err != nil {
   694  				t.Errorf("unexpected error using AfterCopy hook: %v", err)
   695  			}
   696  			testComplete(t, resp)
   697  			if !called {
   698  				t.Error("AfterCopy hook was never called")
   699  			}
   700  		})
   701  	})
   702  
   703  	t.Run("WithError", func(t *testing.T) {
   704  		defer os.RemoveAll(filename)
   705  		grabtest.WithTestServer(t, func(url string) {
   706  			testError := errors.New("test")
   707  			req := mustNewRequest(filename, url)
   708  			req.AfterCopy = func(resp *Response) error {
   709  				return testError
   710  			}
   711  			resp := DefaultClient.Do(req)
   712  			if err := resp.Err(); err != testError {
   713  				t.Errorf("expected error '%v', got '%v'", testError, err)
   714  			}
   715  			if resp.BytesComplete() <= 0 {
   716  				t.Errorf("ByteCompleted was %d after AfterCopy hook was called",
   717  					resp.BytesComplete())
   718  			}
   719  			testComplete(t, resp)
   720  		})
   721  	})
   722  }
   723  
   724  func TestIssue37(t *testing.T) {
   725  	// ref: https://github.com/cavaliergopher/grab/v3/issues/37
   726  	filename := "./.testIssue37"
   727  	largeSize := int64(2097152)
   728  	smallSize := int64(1048576)
   729  	defer os.RemoveAll(filename)
   730  
   731  	// download large file
   732  	grabtest.WithTestServer(t, func(url string) {
   733  		resp := mustDo(mustNewRequest(filename, url))
   734  		if resp.Size() != largeSize {
   735  			t.Errorf("expected response size: %d, got: %d", largeSize, resp.Size())
   736  		}
   737  	}, grabtest.ContentLength(int(largeSize)))
   738  
   739  	// download new, smaller version of same file
   740  	grabtest.WithTestServer(t, func(url string) {
   741  		req := mustNewRequest(filename, url)
   742  		req.NoResume = true
   743  		resp := mustDo(req)
   744  		if resp.Size() != smallSize {
   745  			t.Errorf("expected response size: %d, got: %d", smallSize, resp.Size())
   746  		}
   747  
   748  		// local file should have truncated and not resumed
   749  		if resp.DidResume {
   750  			t.Errorf("expected download to truncate, resumed instead")
   751  		}
   752  	}, grabtest.ContentLength(int(smallSize)))
   753  
   754  	fi, err := os.Stat(filename)
   755  	if err != nil {
   756  		t.Fatal(err)
   757  	}
   758  	if fi.Size() != int64(smallSize) {
   759  		t.Errorf("expected file size %d, got %d", smallSize, fi.Size())
   760  	}
   761  }
   762  
   763  // TestHeadBadStatus validates that HEAD requests that return non-200 can be
   764  // ignored and succeed if the GET requests succeeeds.
   765  //
   766  // Fixes: https://github.com/cavaliergopher/grab/v3/issues/43
   767  func TestHeadBadStatus(t *testing.T) {
   768  	expect := http.StatusOK
   769  	filename := ".testIssue43"
   770  
   771  	statusFunc := func(r *http.Request) int {
   772  		if r.Method == "HEAD" {
   773  			return http.StatusForbidden
   774  		}
   775  		return http.StatusOK
   776  	}
   777  
   778  	grabtest.WithTestServer(t, func(url string) {
   779  		testURL := fmt.Sprintf("%s/%s", url, filename)
   780  		resp := mustDo(mustNewRequest("", testURL))
   781  		if resp.HTTPResponse.StatusCode != expect {
   782  			t.Errorf(
   783  				"expected status code: %d, got:% d",
   784  				expect,
   785  				resp.HTTPResponse.StatusCode)
   786  		}
   787  	},
   788  		grabtest.StatusCode(statusFunc),
   789  	)
   790  }
   791  
   792  // TestMissingContentLength ensures that the Response.Size is correct for
   793  // transfers where the remote server does not send a Content-Length header.
   794  //
   795  // TestAutoResume also covers cases with checksum validation.
   796  //
   797  // Kudos to Setnička Jiří <Jiri.Setnicka@ysoft.com> for identifying and raising
   798  // a solution to this issue. Ref: https://github.com/cavaliergopher/grab/v3/pull/27
   799  func TestMissingContentLength(t *testing.T) {
   800  	// expectSize must be sufficiently large that DefaultClient.Do won't prefetch
   801  	// the entire body and compute ContentLength before returning a Response.
   802  	expectSize := 1048576
   803  	opts := []grabtest.HandlerOption{
   804  		grabtest.ContentLength(expectSize),
   805  		grabtest.HeaderBlacklist("Content-Length"),
   806  		grabtest.TimeToFirstByte(time.Millisecond * 100), // delay for initial read
   807  	}
   808  	grabtest.WithTestServer(t, func(url string) {
   809  		req := mustNewRequest(".testMissingContentLength", url)
   810  		req.SetChecksum(
   811  			md5.New(),
   812  			grabtest.DefaultHandlerMD5ChecksumBytes,
   813  			false)
   814  		resp := DefaultClient.Do(req)
   815  
   816  		// ensure remote server is not sending content-length header
   817  		if v := resp.HTTPResponse.Header.Get("Content-Length"); v != "" {
   818  			panic(fmt.Sprintf("http header content length must be empty, got: %s", v))
   819  		}
   820  		if v := resp.HTTPResponse.ContentLength; v != -1 {
   821  			panic(fmt.Sprintf("http response content length must be -1, got: %d", v))
   822  		}
   823  
   824  		// before completion, response size should be -1
   825  		if resp.Size() != -1 {
   826  			t.Errorf("expected response size: -1, got: %d", resp.Size())
   827  		}
   828  
   829  		// block for completion
   830  		if err := resp.Err(); err != nil {
   831  			panic(err)
   832  		}
   833  
   834  		// on completion, response size should be actual transfer size
   835  		if resp.Size() != int64(expectSize) {
   836  			t.Errorf("expected response size: %d, got: %d", expectSize, resp.Size())
   837  		}
   838  	}, opts...)
   839  }
   840  
   841  func TestNoStore(t *testing.T) {
   842  	filename := ".testSubdir/testNoStore"
   843  	t.Run("DefaultCase", func(t *testing.T) {
   844  		grabtest.WithTestServer(t, func(url string) {
   845  			req := mustNewRequest(filename, url)
   846  			req.NoStore = true
   847  			req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true)
   848  			resp := mustDo(req)
   849  
   850  			// ensure Response.Bytes is correct and can be reread
   851  			b, err := resp.Bytes()
   852  			if err != nil {
   853  				panic(err)
   854  			}
   855  			grabtest.AssertSHA256Sum(
   856  				t,
   857  				grabtest.DefaultHandlerSHA256ChecksumBytes,
   858  				bytes.NewReader(b),
   859  			)
   860  
   861  			// ensure Response.Open stream is correct and can be reread
   862  			r, err := resp.Open()
   863  			if err != nil {
   864  				panic(err)
   865  			}
   866  			defer r.Close()
   867  			grabtest.AssertSHA256Sum(
   868  				t,
   869  				grabtest.DefaultHandlerSHA256ChecksumBytes,
   870  				r,
   871  			)
   872  
   873  			// Response.Filename should still be set
   874  			if resp.Filename != filename {
   875  				t.Errorf("expected Response.Filename: %s, got: %s", filename, resp.Filename)
   876  			}
   877  
   878  			// ensure no files were written
   879  			paths := []string{
   880  				filename,
   881  				filepath.Base(filename),
   882  				filepath.Dir(filename),
   883  				resp.Filename,
   884  				filepath.Base(resp.Filename),
   885  				filepath.Dir(resp.Filename),
   886  			}
   887  			for _, path := range paths {
   888  				_, err := os.Stat(path)
   889  				if !os.IsNotExist(err) {
   890  					t.Errorf(
   891  						"expect error: %v, got: %v, for path: %s",
   892  						os.ErrNotExist,
   893  						err,
   894  						path)
   895  				}
   896  			}
   897  		})
   898  	})
   899  
   900  	t.Run("ChecksumValidation", func(t *testing.T) {
   901  		grabtest.WithTestServer(t, func(url string) {
   902  			req := mustNewRequest("", url)
   903  			req.NoStore = true
   904  			req.SetChecksum(
   905  				md5.New(),
   906  				grabtest.MustHexDecodeString("deadbeefcafebabe"),
   907  				true)
   908  			resp := DefaultClient.Do(req)
   909  			if err := resp.Err(); err != ErrBadChecksum {
   910  				t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
   911  			}
   912  		})
   913  	})
   914  }