github.com/qjfoidnh/BaiduPCS-Go@v0.0.0-20231011165705-caa18a3765f3/requester/downloader/worker.go (about)

     1  package downloader
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/qjfoidnh/BaiduPCS-Go/pcsutil/cachepool"
     8  	"github.com/qjfoidnh/BaiduPCS-Go/pcsverbose"
     9  	"github.com/qjfoidnh/BaiduPCS-Go/requester"
    10  	"github.com/qjfoidnh/BaiduPCS-Go/requester/rio/speeds"
    11  	"github.com/qjfoidnh/BaiduPCS-Go/requester/transfer"
    12  	"io"
    13  	"net/http"
    14  	"sync"
    15  )
    16  
    17  type (
    18  	//Worker 工作单元
    19  	Worker struct {
    20  		totalSize    int64 // 整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误
    21  		wrange       *transfer.Range
    22  		speedsStat   *speeds.Speeds
    23  		id           int    //id
    24  		url          string //下载地址
    25  		referer      string //来源地址
    26  		acceptRanges string
    27  		client       *requester.HTTPClient
    28  		firstResp    *http.Response // 第一个响应
    29  		writerAt     io.WriterAt
    30  		writeMu      *sync.Mutex
    31  		execMu       sync.Mutex
    32  
    33  		pauseChan              chan struct{}
    34  		workerCancelFunc       context.CancelFunc
    35  		resetFunc              context.CancelFunc
    36  		readRespBodyCancelFunc func()
    37  		err                    error //错误信息
    38  		status                 WorkerStatus
    39  		downloadStatus         *transfer.DownloadStatus //总的下载状态
    40  	}
    41  
    42  	// WorkerList worker列表
    43  	WorkerList []*Worker
    44  )
    45  
    46  // Duplicate 构造新的列表
    47  func (wl WorkerList) Duplicate() WorkerList {
    48  	n := make(WorkerList, len(wl))
    49  	copy(n, wl)
    50  	return n
    51  }
    52  
    53  //NewWorker 初始化Worker
    54  func NewWorker(id int, durl string, writerAt io.WriterAt) *Worker {
    55  	return &Worker{
    56  		id:       id,
    57  		url:      durl,
    58  		writerAt: writerAt,
    59  	}
    60  }
    61  
    62  //ID 返回worker ID
    63  func (wer *Worker) ID() int {
    64  	return wer.id
    65  }
    66  
    67  func (wer *Worker) lazyInit() {
    68  	if wer.client == nil {
    69  		wer.client = requester.NewHTTPClient()
    70  	}
    71  	if wer.pauseChan == nil {
    72  		wer.pauseChan = make(chan struct{})
    73  	}
    74  	if wer.wrange == nil {
    75  		wer.wrange = &transfer.Range{}
    76  	}
    77  	if wer.wrange.LoadBegin() == 0 && wer.wrange.LoadEnd() == 0 {
    78  		// 取消多线程下载
    79  		wer.acceptRanges = ""
    80  		wer.wrange.StoreEnd(-2)
    81  	}
    82  	if wer.speedsStat == nil {
    83  		wer.speedsStat = &speeds.Speeds{}
    84  	}
    85  }
    86  
    87  // SetTotalSize 设置整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误
    88  func (wer *Worker) SetTotalSize(size int64) {
    89  	wer.totalSize = size
    90  }
    91  
    92  //SetClient 设置http客户端
    93  func (wer *Worker) SetClient(c *requester.HTTPClient) {
    94  	wer.client = c
    95  }
    96  
    97  //SetAcceptRange 设置AcceptRange
    98  func (wer *Worker) SetAcceptRange(acceptRanges string) {
    99  	wer.acceptRanges = acceptRanges
   100  }
   101  
   102  //SetRange 设置请求范围
   103  func (wer *Worker) SetRange(r *transfer.Range) {
   104  	if wer.wrange == nil {
   105  		wer.wrange = r
   106  		return
   107  	}
   108  	wer.wrange.StoreBegin(r.LoadBegin())
   109  	wer.wrange.StoreEnd(r.LoadEnd())
   110  }
   111  
   112  //SetReferer 设置来源
   113  func (wer *Worker) SetReferer(referer string) {
   114  	wer.referer = referer
   115  }
   116  
   117  //SetWriteMutex 设置数据写锁
   118  func (wer *Worker) SetWriteMutex(mu *sync.Mutex) {
   119  	wer.writeMu = mu
   120  }
   121  
   122  //SetDownloadStatus 增加其他需要统计的数据
   123  func (wer *Worker) SetDownloadStatus(downloadStatus *transfer.DownloadStatus) {
   124  	wer.downloadStatus = downloadStatus
   125  }
   126  
   127  //GetStatus 返回下载状态
   128  func (wer *Worker) GetStatus() WorkerStatuser {
   129  	// 空接口与空指针不等价
   130  	return &wer.status
   131  }
   132  
   133  //GetRange 返回worker范围
   134  func (wer *Worker) GetRange() *transfer.Range {
   135  	return wer.wrange
   136  }
   137  
   138  //GetSpeedsPerSecond 获取每秒的速度
   139  func (wer *Worker) GetSpeedsPerSecond() int64 {
   140  	return wer.speedsStat.GetSpeeds()
   141  }
   142  
   143  //Pause 暂停下载
   144  func (wer *Worker) Pause() {
   145  	wer.lazyInit()
   146  	if wer.acceptRanges == "" {
   147  		pcsverbose.Verbosef("WARNING: worker unsupport pause")
   148  		return
   149  	}
   150  
   151  	if wer.status.statusCode == StatusCodePaused {
   152  		return
   153  	}
   154  	wer.pauseChan <- struct{}{}
   155  	wer.status.statusCode = StatusCodePaused
   156  }
   157  
   158  //Resume 恢复下载
   159  func (wer *Worker) Resume() {
   160  	if wer.status.statusCode != StatusCodePaused {
   161  		return
   162  	}
   163  	go wer.Execute()
   164  }
   165  
   166  //Cancel 取消下载
   167  func (wer *Worker) Cancel() error {
   168  	if wer.workerCancelFunc == nil {
   169  		return errors.New("cancelFunc not set")
   170  	}
   171  	wer.workerCancelFunc()
   172  	if wer.readRespBodyCancelFunc != nil {
   173  		wer.readRespBodyCancelFunc()
   174  	}
   175  	return nil
   176  }
   177  
   178  //Reset 重设连接
   179  func (wer *Worker) Reset() {
   180  	if wer.resetFunc == nil {
   181  		pcsverbose.Verbosef("DEBUG: worker: resetFunc not set")
   182  		return
   183  	}
   184  	wer.resetFunc()
   185  	if wer.readRespBodyCancelFunc != nil {
   186  		wer.readRespBodyCancelFunc()
   187  	}
   188  	wer.ClearStatus()
   189  	go wer.Execute()
   190  }
   191  
   192  // Canceled 是否已经取消
   193  func (wer *Worker) Canceled() bool {
   194  	return wer.status.statusCode == StatusCodeCanceled
   195  }
   196  
   197  //Completed 是否已经完成
   198  func (wer *Worker) Completed() bool {
   199  	switch wer.status.statusCode {
   200  	case StatusCodeSuccessed, StatusCodeCanceled:
   201  		return true
   202  	default:
   203  		return false
   204  	}
   205  }
   206  
   207  //Failed 是否失败
   208  func (wer *Worker) Failed() bool {
   209  	switch wer.status.statusCode {
   210  	case StatusCodeFailed, StatusCodeInternalError, StatusCodeTooManyConnections, StatusCodeNetError:
   211  		return true
   212  	default:
   213  		return false
   214  	}
   215  }
   216  
   217  //ClearStatus 清空状态
   218  func (wer *Worker) ClearStatus() {
   219  	wer.status.statusCode = StatusCodeInit
   220  }
   221  
   222  //Err 返回worker错误
   223  func (wer *Worker) Err() error {
   224  	return wer.err
   225  }
   226  
   227  //Execute 执行任务
   228  func (wer *Worker) Execute() {
   229  	wer.lazyInit()
   230  
   231  	wer.execMu.Lock()
   232  	defer wer.execMu.Unlock()
   233  
   234  	wer.status.statusCode = StatusCodeInit
   235  	single := wer.acceptRanges == ""
   236  
   237  	// 如果已暂停, 退出
   238  	if wer.status.statusCode == StatusCodePaused {
   239  		return
   240  	}
   241  
   242  	if !single {
   243  		// 已完成
   244  		if rlen := wer.wrange.Len(); rlen <= 0 {
   245  			if rlen < 0 {
   246  				pcsverbose.Verbosef("DEBUG: RangeLen is negative at begin: %v, %d\n", wer.wrange, wer.wrange.Len())
   247  			}
   248  			wer.status.statusCode = StatusCodeSuccessed
   249  			return
   250  		}
   251  	}
   252  
   253  	workerCancelCtx, workerCancelFunc := context.WithCancel(context.Background())
   254  	wer.workerCancelFunc = workerCancelFunc
   255  	resetCtx, resetFunc := context.WithCancel(context.Background())
   256  	wer.resetFunc = resetFunc
   257  
   258  	header := map[string]string{}
   259  	if wer.referer != "" {
   260  		header["Referer"] = wer.referer
   261  	}
   262  	//检测是否支持range
   263  	if wer.acceptRanges != "" && wer.wrange.Len() >= 0 {
   264  		header["Range"] = fmt.Sprintf("%s=%d-%d", wer.acceptRanges, wer.wrange.LoadBegin(), wer.wrange.LoadEnd()-1)
   265  	}
   266  
   267  	wer.status.statusCode = StatusCodePending
   268  
   269  	var resp *http.Response
   270  	if wer.firstResp != nil {
   271  		resp = wer.firstResp // 使用第一个连接
   272  	} else {
   273  		resp, wer.err = wer.client.Req(http.MethodGet, wer.url, nil, header)
   274  	}
   275  	if resp != nil {
   276  		defer func() {
   277  			resp.Body.Close()
   278  			wer.firstResp = nil // 去掉第一个连接
   279  		}()
   280  		wer.readRespBodyCancelFunc = func() {
   281  			resp.Body.Close()
   282  		}
   283  	}
   284  	if wer.err != nil {
   285  		wer.status.statusCode = StatusCodeNetError
   286  		return
   287  	}
   288  
   289  	// 判断响应状态
   290  	switch resp.StatusCode {
   291  	case 200, 206:
   292  		// do nothing, continue
   293  	case 416: //Requested Range Not Satisfiable
   294  		fallthrough
   295  	case 403: // Forbidden
   296  		fallthrough
   297  	case 404: // file block not exists
   298  		wer.status.statusCode = StatusCodeInternalError
   299  		wer.err = errors.New(resp.Status)
   300  		return
   301  	case 406: // Not Acceptable
   302  		wer.status.statusCode = StatusCodeNetError
   303  		wer.err = errors.New(resp.Status)
   304  		return
   305  	case 429, 509: // Too Many Requests
   306  		wer.status.SetStatusCode(StatusCodeTooManyConnections)
   307  		wer.err = errors.New(resp.Status)
   308  		return
   309  	default:
   310  		wer.status.statusCode = StatusCodeNetError
   311  		wer.err = fmt.Errorf("unexpected http status code, %d, %s", resp.StatusCode, resp.Status)
   312  		return
   313  	}
   314  
   315  	var (
   316  		contentLength = resp.ContentLength
   317  		rangeLength   = wer.wrange.Len()
   318  	)
   319  
   320  	if !single {
   321  		// 检查请求长度
   322  		if contentLength != rangeLength && wer.firstResp == nil { // 跳过检查第一个连接
   323  			wer.status.statusCode = StatusCodeNetError
   324  			wer.err = fmt.Errorf("Content-Length is unexpected: %d, need %d", contentLength, rangeLength)
   325  			return
   326  		}
   327  		// 检查总大小
   328  		if wer.totalSize > 0 {
   329  			total := ParseContentRange(resp.Header.Get("Content-Range"))
   330  			if total > 0 {
   331  				if total != wer.totalSize {
   332  					wer.status.statusCode = StatusCodeInternalError // 这里设置为内部错误, 强制停止下载
   333  					wer.err = fmt.Errorf("Content-Range total length is unexpected: %d, need %d", total, wer.totalSize)
   334  					return
   335  				}
   336  			}
   337  		}
   338  	}
   339  
   340  	var (
   341  		buf       = cachepool.SyncPool.Get().([]byte)
   342  		n, nn     int
   343  		n64, nn64 int64
   344  	)
   345  	defer cachepool.SyncPool.Put(buf)
   346  
   347  	for {
   348  		select {
   349  		case <-workerCancelCtx.Done(): //取消
   350  			wer.status.statusCode = StatusCodeCanceled
   351  			return
   352  		case <-resetCtx.Done(): //重设连接
   353  			wer.status.statusCode = StatusCodeReseted
   354  			return
   355  		case <-wer.pauseChan: //暂停
   356  			return
   357  		default:
   358  			wer.status.statusCode = StatusCodeDownloading
   359  
   360  			// 初始化数据
   361  			var readErr error
   362  			n = 0
   363  
   364  			// 读取数据
   365  			for n < len(buf) && readErr == nil && (single || wer.wrange.Len() > 0) {
   366  				nn, readErr = resp.Body.Read(buf[n:])
   367  				nn64 = int64(nn)
   368  
   369  				// 更新速度统计
   370  				if wer.downloadStatus != nil {
   371  					wer.downloadStatus.AddSpeedsDownloaded(nn64) // 限速在这里阻塞
   372  				}
   373  				wer.speedsStat.Add(nn64)
   374  				n += nn
   375  			}
   376  
   377  			if n > 0 && readErr == io.EOF {
   378  				readErr = io.ErrUnexpectedEOF
   379  			}
   380  
   381  			n64 = int64(n)
   382  
   383  			// 非单线程模式下
   384  			if !single {
   385  				rangeLength = wer.wrange.Len()
   386  
   387  				// 已完成 (未雨绸缪)
   388  				if rangeLength <= 0 {
   389  					wer.status.statusCode = StatusCodeCanceled
   390  					wer.err = errors.New("worker already complete")
   391  					return
   392  				}
   393  
   394  				if n64 > rangeLength {
   395  					// 数据大小不正常
   396  					n64 = rangeLength
   397  					n = int(rangeLength)
   398  					readErr = io.EOF
   399  				}
   400  			}
   401  
   402  			// 写入数据
   403  			if wer.writerAt != nil {
   404  				wer.status.statusCode = StatusCodeWaitToWrite
   405  				if wer.writeMu != nil {
   406  					wer.writeMu.Lock() // 加锁, 减轻硬盘的压力
   407  				}
   408  				_, wer.err = wer.writerAt.WriteAt(buf[:n], wer.wrange.Begin) // 写入数据
   409  				if wer.err != nil {
   410  					if wer.writeMu != nil {
   411  						wer.writeMu.Unlock() //解锁
   412  					}
   413  					wer.status.statusCode = StatusCodeInternalError
   414  					return
   415  				}
   416  
   417  				if wer.writeMu != nil {
   418  					wer.writeMu.Unlock() //解锁
   419  				}
   420  				wer.status.statusCode = StatusCodeDownloading
   421  			}
   422  
   423  			// 更新下载统计数据
   424  			wer.wrange.AddBegin(n64)
   425  			if wer.downloadStatus != nil {
   426  				wer.downloadStatus.AddDownloaded(n64)
   427  				if single {
   428  					wer.downloadStatus.AddTotalSize(n64)
   429  				}
   430  			}
   431  
   432  			if readErr != nil {
   433  				rlen := wer.wrange.Len()
   434  				switch {
   435  				case single && readErr == io.ErrUnexpectedEOF:
   436  					// 单线程判断下载成功
   437  					fallthrough
   438  				case readErr == io.EOF:
   439  					fallthrough
   440  				case rlen <= 0:
   441  					// 下载完成
   442  					// 小于0可能是因为 worker 被 duplicate
   443  					wer.status.statusCode = StatusCodeSuccessed
   444  					if rlen < 0 {
   445  						pcsverbose.Verbosef("DEBUG: RangeLen is negative at end: %v, %d\n", wer.wrange, wer.wrange.Len())
   446  					}
   447  					return
   448  				default:
   449  					// 其他错误, 返回
   450  					wer.status.statusCode = StatusCodeFailed
   451  					wer.err = readErr
   452  					return
   453  				}
   454  			}
   455  		}
   456  	}
   457  }