github.com/webx-top/com@v1.2.12/range_downloader.go (about)

     1  package com
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"log"
    11  	"math"
    12  	"net/http"
    13  	"os"
    14  	"strconv"
    15  	"sync"
    16  	"time"
    17  )
    18  
    19  var (
    20  	ErrNoHeaderContentLength = errors.New(`No Content-Length Provided`)
    21  	ErrMd5Unmatched          = errors.New("WARNING: MD5 Sums don't match")
    22  )
    23  
    24  func RangeDownload(url string, saveTo string, args ...int) error {
    25  	threads := 10
    26  	if len(args) > 0 {
    27  		threads = args[0]
    28  	}
    29  	defer timeTrack(time.Now(), "Full download")
    30  	resp, err := http.Get(url)
    31  	if err != nil {
    32  		return err
    33  	}
    34  	contentLength := resp.Header.Get("Content-Length")
    35  	if len(contentLength) < 1 {
    36  		return ErrNoHeaderContentLength
    37  	}
    38  	var startByte int64
    39  	outfile, err := os.OpenFile(saveTo, os.O_RDWR, 0666)
    40  	if err != nil {
    41  		if os.IsNotExist(err) {
    42  			outfile, err = os.Create(saveTo)
    43  		}
    44  	} else {
    45  		stat, err := outfile.Stat()
    46  		if err == nil {
    47  			outfile.Seek(stat.Size(), 0)
    48  			startByte = stat.Size()
    49  		}
    50  	}
    51  	if outfile != nil {
    52  		defer outfile.Close()
    53  	}
    54  	if err != nil {
    55  		return err
    56  	}
    57  	contentSize, _ := strconv.ParseInt(contentLength, 10, 64)
    58  	if resp.Header.Get("Accept-Ranges") == "bytes" {
    59  		var wg sync.WaitGroup
    60  		log.Println("Ranges Supported!")
    61  		log.Println("Content Size:", contentLength, `(`+FormatByte(contentSize)+`)`)
    62  		if contentSize <= startByte {
    63  			log.Println("Download Complete! Total Size:", contentSize, `(`+FormatByte(contentSize)+`)`)
    64  			return nil
    65  		}
    66  		contentSize -= startByte
    67  		calculatedChunksize := contentSize / int64(threads)
    68  		log.Println("Chunk Size: ", calculatedChunksize, `(`+FormatByte(calculatedChunksize)+`)`)
    69  		var endByte int64
    70  		chunks := 0
    71  		completedChunks := 0
    72  		totalChunks := threads
    73  		if math.Mod(float64(contentSize), float64(threads)) > 0 {
    74  			totalChunks++
    75  		}
    76  		lengthStr := strconv.Itoa(len(strconv.Itoa(totalChunks)))
    77  		completedChunkCallback := func() {
    78  			completedChunks++
    79  			log.Println(`Completed`, saveTo, `chunks:`, fmt.Sprintf(`%`+lengthStr+`d`, completedChunks), `/`, totalChunks)
    80  		}
    81  		for i := 0; i < threads; i++ {
    82  			wg.Add(1)
    83  			endByte = startByte + calculatedChunksize
    84  			go fetchChunk(startByte, endByte, url, outfile, &wg, completedChunkCallback)
    85  			startByte = endByte
    86  			chunks++
    87  		}
    88  		if endByte < contentSize {
    89  			wg.Add(1)
    90  			startByte = endByte
    91  			endByte = contentSize
    92  			go fetchChunk(startByte, endByte, url, outfile, &wg, completedChunkCallback)
    93  			chunks++
    94  		}
    95  		wg.Wait()
    96  		log.Println("Download Complete! Total Size:", contentSize, `(`+FormatByte(contentSize)+`)`)
    97  		log.Println("Building File...")
    98  		defer timeTrack(time.Now(), "File Assembled")
    99  		//Verify file size
   100  		filestats, err := outfile.Stat()
   101  		if err != nil {
   102  			return err
   103  		}
   104  		actualFileSize := filestats.Size()
   105  		if actualFileSize != contentSize {
   106  			return errors.New(fmt.Sprint("Actual Size: ", actualFileSize, " Expected: ", contentSize))
   107  		}
   108  		//Verify Md5
   109  		fileHash := resp.Header.Get("X-File-Hash")
   110  		if len(fileHash) == 0 {
   111  			if len(resp.Header["X-Goog-Hash"]) > 1 {
   112  				if len(resp.Header["X-Goog-Hash"][1]) > 4 {
   113  					fileHash = resp.Header["X-Goog-Hash"][1][4:]
   114  				}
   115  			}
   116  		}
   117  		if len(fileHash) > 0 {
   118  			contentMd5, err := hex.DecodeString(fileHash)
   119  			if err != nil {
   120  				return err
   121  			}
   122  			barray, _ := os.ReadFile(saveTo)
   123  			computedHash := md5.Sum(barray)
   124  			computedSlice := computedHash[0:]
   125  			if bytes.Compare(computedSlice, contentMd5) != 0 {
   126  				return ErrMd5Unmatched
   127  			}
   128  			//log.Println("File MD5 Matches!")
   129  		}
   130  		log.Println("File Build Complete!")
   131  		return nil
   132  	}
   133  	log.Println("Range Download unsupported")
   134  	log.Println("Beginning full download...")
   135  	err = fetchChunk(0, contentSize, url, outfile, nil, nil)
   136  	log.Println("Download Complete")
   137  	return err
   138  }
   139  
   140  func assembleChunk(filename string, outfile *os.File) error {
   141  	chunkFile, err := os.Open(filename)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	defer chunkFile.Close()
   146  	_, err = io.Copy(outfile, chunkFile)
   147  	if err != nil {
   148  		return err
   149  	}
   150  	return os.Remove(filename)
   151  }
   152  
   153  func fetchChunk(startByte, endByte int64, url string, outfile *os.File, wg *sync.WaitGroup, callback func()) error {
   154  	if wg != nil {
   155  		defer wg.Done()
   156  	}
   157  	client := new(http.Client)
   158  	req, err := http.NewRequest("GET", url, nil)
   159  	if err != nil {
   160  		return err
   161  	}
   162  	defer func() {
   163  		if err != nil {
   164  			log.Println(err)
   165  			return
   166  		}
   167  		//log.Println("Finished Downloading byte ", startByte, `(`+FormatByte(startByte)+`)`)
   168  		if callback != nil {
   169  			callback()
   170  		}
   171  	}()
   172  
   173  	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", startByte, endByte-1))
   174  	res, err := client.Do(req)
   175  	/*
   176  		var retry int = 3
   177  		var res *http.Response
   178  		for i := retry; i > 0; i-- {
   179  			res, err = client.Do(req)
   180  			if res.StatusCode == 200 {
   181  				retry = 3
   182  				break
   183  			}
   184  			retry = i
   185  		}
   186  		if retry == 0 && res == nil {
   187  			log.Fatal(err)
   188  			return
   189  		}
   190  	*/
   191  	if err != nil {
   192  		return err
   193  	}
   194  	defer res.Body.Close()
   195  	ra, err := io.ReadAll(res.Body)
   196  	if err != nil {
   197  		return err
   198  	}
   199  	_, err = outfile.WriteAt(ra, startByte)
   200  	return err
   201  }
   202  
   203  func timeTrack(start time.Time, name string) {
   204  	elapsed := time.Since(start)
   205  	log.Printf("%s took %s", name, elapsed)
   206  }