github.com/PDOK/gokoala@v0.50.6/internal/engine/downloader.go (about)

     1  package engine
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/url"
    10  	"os"
    11  	"time"
    12  
    13  	"github.com/failsafe-go/failsafe-go/failsafehttp"
    14  	"golang.org/x/sync/errgroup"
    15  )
    16  
    17  const bufferSize = 1 * 1024 * 1024 // 1MiB
    18  
    19  // Part piece of the file to download when HTTP Range Requests are supported
    20  type Part struct {
    21  	Start int64
    22  	End   int64
    23  	Size  int64
    24  }
    25  
    26  // Download downloads file from the given URL and stores the result in the given output location.
    27  // Will utilize multiple concurrent connections to increase transfer speed. The latter is only
    28  // possible when the remote server supports HTTP Range Requests, otherwise it falls back
    29  // to a regular/single connection download. Additionally, failed requests will be retried according
    30  // to the given settings.
    31  func Download(url url.URL, outputFilepath string, parallelism int, tlsSkipVerify bool, timeout time.Duration,
    32  	retryDelay time.Duration, retryMaxDelay time.Duration, maxRetries int) (*time.Duration, error) {
    33  
    34  	client := createHTTPClient(tlsSkipVerify, timeout, retryDelay, retryMaxDelay, maxRetries)
    35  	outputFile, err := os.OpenFile(outputFilepath, os.O_CREATE|os.O_RDWR, 0644)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	defer outputFile.Close()
    40  
    41  	start := time.Now()
    42  
    43  	supportRanges, contentLength, err := checkRemoteFile(url, client)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	if supportRanges && parallelism > 1 {
    48  		err = downloadWithMultipleConnections(url, outputFile, contentLength, int64(parallelism), client)
    49  	} else {
    50  		err = downloadWithSingleConnection(url, outputFile, client)
    51  	}
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	err = assertFileValid(outputFile, contentLength)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	timeSpent := time.Since(start)
    61  	return &timeSpent, err
    62  }
    63  
    64  func checkRemoteFile(url url.URL, client *http.Client) (supportRanges bool, contentLength int64, err error) {
    65  	res, err := client.Head(url.String())
    66  	if err != nil {
    67  		return
    68  	}
    69  	defer res.Body.Close()
    70  
    71  	contentLength = res.ContentLength
    72  	supportRanges = res.Header.Get(HeaderAcceptRanges) == "bytes" && contentLength != 0
    73  	return
    74  }
    75  
    76  func downloadWithSingleConnection(url url.URL, outputFile *os.File, client *http.Client) error {
    77  	res, err := client.Get(url.String())
    78  	if err != nil {
    79  		return err
    80  	}
    81  	defer res.Body.Close()
    82  
    83  	buf := make([]byte, bufferSize)
    84  	_, err = io.CopyBuffer(outputFile, res.Body, buf)
    85  	return err
    86  }
    87  
    88  func downloadWithMultipleConnections(url url.URL, outputFile *os.File, contentLength int64, parallelism int64, client *http.Client) error {
    89  	parts := make([]Part, parallelism)
    90  	partSize := contentLength / parallelism
    91  	remainder := contentLength % parallelism
    92  
    93  	wg, _ := errgroup.WithContext(context.Background())
    94  	for i, part := range parts {
    95  		start := int64(i) * partSize
    96  		end := start + partSize
    97  		if remainder != 0 && i == len(parts)-1 {
    98  			end += remainder
    99  		}
   100  		part = Part{start, end, partSize}
   101  		wg.Go(func() error {
   102  			return downloadPart(client, url, outputFile.Name(), part)
   103  		})
   104  	}
   105  	return wg.Wait()
   106  }
   107  
   108  func downloadPart(client *http.Client, url url.URL, outputFilepath string, part Part) error {
   109  	outputFile, err := os.OpenFile(outputFilepath, os.O_RDWR, 0664)
   110  	if err != nil {
   111  		return err
   112  	}
   113  	defer outputFile.Close()
   114  	_, err = outputFile.Seek(part.Start, 0)
   115  	if err != nil {
   116  		return err
   117  	}
   118  
   119  	req, err := http.NewRequest(http.MethodGet, url.String(), nil)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	req.Header.Set(HeaderRange, fmt.Sprintf("bytes=%d-%d", part.Start, part.End-1))
   124  	res, err := client.Do(req)
   125  	if err != nil {
   126  		return err
   127  	}
   128  	defer res.Body.Close()
   129  	if res.StatusCode != http.StatusPartialContent {
   130  		return fmt.Errorf("server advertises HTTP Range Request support "+
   131  			"but doesn't return status %d", http.StatusPartialContent)
   132  	}
   133  
   134  	buf := make([]byte, bufferSize)
   135  	_, err = io.CopyBuffer(outputFile, res.Body, buf)
   136  	return err
   137  }
   138  
   139  func assertFileValid(outputFile *os.File, contentLength int64) error {
   140  	fi, err := outputFile.Stat()
   141  	if err != nil {
   142  		return err
   143  	}
   144  	if fi.Size() != contentLength {
   145  		return fmt.Errorf("invalid file, content-length %d and file size %d mismatch", contentLength, fi.Size())
   146  	}
   147  	return nil
   148  }
   149  
   150  func createHTTPClient(tlsSkipVerify bool, timeout time.Duration, retryDelay time.Duration,
   151  	retryMaxDelay time.Duration, maxRetries int) *http.Client {
   152  
   153  	transport := &http.Transport{
   154  		TLSClientConfig: &tls.Config{
   155  			InsecureSkipVerify: tlsSkipVerify, //nolint:gosec // on purpose, default is false
   156  		},
   157  	}
   158  	//nolint:bodyclose // false positive
   159  	retryPolicy := failsafehttp.RetryPolicyBuilder().
   160  		WithBackoff(retryDelay, retryMaxDelay). //nolint:bodyclose // false positive
   161  		WithMaxRetries(maxRetries).             //nolint:bodyclose // false positive
   162  		Build()                                 //nolint:bodyclose // false positive
   163  	return &http.Client{
   164  		Timeout:   timeout,
   165  		Transport: failsafehttp.NewRoundTripper(transport, retryPolicy),
   166  	}
   167  }