github.com/tyler-smith/grab@v2.0.1-0.20190224022517-abcee96e61b1+incompatible/client.go (about)

     1  package grab
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"net/http"
     8  	"os"
     9  	"path/filepath"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/btcsuite/go-socks/socks"
    14  )
    15  
    16  // A Client is a file download client.
    17  //
    18  // Clients are safe for concurrent use by multiple goroutines.
    19  type Client struct {
    20  	// HTTPClient specifies the http.Client which will be used for communicating
    21  	// with the remote server during the file transfer.
    22  	HTTPClient *http.Client
    23  
    24  	// UserAgent specifies the User-Agent string which will be set in the
    25  	// headers of all requests made by this client.
    26  	//
    27  	// The user agent string may be overridden in the headers of each request.
    28  	UserAgent string
    29  
    30  	// BufferSize specifies the size in bytes of the buffer that is used for
    31  	// transferring all requested files. Larger buffers may result in faster
    32  	// throughput but will use more memory and result in less frequent updates
    33  	// to the transfer progress statistics. The BufferSize of each request can
    34  	// be overridden on each Request object. Default: 32KB.
    35  	BufferSize int
    36  }
    37  
    38  // NewClient returns a new file download Client, using default configuration.
    39  func NewClient() *Client {
    40  	return &Client{
    41  		UserAgent: "grab",
    42  		HTTPClient: &http.Client{
    43  			Transport: &http.Transport{
    44  				Proxy: http.ProxyFromEnvironment,
    45  			},
    46  		},
    47  	}
    48  }
    49  
    50  // NewClientWithProxy returns a new file download Client
    51  // with proxy support.
    52  func NewClientWithProxy(proxy *socks.Proxy) *Client {
    53  	transport := &http.Transport{
    54  		Proxy: http.ProxyFromEnvironment,
    55  	}
    56  
    57  	if proxy != nil {
    58  		transport.Dial = proxy.Dial
    59  	}
    60  
    61  	client := &Client{
    62  		UserAgent: "grab",
    63  		HTTPClient: &http.Client{
    64  			Transport: transport,
    65  		},
    66  	}
    67  
    68  	return client
    69  }
    70  
    71  // DefaultClient is the default client and is used by all Get convenience
    72  // functions.
    73  var DefaultClient = NewClient()
    74  
    75  // SetDefaultClientWithProxy allows you to setup the default client with
    76  // a socks5 proxy.
    77  func SetDefaultClientWithProxy(proxy *socks.Proxy) {
    78  	DefaultClient = NewClientWithProxy(proxy)
    79  }
    80  
    81  // Do sends a file transfer request and returns a file transfer response,
    82  // following policy (e.g. redirects, cookies, auth) as configured on the
    83  // client's HTTPClient.
    84  //
    85  // Like http.Get, Do blocks while the transfer is initiated, but returns as soon
    86  // as the transfer has started transferring in a background goroutine, or if it
    87  // failed early.
    88  //
    89  // An error is returned via Response.Err if caused by client policy (such as
    90  // CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
    91  // will block the caller until the transfer is completed, successfully or
    92  // otherwise.
    93  func (c *Client) Do(req *Request) *Response {
    94  	// cancel will be called on all code-paths via closeResponse
    95  	ctx, cancel := context.WithCancel(req.Context())
    96  	resp := &Response{
    97  		Request:    req,
    98  		Start:      time.Now(),
    99  		Done:       make(chan struct{}, 0),
   100  		Filename:   req.Filename,
   101  		ctx:        ctx,
   102  		cancel:     cancel,
   103  		bufferSize: req.BufferSize,
   104  	}
   105  	if resp.bufferSize == 0 {
   106  		// default to Client.BufferSize
   107  		resp.bufferSize = c.BufferSize
   108  	}
   109  
   110  	// Run state-machine while caller is blocked to initialize the file transfer.
   111  	// Must never transition to the copyFile state - this happens next in another
   112  	// goroutine.
   113  	c.run(resp, c.statFileInfo)
   114  
   115  	// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
   116  	// already complete or failed.
   117  	go c.run(resp, c.copyFile)
   118  	return resp
   119  }
   120  
   121  // DoChannel executes all requests sent through the given Request channel, one
   122  // at a time, until it is closed by another goroutine. The caller is blocked
   123  // until the Request channel is closed and all transfers have completed. All
   124  // responses are sent through the given Response channel as soon as they are
   125  // received from the remote servers and can be used to track the progress of
   126  // each download.
   127  //
   128  // Slow Response receivers will cause a worker to block and therefore delay the
   129  // start of the transfer for an already initiated connection - potentially
   130  // causing a server timeout. It is the caller's responsibility to ensure a
   131  // sufficient buffer size is used for the Response channel to prevent this.
   132  //
   133  // If an error occurs during any of the file transfers it will be accessible via
   134  // the associated Response.Err function.
   135  func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
   136  	// TODO: enable cancelling of batch jobs
   137  	for req := range reqch {
   138  		resp := c.Do(req)
   139  		respch <- resp
   140  		<-resp.Done
   141  	}
   142  }
   143  
   144  // DoBatch executes all the given requests using the given number of concurrent
   145  // workers. Control is passed back to the caller as soon as the workers are
   146  // initiated.
   147  //
   148  // If the requested number of workers is less than one, a worker will be created
   149  // for every request. I.e. all requests will be executed concurrently.
   150  //
   151  // If an error occurs during any of the file transfers it will be accessible via
   152  // call to the associated Response.Err.
   153  //
   154  // The returned Response channel is closed only after all of the given Requests
   155  // have completed, successfully or otherwise.
   156  func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
   157  	if workers < 1 {
   158  		workers = len(requests)
   159  	}
   160  	reqch := make(chan *Request, len(requests))
   161  	respch := make(chan *Response, len(requests))
   162  	wg := sync.WaitGroup{}
   163  	for i := 0; i < workers; i++ {
   164  		wg.Add(1)
   165  		go func() {
   166  			c.DoChannel(reqch, respch)
   167  			wg.Done()
   168  		}()
   169  	}
   170  
   171  	// queue requests
   172  	go func() {
   173  		for _, req := range requests {
   174  			reqch <- req
   175  		}
   176  		close(reqch)
   177  		wg.Wait()
   178  		close(respch)
   179  	}()
   180  	return respch
   181  }
   182  
   183  // An stateFunc is an action that mutates the state of a Response and returns
   184  // the next stateFunc to be called.
   185  type stateFunc func(*Response) stateFunc
   186  
   187  // run calls the given stateFunc function and all subsequent returned stateFuncs
   188  // until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
   189  // should mutate the state of the given Response until it has completed
   190  // downloading or failed.
   191  func (c *Client) run(resp *Response, f stateFunc) {
   192  	for {
   193  		select {
   194  		case <-resp.ctx.Done():
   195  			if resp.IsComplete() {
   196  				return
   197  			}
   198  			resp.err = resp.ctx.Err()
   199  			f = c.closeResponse
   200  
   201  		default:
   202  			// keep working
   203  		}
   204  		if f = f(resp); f == nil {
   205  			return
   206  		}
   207  	}
   208  }
   209  
   210  // statFileInfo retrieves FileInfo for any local file matching
   211  // Response.Filename.
   212  //
   213  // If the file does not exist, is a directory, or its name is unknown the next
   214  // stateFunc is headRequest.
   215  //
   216  // If the file exists, Response.fi is set and the next stateFunc is
   217  // validateLocal.
   218  //
   219  // If an error occurs, the next stateFunc is closeResponse.
   220  func (c *Client) statFileInfo(resp *Response) stateFunc {
   221  	if resp.Filename == "" {
   222  		return c.headRequest
   223  	}
   224  	fi, err := os.Stat(resp.Filename)
   225  	if err != nil {
   226  		if os.IsNotExist(err) {
   227  			return c.headRequest
   228  		}
   229  		resp.err = err
   230  		return c.closeResponse
   231  	}
   232  	if fi.IsDir() {
   233  		resp.Filename = ""
   234  		return c.headRequest
   235  	}
   236  	resp.fi = fi
   237  	return c.validateLocal
   238  }
   239  
   240  // validateLocal compares a local copy of the downloaded file to the remote
   241  // file.
   242  //
   243  // An error is returned if the local file is larger than the remote file, or
   244  // Request.SkipExisting is true.
   245  //
   246  // If the existing file matches the length of the remote file, the next
   247  // stateFunc is checksumFile.
   248  //
   249  // If the local file is smaller than the remote file and the remote server is
   250  // known to support ranged requests, the next stateFunc is getRequest.
   251  func (c *Client) validateLocal(resp *Response) stateFunc {
   252  	if resp.Request.SkipExisting {
   253  		resp.err = ErrFileExists
   254  		return c.closeResponse
   255  	}
   256  
   257  	// determine expected file size
   258  	size := resp.Request.Size
   259  	if size == 0 && resp.HTTPResponse != nil {
   260  		size = resp.HTTPResponse.ContentLength
   261  	}
   262  	if size == 0 {
   263  		return c.headRequest
   264  	}
   265  
   266  	if size == resp.fi.Size() {
   267  		resp.DidResume = true
   268  		resp.bytesResumed = resp.fi.Size()
   269  		return c.checksumFile
   270  	}
   271  
   272  	if resp.Request.NoResume {
   273  		return c.getRequest
   274  	}
   275  
   276  	if size < resp.fi.Size() {
   277  		resp.err = ErrBadLength
   278  		return c.closeResponse
   279  	}
   280  
   281  	if resp.CanResume {
   282  		resp.Request.HTTPRequest.Header.Set(
   283  			"Range",
   284  			fmt.Sprintf("bytes=%d-", resp.fi.Size()))
   285  		resp.DidResume = true
   286  		resp.bytesResumed = resp.fi.Size()
   287  		return c.getRequest
   288  	}
   289  	return c.headRequest
   290  }
   291  
   292  func (c *Client) checksumFile(resp *Response) stateFunc {
   293  	if resp.Request.hash == nil {
   294  		return c.closeResponse
   295  	}
   296  	if resp.Filename == "" {
   297  		panic("filename not set")
   298  	}
   299  	req := resp.Request
   300  
   301  	// compare checksum
   302  	var sum []byte
   303  	sum, resp.err = checksum(req.Context(), resp.Filename, req.hash)
   304  	if resp.err != nil {
   305  		return c.closeResponse
   306  	}
   307  	if !bytes.Equal(sum, req.checksum) {
   308  		resp.err = ErrBadChecksum
   309  		if req.deleteOnError {
   310  			if err := os.Remove(resp.Filename); err != nil {
   311  				// err should be os.PathError and include file path
   312  				resp.err = fmt.Errorf(
   313  					"cannot remove downloaded file with checksum mismatch: %v",
   314  					err)
   315  			}
   316  		}
   317  	}
   318  	return c.closeResponse
   319  }
   320  
   321  // doHTTPRequest sends a HTTP Request and returns the response
   322  func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
   323  	if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
   324  		req.Header.Set("User-Agent", c.UserAgent)
   325  	}
   326  	return c.HTTPClient.Do(req)
   327  }
   328  
   329  func (c *Client) headRequest(resp *Response) stateFunc {
   330  	if resp.optionsKnown {
   331  		return c.getRequest
   332  	}
   333  	resp.optionsKnown = true
   334  
   335  	if resp.Request.NoResume {
   336  		return c.getRequest
   337  	}
   338  
   339  	if resp.Filename != "" && resp.fi == nil {
   340  		// destination path is already known and does not exist
   341  		return c.getRequest
   342  	}
   343  
   344  	hreq := new(http.Request)
   345  	*hreq = *resp.Request.HTTPRequest
   346  	hreq.Method = "HEAD"
   347  
   348  	resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
   349  	if resp.err != nil {
   350  		return c.closeResponse
   351  	}
   352  	resp.HTTPResponse.Body.Close()
   353  
   354  	if resp.HTTPResponse.StatusCode != http.StatusOK {
   355  		return c.getRequest
   356  	}
   357  
   358  	return c.readResponse
   359  }
   360  
   361  func (c *Client) getRequest(resp *Response) stateFunc {
   362  	resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
   363  	if resp.err != nil {
   364  		return c.closeResponse
   365  	}
   366  
   367  	// check status code
   368  	if !resp.Request.IgnoreBadStatusCodes {
   369  		if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
   370  			resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
   371  			return c.closeResponse
   372  		}
   373  	}
   374  
   375  	return c.readResponse
   376  }
   377  
   378  func (c *Client) readResponse(resp *Response) stateFunc {
   379  	if resp.HTTPResponse == nil {
   380  		panic("Response.HTTPResponse is not ready")
   381  	}
   382  
   383  	// check expected size
   384  	resp.Size = resp.bytesResumed + resp.HTTPResponse.ContentLength
   385  	if resp.HTTPResponse.ContentLength > 0 && resp.Request.Size > 0 {
   386  		if resp.Request.Size != resp.Size {
   387  			resp.err = ErrBadLength
   388  			return c.closeResponse
   389  		}
   390  	}
   391  
   392  	// check filename
   393  	if resp.Filename == "" {
   394  		filename, err := guessFilename(resp.HTTPResponse)
   395  		if err != nil {
   396  			resp.err = err
   397  			return c.closeResponse
   398  		}
   399  		// Request.Filename will be empty or a directory
   400  		resp.Filename = filepath.Join(resp.Request.Filename, filename)
   401  	}
   402  
   403  	if resp.requestMethod() == "HEAD" {
   404  		if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
   405  			resp.CanResume = true
   406  		}
   407  		return c.statFileInfo
   408  	}
   409  	return c.openWriter
   410  }
   411  
   412  // openWriter opens the destination file for writing and seeks to the location
   413  // from whence the file transfer will resume.
   414  //
   415  // Requires that Response.Filename and resp.DidResume are already be set.
   416  func (c *Client) openWriter(resp *Response) stateFunc {
   417  	if !resp.Request.NoCreateDirectories {
   418  		resp.err = mkdirp(resp.Filename)
   419  		if resp.err != nil {
   420  			return c.closeResponse
   421  		}
   422  	}
   423  
   424  	// compute write flags
   425  	flag := os.O_CREATE | os.O_WRONLY
   426  	if resp.fi != nil {
   427  		if resp.DidResume {
   428  			flag = os.O_APPEND | os.O_WRONLY
   429  		} else {
   430  			flag = os.O_TRUNC | os.O_WRONLY
   431  		}
   432  	}
   433  
   434  	// open file
   435  	f, err := os.OpenFile(resp.Filename, flag, 0644)
   436  	if err != nil {
   437  		resp.err = err
   438  		return c.closeResponse
   439  	}
   440  	resp.writer = f
   441  
   442  	// seek to start or end
   443  	whence := os.SEEK_SET
   444  	if resp.bytesResumed > 0 {
   445  		whence = os.SEEK_END
   446  	}
   447  	_, resp.err = f.Seek(0, whence)
   448  	if resp.err != nil {
   449  		return c.closeResponse
   450  	}
   451  
   452  	// init transfer
   453  	if resp.bufferSize < 1 {
   454  		resp.bufferSize = 32 * 1024
   455  	}
   456  	b := make([]byte, resp.bufferSize)
   457  	resp.transfer = newTransfer(
   458  		resp.Request.Context(),
   459  		resp.Request.RateLimiter,
   460  		resp.writer,
   461  		resp.HTTPResponse.Body,
   462  		b)
   463  
   464  	// next step is copyFile, but this will be called later in another goroutine
   465  	return nil
   466  }
   467  
   468  // copy transfers content for a HTTP connection established via Client.do()
   469  func (c *Client) copyFile(resp *Response) stateFunc {
   470  	if resp.IsComplete() {
   471  		return nil
   472  	}
   473  
   474  	// run BeforeCopy hook
   475  	if f := resp.Request.BeforeCopy; f != nil {
   476  		resp.err = f(resp)
   477  		if resp.err != nil {
   478  			return c.closeResponse
   479  		}
   480  	}
   481  
   482  	if resp.transfer == nil {
   483  		panic("developer error: Response.transfer is not initialized")
   484  	}
   485  	go resp.watchBps()
   486  	_, resp.err = resp.transfer.copy()
   487  	if resp.err != nil {
   488  		return c.closeResponse
   489  	}
   490  	closeWriter(resp)
   491  
   492  	// set timestamp
   493  	if !resp.Request.IgnoreRemoteTime {
   494  		resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
   495  		if resp.err != nil {
   496  			return c.closeResponse
   497  		}
   498  	}
   499  
   500  	// run AfterCopy hook
   501  	if f := resp.Request.AfterCopy; f != nil {
   502  		resp.err = f(resp)
   503  		if resp.err != nil {
   504  			return c.closeResponse
   505  		}
   506  	}
   507  
   508  	return c.checksumFile
   509  }
   510  
   511  func closeWriter(resp *Response) {
   512  	if resp.writer != nil {
   513  		resp.writer.Close()
   514  		resp.writer = nil
   515  	}
   516  }
   517  
   518  // close finalizes the Response
   519  func (c *Client) closeResponse(resp *Response) stateFunc {
   520  	if resp.IsComplete() {
   521  		panic("response already closed")
   522  	}
   523  
   524  	resp.fi = nil
   525  	closeWriter(resp)
   526  	resp.closeResponseBody()
   527  
   528  	resp.End = time.Now()
   529  	close(resp.Done)
   530  	if resp.cancel != nil {
   531  		resp.cancel()
   532  	}
   533  
   534  	return nil
   535  }