github.com/cavaliergopher/grab/v3@v3.0.1/client.go (about)

     1  package grab
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"os"
    10  	"path/filepath"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  )
    15  
    16  // HTTPClient provides an interface allowing us to perform HTTP requests.
    17  type HTTPClient interface {
    18  	Do(req *http.Request) (*http.Response, error)
    19  }
    20  
    21  // truncater is a private interface allowing different response
    22  // Writers to be truncated
    23  type truncater interface {
    24  	Truncate(size int64) error
    25  }
    26  
    27  // A Client is a file download client.
    28  //
    29  // Clients are safe for concurrent use by multiple goroutines.
    30  type Client struct {
    31  	// HTTPClient specifies the http.Client which will be used for communicating
    32  	// with the remote server during the file transfer.
    33  	HTTPClient HTTPClient
    34  
    35  	// UserAgent specifies the User-Agent string which will be set in the
    36  	// headers of all requests made by this client.
    37  	//
    38  	// The user agent string may be overridden in the headers of each request.
    39  	UserAgent string
    40  
    41  	// BufferSize specifies the size in bytes of the buffer that is used for
    42  	// transferring all requested files. Larger buffers may result in faster
    43  	// throughput but will use more memory and result in less frequent updates
    44  	// to the transfer progress statistics. The BufferSize of each request can
    45  	// be overridden on each Request object. Default: 32KB.
    46  	BufferSize int
    47  }
    48  
    49  // NewClient returns a new file download Client, using default configuration.
    50  func NewClient() *Client {
    51  	return &Client{
    52  		UserAgent: "grab",
    53  		HTTPClient: &http.Client{
    54  			Transport: &http.Transport{
    55  				Proxy: http.ProxyFromEnvironment,
    56  			},
    57  		},
    58  	}
    59  }
    60  
    61  // DefaultClient is the default client and is used by all Get convenience
    62  // functions.
    63  var DefaultClient = NewClient()
    64  
    65  // Do sends a file transfer request and returns a file transfer response,
    66  // following policy (e.g. redirects, cookies, auth) as configured on the
    67  // client's HTTPClient.
    68  //
    69  // Like http.Get, Do blocks while the transfer is initiated, but returns as soon
    70  // as the transfer has started transferring in a background goroutine, or if it
    71  // failed early.
    72  //
    73  // An error is returned via Response.Err if caused by client policy (such as
    74  // CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
    75  // will block the caller until the transfer is completed, successfully or
    76  // otherwise.
    77  func (c *Client) Do(req *Request) *Response {
    78  	// cancel will be called on all code-paths via closeResponse
    79  	ctx, cancel := context.WithCancel(req.Context())
    80  	req = req.WithContext(ctx)
    81  	resp := &Response{
    82  		Request:    req,
    83  		Start:      time.Now(),
    84  		Done:       make(chan struct{}, 0),
    85  		Filename:   req.Filename,
    86  		ctx:        ctx,
    87  		cancel:     cancel,
    88  		bufferSize: req.BufferSize,
    89  	}
    90  	if resp.bufferSize == 0 {
    91  		// default to Client.BufferSize
    92  		resp.bufferSize = c.BufferSize
    93  	}
    94  
    95  	// Run state-machine while caller is blocked to initialize the file transfer.
    96  	// Must never transition to the copyFile state - this happens next in another
    97  	// goroutine.
    98  	c.run(resp, c.statFileInfo)
    99  
   100  	// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
   101  	// already complete or failed.
   102  	go c.run(resp, c.copyFile)
   103  	return resp
   104  }
   105  
   106  // DoChannel executes all requests sent through the given Request channel, one
   107  // at a time, until it is closed by another goroutine. The caller is blocked
   108  // until the Request channel is closed and all transfers have completed. All
   109  // responses are sent through the given Response channel as soon as they are
   110  // received from the remote servers and can be used to track the progress of
   111  // each download.
   112  //
   113  // Slow Response receivers will cause a worker to block and therefore delay the
   114  // start of the transfer for an already initiated connection - potentially
   115  // causing a server timeout. It is the caller's responsibility to ensure a
   116  // sufficient buffer size is used for the Response channel to prevent this.
   117  //
   118  // If an error occurs during any of the file transfers it will be accessible via
   119  // the associated Response.Err function.
   120  func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
   121  	// TODO: enable cancelling of batch jobs
   122  	for req := range reqch {
   123  		resp := c.Do(req)
   124  		respch <- resp
   125  		<-resp.Done
   126  	}
   127  }
   128  
   129  // DoBatch executes all the given requests using the given number of concurrent
   130  // workers. Control is passed back to the caller as soon as the workers are
   131  // initiated.
   132  //
   133  // If the requested number of workers is less than one, a worker will be created
   134  // for every request. I.e. all requests will be executed concurrently.
   135  //
   136  // If an error occurs during any of the file transfers it will be accessible via
   137  // call to the associated Response.Err.
   138  //
   139  // The returned Response channel is closed only after all of the given Requests
   140  // have completed, successfully or otherwise.
   141  func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
   142  	if workers < 1 {
   143  		workers = len(requests)
   144  	}
   145  	reqch := make(chan *Request, len(requests))
   146  	respch := make(chan *Response, len(requests))
   147  	wg := sync.WaitGroup{}
   148  	for i := 0; i < workers; i++ {
   149  		wg.Add(1)
   150  		go func() {
   151  			c.DoChannel(reqch, respch)
   152  			wg.Done()
   153  		}()
   154  	}
   155  
   156  	// queue requests
   157  	go func() {
   158  		for _, req := range requests {
   159  			reqch <- req
   160  		}
   161  		close(reqch)
   162  		wg.Wait()
   163  		close(respch)
   164  	}()
   165  	return respch
   166  }
   167  
   168  // An stateFunc is an action that mutates the state of a Response and returns
   169  // the next stateFunc to be called.
   170  type stateFunc func(*Response) stateFunc
   171  
   172  // run calls the given stateFunc function and all subsequent returned stateFuncs
   173  // until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
   174  // should mutate the state of the given Response until it has completed
   175  // downloading or failed.
   176  func (c *Client) run(resp *Response, f stateFunc) {
   177  	for {
   178  		select {
   179  		case <-resp.ctx.Done():
   180  			if resp.IsComplete() {
   181  				return
   182  			}
   183  			resp.err = resp.ctx.Err()
   184  			f = c.closeResponse
   185  
   186  		default:
   187  			// keep working
   188  		}
   189  		if f = f(resp); f == nil {
   190  			return
   191  		}
   192  	}
   193  }
   194  
   195  // statFileInfo retrieves FileInfo for any local file matching
   196  // Response.Filename.
   197  //
   198  // If the file does not exist, is a directory, or its name is unknown the next
   199  // stateFunc is headRequest.
   200  //
   201  // If the file exists, Response.fi is set and the next stateFunc is
   202  // validateLocal.
   203  //
   204  // If an error occurs, the next stateFunc is closeResponse.
   205  func (c *Client) statFileInfo(resp *Response) stateFunc {
   206  	if resp.Request.NoStore || resp.Filename == "" {
   207  		return c.headRequest
   208  	}
   209  	fi, err := os.Stat(resp.Filename)
   210  	if err != nil {
   211  		if os.IsNotExist(err) {
   212  			return c.headRequest
   213  		}
   214  		resp.err = err
   215  		return c.closeResponse
   216  	}
   217  	if fi.IsDir() {
   218  		resp.Filename = ""
   219  		return c.headRequest
   220  	}
   221  	resp.fi = fi
   222  	return c.validateLocal
   223  }
   224  
   225  // validateLocal compares a local copy of the downloaded file to the remote
   226  // file.
   227  //
   228  // An error is returned if the local file is larger than the remote file, or
   229  // Request.SkipExisting is true.
   230  //
   231  // If the existing file matches the length of the remote file, the next
   232  // stateFunc is checksumFile.
   233  //
   234  // If the local file is smaller than the remote file and the remote server is
   235  // known to support ranged requests, the next stateFunc is getRequest.
   236  func (c *Client) validateLocal(resp *Response) stateFunc {
   237  	if resp.Request.SkipExisting {
   238  		resp.err = ErrFileExists
   239  		return c.closeResponse
   240  	}
   241  
   242  	// determine target file size
   243  	expectedSize := resp.Request.Size
   244  	if expectedSize == 0 && resp.HTTPResponse != nil {
   245  		expectedSize = resp.HTTPResponse.ContentLength
   246  	}
   247  
   248  	if expectedSize == 0 {
   249  		// size is either actually 0 or unknown
   250  		// if unknown, we ask the remote server
   251  		// if known to be 0, we proceed with a GET
   252  		return c.headRequest
   253  	}
   254  
   255  	if expectedSize == resp.fi.Size() {
   256  		// local file matches remote file size - wrap it up
   257  		resp.DidResume = true
   258  		resp.bytesResumed = resp.fi.Size()
   259  		return c.checksumFile
   260  	}
   261  
   262  	if resp.Request.NoResume {
   263  		// local file should be overwritten
   264  		return c.getRequest
   265  	}
   266  
   267  	if expectedSize >= 0 && expectedSize < resp.fi.Size() {
   268  		// remote size is known, is smaller than local size and we want to resume
   269  		resp.err = ErrBadLength
   270  		return c.closeResponse
   271  	}
   272  
   273  	if resp.CanResume {
   274  		// set resume range on GET request
   275  		resp.Request.HTTPRequest.Header.Set(
   276  			"Range",
   277  			fmt.Sprintf("bytes=%d-", resp.fi.Size()))
   278  		resp.DidResume = true
   279  		resp.bytesResumed = resp.fi.Size()
   280  		return c.getRequest
   281  	}
   282  	return c.headRequest
   283  }
   284  
   285  func (c *Client) checksumFile(resp *Response) stateFunc {
   286  	if resp.Request.hash == nil {
   287  		return c.closeResponse
   288  	}
   289  	if resp.Filename == "" {
   290  		panic("grab: developer error: filename not set")
   291  	}
   292  	if resp.Size() < 0 {
   293  		panic("grab: developer error: size unknown")
   294  	}
   295  	req := resp.Request
   296  
   297  	// compute checksum
   298  	var sum []byte
   299  	sum, resp.err = resp.checksumUnsafe()
   300  	if resp.err != nil {
   301  		return c.closeResponse
   302  	}
   303  
   304  	// compare checksum
   305  	if !bytes.Equal(sum, req.checksum) {
   306  		resp.err = ErrBadChecksum
   307  		if !resp.Request.NoStore && req.deleteOnError {
   308  			if err := os.Remove(resp.Filename); err != nil {
   309  				// err should be os.PathError and include file path
   310  				resp.err = fmt.Errorf(
   311  					"cannot remove downloaded file with checksum mismatch: %v",
   312  					err)
   313  			}
   314  		}
   315  	}
   316  	return c.closeResponse
   317  }
   318  
   319  // doHTTPRequest sends a HTTP Request and returns the response
   320  func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
   321  	if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
   322  		req.Header.Set("User-Agent", c.UserAgent)
   323  	}
   324  	return c.HTTPClient.Do(req)
   325  }
   326  
   327  func (c *Client) headRequest(resp *Response) stateFunc {
   328  	if resp.optionsKnown {
   329  		return c.getRequest
   330  	}
   331  	resp.optionsKnown = true
   332  
   333  	if resp.Request.NoResume {
   334  		return c.getRequest
   335  	}
   336  
   337  	if resp.Filename != "" && resp.fi == nil {
   338  		// destination path is already known and does not exist
   339  		return c.getRequest
   340  	}
   341  
   342  	hreq := new(http.Request)
   343  	*hreq = *resp.Request.HTTPRequest
   344  	hreq.Method = "HEAD"
   345  
   346  	resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
   347  	if resp.err != nil {
   348  		return c.closeResponse
   349  	}
   350  	resp.HTTPResponse.Body.Close()
   351  
   352  	if resp.HTTPResponse.StatusCode != http.StatusOK {
   353  		return c.getRequest
   354  	}
   355  
   356  	// In case of redirects during HEAD, record the final URL and use it
   357  	// instead of the original URL when sending future requests.
   358  	// This way we avoid sending potentially unsupported requests to
   359  	// the original URL, e.g. "Range", since it was the final URL
   360  	// that advertised its support.
   361  	resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL
   362  	resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host
   363  
   364  	return c.readResponse
   365  }
   366  
   367  func (c *Client) getRequest(resp *Response) stateFunc {
   368  	resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
   369  	if resp.err != nil {
   370  		return c.closeResponse
   371  	}
   372  
   373  	// TODO: check Content-Range
   374  
   375  	// check status code
   376  	if !resp.Request.IgnoreBadStatusCodes {
   377  		if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
   378  			resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
   379  			return c.closeResponse
   380  		}
   381  	}
   382  
   383  	return c.readResponse
   384  }
   385  
   386  func (c *Client) readResponse(resp *Response) stateFunc {
   387  	if resp.HTTPResponse == nil {
   388  		panic("grab: developer error: Response.HTTPResponse is nil")
   389  	}
   390  
   391  	// check expected size
   392  	resp.sizeUnsafe = resp.HTTPResponse.ContentLength
   393  	if resp.sizeUnsafe >= 0 {
   394  		// remote size is known
   395  		resp.sizeUnsafe += resp.bytesResumed
   396  		if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe {
   397  			resp.err = ErrBadLength
   398  			return c.closeResponse
   399  		}
   400  	}
   401  
   402  	// check filename
   403  	if resp.Filename == "" {
   404  		filename, err := guessFilename(resp.HTTPResponse)
   405  		if err != nil {
   406  			resp.err = err
   407  			return c.closeResponse
   408  		}
   409  		// Request.Filename will be empty or a directory
   410  		resp.Filename = filepath.Join(resp.Request.Filename, filename)
   411  	}
   412  
   413  	if !resp.Request.NoStore && resp.requestMethod() == "HEAD" {
   414  		if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
   415  			resp.CanResume = true
   416  		}
   417  		return c.statFileInfo
   418  	}
   419  	return c.openWriter
   420  }
   421  
   422  // openWriter opens the destination file for writing and seeks to the location
   423  // from whence the file transfer will resume.
   424  //
   425  // Requires that Response.Filename and resp.DidResume are already be set.
   426  func (c *Client) openWriter(resp *Response) stateFunc {
   427  	if !resp.Request.NoStore && !resp.Request.NoCreateDirectories {
   428  		resp.err = mkdirp(resp.Filename)
   429  		if resp.err != nil {
   430  			return c.closeResponse
   431  		}
   432  	}
   433  
   434  	if resp.Request.NoStore {
   435  		resp.writer = &resp.storeBuffer
   436  	} else {
   437  		// compute write flags
   438  		flag := os.O_CREATE | os.O_WRONLY
   439  		if resp.fi != nil {
   440  			if resp.DidResume {
   441  				flag = os.O_APPEND | os.O_WRONLY
   442  			} else {
   443  				// truncate later in copyFile, if not cancelled
   444  				// by BeforeCopy hook
   445  				flag = os.O_WRONLY
   446  			}
   447  		}
   448  
   449  		// open file
   450  		f, err := os.OpenFile(resp.Filename, flag, 0666)
   451  		if err != nil {
   452  			resp.err = err
   453  			return c.closeResponse
   454  		}
   455  		resp.writer = f
   456  
   457  		// seek to start or end
   458  		whence := os.SEEK_SET
   459  		if resp.bytesResumed > 0 {
   460  			whence = os.SEEK_END
   461  		}
   462  		_, resp.err = f.Seek(0, whence)
   463  		if resp.err != nil {
   464  			return c.closeResponse
   465  		}
   466  	}
   467  
   468  	// init transfer
   469  	if resp.bufferSize < 1 {
   470  		resp.bufferSize = 32 * 1024
   471  	}
   472  	b := make([]byte, resp.bufferSize)
   473  	resp.transfer = newTransfer(
   474  		resp.Request.Context(),
   475  		resp.Request.RateLimiter,
   476  		resp.writer,
   477  		resp.HTTPResponse.Body,
   478  		b)
   479  
   480  	// next step is copyFile, but this will be called later in another goroutine
   481  	return nil
   482  }
   483  
   484  // copy transfers content for a HTTP connection established via Client.do()
   485  func (c *Client) copyFile(resp *Response) stateFunc {
   486  	if resp.IsComplete() {
   487  		return nil
   488  	}
   489  
   490  	// run BeforeCopy hook
   491  	if f := resp.Request.BeforeCopy; f != nil {
   492  		resp.err = f(resp)
   493  		if resp.err != nil {
   494  			return c.closeResponse
   495  		}
   496  	}
   497  
   498  	var bytesCopied int64
   499  	if resp.transfer == nil {
   500  		panic("grab: developer error: Response.transfer is nil")
   501  	}
   502  
   503  	// We waited to truncate the file in openWriter() to make sure
   504  	// the BeforeCopy didn't cancel the copy. If this was an existing
   505  	// file that is not going to be resumed, truncate the contents.
   506  	if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume {
   507  		t.Truncate(0)
   508  	}
   509  
   510  	bytesCopied, resp.err = resp.transfer.copy()
   511  	if resp.err != nil {
   512  		return c.closeResponse
   513  	}
   514  	closeWriter(resp)
   515  
   516  	// set file timestamp
   517  	if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime {
   518  		resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
   519  		if resp.err != nil {
   520  			return c.closeResponse
   521  		}
   522  	}
   523  
   524  	// update transfer size if previously unknown
   525  	if resp.Size() < 0 {
   526  		discoveredSize := resp.bytesResumed + bytesCopied
   527  		atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize)
   528  		if resp.Request.Size > 0 && resp.Request.Size != discoveredSize {
   529  			resp.err = ErrBadLength
   530  			return c.closeResponse
   531  		}
   532  	}
   533  
   534  	// run AfterCopy hook
   535  	if f := resp.Request.AfterCopy; f != nil {
   536  		resp.err = f(resp)
   537  		if resp.err != nil {
   538  			return c.closeResponse
   539  		}
   540  	}
   541  
   542  	return c.checksumFile
   543  }
   544  
   545  func closeWriter(resp *Response) {
   546  	if closer, ok := resp.writer.(io.Closer); ok {
   547  		closer.Close()
   548  	}
   549  	resp.writer = nil
   550  }
   551  
   552  // close finalizes the Response
   553  func (c *Client) closeResponse(resp *Response) stateFunc {
   554  	if resp.IsComplete() {
   555  		panic("grab: developer error: response already closed")
   556  	}
   557  
   558  	resp.fi = nil
   559  	closeWriter(resp)
   560  	resp.closeResponseBody()
   561  
   562  	resp.End = time.Now()
   563  	close(resp.Done)
   564  	if resp.cancel != nil {
   565  		resp.cancel()
   566  	}
   567  
   568  	return nil
   569  }