github.com/PDOK/gokoala@v0.50.6/internal/engine/downloader.go (about) 1 package engine 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 "net/http" 9 "net/url" 10 "os" 11 "time" 12 13 "github.com/failsafe-go/failsafe-go/failsafehttp" 14 "golang.org/x/sync/errgroup" 15 ) 16 17 const bufferSize = 1 * 1024 * 1024 // 1MiB 18 19 // Part piece of the file to download when HTTP Range Requests are supported 20 type Part struct { 21 Start int64 22 End int64 23 Size int64 24 } 25 26 // Download downloads file from the given URL and stores the result in the given output location. 27 // Will utilize multiple concurrent connections to increase transfer speed. The latter is only 28 // possible when the remote server supports HTTP Range Requests, otherwise it falls back 29 // to a regular/single connection download. Additionally, failed requests will be retried according 30 // to the given settings. 31 func Download(url url.URL, outputFilepath string, parallelism int, tlsSkipVerify bool, timeout time.Duration, 32 retryDelay time.Duration, retryMaxDelay time.Duration, maxRetries int) (*time.Duration, error) { 33 34 client := createHTTPClient(tlsSkipVerify, timeout, retryDelay, retryMaxDelay, maxRetries) 35 outputFile, err := os.OpenFile(outputFilepath, os.O_CREATE|os.O_RDWR, 0644) 36 if err != nil { 37 return nil, err 38 } 39 defer outputFile.Close() 40 41 start := time.Now() 42 43 supportRanges, contentLength, err := checkRemoteFile(url, client) 44 if err != nil { 45 return nil, err 46 } 47 if supportRanges && parallelism > 1 { 48 err = downloadWithMultipleConnections(url, outputFile, contentLength, int64(parallelism), client) 49 } else { 50 err = downloadWithSingleConnection(url, outputFile, client) 51 } 52 if err != nil { 53 return nil, err 54 } 55 err = assertFileValid(outputFile, contentLength) 56 if err != nil { 57 return nil, err 58 } 59 60 timeSpent := time.Since(start) 61 return &timeSpent, err 62 } 63 64 func checkRemoteFile(url url.URL, client *http.Client) (supportRanges bool, contentLength int64, err error) { 65 res, err := client.Head(url.String()) 66 if err != nil { 67 return 68 } 69 defer res.Body.Close() 70 71 contentLength = res.ContentLength 72 supportRanges = res.Header.Get(HeaderAcceptRanges) == "bytes" && contentLength != 0 73 return 74 } 75 76 func downloadWithSingleConnection(url url.URL, outputFile *os.File, client *http.Client) error { 77 res, err := client.Get(url.String()) 78 if err != nil { 79 return err 80 } 81 defer res.Body.Close() 82 83 buf := make([]byte, bufferSize) 84 _, err = io.CopyBuffer(outputFile, res.Body, buf) 85 return err 86 } 87 88 func downloadWithMultipleConnections(url url.URL, outputFile *os.File, contentLength int64, parallelism int64, client *http.Client) error { 89 parts := make([]Part, parallelism) 90 partSize := contentLength / parallelism 91 remainder := contentLength % parallelism 92 93 wg, _ := errgroup.WithContext(context.Background()) 94 for i, part := range parts { 95 start := int64(i) * partSize 96 end := start + partSize 97 if remainder != 0 && i == len(parts)-1 { 98 end += remainder 99 } 100 part = Part{start, end, partSize} 101 wg.Go(func() error { 102 return downloadPart(client, url, outputFile.Name(), part) 103 }) 104 } 105 return wg.Wait() 106 } 107 108 func downloadPart(client *http.Client, url url.URL, outputFilepath string, part Part) error { 109 outputFile, err := os.OpenFile(outputFilepath, os.O_RDWR, 0664) 110 if err != nil { 111 return err 112 } 113 defer outputFile.Close() 114 _, err = outputFile.Seek(part.Start, 0) 115 if err != nil { 116 return err 117 } 118 119 req, err := http.NewRequest(http.MethodGet, url.String(), nil) 120 if err != nil { 121 return err 122 } 123 req.Header.Set(HeaderRange, fmt.Sprintf("bytes=%d-%d", part.Start, part.End-1)) 124 res, err := client.Do(req) 125 if err != nil { 126 return err 127 } 128 defer res.Body.Close() 129 if res.StatusCode != http.StatusPartialContent { 130 return fmt.Errorf("server advertises HTTP Range Request support "+ 131 "but doesn't return status %d", http.StatusPartialContent) 132 } 133 134 buf := make([]byte, bufferSize) 135 _, err = io.CopyBuffer(outputFile, res.Body, buf) 136 return err 137 } 138 139 func assertFileValid(outputFile *os.File, contentLength int64) error { 140 fi, err := outputFile.Stat() 141 if err != nil { 142 return err 143 } 144 if fi.Size() != contentLength { 145 return fmt.Errorf("invalid file, content-length %d and file size %d mismatch", contentLength, fi.Size()) 146 } 147 return nil 148 } 149 150 func createHTTPClient(tlsSkipVerify bool, timeout time.Duration, retryDelay time.Duration, 151 retryMaxDelay time.Duration, maxRetries int) *http.Client { 152 153 transport := &http.Transport{ 154 TLSClientConfig: &tls.Config{ 155 InsecureSkipVerify: tlsSkipVerify, //nolint:gosec // on purpose, default is false 156 }, 157 } 158 //nolint:bodyclose // false positive 159 retryPolicy := failsafehttp.RetryPolicyBuilder(). 160 WithBackoff(retryDelay, retryMaxDelay). //nolint:bodyclose // false positive 161 WithMaxRetries(maxRetries). //nolint:bodyclose // false positive 162 Build() //nolint:bodyclose // false positive 163 return &http.Client{ 164 Timeout: timeout, 165 Transport: failsafehttp.NewRoundTripper(transport, retryPolicy), 166 } 167 }