github.com/isi-lincoln/grab@v2.0.1-0.20200331080741-9f014744ee41+incompatible/client.go (about)

     1  package grab
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"os"
    10  	"path/filepath"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  )
    15  
    16  // HTTPClient provides an interface allowing us to perform HTTP requests.
    17  type HTTPClient interface {
    18  	Do(req *http.Request) (*http.Response, error)
    19  }
    20  
    21  // A Client is a file download client.
    22  //
    23  // Clients are safe for concurrent use by multiple goroutines.
    24  type Client struct {
    25  	// HTTPClient specifies the http.Client which will be used for communicating
    26  	// with the remote server during the file transfer.
    27  	HTTPClient HTTPClient
    28  
    29  	// UserAgent specifies the User-Agent string which will be set in the
    30  	// headers of all requests made by this client.
    31  	//
    32  	// The user agent string may be overridden in the headers of each request.
    33  	UserAgent string
    34  
    35  	// BufferSize specifies the size in bytes of the buffer that is used for
    36  	// transferring all requested files. Larger buffers may result in faster
    37  	// throughput but will use more memory and result in less frequent updates
    38  	// to the transfer progress statistics. The BufferSize of each request can
    39  	// be overridden on each Request object. Default: 32KB.
    40  	BufferSize int
    41  }
    42  
    43  // NewClient returns a new file download Client, using default configuration.
    44  func NewClient() *Client {
    45  	return &Client{
    46  		UserAgent: "grab",
    47  		HTTPClient: &http.Client{
    48  			Transport: &http.Transport{
    49  				Proxy: http.ProxyFromEnvironment,
    50  			},
    51  		},
    52  	}
    53  }
    54  
    55  // DefaultClient is the default client and is used by all Get convenience
    56  // functions.
    57  var DefaultClient = NewClient()
    58  
    59  // Do sends a file transfer request and returns a file transfer response,
    60  // following policy (e.g. redirects, cookies, auth) as configured on the
    61  // client's HTTPClient.
    62  //
    63  // Like http.Get, Do blocks while the transfer is initiated, but returns as soon
    64  // as the transfer has started transferring in a background goroutine, or if it
    65  // failed early.
    66  //
    67  // An error is returned via Response.Err if caused by client policy (such as
    68  // CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
    69  // will block the caller until the transfer is completed, successfully or
    70  // otherwise.
    71  func (c *Client) Do(req *Request) *Response {
    72  	// cancel will be called on all code-paths via closeResponse
    73  	ctx, cancel := context.WithCancel(req.Context())
    74  	resp := &Response{
    75  		Request:    req,
    76  		Start:      time.Now(),
    77  		Done:       make(chan struct{}, 0),
    78  		Filename:   req.Filename,
    79  		ctx:        ctx,
    80  		cancel:     cancel,
    81  		bufferSize: req.BufferSize,
    82  	}
    83  	if resp.bufferSize == 0 {
    84  		// default to Client.BufferSize
    85  		resp.bufferSize = c.BufferSize
    86  	}
    87  
    88  	// Run state-machine while caller is blocked to initialize the file transfer.
    89  	// Must never transition to the copyFile state - this happens next in another
    90  	// goroutine.
    91  	c.run(resp, c.statFileInfo)
    92  
    93  	// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
    94  	// already complete or failed.
    95  	go c.run(resp, c.copyFile)
    96  	return resp
    97  }
    98  
    99  // DoChannel executes all requests sent through the given Request channel, one
   100  // at a time, until it is closed by another goroutine. The caller is blocked
   101  // until the Request channel is closed and all transfers have completed. All
   102  // responses are sent through the given Response channel as soon as they are
   103  // received from the remote servers and can be used to track the progress of
   104  // each download.
   105  //
   106  // Slow Response receivers will cause a worker to block and therefore delay the
   107  // start of the transfer for an already initiated connection - potentially
   108  // causing a server timeout. It is the caller's responsibility to ensure a
   109  // sufficient buffer size is used for the Response channel to prevent this.
   110  //
   111  // If an error occurs during any of the file transfers it will be accessible via
   112  // the associated Response.Err function.
   113  func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
   114  	// TODO: enable cancelling of batch jobs
   115  	for req := range reqch {
   116  		resp := c.Do(req)
   117  		respch <- resp
   118  		<-resp.Done
   119  	}
   120  }
   121  
   122  // DoBatch executes all the given requests using the given number of concurrent
   123  // workers. Control is passed back to the caller as soon as the workers are
   124  // initiated.
   125  //
   126  // If the requested number of workers is less than one, a worker will be created
   127  // for every request. I.e. all requests will be executed concurrently.
   128  //
   129  // If an error occurs during any of the file transfers it will be accessible via
   130  // call to the associated Response.Err.
   131  //
   132  // The returned Response channel is closed only after all of the given Requests
   133  // have completed, successfully or otherwise.
   134  func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
   135  	if workers < 1 {
   136  		workers = len(requests)
   137  	}
   138  	reqch := make(chan *Request, len(requests))
   139  	respch := make(chan *Response, len(requests))
   140  	wg := sync.WaitGroup{}
   141  	for i := 0; i < workers; i++ {
   142  		wg.Add(1)
   143  		go func() {
   144  			c.DoChannel(reqch, respch)
   145  			wg.Done()
   146  		}()
   147  	}
   148  
   149  	// queue requests
   150  	go func() {
   151  		for _, req := range requests {
   152  			reqch <- req
   153  		}
   154  		close(reqch)
   155  		wg.Wait()
   156  		close(respch)
   157  	}()
   158  	return respch
   159  }
   160  
   161  // An stateFunc is an action that mutates the state of a Response and returns
   162  // the next stateFunc to be called.
   163  type stateFunc func(*Response) stateFunc
   164  
   165  // run calls the given stateFunc function and all subsequent returned stateFuncs
   166  // until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
   167  // should mutate the state of the given Response until it has completed
   168  // downloading or failed.
   169  func (c *Client) run(resp *Response, f stateFunc) {
   170  	for {
   171  		select {
   172  		case <-resp.ctx.Done():
   173  			if resp.IsComplete() {
   174  				return
   175  			}
   176  			resp.err = resp.ctx.Err()
   177  			f = c.closeResponse
   178  
   179  		default:
   180  			// keep working
   181  		}
   182  		if f = f(resp); f == nil {
   183  			return
   184  		}
   185  	}
   186  }
   187  
   188  // statFileInfo retrieves FileInfo for any local file matching
   189  // Response.Filename.
   190  //
   191  // If the file does not exist, is a directory, or its name is unknown the next
   192  // stateFunc is headRequest.
   193  //
   194  // If the file exists, Response.fi is set and the next stateFunc is
   195  // validateLocal.
   196  //
   197  // If an error occurs, the next stateFunc is closeResponse.
   198  func (c *Client) statFileInfo(resp *Response) stateFunc {
   199  	if resp.Request.NoStore || resp.Filename == "" {
   200  		return c.headRequest
   201  	}
   202  	fi, err := os.Stat(resp.Filename)
   203  	if err != nil {
   204  		if os.IsNotExist(err) {
   205  			return c.headRequest
   206  		}
   207  		resp.err = err
   208  		return c.closeResponse
   209  	}
   210  	if fi.IsDir() {
   211  		resp.Filename = ""
   212  		return c.headRequest
   213  	}
   214  	resp.fi = fi
   215  	return c.validateLocal
   216  }
   217  
   218  // validateLocal compares a local copy of the downloaded file to the remote
   219  // file.
   220  //
   221  // An error is returned if the local file is larger than the remote file, or
   222  // Request.SkipExisting is true.
   223  //
   224  // If the existing file matches the length of the remote file, the next
   225  // stateFunc is checksumFile.
   226  //
   227  // If the local file is smaller than the remote file and the remote server is
   228  // known to support ranged requests, the next stateFunc is getRequest.
   229  func (c *Client) validateLocal(resp *Response) stateFunc {
   230  	if resp.Request.SkipExisting {
   231  		resp.err = ErrFileExists
   232  		return c.closeResponse
   233  	}
   234  
   235  	// determine target file size
   236  	expectedSize := resp.Request.Size
   237  	if expectedSize == 0 && resp.HTTPResponse != nil {
   238  		expectedSize = resp.HTTPResponse.ContentLength
   239  	}
   240  
   241  	if expectedSize == 0 {
   242  		// size is either actually 0 or unknown
   243  		// if unknown, we ask the remote server
   244  		// if known to be 0, we proceed with a GET
   245  		return c.headRequest
   246  	}
   247  
   248  	if expectedSize == resp.fi.Size() {
   249  		// local file matches remote file size - wrap it up
   250  		resp.DidResume = true
   251  		resp.bytesResumed = resp.fi.Size()
   252  		return c.checksumFile
   253  	}
   254  
   255  	if resp.Request.NoResume {
   256  		// local file should be overwritten
   257  		return c.getRequest
   258  	}
   259  
   260  	if expectedSize >= 0 && expectedSize < resp.fi.Size() {
   261  		// remote size is known, is smaller than local size and we want to resume
   262  		resp.err = ErrBadLength
   263  		return c.closeResponse
   264  	}
   265  
   266  	if resp.CanResume {
   267  		// set resume range on GET request
   268  		resp.Request.HTTPRequest.Header.Set(
   269  			"Range",
   270  			fmt.Sprintf("bytes=%d-", resp.fi.Size()))
   271  		resp.DidResume = true
   272  		resp.bytesResumed = resp.fi.Size()
   273  		return c.getRequest
   274  	}
   275  	return c.headRequest
   276  }
   277  
   278  func (c *Client) checksumFile(resp *Response) stateFunc {
   279  	if resp.Request.hash == nil {
   280  		return c.closeResponse
   281  	}
   282  	if resp.Filename == "" {
   283  		panic("grab: developer error: filename not set")
   284  	}
   285  	if resp.Size() < 0 {
   286  		panic("grab: developer error: size unknown")
   287  	}
   288  	req := resp.Request
   289  
   290  	// compute checksum
   291  	var sum []byte
   292  	sum, resp.err = resp.checksumUnsafe()
   293  	if resp.err != nil {
   294  		return c.closeResponse
   295  	}
   296  
   297  	// compare checksum
   298  	if !bytes.Equal(sum, req.checksum) {
   299  		resp.err = ErrBadChecksum
   300  		if !resp.Request.NoStore && req.deleteOnError {
   301  			if err := os.Remove(resp.Filename); err != nil {
   302  				// err should be os.PathError and include file path
   303  				resp.err = fmt.Errorf(
   304  					"cannot remove downloaded file with checksum mismatch: %v",
   305  					err)
   306  			}
   307  		}
   308  	}
   309  	return c.closeResponse
   310  }
   311  
   312  // doHTTPRequest sends a HTTP Request and returns the response
   313  func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
   314  	if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
   315  		req.Header.Set("User-Agent", c.UserAgent)
   316  	}
   317  	return c.HTTPClient.Do(req)
   318  }
   319  
   320  func (c *Client) headRequest(resp *Response) stateFunc {
   321  	if resp.optionsKnown {
   322  		return c.getRequest
   323  	}
   324  	resp.optionsKnown = true
   325  
   326  	if resp.Request.NoResume {
   327  		return c.getRequest
   328  	}
   329  
   330  	if resp.Filename != "" && resp.fi == nil {
   331  		// destination path is already known and does not exist
   332  		return c.getRequest
   333  	}
   334  
   335  	hreq := new(http.Request)
   336  	*hreq = *resp.Request.HTTPRequest
   337  	hreq.Method = "HEAD"
   338  
   339  	resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
   340  	if resp.err != nil {
   341  		return c.closeResponse
   342  	}
   343  	resp.HTTPResponse.Body.Close()
   344  
   345  	if resp.HTTPResponse.StatusCode != http.StatusOK {
   346  		return c.getRequest
   347  	}
   348  
   349  	return c.readResponse
   350  }
   351  
   352  func (c *Client) getRequest(resp *Response) stateFunc {
   353  	resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
   354  	if resp.err != nil {
   355  		return c.closeResponse
   356  	}
   357  
   358  	// TODO: check Content-Range
   359  
   360  	// check status code
   361  	if !resp.Request.IgnoreBadStatusCodes {
   362  		if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
   363  			resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
   364  			return c.closeResponse
   365  		}
   366  	}
   367  
   368  	return c.readResponse
   369  }
   370  
   371  func (c *Client) readResponse(resp *Response) stateFunc {
   372  	if resp.HTTPResponse == nil {
   373  		panic("grab: developer error: Response.HTTPResponse is nil")
   374  	}
   375  
   376  	// check expected size
   377  	resp.sizeUnsafe = resp.HTTPResponse.ContentLength
   378  	if resp.sizeUnsafe >= 0 {
   379  		// remote size is known
   380  		resp.sizeUnsafe += resp.bytesResumed
   381  		if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe {
   382  			resp.err = ErrBadLength
   383  			return c.closeResponse
   384  		}
   385  	}
   386  
   387  	// check filename
   388  	if resp.Filename == "" {
   389  		filename, err := guessFilename(resp.HTTPResponse)
   390  		if err != nil {
   391  			resp.err = err
   392  			return c.closeResponse
   393  		}
   394  		// Request.Filename will be empty or a directory
   395  		resp.Filename = filepath.Join(resp.Request.Filename, filename)
   396  	}
   397  
   398  	if !resp.Request.NoStore && resp.requestMethod() == "HEAD" {
   399  		if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
   400  			resp.CanResume = true
   401  		}
   402  		return c.statFileInfo
   403  	}
   404  	return c.openWriter
   405  }
   406  
   407  // openWriter opens the destination file for writing and seeks to the location
   408  // from whence the file transfer will resume.
   409  //
   410  // Requires that Response.Filename and resp.DidResume are already be set.
   411  func (c *Client) openWriter(resp *Response) stateFunc {
   412  	if !resp.Request.NoStore && !resp.Request.NoCreateDirectories {
   413  		resp.err = mkdirp(resp.Filename)
   414  		if resp.err != nil {
   415  			return c.closeResponse
   416  		}
   417  	}
   418  
   419  	if resp.Request.NoStore {
   420  		resp.writer = &resp.storeBuffer
   421  	} else {
   422  		// compute write flags
   423  		flag := os.O_CREATE | os.O_WRONLY
   424  		if resp.fi != nil {
   425  			if resp.DidResume {
   426  				flag = os.O_APPEND | os.O_WRONLY
   427  			} else {
   428  				flag = os.O_TRUNC | os.O_WRONLY
   429  			}
   430  		}
   431  
   432  		// open file
   433  		f, err := os.OpenFile(resp.Filename, flag, 0644)
   434  		if err != nil {
   435  			resp.err = err
   436  			return c.closeResponse
   437  		}
   438  		resp.writer = f
   439  
   440  		// seek to start or end
   441  		whence := os.SEEK_SET
   442  		if resp.bytesResumed > 0 {
   443  			whence = os.SEEK_END
   444  		}
   445  		_, resp.err = f.Seek(0, whence)
   446  		if resp.err != nil {
   447  			return c.closeResponse
   448  		}
   449  	}
   450  
   451  	// init transfer
   452  	if resp.bufferSize < 1 {
   453  		resp.bufferSize = 32 * 1024
   454  	}
   455  	b := make([]byte, resp.bufferSize)
   456  	resp.transfer = newTransfer(
   457  		resp.Request.Context(),
   458  		resp.Request.RateLimiter,
   459  		resp.writer,
   460  		resp.HTTPResponse.Body,
   461  		b)
   462  
   463  	// next step is copyFile, but this will be called later in another goroutine
   464  	return nil
   465  }
   466  
   467  // copy transfers content for a HTTP connection established via Client.do()
   468  func (c *Client) copyFile(resp *Response) stateFunc {
   469  	if resp.IsComplete() {
   470  		return nil
   471  	}
   472  
   473  	// run BeforeCopy hook
   474  	if f := resp.Request.BeforeCopy; f != nil {
   475  		resp.err = f(resp)
   476  		if resp.err != nil {
   477  			return c.closeResponse
   478  		}
   479  	}
   480  
   481  	var bytesCopied int64
   482  	if resp.transfer == nil {
   483  		panic("grab: developer error: Response.transfer is nil")
   484  	}
   485  	bytesCopied, resp.err = resp.transfer.copy()
   486  	if resp.err != nil {
   487  		return c.closeResponse
   488  	}
   489  	closeWriter(resp)
   490  
   491  	// set file timestamp
   492  	if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime {
   493  		resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
   494  		if resp.err != nil {
   495  			return c.closeResponse
   496  		}
   497  	}
   498  
   499  	// update transfer size if previously unknown
   500  	if resp.Size() < 0 {
   501  		discoveredSize := resp.bytesResumed + bytesCopied
   502  		atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize)
   503  		if resp.Request.Size > 0 && resp.Request.Size != discoveredSize {
   504  			resp.err = ErrBadLength
   505  			return c.closeResponse
   506  		}
   507  	}
   508  
   509  	// run AfterCopy hook
   510  	if f := resp.Request.AfterCopy; f != nil {
   511  		resp.err = f(resp)
   512  		if resp.err != nil {
   513  			return c.closeResponse
   514  		}
   515  	}
   516  
   517  	return c.checksumFile
   518  }
   519  
   520  func closeWriter(resp *Response) {
   521  	if closer, ok := resp.writer.(io.Closer); ok {
   522  		closer.Close()
   523  	}
   524  	resp.writer = nil
   525  }
   526  
   527  // close finalizes the Response
   528  func (c *Client) closeResponse(resp *Response) stateFunc {
   529  	if resp.IsComplete() {
   530  		panic("grab: developer error: response already closed")
   531  	}
   532  
   533  	resp.fi = nil
   534  	closeWriter(resp)
   535  	resp.closeResponseBody()
   536  
   537  	resp.End = time.Now()
   538  	close(resp.Done)
   539  	if resp.cancel != nil {
   540  		resp.cancel()
   541  	}
   542  
   543  	return nil
   544  }