gopkg.in/cavaliercoder/grab.v2@v2.0.0/client.go (about)

     1  package grab
     2  
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"time"
)
    13  
    14  // A Client is a file download client.
    15  //
    16  // Clients are safe for concurrent use by multiple goroutines.
    17  type Client struct {
    18  	// HTTPClient specifies the http.Client which will be used for communicating
    19  	// with the remote server during the file transfer.
    20  	HTTPClient *http.Client
    21  
    22  	// UserAgent specifies the User-Agent string which will be set in the
    23  	// headers of all requests made by this client.
    24  	//
    25  	// The user agent string may be overridden in the headers of each request.
    26  	UserAgent string
    27  
    28  	// BufferSize specifies the size in bytes of the buffer that is used for
    29  	// transferring all requested files. Larger buffers may result in faster
    30  	// throughput but will use more memory and result in less frequent updates
    31  	// to the transfer progress statistics. The BufferSize of each request can
    32  	// be overridden on each Request object. Default: 32KB.
    33  	BufferSize int
    34  }
    35  
    36  // NewClient returns a new file download Client, using default configuration.
    37  func NewClient() *Client {
    38  	return &Client{
    39  		UserAgent: "grab",
    40  		HTTPClient: &http.Client{
    41  			Transport: &http.Transport{
    42  				Proxy: http.ProxyFromEnvironment,
    43  			},
    44  		},
    45  	}
    46  }
    47  
    48  // DefaultClient is the default client and is used by all Get convenience
    49  // functions.
    50  var DefaultClient = NewClient()
    51  
    52  // Do sends a file transfer request and returns a file transfer response,
    53  // following policy (e.g. redirects, cookies, auth) as configured on the
    54  // client's HTTPClient.
    55  //
    56  // Like http.Get, Do blocks while the transfer is initiated, but returns as soon
    57  // as the transfer has started transferring in a background goroutine, or if it
    58  // failed early.
    59  //
    60  // An error is returned via Response.Err if caused by client policy (such as
    61  // CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
    62  // will block the caller until the transfer is completed, successfully or
    63  // otherwise.
    64  func (c *Client) Do(req *Request) *Response {
    65  	// cancel will be called on all code-paths via closeResponse
    66  	ctx, cancel := context.WithCancel(req.Context())
    67  	resp := &Response{
    68  		Request:    req,
    69  		Start:      time.Now(),
    70  		Done:       make(chan struct{}, 0),
    71  		Filename:   req.Filename,
    72  		ctx:        ctx,
    73  		cancel:     cancel,
    74  		bufferSize: req.BufferSize,
    75  	}
    76  	if resp.bufferSize == 0 {
    77  		// default to Client.BufferSize
    78  		resp.bufferSize = c.BufferSize
    79  	}
    80  
    81  	// Run state-machine while caller is blocked to initialize the file transfer.
    82  	// Must never transition to the copyFile state - this happens next in another
    83  	// goroutine.
    84  	c.run(resp, c.statFileInfo)
    85  
    86  	// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
    87  	// already complete or failed.
    88  	go c.run(resp, c.copyFile)
    89  	return resp
    90  }
    91  
    92  // DoChannel executes all requests sent through the given Request channel, one
    93  // at a time, until it is closed by another goroutine. The caller is blocked
    94  // until the Request channel is closed and all transfers have completed. All
    95  // responses are sent through the given Response channel as soon as they are
    96  // received from the remote servers and can be used to track the progress of
    97  // each download.
    98  //
    99  // Slow Response receivers will cause a worker to block and therefore delay the
   100  // start of the transfer for an already initiated connection - potentially
   101  // causing a server timeout. It is the caller's responsibility to ensure a
   102  // sufficient buffer size is used for the Response channel to prevent this.
   103  //
   104  // If an error occurs during any of the file transfers it will be accessible via
   105  // the associated Response.Err function.
   106  func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
   107  	// TODO: enable cancelling of batch jobs
   108  	for req := range reqch {
   109  		resp := c.Do(req)
   110  		respch <- resp
   111  		<-resp.Done
   112  	}
   113  }
   114  
   115  // DoBatch executes all the given requests using the given number of concurrent
   116  // workers. Control is passed back to the caller as soon as the workers are
   117  // initiated.
   118  //
   119  // If the requested number of workers is less than one, a worker will be created
   120  // for every request. I.e. all requests will be executed concurrently.
   121  //
   122  // If an error occurs during any of the file transfers it will be accessible via
   123  // call to the associated Response.Err.
   124  //
   125  // The returned Response channel is closed only after all of the given Requests
   126  // have completed, successfully or otherwise.
   127  func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
   128  	if workers < 1 {
   129  		workers = len(requests)
   130  	}
   131  	reqch := make(chan *Request, len(requests))
   132  	respch := make(chan *Response, len(requests))
   133  	wg := sync.WaitGroup{}
   134  	for i := 0; i < workers; i++ {
   135  		wg.Add(1)
   136  		go func() {
   137  			c.DoChannel(reqch, respch)
   138  			wg.Done()
   139  		}()
   140  	}
   141  
   142  	// queue requests
   143  	go func() {
   144  		for _, req := range requests {
   145  			reqch <- req
   146  		}
   147  		close(reqch)
   148  		wg.Wait()
   149  		close(respch)
   150  	}()
   151  	return respch
   152  }
   153  
   154  // An stateFunc is an action that mutates the state of a Response and returns
   155  // the next stateFunc to be called.
   156  type stateFunc func(*Response) stateFunc
   157  
   158  // run calls the given stateFunc function and all subsequent returned stateFuncs
   159  // until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
   160  // should mutate the state of the given Response until it has completed
   161  // downloading or failed.
   162  func (c *Client) run(resp *Response, f stateFunc) {
   163  	for {
   164  		select {
   165  		case <-resp.ctx.Done():
   166  			if resp.IsComplete() {
   167  				return
   168  			}
   169  			resp.err = resp.ctx.Err()
   170  			f = c.closeResponse
   171  
   172  		default:
   173  			// keep working
   174  		}
   175  		if f = f(resp); f == nil {
   176  			return
   177  		}
   178  	}
   179  }
   180  
   181  // statFileInfo retrieves FileInfo for any local file matching
   182  // Response.Filename.
   183  //
   184  // If the file does not exist, is a directory, or its name is unknown the next
   185  // stateFunc is headRequest.
   186  //
   187  // If the file exists, Response.fi is set and the next stateFunc is
   188  // validateLocal.
   189  //
   190  // If an error occurs, the next stateFunc is closeResponse.
   191  func (c *Client) statFileInfo(resp *Response) stateFunc {
   192  	if resp.Filename == "" {
   193  		return c.headRequest
   194  	}
   195  	fi, err := os.Stat(resp.Filename)
   196  	if err != nil {
   197  		if os.IsNotExist(err) {
   198  			return c.headRequest
   199  		}
   200  		resp.err = err
   201  		return c.closeResponse
   202  	}
   203  	if fi.IsDir() {
   204  		resp.Filename = ""
   205  		return c.headRequest
   206  	}
   207  	resp.fi = fi
   208  	return c.validateLocal
   209  }
   210  
   211  // validateLocal compares a local copy of the downloaded file to the remote
   212  // file.
   213  //
   214  // An error is returned if the local file is larger than the remote file, or
   215  // Request.SkipExisting is true.
   216  //
   217  // If the existing file matches the length of the remote file, the next
   218  // stateFunc is checksumFile.
   219  //
   220  // If the local file is smaller than the remote file and the remote server is
   221  // known to support ranged requests, the next stateFunc is getRequest.
   222  func (c *Client) validateLocal(resp *Response) stateFunc {
   223  	if resp.Request.SkipExisting {
   224  		resp.err = ErrFileExists
   225  		return c.closeResponse
   226  	}
   227  
   228  	// determine expected file size
   229  	size := resp.Request.Size
   230  	if size == 0 && resp.HTTPResponse != nil {
   231  		size = resp.HTTPResponse.ContentLength
   232  	}
   233  	if size == 0 {
   234  		return c.headRequest
   235  	}
   236  
   237  	if size == resp.fi.Size() {
   238  		resp.DidResume = true
   239  		resp.bytesResumed = resp.fi.Size()
   240  		return c.checksumFile
   241  	}
   242  
   243  	if resp.Request.NoResume {
   244  		return c.getRequest
   245  	}
   246  
   247  	if size < resp.fi.Size() {
   248  		resp.err = ErrBadLength
   249  		return c.closeResponse
   250  	}
   251  
   252  	if resp.CanResume {
   253  		resp.Request.HTTPRequest.Header.Set(
   254  			"Range",
   255  			fmt.Sprintf("bytes=%d-", resp.fi.Size()))
   256  		resp.DidResume = true
   257  		resp.bytesResumed = resp.fi.Size()
   258  		return c.getRequest
   259  	}
   260  	return c.headRequest
   261  }
   262  
   263  func (c *Client) checksumFile(resp *Response) stateFunc {
   264  	if resp.Request.hash == nil {
   265  		return c.closeResponse
   266  	}
   267  	if resp.Filename == "" {
   268  		panic("filename not set")
   269  	}
   270  	req := resp.Request
   271  
   272  	// compare checksum
   273  	var sum []byte
   274  	sum, resp.err = checksum(req.Context(), resp.Filename, req.hash)
   275  	if resp.err != nil {
   276  		return c.closeResponse
   277  	}
   278  	if !bytes.Equal(sum, req.checksum) {
   279  		resp.err = ErrBadChecksum
   280  		if req.deleteOnError {
   281  			if err := os.Remove(resp.Filename); err != nil {
   282  				// err should be os.PathError and include file path
   283  				resp.err = fmt.Errorf(
   284  					"cannot remove downloaded file with checksum mismatch: %v",
   285  					err)
   286  			}
   287  		}
   288  	}
   289  	return c.closeResponse
   290  }
   291  
   292  // doHTTPRequest sends a HTTP Request and returns the response
   293  func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
   294  	if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
   295  		req.Header.Set("User-Agent", c.UserAgent)
   296  	}
   297  	return c.HTTPClient.Do(req)
   298  }
   299  
   300  func (c *Client) headRequest(resp *Response) stateFunc {
   301  	if resp.optionsKnown {
   302  		return c.getRequest
   303  	}
   304  	resp.optionsKnown = true
   305  
   306  	if resp.Request.NoResume {
   307  		return c.getRequest
   308  	}
   309  
   310  	if resp.Filename != "" && resp.fi == nil {
   311  		// destination path is already known and does not exist
   312  		return c.getRequest
   313  	}
   314  
   315  	hreq := new(http.Request)
   316  	*hreq = *resp.Request.HTTPRequest
   317  	hreq.Method = "HEAD"
   318  
   319  	resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
   320  	if resp.err != nil {
   321  		return c.closeResponse
   322  	}
   323  	resp.HTTPResponse.Body.Close()
   324  
   325  	if resp.HTTPResponse.StatusCode != http.StatusOK {
   326  		return c.getRequest
   327  	}
   328  
   329  	return c.readResponse
   330  }
   331  
   332  func (c *Client) getRequest(resp *Response) stateFunc {
   333  	resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
   334  	if resp.err != nil {
   335  		return c.closeResponse
   336  	}
   337  
   338  	// check status code
   339  	if !resp.Request.IgnoreBadStatusCodes {
   340  		if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
   341  			resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
   342  			return c.closeResponse
   343  		}
   344  	}
   345  
   346  	return c.readResponse
   347  }
   348  
   349  func (c *Client) readResponse(resp *Response) stateFunc {
   350  	if resp.HTTPResponse == nil {
   351  		panic("Response.HTTPResponse is not ready")
   352  	}
   353  
   354  	// check expected size
   355  	resp.Size = resp.bytesResumed + resp.HTTPResponse.ContentLength
   356  	if resp.HTTPResponse.ContentLength > 0 && resp.Request.Size > 0 {
   357  		if resp.Request.Size != resp.Size {
   358  			resp.err = ErrBadLength
   359  			return c.closeResponse
   360  		}
   361  	}
   362  
   363  	// check filename
   364  	if resp.Filename == "" {
   365  		filename, err := guessFilename(resp.HTTPResponse)
   366  		if err != nil {
   367  			resp.err = err
   368  			return c.closeResponse
   369  		}
   370  		// Request.Filename will be empty or a directory
   371  		resp.Filename = filepath.Join(resp.Request.Filename, filename)
   372  	}
   373  
   374  	if resp.requestMethod() == "HEAD" {
   375  		if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
   376  			resp.CanResume = true
   377  		}
   378  		return c.statFileInfo
   379  	}
   380  	return c.openWriter
   381  }
   382  
   383  // openWriter opens the destination file for writing and seeks to the location
   384  // from whence the file transfer will resume.
   385  //
   386  // Requires that Response.Filename and resp.DidResume are already be set.
   387  func (c *Client) openWriter(resp *Response) stateFunc {
   388  	if !resp.Request.NoCreateDirectories {
   389  		resp.err = mkdirp(resp.Filename)
   390  		if resp.err != nil {
   391  			return c.closeResponse
   392  		}
   393  	}
   394  
   395  	// compute write flags
   396  	flag := os.O_CREATE | os.O_WRONLY
   397  	if resp.fi != nil {
   398  		if resp.DidResume {
   399  			flag = os.O_APPEND | os.O_WRONLY
   400  		} else {
   401  			flag = os.O_TRUNC | os.O_WRONLY
   402  		}
   403  	}
   404  
   405  	// open file
   406  	f, err := os.OpenFile(resp.Filename, flag, 0644)
   407  	if err != nil {
   408  		resp.err = err
   409  		return c.closeResponse
   410  	}
   411  	resp.writer = f
   412  
   413  	// seek to start or end
   414  	whence := os.SEEK_SET
   415  	if resp.bytesResumed > 0 {
   416  		whence = os.SEEK_END
   417  	}
   418  	_, resp.err = f.Seek(0, whence)
   419  	if resp.err != nil {
   420  		return c.closeResponse
   421  	}
   422  
   423  	// init transfer
   424  	if resp.bufferSize < 1 {
   425  		resp.bufferSize = 32 * 1024
   426  	}
   427  	b := make([]byte, resp.bufferSize)
   428  	resp.transfer = newTransfer(
   429  		resp.Request.Context(),
   430  		resp.Request.RateLimiter,
   431  		resp.writer,
   432  		resp.HTTPResponse.Body,
   433  		b)
   434  
   435  	// next step is copyFile, but this will be called later in another goroutine
   436  	return nil
   437  }
   438  
   439  // copy transfers content for a HTTP connection established via Client.do()
   440  func (c *Client) copyFile(resp *Response) stateFunc {
   441  	if resp.IsComplete() {
   442  		return nil
   443  	}
   444  
   445  	// run BeforeCopy hook
   446  	if f := resp.Request.BeforeCopy; f != nil {
   447  		resp.err = f(resp)
   448  		if resp.err != nil {
   449  			return c.closeResponse
   450  		}
   451  	}
   452  
   453  	if resp.transfer == nil {
   454  		panic("developer error: Response.transfer is not initialized")
   455  	}
   456  	go resp.watchBps()
   457  	_, resp.err = resp.transfer.copy()
   458  	if resp.err != nil {
   459  		return c.closeResponse
   460  	}
   461  	closeWriter(resp)
   462  
   463  	// set timestamp
   464  	if !resp.Request.IgnoreRemoteTime {
   465  		resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
   466  		if resp.err != nil {
   467  			return c.closeResponse
   468  		}
   469  	}
   470  
   471  	// run AfterCopy hook
   472  	if f := resp.Request.AfterCopy; f != nil {
   473  		resp.err = f(resp)
   474  		if resp.err != nil {
   475  			return c.closeResponse
   476  		}
   477  	}
   478  
   479  	return c.checksumFile
   480  }
   481  
   482  func closeWriter(resp *Response) {
   483  	if resp.writer != nil {
   484  		resp.writer.Close()
   485  		resp.writer = nil
   486  	}
   487  }
   488  
   489  // close finalizes the Response
   490  func (c *Client) closeResponse(resp *Response) stateFunc {
   491  	if resp.IsComplete() {
   492  		panic("response already closed")
   493  	}
   494  
   495  	resp.fi = nil
   496  	closeWriter(resp)
   497  	resp.closeResponseBody()
   498  
   499  	resp.End = time.Now()
   500  	close(resp.Done)
   501  	if resp.cancel != nil {
   502  		resp.cancel()
   503  	}
   504  
   505  	return nil
   506  }