github.com/iikira/iikira-go-utils@v0.0.0-20230610031953-f2cb11cde33a/requester/downloader/downloader.go (about)

     1  // Package downloader 多线程下载器, 重构版
     2  package downloader
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"github.com/iikira/iikira-go-utils/pcsverbose"
     8  	"github.com/iikira/iikira-go-utils/requester"
     9  	"github.com/iikira/iikira-go-utils/requester/rio/speeds"
    10  	"github.com/iikira/iikira-go-utils/requester/transfer"
    11  	"github.com/iikira/iikira-go-utils/utils"
    12  	"github.com/iikira/iikira-go-utils/utils/cachepool"
    13  	"github.com/iikira/iikira-go-utils/utils/prealloc"
    14  	"github.com/iikira/iikira-go-utils/utils/waitgroup"
    15  	"io"
    16  	"net/http"
    17  	"sync"
    18  	"time"
    19  )
    20  
    21  const (
    22  	// DefaultAcceptRanges 默认的 Accept-Ranges
    23  	DefaultAcceptRanges = "bytes"
    24  )
    25  
    26  type (
    27  	// Downloader 下载
    28  	Downloader struct {
    29  		onExecuteEvent        requester.Event    //开始下载事件
    30  		onSuccessEvent        requester.Event    //成功下载事件
    31  		onFinishEvent         requester.Event    //结束下载事件
    32  		onPauseEvent          requester.Event    //暂停下载事件
    33  		onResumeEvent         requester.Event    //恢复下载事件
    34  		onCancelEvent         requester.Event    //取消下载事件
    35  		onDownloadStatusEvent DownloadStatusFunc //状态处理事件
    36  
    37  		monitorCancelFunc context.CancelFunc
    38  
    39  		firstInfo               *DownloadFirstInfo      // 初始信息
    40  		loadBalancerCompareFunc LoadBalancerCompareFunc // 负载均衡检测函数
    41  		durlCheckFunc           DURLCheckFunc           // 下载url检测函数
    42  		statusCodeBodyCheckFunc StatusCodeBodyCheckFunc
    43  		executeTime             time.Time
    44  		durl                    string
    45  		loadBalansers           []string
    46  		writer                  io.WriterAt
    47  		client                  *requester.HTTPClient
    48  		config                  *Config
    49  		monitor                 *Monitor
    50  		instanceState           *InstanceState
    51  	}
    52  
    53  	// DURLCheckFunc 下载URL检测函数
    54  	DURLCheckFunc func(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error)
    55  	// StatusCodeBodyCheckFunc 响应状态码出错的检查函数
    56  	StatusCodeBodyCheckFunc func(respBody io.Reader) error
    57  )
    58  
    59  //NewDownloader 初始化Downloader
    60  func NewDownloader(durl string, writer io.WriterAt, config *Config) (der *Downloader) {
    61  	der = &Downloader{
    62  		durl:   durl,
    63  		config: config,
    64  		writer: writer,
    65  	}
    66  
    67  	return
    68  }
    69  
    70  // SetFirstInfo 设置初始信息
    71  // 如果设置了此值, 将忽略检测url
    72  func (der *Downloader) SetFirstInfo(i *DownloadFirstInfo) {
    73  	der.firstInfo = i
    74  }
    75  
    76  //SetClient 设置http客户端
    77  func (der *Downloader) SetClient(client *requester.HTTPClient) {
    78  	der.client = client
    79  }
    80  
    81  // SetDURLCheckFunc 设置下载URL检测函数
    82  func (der *Downloader) SetDURLCheckFunc(f DURLCheckFunc) {
    83  	der.durlCheckFunc = f
    84  }
    85  
    86  // SetLoadBalancerCompareFunc 设置负载均衡检测函数
    87  func (der *Downloader) SetLoadBalancerCompareFunc(f LoadBalancerCompareFunc) {
    88  	der.loadBalancerCompareFunc = f
    89  }
    90  
    91  //SetStatusCodeBodyCheckFunc 设置响应状态码出错的检查函数, 当FirstCheckMethod不为HEAD时才有效
    92  func (der *Downloader) SetStatusCodeBodyCheckFunc(f StatusCodeBodyCheckFunc) {
    93  	der.statusCodeBodyCheckFunc = f
    94  }
    95  
    96  func (der *Downloader) lazyInit() {
    97  	if der.config == nil {
    98  		der.config = NewConfig()
    99  	}
   100  	if der.client == nil {
   101  		der.client = requester.NewHTTPClient()
   102  		der.client.SetTimeout(20 * time.Minute)
   103  	}
   104  	if der.monitor == nil {
   105  		der.monitor = NewMonitor()
   106  	}
   107  	if der.durlCheckFunc == nil {
   108  		der.durlCheckFunc = DefaultDURLCheckFunc
   109  	}
   110  	if der.loadBalancerCompareFunc == nil {
   111  		der.loadBalancerCompareFunc = DefaultLoadBalancerCompareFunc
   112  	}
   113  }
   114  
   115  // SelectParallel 获取合适的 parallel
   116  func (der *Downloader) SelectParallel(single bool, maxParallel int, totalSize int64, instanceRangeList transfer.RangeList) (parallel int) {
   117  	isRange := instanceRangeList != nil && len(instanceRangeList) > 0
   118  	if single { //不支持多线程
   119  		parallel = 1
   120  	} else if isRange {
   121  		parallel = len(instanceRangeList)
   122  	} else {
   123  		parallel = der.config.MaxParallel
   124  		if int64(parallel) > totalSize/int64(MinParallelSize) {
   125  			parallel = int(totalSize/int64(MinParallelSize)) + 1
   126  		}
   127  	}
   128  
   129  	if parallel < 1 {
   130  		parallel = 1
   131  	}
   132  	return
   133  }
   134  
   135  // SelectBlockSizeAndInitRangeGen 获取合适的 BlockSize, 和初始化 RangeGen
   136  func (der *Downloader) SelectBlockSizeAndInitRangeGen(single bool, status *transfer.DownloadStatus, parallel int) (blockSize int64, initErr error) {
   137  	// Range 生成器
   138  	if single { // 单线程
   139  		blockSize = -1
   140  		return
   141  	}
   142  	gen := status.RangeListGen()
   143  	if gen == nil {
   144  		switch der.config.Mode {
   145  		case transfer.RangeGenMode_Default:
   146  			gen = transfer.NewRangeListGenDefault(status.TotalSize(), 0, 0, parallel)
   147  			blockSize = gen.LoadBlockSize()
   148  		case transfer.RangeGenMode_BlockSize:
   149  			b2 := status.TotalSize()/int64(parallel) + 1
   150  			if b2 > der.config.BlockSize { // 选小的BlockSize, 以更高并发
   151  				blockSize = der.config.BlockSize
   152  			} else {
   153  				blockSize = b2
   154  			}
   155  
   156  			gen = transfer.NewRangeListGenBlockSize(status.TotalSize(), 0, blockSize)
   157  		default:
   158  			initErr = transfer.ErrUnknownRangeGenMode
   159  			return
   160  		}
   161  	} else {
   162  		blockSize = gen.LoadBlockSize()
   163  	}
   164  	status.SetRangeListGen(gen)
   165  	return
   166  }
   167  
   168  // SelectCacheSize 获取合适的 cacheSize
   169  func (der *Downloader) SelectCacheSize(confCacheSize int, blockSize int64) (cacheSize int) {
   170  	if blockSize > 0 && int64(confCacheSize) > blockSize {
   171  		// 如果 cache size 过高, 则调低
   172  		cacheSize = int(blockSize)
   173  	} else {
   174  		cacheSize = confCacheSize
   175  	}
   176  	return
   177  }
   178  
   179  // DefaultDURLCheckFunc 默认的 DURLCheckFunc
   180  func DefaultDURLCheckFunc(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) {
   181  	resp, err = client.Req(http.MethodGet, durl, nil, nil)
   182  	if err != nil {
   183  		if resp != nil {
   184  			resp.Body.Close()
   185  		}
   186  		return 0, nil, err
   187  	}
   188  	return resp.ContentLength, resp, nil
   189  }
   190  
   191  func (der *Downloader) checkLoadBalancers() *LoadBalancerResponseList {
   192  	var (
   193  		loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1)
   194  		handleLoadBalancer    = func(req *http.Request) {
   195  			if req == nil {
   196  				return
   197  			}
   198  
   199  			if der.config.TryHTTP {
   200  				req.URL.Scheme = "http"
   201  			}
   202  
   203  			loadBalancer := &LoadBalancerResponse{
   204  				URL:     req.URL.String(),
   205  				Referer: req.Referer(),
   206  			}
   207  
   208  			loadBalancerResponses = append(loadBalancerResponses, loadBalancer)
   209  			pcsverbose.Verbosef("DEBUG: load balance task: URL: %s, Referer: %s\n", loadBalancer.URL, loadBalancer.Referer)
   210  		}
   211  	)
   212  
   213  	// 加入第一个
   214  	loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{
   215  		URL: der.durl,
   216  	})
   217  
   218  	// 负载均衡
   219  	wg := waitgroup.NewWaitGroup(10)
   220  	privTimeout := der.client.Client.Timeout
   221  	der.client.SetTimeout(5 * time.Second)
   222  	for _, loadBalanser := range der.loadBalansers {
   223  		wg.AddDelta()
   224  		go func(loadBalanser string) {
   225  			defer wg.Done()
   226  
   227  			subContentLength, subResp, subErr := der.durlCheckFunc(der.client, loadBalanser)
   228  			if subResp != nil {
   229  				subResp.Body.Close() // 不读Body, 马上关闭连接
   230  			}
   231  			if subErr != nil {
   232  				pcsverbose.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr)
   233  				return
   234  			}
   235  
   236  			// 检测状态码
   237  			switch subResp.StatusCode / 100 {
   238  			case 2: // succeed
   239  			case 4, 5: // error
   240  				var err error
   241  				if der.statusCodeBodyCheckFunc != nil {
   242  					err = der.statusCodeBodyCheckFunc(subResp.Body)
   243  				} else {
   244  					err = errors.New(subResp.Status)
   245  				}
   246  				pcsverbose.Verbosef("DEBUG: loadBalanser Status Error: %s\n", err)
   247  				return
   248  			}
   249  
   250  			// 检测长度
   251  			if der.firstInfo.ContentLength != subContentLength {
   252  				pcsverbose.Verbosef("DEBUG: loadBalanser Content-Length not equal to main server\n")
   253  				return
   254  			}
   255  
   256  			if !der.loadBalancerCompareFunc(der.firstInfo.ToMap(), subResp) {
   257  				pcsverbose.Verbosef("DEBUG: loadBalanser not equal to main server\n")
   258  				return
   259  			}
   260  
   261  			handleLoadBalancer(subResp.Request)
   262  		}(loadBalanser)
   263  	}
   264  	wg.Wait()
   265  	der.client.SetTimeout(privTimeout)
   266  
   267  	loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses)
   268  	return loadBalancerResponseList
   269  }
   270  
   271  //Execute 开始任务
   272  func (der *Downloader) Execute() error {
   273  	der.lazyInit()
   274  
   275  	var (
   276  		resp *http.Response
   277  	)
   278  	if der.firstInfo == nil {
   279  		// 检测
   280  		var (
   281  			// 这个声明不得去掉, 否则会覆盖resp
   282  			contentLength int64
   283  			err error
   284  		)
   285  		contentLength, resp, err = der.durlCheckFunc(der.client, der.durl)
   286  		if err != nil {
   287  			return err
   288  		}
   289  
   290  		// 检测网络错误
   291  		switch resp.StatusCode / 100 {
   292  		case 2: // succeed
   293  		case 4, 5: // error
   294  			if der.statusCodeBodyCheckFunc != nil {
   295  				err = der.statusCodeBodyCheckFunc(resp.Body)
   296  				resp.Body.Close() // 关闭连接
   297  				if err != nil {
   298  					return err
   299  				}
   300  			}
   301  			return errors.New(resp.Status)
   302  		}
   303  
   304  		acceptRanges := resp.Header.Get("Accept-Ranges")
   305  		if contentLength < 0 {
   306  			acceptRanges = ""
   307  		} else {
   308  			acceptRanges = DefaultAcceptRanges
   309  		}
   310  
   311  		// 初始化firstInfo
   312  		der.firstInfo = &DownloadFirstInfo{
   313  			ContentLength: contentLength,
   314  			ContentMD5:    resp.Header.Get("Content-MD5"),
   315  			ContentCRC32:  resp.Header.Get("x-bs-meta-crc32"),
   316  			AcceptRanges:  acceptRanges,
   317  			Referer:       resp.Header.Get("Referer"),
   318  		}
   319  		pcsverbose.Verbosef("DEBUG: download task: URL: %s, Referer: %s\n", resp.Request.URL, resp.Request.Referer())
   320  	} else {
   321  		if der.firstInfo.AcceptRanges == "" {
   322  			der.firstInfo.AcceptRanges = DefaultAcceptRanges
   323  		}
   324  	}
   325  
   326  	var (
   327  		loadBalancerResponseList = der.checkLoadBalancers()
   328  		single                   = der.firstInfo.AcceptRanges == ""
   329  		bii                      *transfer.DownloadInstanceInfo
   330  	)
   331  
   332  	if !single {
   333  		//load breakpoint
   334  		//服务端不支持多线程时, 不记录断点
   335  		err := der.initInstanceState(der.config.InstanceStateStorageFormat)
   336  		if err != nil {
   337  			return err
   338  		}
   339  		bii = der.instanceState.Get()
   340  	}
   341  
   342  	var (
   343  		isInstance = bii != nil // 是否存在断点信息
   344  		status     *transfer.DownloadStatus
   345  	)
   346  	if !isInstance {
   347  		bii = &transfer.DownloadInstanceInfo{}
   348  	}
   349  	if isInstance {
   350  		// 使用断点续传时,不启用firstResp
   351  		resp.Body.Close()
   352  	}
   353  
   354  	if bii.DownloadStatus != nil {
   355  		// 使用断点信息的状态
   356  		status = bii.DownloadStatus
   357  	} else {
   358  		// 新建状态
   359  		status = transfer.NewDownloadStatus()
   360  		status.SetTotalSize(der.firstInfo.ContentLength)
   361  	}
   362  
   363  	// 设置限速
   364  	if der.config.MaxRate > 0 {
   365  		rl := speeds.NewRateLimit(der.config.MaxRate)
   366  		status.SetRateLimit(rl)
   367  		defer rl.Stop()
   368  	}
   369  
   370  	// 数据处理
   371  	parallel := der.SelectParallel(single, der.config.MaxParallel, status.TotalSize(), bii.Ranges) // 实际的下载并行量
   372  	blockSize, err := der.SelectBlockSizeAndInitRangeGen(single, status, parallel)                 // 实际的BlockSize
   373  	if err != nil {
   374  		return err
   375  	}
   376  
   377  	cacheSize := der.SelectCacheSize(der.config.CacheSize, blockSize) // 实际下载缓存
   378  	cachepool.SetSyncPoolSize(cacheSize)                              // 调整pool大小
   379  
   380  	pcsverbose.Verbosef("DEBUG: download task CREATED: parallel: %d, cache size: %d\n", parallel, cacheSize)
   381  
   382  	der.monitor.InitMonitorCapacity(parallel)
   383  
   384  	var writer Writer
   385  	if !der.config.IsTest {
   386  		// 尝试修剪文件
   387  		if fder, ok := der.writer.(Fder); ok {
   388  			err = prealloc.PreAlloc(fder.Fd(), status.TotalSize())
   389  			if err != nil {
   390  				pcsverbose.Verbosef("DEBUG: truncate file error: %s\n", err)
   391  			}
   392  		}
   393  		writer = der.writer // 非测试模式, 赋值writer
   394  	}
   395  
   396  	// 数据平均分配给各个线程
   397  	isRange := bii.Ranges != nil && len(bii.Ranges) > 0
   398  	if !isRange {
   399  		// 没有使用断点续传
   400  		// 分配线程
   401  		bii.Ranges = make(transfer.RangeList, 0, parallel)
   402  		if single { // 单线程
   403  			bii.Ranges = append(bii.Ranges, &transfer.Range{})
   404  		} else {
   405  			gen := status.RangeListGen()
   406  			for i := 0; i < cap(bii.Ranges); i++ {
   407  				_, r := gen.GenRange()
   408  				if r == nil { // 没有了(不正常)
   409  					break
   410  				}
   411  				bii.Ranges = append(bii.Ranges, r)
   412  			}
   413  		}
   414  	}
   415  
   416  	var (
   417  		writeMu = &sync.Mutex{}
   418  	)
   419  	for k, r := range bii.Ranges {
   420  		loadBalancer := loadBalancerResponseList.SequentialGet()
   421  		if loadBalancer == nil {
   422  			continue
   423  		}
   424  
   425  		worker := NewWorker(k, loadBalancer.URL, writer)
   426  		worker.SetClient(der.client)
   427  		worker.SetWriteMutex(writeMu)
   428  		worker.SetReferer(loadBalancer.Referer)
   429  		worker.SetTotalSize(der.firstInfo.ContentLength)
   430  
   431  		// 使用第一个连接
   432  		// 断点续传时不使用
   433  		if k == 0 && !isInstance {
   434  			worker.firstResp = resp
   435  		}
   436  
   437  		worker.SetAcceptRange(der.firstInfo.AcceptRanges)
   438  		worker.SetRange(r) // 分配Range
   439  		der.monitor.Append(worker)
   440  	}
   441  
   442  	der.monitor.SetStatus(status)
   443  
   444  	// 服务器不支持断点续传, 或者单线程下载, 都不重载worker
   445  	der.monitor.SetReloadWorker(parallel > 1)
   446  
   447  	moniterCtx, moniterCancelFunc := context.WithCancel(context.Background())
   448  	der.monitorCancelFunc = moniterCancelFunc
   449  
   450  	der.monitor.SetInstanceState(der.instanceState)
   451  
   452  	// 开始执行
   453  	der.executeTime = time.Now()
   454  	utils.Trigger(der.onExecuteEvent)
   455  	der.downloadStatusEvent() // 启动执行状态处理事件
   456  	der.monitor.Execute(moniterCtx)
   457  
   458  	// 检查错误
   459  	err = der.monitor.Err()
   460  	if err == nil { // 成功
   461  		utils.Trigger(der.onSuccessEvent)
   462  		if !single {
   463  			der.removeInstanceState() // 移除断点续传文件
   464  		}
   465  	}
   466  
   467  	// 执行结束
   468  	utils.Trigger(der.onFinishEvent)
   469  	return err
   470  }
   471  
   472  //downloadStatusEvent 执行状态处理事件
   473  func (der *Downloader) downloadStatusEvent() {
   474  	if der.onDownloadStatusEvent == nil {
   475  		return
   476  	}
   477  
   478  	status := der.monitor.Status()
   479  	go func() {
   480  		ticker := time.NewTicker(1 * time.Second)
   481  		defer ticker.Stop()
   482  		for {
   483  			select {
   484  			case <-der.monitor.completed:
   485  				return
   486  			case <-ticker.C:
   487  				der.onDownloadStatusEvent(status, der.monitor.RangeWorker)
   488  			}
   489  		}
   490  	}()
   491  }
   492  
   493  //Pause 暂停
   494  func (der *Downloader) Pause() {
   495  	if der.monitor == nil {
   496  		return
   497  	}
   498  	utils.Trigger(der.onPauseEvent)
   499  	der.monitor.Pause()
   500  }
   501  
   502  //Resume 恢复
   503  func (der *Downloader) Resume() {
   504  	if der.monitor == nil {
   505  		return
   506  	}
   507  	utils.Trigger(der.onResumeEvent)
   508  	der.monitor.Resume()
   509  }
   510  
   511  //Cancel 取消
   512  func (der *Downloader) Cancel() {
   513  	if der.monitor == nil {
   514  		return
   515  	}
   516  	utils.Trigger(der.onCancelEvent)
   517  	utils.Trigger(der.monitorCancelFunc)
   518  }
   519  
   520  //OnExecute 设置开始下载事件
   521  func (der *Downloader) OnExecute(onExecuteEvent requester.Event) {
   522  	der.onExecuteEvent = onExecuteEvent
   523  }
   524  
   525  //OnSuccess 设置成功下载事件
   526  func (der *Downloader) OnSuccess(onSuccessEvent requester.Event) {
   527  	der.onSuccessEvent = onSuccessEvent
   528  }
   529  
   530  //OnFinish 设置结束下载事件
   531  func (der *Downloader) OnFinish(onFinishEvent requester.Event) {
   532  	der.onFinishEvent = onFinishEvent
   533  }
   534  
   535  //OnPause 设置暂停下载事件
   536  func (der *Downloader) OnPause(onPauseEvent requester.Event) {
   537  	der.onPauseEvent = onPauseEvent
   538  }
   539  
   540  //OnResume 设置恢复下载事件
   541  func (der *Downloader) OnResume(onResumeEvent requester.Event) {
   542  	der.onResumeEvent = onResumeEvent
   543  }
   544  
   545  //OnCancel 设置取消下载事件
   546  func (der *Downloader) OnCancel(onCancelEvent requester.Event) {
   547  	der.onCancelEvent = onCancelEvent
   548  }
   549  
   550  //OnDownloadStatusEvent 设置状态处理函数
   551  func (der *Downloader) OnDownloadStatusEvent(f DownloadStatusFunc) {
   552  	der.onDownloadStatusEvent = f
   553  }