github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/s3manager/download.go

package s3manager

import (
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"

	"github.com/aavshr/aws-sdk-go/aws"
	"github.com/aavshr/aws-sdk-go/aws/awserr"
	"github.com/aavshr/aws-sdk-go/aws/awsutil"
	"github.com/aavshr/aws-sdk-go/aws/client"
	"github.com/aavshr/aws-sdk-go/aws/request"
	"github.com/aavshr/aws-sdk-go/service/s3"
	"github.com/aavshr/aws-sdk-go/service/s3/s3iface"
)

// DefaultDownloadPartSize is the default range of bytes to get at a time when
// using Download().
const DefaultDownloadPartSize = 1024 * 1024 * 5

// DefaultDownloadConcurrency is the default number of goroutines to spin up
// when using Download().
const DefaultDownloadConcurrency = 5

type errReadingBody struct {
	err error
}

func (e *errReadingBody) Error() string {
	return fmt.Sprintf("failed to read part body: %v", e.err)
}

func (e *errReadingBody) Unwrap() error {
	return e.err
}
// Downloader is the structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Downloader's properties is not safe to do concurrently.
type Downloader struct {
	// The size (in bytes) to request from S3 for each part.
	// The minimum allowed part size is 5MB, and if this value is set to zero,
	// the DefaultDownloadPartSize value will be used.
	//
	// PartSize is ignored if the Range input parameter is provided.
	PartSize int64

	// The number of goroutines to spin up in parallel when sending parts.
	// If this is set to zero, the DefaultDownloadConcurrency value will be used.
	//
	// Concurrency of 1 will download the parts sequentially.
	//
	// Concurrency is ignored if the Range input parameter is provided.
	Concurrency int

	// An S3 client to use when performing downloads.
	S3 s3iface.S3API

	// List of request options that will be passed down to individual API
	// operation requests made by the downloader.
	RequestOptions []request.Option

	// Defines the buffer strategy used when downloading a part.
	//
	// If a WriterReadFromProvider is given, the Download manager
	// will pass the io.WriterAt of the Download request to the provider
	// and will use the returned WriterReadFrom from the provider as the
	// destination writer when copying from the http response body.
	BufferProvider WriterReadFromProvider
}
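
// As a rough sketch, a Downloader with a larger part size and a pooled
// buffer strategy could be configured as follows (assumes an existing
// *session.Session named sess; NewPooledBufferedWriterReadFromProvider is
// the pooled WriterReadFromProvider defined elsewhere in this package):
//
//     downloader := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
//         d.PartSize = 10 * 1024 * 1024 // 10MB parts
//         d.Concurrency = 8
//         d.BufferProvider = s3manager.NewPooledBufferedWriterReadFromProvider(5 * 1024 * 1024)
//     })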

// WithDownloaderRequestOptions appends to the Downloader's API request options.
func WithDownloaderRequestOptions(opts ...request.Option) func(*Downloader) {
	return func(d *Downloader) {
		d.RequestOptions = append(d.RequestOptions, opts...)
	}
}
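
// As an illustrative sketch, request options can be attached when the
// downloader is constructed (assumes an existing *session.Session named
// sess; request.WithLogLevel is an option from the aws/request package):
//
//     downloader := s3manager.NewDownloader(sess,
//         s3manager.WithDownloaderRequestOptions(
//             request.WithLogLevel(aws.LogDebugWithHTTPBody),
//         ),
//     )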

// NewDownloader creates a new Downloader instance to download objects from
// S3 in concurrent chunks. Pass in additional functional options to customize
// the downloader behavior. Requires a client.ConfigProvider in order to create
// an S3 service client. The session.Session satisfies the client.ConfigProvider
// interface.
//
// Example:
//     // The session the S3 Downloader will use
//     sess := session.Must(session.NewSession())
//
//     // Create a downloader with the session and default options
//     downloader := s3manager.NewDownloader(sess)
//
//     // Create a downloader with the session and custom options
//     downloader := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
//          d.PartSize = 64 * 1024 * 1024 // 64MB per part
//     })
func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
	return newDownloader(s3.New(c), options...)
}

func newDownloader(client s3iface.S3API, options ...func(*Downloader)) *Downloader {
	d := &Downloader{
		S3:             client,
		PartSize:       DefaultDownloadPartSize,
		Concurrency:    DefaultDownloadConcurrency,
		BufferProvider: defaultDownloadBufferProvider(),
	}
	for _, option := range options {
		option(d)
	}

	return d
}

// NewDownloaderWithClient creates a new Downloader instance to download
// objects from S3 in concurrent chunks. Pass in additional functional
// options to customize the downloader behavior. Requires an S3 service client
// to make S3 API calls.
//
// Example:
//     // The session the S3 Downloader will use
//     sess := session.Must(session.NewSession())
//
//     // The S3 client the S3 Downloader will use
//     s3Svc := s3.New(sess)
//
//     // Create a downloader with the s3 client and default options
//     downloader := s3manager.NewDownloaderWithClient(s3Svc)
//
//     // Create a downloader with the s3 client and custom options
//     downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Downloader) {
//          d.PartSize = 64 * 1024 * 1024 // 64MB per part
//     })
func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
	return newDownloader(svc, options...)
}

type maxRetrier interface {
	MaxRetries() int
}

// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests. The n int64 returned is the size of the object downloaded
// in bytes.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is called from.
// Modifying the options will not impact the original Downloader instance.
//
// It is safe to call this method concurrently across goroutines.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or an in-memory []byte wrapper using aws.WriteAtBuffer.
//
// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
// download the parts from S3 sequentially.
//
// If the GetObjectInput's Range value is provided, the downloader will
// perform a single GetObject request for that object's range. This will
// cause the part size and concurrency configurations to be ignored.
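//
// Example (a minimal sketch; the file path, bucket, and key here are
// hypothetical):
//     // Create a file to write the S3 Object contents to.
//     f, err := os.Create("my-key")
//     if err != nil {
//         return err
//     }
//     defer f.Close()
//
//     // Write the contents of the S3 Object to the file.
//     n, err := downloader.Download(f, &s3.GetObjectInput{
//         Bucket: aws.String("my-bucket"),
//         Key:    aws.String("my-key"),
//     })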
func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
	return d.DownloadWithContext(aws.BackgroundContext(), w, input, options...)
}

// DownloadWithContext downloads an object in S3 and writes the payload into w
// using concurrent GET requests. The n int64 returned is the size of the object downloaded
// in bytes.
//
// DownloadWithContext is the same as Download with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlines, timeouts, etc. The
// DownloadWithContext may create sub-contexts for individual underlying
// requests.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is
// called from. Modifying the options will not impact the original Downloader
// instance. Use the WithDownloaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this downloader.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or an in-memory []byte wrapper using aws.WriteAtBuffer.
//
// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
// download the parts from S3 sequentially.
//
// It is safe to call this method concurrently across goroutines.
//
// If the GetObjectInput's Range value is provided, the downloader will
// perform a single GetObject request for that object's range. This will
// cause the part size and concurrency configurations to be ignored.
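//
// Example (a minimal sketch; assumes the context and time packages are
// imported, and the bucket and key names are hypothetical):
//     ctx, cancel := context.WithTimeout(aws.BackgroundContext(), time.Minute)
//     defer cancel()
//
//     // Download into an in-memory buffer that satisfies io.WriterAt.
//     buf := aws.NewWriteAtBuffer([]byte{})
//     n, err := downloader.DownloadWithContext(ctx, buf, &s3.GetObjectInput{
//         Bucket: aws.String("my-bucket"),
//         Key:    aws.String("my-key"),
//     })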
func (d Downloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
	if err := validateSupportedARNType(aws.StringValue(input.Bucket)); err != nil {
		return 0, err
	}

	impl := downloader{w: w, in: input, cfg: d, ctx: ctx}

	for _, option := range options {
		option(&impl.cfg)
	}
	impl.cfg.RequestOptions = append(impl.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager"))

	if s, ok := d.S3.(maxRetrier); ok {
		impl.partBodyMaxRetries = s.MaxRetries()
	}

	impl.totalBytes = -1
	if impl.cfg.Concurrency == 0 {
		impl.cfg.Concurrency = DefaultDownloadConcurrency
	}

	if impl.cfg.PartSize == 0 {
		impl.cfg.PartSize = DefaultDownloadPartSize
	}

	return impl.download()
}

// DownloadWithIterator downloads a batch of objects from S3 and writes them
// to the io.WriterAt specified in the iterator.
//
// Example:
//	svc := s3manager.NewDownloader(session)
//
//	fooFile, err := os.Create("/tmp/foo.file")
//	if err != nil {
//		return err
//	}
//
//	barFile, err := os.Create("/tmp/bar.file")
//	if err != nil {
//		return err
//	}
//
//	objects := []s3manager.BatchDownloadObject {
//		{
//			Object: &s3.GetObjectInput {
//				Bucket: aws.String("bucket"),
//				Key: aws.String("foo"),
//			},
//			Writer: fooFile,
//		},
//		{
//			Object: &s3.GetObjectInput {
//				Bucket: aws.String("bucket"),
//				Key: aws.String("bar"),
//			},
//			Writer: barFile,
//		},
//	}
//
//	iter := &s3manager.DownloadObjectsIterator{Objects: objects}
//	if err := svc.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil {
//		return err
//	}
func (d Downloader) DownloadWithIterator(ctx aws.Context, iter BatchDownloadIterator, opts ...func(*Downloader)) error {
	var errs []Error
	for iter.Next() {
		object := iter.DownloadObject()
		if _, err := d.DownloadWithContext(ctx, object.Writer, object.Object, opts...); err != nil {
			errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
		}

		if object.After == nil {
			continue
		}

		if err := object.After(); err != nil {
			errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
		}
	}

	if len(errs) > 0 {
		return NewBatchError("BatchedDownloadIncomplete", "some objects have failed to download.", errs)
	}
	return nil
}

// downloader is the implementation structure used internally by Downloader.
type downloader struct {
	ctx aws.Context
	cfg Downloader

	in *s3.GetObjectInput
	w  io.WriterAt

	wg sync.WaitGroup
	m  sync.Mutex

	pos        int64
	totalBytes int64
	written    int64
	err        error

	partBodyMaxRetries int
}

// download performs the implementation of the object download across ranged
// GETs.
func (d *downloader) download() (n int64, err error) {
	// If a range is specified, fall back to a single download of that range.
	// This enables ranged gets with the downloader, but at the cost of no
	// multipart downloads.
	if rng := aws.StringValue(d.in.Range); len(rng) > 0 {
		d.downloadRange(rng)
		return d.written, d.err
	}

	// Spin off first worker to check additional header information
	d.getChunk()

	if total := d.getTotalBytes(); total >= 0 {
		// Spin up workers
		ch := make(chan dlchunk, d.cfg.Concurrency)

		for i := 0; i < d.cfg.Concurrency; i++ {
			d.wg.Add(1)
			go d.downloadPart(ch)
		}

		// Assign work
		for d.getErr() == nil {
			if d.pos >= total {
				break // We're finished queuing chunks
			}

			// Queue the next range of bytes to read.
			ch <- dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
			d.pos += d.cfg.PartSize
		}

		// Wait for completion
		close(ch)
		d.wg.Wait()
	} else {
		// Checking if we read anything new
		for d.err == nil {
			d.getChunk()
		}

		// We expect a 416 error letting us know we are done downloading the
		// total bytes. Since we do not know the content's length, this will
		// keep grabbing chunks of data until the range of bytes specified in
		// the request is out of range of the content. Once this happens, a
		// 416 should occur.
		e, ok := d.err.(awserr.RequestFailure)
		if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
			d.err = nil
		}
	}

	// Return error
	return d.written, d.err
}

// downloadPart is an individual goroutine worker reading from the ch channel
// and performing a GetObject request on the data with a given byte range.
//
// If this is the first worker, this operation also resolves the total number
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
	defer d.wg.Done()
	for {
		chunk, ok := <-ch
		if !ok {
			break
		}
		if d.getErr() != nil {
			// Drain the channel if there is an error, to prevent deadlocking
			// of download producer.
			continue
		}

		if err := d.downloadChunk(chunk); err != nil {
			d.setErr(err)
		}
	}
}

// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only be used when grabbing data on a single thread.
func (d *downloader) getChunk() {
	if d.getErr() != nil {
		return
	}

	chunk := dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
	d.pos += d.cfg.PartSize

	if err := d.downloadChunk(chunk); err != nil {
		d.setErr(err)
	}
}

// downloadRange downloads an Object given the passed in Byte-Range value.
// The chunk used to download the range will be configured for that range.
func (d *downloader) downloadRange(rng string) {
	if d.getErr() != nil {
		return
	}

	chunk := dlchunk{w: d.w, start: d.pos}
	// Ranges specified will short circuit the multipart download
	chunk.withRange = rng

	if err := d.downloadChunk(chunk); err != nil {
		d.setErr(err)
	}

	// Update the position based on the amount of data received.
	d.pos = d.written
}
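
// As an illustrative sketch, a caller that only wants the first kilobyte of
// an object can set the input's Range, which routes the request through
// downloadRange and skips the multipart path (the bucket, key, and file f
// are hypothetical):
//
//     n, err := downloader.Download(f, &s3.GetObjectInput{
//         Bucket: aws.String("my-bucket"),
//         Key:    aws.String("my-key"),
//         Range:  aws.String("bytes=0-1023"),
//     })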

// downloadChunk downloads the chunk from S3.
func (d *downloader) downloadChunk(chunk dlchunk) error {
	in := &s3.GetObjectInput{}
	awsutil.Copy(in, d.in)

	// Get the next byte range of data
	in.Range = aws.String(chunk.ByteRange())

	var n int64
	var err error
	for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
		n, err = d.tryDownloadChunk(in, &chunk)
		if err == nil {
			break
		}
		// Check if the returned error is an errReadingBody.
		// If err is errReadingBody this indicates that an error
		// occurred while copying the http response body.
		// If this occurs we unwrap the err to set the underlying error
		// and attempt any remaining retries.
		if bodyErr, ok := err.(*errReadingBody); ok {
			err = bodyErr.Unwrap()
		} else {
			return err
		}

		chunk.cur = 0
		logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries,
			fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
				aws.StringValue(in.Key), err, retry))
	}

	d.incrWritten(n)

	return err
}

func (d *downloader) tryDownloadChunk(in *s3.GetObjectInput, w io.Writer) (int64, error) {
	cleanup := func() {}
	if d.cfg.BufferProvider != nil {
		w, cleanup = d.cfg.BufferProvider.GetReadFrom(w)
	}
	defer cleanup()

	resp, err := d.cfg.S3.GetObjectWithContext(d.ctx, in, d.cfg.RequestOptions...)
	if err != nil {
		return 0, err
	}
	d.setTotalBytes(resp) // Set total if not yet set.

	n, err := io.Copy(w, resp.Body)
	resp.Body.Close()
	if err != nil {
		return n, &errReadingBody{err: err}
	}

	return n, nil
}
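
// A minimal sketch of a custom BufferProvider, assuming WriterReadFrom is
// the io.Writer plus io.ReaderFrom pair defined alongside
// WriterReadFromProvider in this package (*bufio.Writer satisfies it
// directly):
//
//     // bufferedProvider wraps each part writer in a 1MB bufio.Writer.
//     type bufferedProvider struct{}
//
//     func (bufferedProvider) GetReadFrom(w io.Writer) (s3manager.WriterReadFrom, func()) {
//         bw := bufio.NewWriterSize(w, 1024*1024)
//         // The cleanup func flushes buffered bytes once the part copy completes.
//         return bw, func() { bw.Flush() }
//     }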

func logMessage(svc s3iface.S3API, level aws.LogLevelType, msg string) {
	s, ok := svc.(*s3.S3)
	if !ok {
		return
	}

	if s.Config.Logger == nil {
		return
	}

	if s.Config.LogLevel.Matches(level) {
		s.Config.Logger.Log(msg)
	}
}

// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
	d.m.Lock()
	defer d.m.Unlock()

	return d.totalBytes
}

// setTotalBytes is a thread-safe setter for the total byte status. It
// extracts the object's total size from the Content-Range header if the
// object is downloaded in chunks, and falls back to Content-Length when the
// response has no Content-Range, meaning the object was not chunked because
// the full file fits within the PartSize directive.
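//
// For example, a part response carrying
// "Content-Range: bytes 0-5242879/17825792" yields a total of 17825792
// bytes, while "bytes 0-5242879/*" leaves the total undefined (-1), in which
// case the downloader keeps fetching chunks sequentially until a 416
// response signals the end of the object.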
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
	d.m.Lock()
	defer d.m.Unlock()

	if d.totalBytes >= 0 {
		return
	}

	if resp.ContentRange == nil {
		// ContentRange is nil when the full file contents are provided, and
		// the download is not chunked. Use ContentLength instead.
		if resp.ContentLength != nil {
			d.totalBytes = *resp.ContentLength
			return
		}
	} else {
		parts := strings.Split(*resp.ContentRange, "/")

		total := int64(-1)
		var err error
		// Check whether a numbered total exists.
		// If one does not exist, we will assume the total to be -1, undefined,
		// and sequentially download each chunk until hitting a 416 error.
		totalStr := parts[len(parts)-1]
		if totalStr != "*" {
			total, err = strconv.ParseInt(totalStr, 10, 64)
			if err != nil {
				d.err = err
				return
			}
		}

		d.totalBytes = total
	}
}

func (d *downloader) incrWritten(n int64) {
	d.m.Lock()
	defer d.m.Unlock()

	d.written += n
}

// getErr is a thread-safe getter for the error object
func (d *downloader) getErr() error {
	d.m.Lock()
	defer d.m.Unlock()

	return d.err
}

// setErr is a thread-safe setter for the error object
func (d *downloader) setErr(e error) {
	d.m.Lock()
	defer d.m.Unlock()

	d.err = e
}

// dlchunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
	w     io.WriterAt
	start int64
	size  int64
	cur   int64

	// specifies the byte range the chunk should be downloaded with.
	withRange string
}

// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
//
// If a range is specified on the dlchunk, the size will be ignored when
// writing, as the total size may not be known ahead of time.
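//
// For example, a chunk with start 5242880 that has already written 1024
// bytes (cur == 1024) writes the next p at absolute offset 5243904.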
func (c *dlchunk) Write(p []byte) (n int, err error) {
	if c.cur >= c.size && len(c.withRange) == 0 {
		return 0, io.EOF
	}

	n, err = c.w.WriteAt(p, c.start+c.cur)
	c.cur += int64(n)

	return
}

// ByteRange returns an HTTP Byte-Range header value that should be used by
// the client to request the chunk's range.
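//
// For example, a chunk starting at byte 5242880 with the default 5MB part
// size yields "bytes=5242880-10485759".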
func (c *dlchunk) ByteRange() string {
	if len(c.withRange) != 0 {
		return c.withRange
	}

	return fmt.Sprintf("bytes=%d-%d", c.start, c.start+c.size-1)
}