github.com/fzfile/BaiduPCS-Go@v0.0.0-20200606205115-4408961cf336/requester/downloader/downloader.go (about)

     1  // Package downloader 多线程下载器, 重构版
     2  package downloader
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"github.com/fzfile/BaiduPCS-Go/pcsutil"
     8  	"github.com/fzfile/BaiduPCS-Go/pcsutil/cachepool"
     9  	"github.com/fzfile/BaiduPCS-Go/pcsutil/prealloc"
    10  	"github.com/fzfile/BaiduPCS-Go/pcsutil/waitgroup"
    11  	"github.com/fzfile/BaiduPCS-Go/pcsverbose"
    12  	"github.com/fzfile/BaiduPCS-Go/requester"
    13  	"github.com/fzfile/BaiduPCS-Go/requester/rio/speeds"
    14  	"github.com/fzfile/BaiduPCS-Go/requester/transfer"
    15  	"io"
    16  	"net/http"
    17  	"sync"
    18  	"time"
    19  )
    20  
    // DefaultAcceptRanges is the Accept-Ranges unit the downloader assumes when
// it uses ranged (multi-connection) downloading.
const DefaultAcceptRanges = "bytes"
    25  
    26  type (
    27  	// Downloader 下载
    28  	Downloader struct {
    29  		onExecuteEvent        requester.Event    //开始下载事件
    30  		onSuccessEvent        requester.Event    //成功下载事件
    31  		onFinishEvent         requester.Event    //结束下载事件
    32  		onPauseEvent          requester.Event    //暂停下载事件
    33  		onResumeEvent         requester.Event    //恢复下载事件
    34  		onCancelEvent         requester.Event    //取消下载事件
    35  		onDownloadStatusEvent DownloadStatusFunc //状态处理事件
    36  
    37  		monitorCancelFunc context.CancelFunc
    38  
    39  		firstInfo               *DownloadFirstInfo      // 初始信息
    40  		loadBalancerCompareFunc LoadBalancerCompareFunc // 负载均衡检测函数
    41  		durlCheckFunc           DURLCheckFunc           // 下载url检测函数
    42  		statusCodeBodyCheckFunc StatusCodeBodyCheckFunc
    43  		executeTime             time.Time
    44  		durl                    string
    45  		loadBalansers           []string
    46  		writer                  io.WriterAt
    47  		client                  *requester.HTTPClient
    48  		config                  *Config
    49  		monitor                 *Monitor
    50  		instanceState           *InstanceState
    51  	}
    52  
    53  	// DURLCheckFunc 下载URL检测函数
    54  	DURLCheckFunc func(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error)
    55  	// StatusCodeBodyCheckFunc 响应状态码出错的检查函数
    56  	StatusCodeBodyCheckFunc func(respBody io.Reader) error
    57  )
    58  
    59  //NewDownloader 初始化Downloader
    60  func NewDownloader(durl string, writer io.WriterAt, config *Config) (der *Downloader) {
    61  	der = &Downloader{
    62  		durl:   durl,
    63  		config: config,
    64  		writer: writer,
    65  	}
    66  
    67  	return
    68  }
    69  
    70  // SetFirstInfo 设置初始信息
    71  // 如果设置了此值, 将忽略检测url
    72  func (der *Downloader) SetFirstInfo(i *DownloadFirstInfo) {
    73  	der.firstInfo = i
    74  }
    75  
    76  //SetClient 设置http客户端
    77  func (der *Downloader) SetClient(client *requester.HTTPClient) {
    78  	der.client = client
    79  }
    80  
    81  // SetDURLCheckFunc 设置下载URL检测函数
    82  func (der *Downloader) SetDURLCheckFunc(f DURLCheckFunc) {
    83  	der.durlCheckFunc = f
    84  }
    85  
    86  // SetLoadBalancerCompareFunc 设置负载均衡检测函数
    87  func (der *Downloader) SetLoadBalancerCompareFunc(f LoadBalancerCompareFunc) {
    88  	der.loadBalancerCompareFunc = f
    89  }
    90  
    91  //SetStatusCodeBodyCheckFunc 设置响应状态码出错的检查函数, 当FirstCheckMethod不为HEAD时才有效
    92  func (der *Downloader) SetStatusCodeBodyCheckFunc(f StatusCodeBodyCheckFunc) {
    93  	der.statusCodeBodyCheckFunc = f
    94  }
    95  
    96  func (der *Downloader) lazyInit() {
    97  	if der.config == nil {
    98  		der.config = NewConfig()
    99  	}
   100  	if der.client == nil {
   101  		der.client = requester.NewHTTPClient()
   102  		der.client.SetTimeout(20 * time.Minute)
   103  	}
   104  	if der.monitor == nil {
   105  		der.monitor = NewMonitor()
   106  	}
   107  	if der.durlCheckFunc == nil {
   108  		der.durlCheckFunc = DefaultDURLCheckFunc
   109  	}
   110  	if der.loadBalancerCompareFunc == nil {
   111  		der.loadBalancerCompareFunc = DefaultLoadBalancerCompareFunc
   112  	}
   113  }
   114  
   115  // SelectParallel 获取合适的 parallel
   116  func (der *Downloader) SelectParallel(single bool, maxParallel int, totalSize int64, instanceRangeList transfer.RangeList) (parallel int) {
   117  	isRange := instanceRangeList != nil && len(instanceRangeList) > 0
   118  	if single { //不支持多线程
   119  		parallel = 1
   120  	} else if isRange {
   121  		parallel = len(instanceRangeList)
   122  	} else {
   123  		parallel = der.config.MaxParallel
   124  		if int64(parallel) > totalSize/int64(MinParallelSize) {
   125  			parallel = int(totalSize/int64(MinParallelSize)) + 1
   126  		}
   127  	}
   128  
   129  	if parallel < 1 {
   130  		parallel = 1
   131  	}
   132  	return
   133  }
   134  
   135  // SelectBlockSizeAndInitRangeGen 获取合适的 BlockSize, 和初始化 RangeGen
   136  func (der *Downloader) SelectBlockSizeAndInitRangeGen(single bool, status *transfer.DownloadStatus, parallel int) (blockSize int64, initErr error) {
   137  	// Range 生成器
   138  	if single { // 单线程
   139  		blockSize = -1
   140  		return
   141  	}
   142  	gen := status.RangeListGen()
   143  	if gen == nil {
   144  		switch der.config.Mode {
   145  		case transfer.RangeGenMode_Default:
   146  			gen = transfer.NewRangeListGenDefault(status.TotalSize(), 0, 0, parallel)
   147  			blockSize = gen.LoadBlockSize()
   148  		case transfer.RangeGenMode_BlockSize:
   149  			b2 := status.TotalSize()/int64(parallel) + 1
   150  			if b2 > der.config.BlockSize { // 选小的BlockSize, 以更高并发
   151  				blockSize = der.config.BlockSize
   152  			} else {
   153  				blockSize = b2
   154  			}
   155  
   156  			gen = transfer.NewRangeListGenBlockSize(status.TotalSize(), 0, blockSize)
   157  		default:
   158  			initErr = transfer.ErrUnknownRangeGenMode
   159  			return
   160  		}
   161  	} else {
   162  		blockSize = gen.LoadBlockSize()
   163  	}
   164  	status.SetRangeListGen(gen)
   165  	return
   166  }
   167  
   168  // SelectCacheSize 获取合适的 cacheSize
   169  func (der *Downloader) SelectCacheSize(confCacheSize int, blockSize int64) (cacheSize int) {
   170  	if blockSize > 0 && int64(confCacheSize) > blockSize {
   171  		// 如果 cache size 过高, 则调低
   172  		cacheSize = int(blockSize)
   173  	} else {
   174  		cacheSize = confCacheSize
   175  	}
   176  	return
   177  }
   178  
   179  // DefaultDURLCheckFunc 默认的 DURLCheckFunc
   180  func DefaultDURLCheckFunc(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) {
   181  	resp, err = client.Req(http.MethodGet, durl, nil, nil)
   182  	if err != nil {
   183  		if resp != nil {
   184  			resp.Body.Close()
   185  		}
   186  		return 0, nil, err
   187  	}
   188  	return resp.ContentLength, resp, nil
   189  }
   190  
   191  func (der *Downloader) checkLoadBalancers() *LoadBalancerResponseList {
   192  	var (
   193  		loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1)
   194  		handleLoadBalancer    = func(req *http.Request) {
   195  			if req == nil {
   196  				return
   197  			}
   198  
   199  			if der.config.TryHTTP {
   200  				req.URL.Scheme = "http"
   201  			}
   202  
   203  			loadBalancer := &LoadBalancerResponse{
   204  				URL:     req.URL.String(),
   205  				Referer: req.Referer(),
   206  			}
   207  
   208  			loadBalancerResponses = append(loadBalancerResponses, loadBalancer)
   209  			pcsverbose.Verbosef("DEBUG: load balance task: URL: %s, Referer: %s\n", loadBalancer.URL, loadBalancer.Referer)
   210  		}
   211  	)
   212  
   213  	// 加入第一个
   214  	loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{
   215  		URL: der.durl,
   216  	})
   217  
   218  	// 负载均衡
   219  	wg := waitgroup.NewWaitGroup(10)
   220  	privTimeout := der.client.Client.Timeout
   221  	der.client.SetTimeout(5 * time.Second)
   222  	for _, loadBalanser := range der.loadBalansers {
   223  		wg.AddDelta()
   224  		go func(loadBalanser string) {
   225  			defer wg.Done()
   226  
   227  			subContentLength, subResp, subErr := der.durlCheckFunc(der.client, loadBalanser)
   228  			if subResp != nil {
   229  				subResp.Body.Close() // 不读Body, 马上关闭连接
   230  			}
   231  			if subErr != nil {
   232  				pcsverbose.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr)
   233  				return
   234  			}
   235  
   236  			// 检测状态码
   237  			switch subResp.StatusCode / 100 {
   238  			case 2: // succeed
   239  			case 4, 5: // error
   240  				var err error
   241  				if der.statusCodeBodyCheckFunc != nil {
   242  					err = der.statusCodeBodyCheckFunc(subResp.Body)
   243  				} else {
   244  					err = errors.New(subResp.Status)
   245  				}
   246  				pcsverbose.Verbosef("DEBUG: loadBalanser Status Error: %s\n", err)
   247  				return
   248  			}
   249  
   250  			// 检测长度
   251  			if der.firstInfo.ContentLength != subContentLength {
   252  				pcsverbose.Verbosef("DEBUG: loadBalanser Content-Length not equal to main server\n")
   253  				return
   254  			}
   255  
   256  			if !der.loadBalancerCompareFunc(der.firstInfo.ToMap(), subResp) {
   257  				pcsverbose.Verbosef("DEBUG: loadBalanser not equal to main server\n")
   258  				return
   259  			}
   260  
   261  			handleLoadBalancer(subResp.Request)
   262  		}(loadBalanser)
   263  	}
   264  	wg.Wait()
   265  	der.client.SetTimeout(privTimeout)
   266  
   267  	loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses)
   268  	return loadBalancerResponseList
   269  }
   270  
   271  //Execute 开始任务
   272  func (der *Downloader) Execute() error {
   273  	der.lazyInit()
   274  
   275  	var (
   276  		resp *http.Response
   277  	)
   278  	if der.firstInfo == nil {
   279  		// 检测
   280  		contentLength, resp, err := der.durlCheckFunc(der.client, der.durl)
   281  		if err != nil {
   282  			return err
   283  		}
   284  
   285  		// 检测网络错误
   286  		switch resp.StatusCode / 100 {
   287  		case 2: // succeed
   288  		case 4, 5: // error
   289  			if der.statusCodeBodyCheckFunc != nil {
   290  				err = der.statusCodeBodyCheckFunc(resp.Body)
   291  				resp.Body.Close() // 关闭连接
   292  				if err != nil {
   293  					return err
   294  				}
   295  			}
   296  			return errors.New(resp.Status)
   297  		}
   298  
   299  		acceptRanges := resp.Header.Get("Accept-Ranges")
   300  		if contentLength < 0 {
   301  			acceptRanges = ""
   302  		} else {
   303  			acceptRanges = DefaultAcceptRanges
   304  		}
   305  
   306  		// 初始化firstInfo
   307  		der.firstInfo = &DownloadFirstInfo{
   308  			ContentLength: contentLength,
   309  			ContentMD5:    resp.Header.Get("Content-MD5"),
   310  			ContentCRC32:  resp.Header.Get("x-bs-meta-crc32"),
   311  			AcceptRanges:  acceptRanges,
   312  			Referer:       resp.Header.Get("Referer"),
   313  		}
   314  		pcsverbose.Verbosef("DEBUG: download task: URL: %s, Referer: %s\n", resp.Request.URL, resp.Request.Referer())
   315  	} else {
   316  		if der.firstInfo.AcceptRanges == "" {
   317  			der.firstInfo.AcceptRanges = DefaultAcceptRanges
   318  		}
   319  	}
   320  
   321  	var (
   322  		loadBalancerResponseList = der.checkLoadBalancers()
   323  		single                   = der.firstInfo.AcceptRanges == ""
   324  		bii                      *transfer.DownloadInstanceInfo
   325  	)
   326  
   327  	if !single {
   328  		//load breakpoint
   329  		//服务端不支持多线程时, 不记录断点
   330  		err := der.initInstanceState(der.config.InstanceStateStorageFormat)
   331  		if err != nil {
   332  			return err
   333  		}
   334  		bii = der.instanceState.Get()
   335  	}
   336  
   337  	var (
   338  		isInstance = bii != nil // 是否存在断点信息
   339  		status     *transfer.DownloadStatus
   340  	)
   341  	if !isInstance {
   342  		bii = &transfer.DownloadInstanceInfo{}
   343  	}
   344  
   345  	if bii.DownloadStatus != nil {
   346  		// 使用断点信息的状态
   347  		status = bii.DownloadStatus
   348  	} else {
   349  		// 新建状态
   350  		status = transfer.NewDownloadStatus()
   351  		status.SetTotalSize(der.firstInfo.ContentLength)
   352  	}
   353  
   354  	// 设置限速
   355  	if der.config.MaxRate > 0 {
   356  		rl := speeds.NewRateLimit(der.config.MaxRate)
   357  		status.SetRateLimit(rl)
   358  		defer rl.Stop()
   359  	}
   360  
   361  	// 数据处理
   362  	parallel := der.SelectParallel(single, der.config.MaxParallel, status.TotalSize(), bii.Ranges) // 实际的下载并行量
   363  	blockSize, err := der.SelectBlockSizeAndInitRangeGen(single, status, parallel)                 // 实际的BlockSize
   364  	if err != nil {
   365  		return err
   366  	}
   367  
   368  	cacheSize := der.SelectCacheSize(der.config.CacheSize, blockSize) // 实际下载缓存
   369  	cachepool.SetSyncPoolSize(cacheSize)                              // 调整pool大小
   370  
   371  	pcsverbose.Verbosef("DEBUG: download task CREATED: parallel: %d, cache size: %d\n", parallel, cacheSize)
   372  
   373  	der.monitor.InitMonitorCapacity(parallel)
   374  
   375  	var writer Writer
   376  	if !der.config.IsTest {
   377  		// 尝试修剪文件
   378  		if fder, ok := der.writer.(Fder); ok {
   379  			err = prealloc.PreAlloc(fder.Fd(), status.TotalSize())
   380  			if err != nil {
   381  				pcsverbose.Verbosef("DEBUG: truncate file error: %s\n", err)
   382  			}
   383  		}
   384  		writer = der.writer // 非测试模式, 赋值writer
   385  	}
   386  
   387  	// 数据平均分配给各个线程
   388  	isRange := bii.Ranges != nil && len(bii.Ranges) > 0
   389  	if !isRange {
   390  		// 没有使用断点续传
   391  		// 分配线程
   392  		bii.Ranges = make(transfer.RangeList, 0, parallel)
   393  		if single { // 单线程
   394  			bii.Ranges = append(bii.Ranges, &transfer.Range{})
   395  		} else {
   396  			gen := status.RangeListGen()
   397  			for i := 0; i < cap(bii.Ranges); i++ {
   398  				_, r := gen.GenRange()
   399  				if r == nil { // 没有了(不正常)
   400  					break
   401  				}
   402  				bii.Ranges = append(bii.Ranges, r)
   403  			}
   404  		}
   405  	}
   406  
   407  	var (
   408  		writeMu = &sync.Mutex{}
   409  	)
   410  	for k, r := range bii.Ranges {
   411  		loadBalancer := loadBalancerResponseList.SequentialGet()
   412  		if loadBalancer == nil {
   413  			continue
   414  		}
   415  
   416  		worker := NewWorker(k, loadBalancer.URL, writer)
   417  		worker.SetClient(der.client)
   418  		worker.SetWriteMutex(writeMu)
   419  		worker.SetReferer(loadBalancer.Referer)
   420  		worker.SetTotalSize(der.firstInfo.ContentLength)
   421  
   422  		// 使用第一个连接
   423  		// 断点续传时不使用
   424  		if k == 0 && !isInstance {
   425  			worker.firstResp = resp
   426  		}
   427  
   428  		worker.SetAcceptRange(der.firstInfo.AcceptRanges)
   429  		worker.SetRange(r) // 分配Range
   430  		der.monitor.Append(worker)
   431  	}
   432  
   433  	der.monitor.SetStatus(status)
   434  
   435  	// 服务器不支持断点续传, 或者单线程下载, 都不重载worker
   436  	der.monitor.SetReloadWorker(parallel > 1)
   437  
   438  	moniterCtx, moniterCancelFunc := context.WithCancel(context.Background())
   439  	der.monitorCancelFunc = moniterCancelFunc
   440  
   441  	der.monitor.SetInstanceState(der.instanceState)
   442  
   443  	// 开始执行
   444  	der.executeTime = time.Now()
   445  	pcsutil.Trigger(der.onExecuteEvent)
   446  	der.downloadStatusEvent() // 启动执行状态处理事件
   447  	der.monitor.Execute(moniterCtx)
   448  
   449  	// 检查错误
   450  	err = der.monitor.Err()
   451  	if err == nil { // 成功
   452  		pcsutil.Trigger(der.onSuccessEvent)
   453  		if !single {
   454  			der.removeInstanceState() // 移除断点续传文件
   455  		}
   456  	}
   457  
   458  	// 执行结束
   459  	pcsutil.Trigger(der.onFinishEvent)
   460  	return err
   461  }
   462  
   463  //downloadStatusEvent 执行状态处理事件
   464  func (der *Downloader) downloadStatusEvent() {
   465  	if der.onDownloadStatusEvent == nil {
   466  		return
   467  	}
   468  
   469  	status := der.monitor.Status()
   470  	go func() {
   471  		ticker := time.NewTicker(1 * time.Second)
   472  		defer ticker.Stop()
   473  		for {
   474  			select {
   475  			case <-der.monitor.completed:
   476  				return
   477  			case <-ticker.C:
   478  				der.onDownloadStatusEvent(status, der.monitor.RangeWorker)
   479  			}
   480  		}
   481  	}()
   482  }
   483  
   484  //Pause 暂停
   485  func (der *Downloader) Pause() {
   486  	if der.monitor == nil {
   487  		return
   488  	}
   489  	pcsutil.Trigger(der.onPauseEvent)
   490  	der.monitor.Pause()
   491  }
   492  
   493  //Resume 恢复
   494  func (der *Downloader) Resume() {
   495  	if der.monitor == nil {
   496  		return
   497  	}
   498  	pcsutil.Trigger(der.onResumeEvent)
   499  	der.monitor.Resume()
   500  }
   501  
   502  //Cancel 取消
   503  func (der *Downloader) Cancel() {
   504  	if der.monitor == nil {
   505  		return
   506  	}
   507  	pcsutil.Trigger(der.onCancelEvent)
   508  	pcsutil.Trigger(der.monitorCancelFunc)
   509  }
   510  
   511  //OnExecute 设置开始下载事件
   512  func (der *Downloader) OnExecute(onExecuteEvent requester.Event) {
   513  	der.onExecuteEvent = onExecuteEvent
   514  }
   515  
   516  //OnSuccess 设置成功下载事件
   517  func (der *Downloader) OnSuccess(onSuccessEvent requester.Event) {
   518  	der.onSuccessEvent = onSuccessEvent
   519  }
   520  
   521  //OnFinish 设置结束下载事件
   522  func (der *Downloader) OnFinish(onFinishEvent requester.Event) {
   523  	der.onFinishEvent = onFinishEvent
   524  }
   525  
   526  //OnPause 设置暂停下载事件
   527  func (der *Downloader) OnPause(onPauseEvent requester.Event) {
   528  	der.onPauseEvent = onPauseEvent
   529  }
   530  
   531  //OnResume 设置恢复下载事件
   532  func (der *Downloader) OnResume(onResumeEvent requester.Event) {
   533  	der.onResumeEvent = onResumeEvent
   534  }
   535  
   536  //OnCancel 设置取消下载事件
   537  func (der *Downloader) OnCancel(onCancelEvent requester.Event) {
   538  	der.onCancelEvent = onCancelEvent
   539  }
   540  
   541  //OnDownloadStatusEvent 设置状态处理函数
   542  func (der *Downloader) OnDownloadStatusEvent(f DownloadStatusFunc) {
   543  	der.onDownloadStatusEvent = f
   544  }