github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/api/helpers/download.go (about)

     1  package helpers
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"os"
     9  	"path/filepath"
    10  
    11  	"github.com/go-openapi/swag"
    12  	"github.com/treeverse/lakefs/pkg/api/apigen"
    13  	"github.com/treeverse/lakefs/pkg/uri"
    14  	"golang.org/x/sync/errgroup"
    15  )
    16  
    17  const (
    18  	MinDownloadPartSize        int64 = 1024 * 64       // 64KB
    19  	DefaultDownloadPartSize    int64 = 1024 * 1024 * 8 // 8MB
    20  	DefaultDownloadConcurrency       = 10
    21  )
    22  
// Downloader downloads lakeFS objects to local files.
type Downloader struct {
	// Client is the lakeFS API client used to stat and get objects.
	Client *apigen.ClientWithResponses
	// PreSign makes Download use the presigned multipart path and is passed
	// as the Presign parameter of the get-object request.
	PreSign bool
	// HTTPClient issues the ranged GET requests against the presigned
	// physical address.
	HTTPClient *http.Client
	// PartSize is the number of bytes requested per range; objects smaller
	// than PartSize are downloaded with a single get-object request.
	PartSize int64
}
    29  
    30  type downloadPart struct {
    31  	Number     int
    32  	RangeStart int64
    33  	PartSize   int64
    34  }
    35  
    36  func NewDownloader(client *apigen.ClientWithResponses, preSign bool) *Downloader {
    37  	// setup http client
    38  	transport := http.DefaultTransport.(*http.Transport).Clone()
    39  	transport.MaxIdleConnsPerHost = 10
    40  	httpClient := &http.Client{
    41  		Transport: transport,
    42  	}
    43  
    44  	return &Downloader{
    45  		Client:     client,
    46  		PreSign:    preSign,
    47  		HTTPClient: httpClient,
    48  		PartSize:   DefaultDownloadPartSize,
    49  	}
    50  }
    51  
    52  // Download downloads an object from lakeFS to a local file, create the destination directory if needed.
    53  func (d *Downloader) Download(ctx context.Context, src uri.URI, dst string) error {
    54  	// create destination dir if needed
    55  	dir := filepath.Dir(dst)
    56  	_ = os.MkdirAll(dir, os.ModePerm)
    57  
    58  	// download object
    59  	var err error
    60  	if d.PreSign {
    61  		// download using presigned multipart download, it will fall back to presign single object download if needed
    62  		err = d.downloadPresignMultipart(ctx, src, dst)
    63  	} else {
    64  		err = d.downloadObject(ctx, src, dst)
    65  	}
    66  	if err != nil {
    67  		return fmt.Errorf("download failed: %w", err)
    68  	}
    69  	return nil
    70  }
    71  
    72  func (d *Downloader) downloadPresignMultipart(ctx context.Context, src uri.URI, dst string) (err error) {
    73  	// get object metadata for size and physical address (presigned)
    74  	statResp, err := d.Client.StatObjectWithResponse(ctx, src.Repository, src.Ref, &apigen.StatObjectParams{
    75  		Path:    *src.Path,
    76  		Presign: swag.Bool(true),
    77  	})
    78  	if err != nil {
    79  		return err
    80  	}
    81  
    82  	// fallback to download if missing size
    83  	if statResp.JSON200 == nil || statResp.JSON200.SizeBytes == nil {
    84  		return d.downloadObject(ctx, src, dst)
    85  	}
    86  
    87  	// check if the object is small enough to download in one request
    88  	sizeBytes := *statResp.JSON200.SizeBytes
    89  	if sizeBytes < d.PartSize {
    90  		return d.downloadObject(ctx, src, dst)
    91  	}
    92  
    93  	f, err := os.Create(dst)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer func() {
    98  		_ = f.Close()
    99  	}()
   100  
   101  	// make sure the destination file is in the right size
   102  	size := swag.Int64Value(statResp.JSON200.SizeBytes)
   103  	if err := f.Truncate(size); err != nil {
   104  		return fmt.Errorf("failed to truncate '%s' to size %d: %w", f.Name(), size, err)
   105  	}
   106  
   107  	// download the file using ranges and concurrency
   108  	physicalAddress := statResp.JSON200.PhysicalAddress
   109  
   110  	ch := make(chan downloadPart, DefaultDownloadConcurrency)
   111  	// start download workers
   112  	g, grpCtx := errgroup.WithContext(context.Background())
   113  	for i := 0; i < DefaultDownloadConcurrency; i++ {
   114  		g.Go(func() error {
   115  			buf := make([]byte, d.PartSize)
   116  			for part := range ch {
   117  				err := d.downloadPresignedPart(grpCtx, physicalAddress, part.RangeStart, part.PartSize, part.Number, f, buf)
   118  				if err != nil {
   119  					return err
   120  				}
   121  			}
   122  			return nil
   123  		})
   124  	}
   125  
   126  	// send parts to download to the channel
   127  	partNumber := 0
   128  	for off := int64(0); off < size; off += d.PartSize {
   129  		partNumber++ // part numbers start from 1
   130  		part := downloadPart{
   131  			Number:     partNumber,
   132  			RangeStart: off,
   133  			PartSize:   d.PartSize,
   134  		}
   135  		// adjust last part size
   136  		if part.RangeStart+part.PartSize > size {
   137  			part.PartSize = size - part.RangeStart
   138  		}
   139  		ch <- part
   140  	}
   141  	close(ch)
   142  
   143  	return g.Wait()
   144  }
   145  
   146  func (d *Downloader) downloadPresignedPart(ctx context.Context, physicalAddress string, rangeStart int64, partSize int64, partNumber int, f *os.File, buf []byte) error {
   147  	rangeEnd := rangeStart + partSize - 1
   148  	rangeHeader := fmt.Sprintf("bytes=%d-%d", rangeStart, rangeEnd)
   149  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, physicalAddress, nil)
   150  	if err != nil {
   151  		return err
   152  	}
   153  	req.Header.Set("Range", rangeHeader)
   154  	resp, err := d.HTTPClient.Do(req)
   155  	if err != nil {
   156  		return err
   157  	}
   158  	defer func() { _ = resp.Body.Close() }()
   159  
   160  	if resp.StatusCode != http.StatusPartialContent {
   161  		return fmt.Errorf("%w: %s", ErrRequestFailed, resp.Status)
   162  	}
   163  	if resp.ContentLength != partSize {
   164  		return fmt.Errorf("%w: part %d expected %d bytes, got %d", ErrRequestFailed, partNumber, partSize, resp.ContentLength)
   165  	}
   166  
   167  	// reuse buffer if possible
   168  	if buf == nil {
   169  		buf = make([]byte, partSize)
   170  	} else {
   171  		buf = buf[:partSize]
   172  	}
   173  
   174  	_, err = io.ReadFull(resp.Body, buf)
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	_, err = f.WriteAt(buf, rangeStart)
   180  	if err != nil {
   181  		return err
   182  	}
   183  	return nil
   184  }
   185  
   186  func (d *Downloader) downloadObject(ctx context.Context, src uri.URI, dst string) error {
   187  	// get object content
   188  	resp, err := d.Client.GetObject(ctx, src.Repository, src.Ref, &apigen.GetObjectParams{
   189  		Path:    *src.Path,
   190  		Presign: swag.Bool(d.PreSign),
   191  	})
   192  	if err != nil {
   193  		return err
   194  	}
   195  	defer func() { _ = resp.Body.Close() }()
   196  	if resp.StatusCode != http.StatusOK {
   197  		return fmt.Errorf("%w: %s", ErrRequestFailed, resp.Status)
   198  	}
   199  
   200  	// create and copy object content
   201  	f, err := os.Create(dst)
   202  	if err != nil {
   203  		return err
   204  	}
   205  	defer func() { _ = f.Close() }()
   206  	_, err = io.Copy(f, resp.Body)
   207  	return err
   208  }