github.com/qjfoidnh/BaiduPCS-Go@v0.0.0-20231011165705-caa18a3765f3/requester/downloader/downloader.go (about)

     1  // Package downloader 多线程下载器, 重构版
     2  package downloader
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"github.com/qjfoidnh/BaiduPCS-Go/pcsutil"
     8  	"github.com/qjfoidnh/BaiduPCS-Go/pcsutil/cachepool"
     9  	"github.com/qjfoidnh/BaiduPCS-Go/pcsutil/converter"
    10  	"github.com/qjfoidnh/BaiduPCS-Go/pcsutil/prealloc"
    11  	"github.com/qjfoidnh/BaiduPCS-Go/pcsutil/waitgroup"
    12  	"github.com/qjfoidnh/BaiduPCS-Go/pcsverbose"
    13  	"github.com/qjfoidnh/BaiduPCS-Go/requester"
    14  	"github.com/qjfoidnh/BaiduPCS-Go/requester/rio/speeds"
    15  	"github.com/qjfoidnh/BaiduPCS-Go/requester/transfer"
    16  	"io"
    17  	"net/http"
    18  	"sync"
    19  	"time"
    20  )
    21  
    22  const (
    23  	// DefaultAcceptRanges 默认的 Accept-Ranges
    24  	DefaultAcceptRanges = "bytes"
    25  )
    26  
    27  var BlockSizeList = [6]int64{128*converter.KB, 256*converter.KB, 1024*converter.KB, 2*converter.MB, 4*converter.MB, 999*converter.GB}
    28  
    29  type (
    30  	// Downloader 下载
    31  	Downloader struct {
    32  		onExecuteEvent        requester.Event    //开始下载事件
    33  		onSuccessEvent        requester.Event    //成功下载事件
    34  		onFinishEvent         requester.Event    //结束下载事件
    35  		onPauseEvent          requester.Event    //暂停下载事件
    36  		onResumeEvent         requester.Event    //恢复下载事件
    37  		onCancelEvent         requester.Event    //取消下载事件
    38  		onDownloadStatusEvent DownloadStatusFunc //状态处理事件
    39  
    40  		monitorCancelFunc context.CancelFunc
    41  
    42  		firstInfo               *DownloadFirstInfo      // 初始信息
    43  		loadBalancerCompareFunc LoadBalancerCompareFunc // 负载均衡检测函数
    44  		durlCheckFunc           DURLCheckFunc           // 下载url检测函数
    45  		statusCodeBodyCheckFunc StatusCodeBodyCheckFunc
    46  		executeTime             time.Time
    47  		durl                    string
    48  		loadBalansers           []string
    49  		writer                  io.WriterAt
    50  		client                  *requester.HTTPClient
    51  		config                  *Config
    52  		monitor                 *Monitor
    53  		instanceState           *InstanceState
    54  	}
    55  
    56  	// DURLCheckFunc 下载URL检测函数
    57  	DURLCheckFunc func(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error)
    58  	// StatusCodeBodyCheckFunc 响应状态码出错的检查函数
    59  	StatusCodeBodyCheckFunc func(respBody io.Reader) error
    60  )
    61  
    62  //NewDownloader 初始化Downloader
    63  func NewDownloader(durl string, writer io.WriterAt, config *Config) (der *Downloader) {
    64  	der = &Downloader{
    65  		durl:   durl,
    66  		config: config,
    67  		writer: writer,
    68  	}
    69  
    70  	return
    71  }
    72  
    73  // SetFirstInfo 设置初始信息
    74  // 如果设置了此值, 将忽略检测url
    75  func (der *Downloader) SetFirstInfo(i *DownloadFirstInfo) {
    76  	der.firstInfo = i
    77  }
    78  
    79  //SetClient 设置http客户端
    80  func (der *Downloader) SetClient(client *requester.HTTPClient) {
    81  	der.client = client
    82  }
    83  
    84  // SetDURLCheckFunc 设置下载URL检测函数
    85  func (der *Downloader) SetDURLCheckFunc(f DURLCheckFunc) {
    86  	der.durlCheckFunc = f
    87  }
    88  
    89  func (der *Downloader) SetFileContentLength(length int64) {
    90  	if der.firstInfo == nil {
    91  		der.firstInfo = &DownloadFirstInfo{
    92  			ContentLength: length,
    93  			AcceptRanges : DefaultAcceptRanges,
    94  		}
    95  	}
    96  }
    97  // SetLoadBalancerCompareFunc 设置负载均衡检测函数
    98  func (der *Downloader) SetLoadBalancerCompareFunc(f LoadBalancerCompareFunc) {
    99  	der.loadBalancerCompareFunc = f
   100  }
   101  
   102  //SetStatusCodeBodyCheckFunc 设置响应状态码出错的检查函数, 当FirstCheckMethod不为HEAD时才有效
   103  func (der *Downloader) SetStatusCodeBodyCheckFunc(f StatusCodeBodyCheckFunc) {
   104  	der.statusCodeBodyCheckFunc = f
   105  }
   106  
   107  func (der *Downloader) lazyInit() {
   108  	if der.config == nil {
   109  		der.config = NewConfig()
   110  	}
   111  	if der.client == nil {
   112  		der.client = requester.NewHTTPClient()
   113  		der.client.SetTimeout(5 * time.Minute)
   114  	}
   115  	if der.monitor == nil {
   116  		der.monitor = NewMonitor()
   117  	}
   118  	if der.durlCheckFunc == nil {
   119  		der.durlCheckFunc = DefaultDURLCheckFunc
   120  	}
   121  	if der.loadBalancerCompareFunc == nil {
   122  		der.loadBalancerCompareFunc = DefaultLoadBalancerCompareFunc
   123  	}
   124  }
   125  
   126  // SelectParallel 获取合适的 parallel
   127  func (der *Downloader) SelectParallel(single bool, maxParallel int, totalSize int64, instanceRangeList transfer.RangeList) (parallel int) {
   128  	isRange := instanceRangeList != nil && len(instanceRangeList) > 0
   129  	if single { //不支持多线程
   130  		parallel = 1
   131  	} else if isRange {
   132  		parallel = len(instanceRangeList)
   133  	} else {
   134  		parallel = der.config.MaxParallel
   135  		if int64(parallel) > totalSize/int64(MinParallelSize) {
   136  			parallel = int(totalSize/int64(MinParallelSize)) + 1
   137  		}
   138  	}
   139  
   140  	if parallel < 1 {
   141  		parallel = 1
   142  	}
   143  	return
   144  }
   145  
   146  // SelectBlockSizeAndInitRangeGen 获取合适的 BlockSize, 和初始化 RangeGen
   147  func (der *Downloader) SelectBlockSizeAndInitRangeGen(single bool, status *transfer.DownloadStatus, parallel int) (blockSize int64, initErr error) {
   148  	// Range 生成器
   149  	if single { // 单线程
   150  		blockSize = -1
   151  		return
   152  	}
   153  	gen := status.RangeListGen()
   154  	if gen == nil {
   155  		switch der.config.Mode {
   156  		case transfer.RangeGenMode_Default:
   157  			gen = transfer.NewRangeListGenDefault(status.TotalSize(), 0, 0, parallel)
   158  			blockSize = gen.LoadBlockSize()
   159  		case transfer.RangeGenMode_BlockSize:
   160  			//b2 := status.TotalSize()/int64(parallel) + 1
   161  			//if b2 > der.config.BlockSize { // 选小的BlockSize, 以更高并发
   162  			//	blockSize = der.config.BlockSize
   163  			//} else {
   164  			//	blockSize = b2
   165  			//}
   166  			totalSize := status.TotalSize()
   167  			if totalSize < 2 * converter.MB {
   168  				blockSize = BlockSizeList[1]
   169  			} else if totalSize < 10 * converter.MB {
   170  				blockSize = BlockSizeList[2]
   171  			} else if totalSize < 80 * converter.MB {
   172  				blockSize = BlockSizeList[3]
   173  			} else {
   174  				blockSize = BlockSizeList[4]
   175  			}
   176  			gen = transfer.NewRangeListGenBlockSize(totalSize, 0, blockSize)
   177  		default:
   178  			initErr = transfer.ErrUnknownRangeGenMode
   179  			return
   180  		}
   181  	} else {
   182  		blockSize = gen.LoadBlockSize()
   183  	}
   184  	status.SetRangeListGen(gen)
   185  	return
   186  }
   187  
   188  // SelectCacheSize 获取合适的 cacheSize
   189  func (der *Downloader) SelectCacheSize(confCacheSize int, blockSize int64) (cacheSize int) {
   190  	if blockSize > 0 && int64(confCacheSize) > blockSize {
   191  		// 如果 cache size 过高, 则调低
   192  		cacheSize = int(blockSize)
   193  	} else {
   194  		cacheSize = confCacheSize
   195  	}
   196  	return
   197  }
   198  
   199  // DefaultDURLCheckFunc 默认的 DURLCheckFunc
   200  func DefaultDURLCheckFunc(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) {
   201  	resp, err = client.Req(http.MethodGet, durl, nil, nil)
   202  	if err != nil {
   203  		if resp != nil {
   204  			resp.Body.Close()
   205  		}
   206  		return 0, nil, err
   207  	}
   208  	return resp.ContentLength, resp, nil
   209  }
   210  
   211  func (der *Downloader) checkLoadBalancers() *LoadBalancerResponseList {
   212  	var (
   213  		loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1)
   214  		handleLoadBalancer    = func(req *http.Request) {
   215  			if req == nil {
   216  				return
   217  			}
   218  
   219  			if der.config.TryHTTP {
   220  				req.URL.Scheme = "http"
   221  			}
   222  
   223  			loadBalancer := &LoadBalancerResponse{
   224  				URL:     req.URL.String(),
   225  				Referer: req.Referer(),
   226  			}
   227  
   228  			loadBalancerResponses = append(loadBalancerResponses, loadBalancer)
   229  			pcsverbose.Verbosef("DEBUG: load balance task: URL: %s, Referer: %s\n", loadBalancer.URL, loadBalancer.Referer)
   230  		}
   231  	)
   232  
   233  	// 加入第一个
   234  	loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{
   235  		URL: der.durl,
   236  	})
   237  
   238  	// 多下载服务器的负载均衡, 在本项目中无意义,因为百度只用单下载服务器
   239  	wg := waitgroup.NewWaitGroup(4)
   240  	privTimeout := der.client.Client.Timeout
   241  	der.client.SetTimeout(5 * time.Second)
   242  	for _, loadBalanser := range der.loadBalansers { // 这里服务器列表数量为0,所以逻辑实际未生效
   243  		wg.AddDelta()
   244  		go func(loadBalanser string) {
   245  			defer wg.Done()
   246  
   247  			subContentLength, subResp, subErr := der.durlCheckFunc(der.client, loadBalanser)
   248  			if subResp != nil {
   249  				subResp.Body.Close() // 不读Body, 马上关闭连接
   250  			}
   251  			if subErr != nil {
   252  				pcsverbose.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr)
   253  				return
   254  			}
   255  
   256  			// 检测状态码
   257  			switch subResp.StatusCode / 100 {
   258  			case 2: // succeed
   259  			case 4, 5: // error
   260  				var err error
   261  				if der.statusCodeBodyCheckFunc != nil {
   262  					err = der.statusCodeBodyCheckFunc(subResp.Body)
   263  				} else {
   264  					err = errors.New(subResp.Status)
   265  				}
   266  				pcsverbose.Verbosef("DEBUG: loadBalanser Status Error: %s\n", err)
   267  				return
   268  			}
   269  
   270  			// 检测长度
   271  			if der.firstInfo.ContentLength != subContentLength {
   272  				pcsverbose.Verbosef("DEBUG: loadBalanser Content-Length not equal to main server\n")
   273  				return
   274  			}
   275  
   276  			if !der.loadBalancerCompareFunc(der.firstInfo.ToMap(), subResp) {
   277  				pcsverbose.Verbosef("DEBUG: loadBalanser not equal to main server\n")
   278  				return
   279  			}
   280  
   281  			handleLoadBalancer(subResp.Request)
   282  		}(loadBalanser)
   283  	}
   284  	wg.Wait()
   285  	der.client.SetTimeout(privTimeout)
   286  
   287  	loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses)
   288  	return loadBalancerResponseList
   289  }
   290  
   291  //Execute 开始任务
   292  func (der *Downloader) Execute() error {
   293  	der.lazyInit()
   294  	var (
   295  		resp *http.Response
   296  	)
   297  	if der.firstInfo == nil {
   298  		// 检测
   299  		contentLength, resp, err := der.durlCheckFunc(der.client, der.durl)
   300  		if err != nil {
   301  			return err
   302  		}
   303  
   304  		// 检测网络错误
   305  		switch resp.StatusCode / 100 {
   306  		case 2: // succeed
   307  		case 4, 5: // error
   308  			if der.statusCodeBodyCheckFunc != nil {
   309  				err = der.statusCodeBodyCheckFunc(resp.Body)
   310  				resp.Body.Close() // 关闭连接
   311  				if err != nil {
   312  					return err
   313  				}
   314  			}
   315  			return errors.New(resp.Status)
   316  		}
   317  
   318  		acceptRanges := resp.Header.Get("Accept-Ranges")
   319  		if contentLength < 0 {
   320  			acceptRanges = ""
   321  		} else {
   322  			acceptRanges = DefaultAcceptRanges
   323  		}
   324  
   325  		// 初始化firstInfo
   326  		der.firstInfo = &DownloadFirstInfo{
   327  			ContentLength: contentLength,
   328  			ContentMD5:    resp.Header.Get("Content-MD5"),
   329  			ContentCRC32:  resp.Header.Get("x-bs-meta-crc32"),
   330  			AcceptRanges:  acceptRanges,
   331  			Referer:       resp.Header.Get("Referer"),
   332  		}
   333  		pcsverbose.Verbosef("DEBUG: download task: URL: %s, Referer: %s\n", resp.Request.URL, resp.Request.Referer())
   334  	} else {
   335  		if der.firstInfo.AcceptRanges == "" {
   336  			der.firstInfo.AcceptRanges = DefaultAcceptRanges
   337  		}
   338  	}
   339  
   340  	var (
   341  		loadBalancerResponseList = der.checkLoadBalancers()
   342  		single                   = der.firstInfo.AcceptRanges == ""
   343  		bii                      *transfer.DownloadInstanceInfo
   344  	)
   345  
   346  	if !single {
   347  		//load breakpoint
   348  		//服务端不支持多线程时, 不记录断点
   349  		err := der.initInstanceState(der.config.InstanceStateStorageFormat)
   350  		if err != nil {
   351  			return err
   352  		}
   353  		bii = der.instanceState.Get()
   354  	}
   355  
   356  	var (
   357  		isInstance = bii != nil // 是否存在断点信息
   358  		status     *transfer.DownloadStatus
   359  	)
   360  	if !isInstance {
   361  		bii = &transfer.DownloadInstanceInfo{}
   362  	}
   363  
   364  	if bii.DownloadStatus != nil {
   365  		// 使用断点信息的状态
   366  		status = bii.DownloadStatus
   367  	} else {
   368  		// 新建状态
   369  		status = transfer.NewDownloadStatus()
   370  		status.SetTotalSize(der.firstInfo.ContentLength)
   371  	}
   372  
   373  	// 设置限速
   374  	if der.config.MaxRate > 0 {
   375  		rl := speeds.NewRateLimit(der.config.MaxRate)
   376  		status.SetRateLimit(rl)
   377  		defer rl.Stop()
   378  	}
   379  
   380  	// 数据处理
   381  	parallel := der.SelectParallel(single, der.config.MaxParallel, status.TotalSize(), bii.Ranges) // 实际的下载并行量
   382  	blockSize, err := der.SelectBlockSizeAndInitRangeGen(single, status, parallel)                 // 实际的BlockSize
   383  	if err != nil {
   384  		return err
   385  	}
   386  
   387  	cacheSize := der.SelectCacheSize(der.config.CacheSize, blockSize) // 实际下载缓存
   388  	cachepool.SetSyncPoolSize(cacheSize)                              // 调整pool大小
   389  
   390  	pcsverbose.Verbosef("DEBUG: download task CREATED: parallel: %d, cache size: %d\n", parallel, cacheSize)
   391  
   392  	der.monitor.InitMonitorCapacity(parallel)
   393  
   394  	var writer Writer
   395  	if !der.config.IsTest {
   396  		// 尝试修剪文件
   397  		if fder, ok := der.writer.(Fder); ok {
   398  			err = prealloc.PreAlloc(fder.Fd(), status.TotalSize())
   399  			if err != nil {
   400  				pcsverbose.Verbosef("DEBUG: truncate file error: %s\n", err)
   401  			}
   402  		}
   403  		writer = der.writer // 非测试模式, 赋值writer
   404  	}
   405  
   406  	// 数据平均分配给各个线程
   407  	isRange := bii.Ranges != nil && len(bii.Ranges) > 0
   408  	if !isRange {
   409  		// 没有使用断点续传
   410  		// 分配线程
   411  		bii.Ranges = make(transfer.RangeList, 0, parallel)
   412  		if single { // 单线程
   413  			bii.Ranges = append(bii.Ranges, &transfer.Range{})
   414  		} else {
   415  			gen := status.RangeListGen()
   416  			for i := 0; i < cap(bii.Ranges); i++ {
   417  				_, r := gen.GenRange()
   418  				if r == nil { // 没有了(不正常)
   419  					break
   420  				}
   421  				bii.Ranges = append(bii.Ranges, r)
   422  			}
   423  		}
   424  	}
   425  
   426  	var (
   427  		writeMu = &sync.Mutex{}
   428  	)
   429  	for k, r := range bii.Ranges {
   430  		loadBalancer := loadBalancerResponseList.SequentialGet()
   431  		if loadBalancer == nil {
   432  			continue
   433  		}
   434  
   435  		worker := NewWorker(k, loadBalancer.URL, writer)
   436  		worker.SetClient(der.client)
   437  		worker.SetWriteMutex(writeMu)
   438  		worker.SetReferer(loadBalancer.Referer)
   439  		worker.SetTotalSize(der.firstInfo.ContentLength)
   440  
   441  		// 使用第一个连接
   442  		// 断点续传时不使用
   443  		if k == 0 && !isInstance {
   444  			worker.firstResp = resp
   445  		}
   446  
   447  		worker.SetAcceptRange(der.firstInfo.AcceptRanges)
   448  		worker.SetRange(r) // 分配Range
   449  		der.monitor.Append(worker)
   450  	}
   451  
   452  	der.monitor.SetStatus(status)
   453  
   454  	// 服务器不支持断点续传, 或者单线程下载, 都不重载worker
   455  	der.monitor.SetReloadWorker(parallel > 1)
   456  
   457  	moniterCtx, moniterCancelFunc := context.WithCancel(context.Background())
   458  	der.monitorCancelFunc = moniterCancelFunc
   459  
   460  	der.monitor.SetInstanceState(der.instanceState)
   461  
   462  	// 开始执行
   463  	der.executeTime = time.Now()
   464  	pcsutil.Trigger(der.onExecuteEvent)
   465  	der.downloadStatusEvent() // 启动执行状态处理事件
   466  	der.monitor.Execute(moniterCtx)
   467  
   468  	// 检查错误
   469  	err = der.monitor.Err()
   470  	if err == nil { // 成功
   471  		pcsutil.Trigger(der.onSuccessEvent)
   472  		if !single {
   473  			der.removeInstanceState() // 移除断点续传文件
   474  		}
   475  	}
   476  
   477  	// 执行结束
   478  	pcsutil.Trigger(der.onFinishEvent)
   479  	return err
   480  }
   481  
   482  //downloadStatusEvent 执行状态处理事件
   483  func (der *Downloader) downloadStatusEvent() {
   484  	if der.onDownloadStatusEvent == nil {
   485  		return
   486  	}
   487  
   488  	status := der.monitor.Status()
   489  	go func() {
   490  		ticker := time.NewTicker(1 * time.Second)
   491  		defer ticker.Stop()
   492  		for {
   493  			select {
   494  			case <-der.monitor.completed:
   495  				return
   496  			case <-ticker.C:
   497  				der.onDownloadStatusEvent(status, der.monitor.RangeWorker)
   498  			}
   499  		}
   500  	}()
   501  }
   502  
   503  //Pause 暂停
   504  func (der *Downloader) Pause() {
   505  	if der.monitor == nil {
   506  		return
   507  	}
   508  	pcsutil.Trigger(der.onPauseEvent)
   509  	der.monitor.Pause()
   510  }
   511  
   512  //Resume 恢复
   513  func (der *Downloader) Resume() {
   514  	if der.monitor == nil {
   515  		return
   516  	}
   517  	pcsutil.Trigger(der.onResumeEvent)
   518  	der.monitor.Resume()
   519  }
   520  
   521  //Cancel 取消
   522  func (der *Downloader) Cancel() {
   523  	if der.monitor == nil {
   524  		return
   525  	}
   526  	pcsutil.Trigger(der.onCancelEvent)
   527  	pcsutil.Trigger(der.monitorCancelFunc)
   528  }
   529  
   530  //OnExecute 设置开始下载事件
   531  func (der *Downloader) OnExecute(onExecuteEvent requester.Event) {
   532  	der.onExecuteEvent = onExecuteEvent
   533  }
   534  
   535  //OnSuccess 设置成功下载事件
   536  func (der *Downloader) OnSuccess(onSuccessEvent requester.Event) {
   537  	der.onSuccessEvent = onSuccessEvent
   538  }
   539  
   540  //OnFinish 设置结束下载事件
   541  func (der *Downloader) OnFinish(onFinishEvent requester.Event) {
   542  	der.onFinishEvent = onFinishEvent
   543  }
   544  
   545  //OnPause 设置暂停下载事件
   546  func (der *Downloader) OnPause(onPauseEvent requester.Event) {
   547  	der.onPauseEvent = onPauseEvent
   548  }
   549  
   550  //OnResume 设置恢复下载事件
   551  func (der *Downloader) OnResume(onResumeEvent requester.Event) {
   552  	der.onResumeEvent = onResumeEvent
   553  }
   554  
   555  //OnCancel 设置取消下载事件
   556  func (der *Downloader) OnCancel(onCancelEvent requester.Event) {
   557  	der.onCancelEvent = onCancelEvent
   558  }
   559  
   560  //OnDownloadStatusEvent 设置状态处理函数
   561  func (der *Downloader) OnDownloadStatusEvent(f DownloadStatusFunc) {
   562  	der.onDownloadStatusEvent = f
   563  }