github.com/iikira/iikira-go-utils@v0.0.0-20230610031953-f2cb11cde33a/requester/downloader/worker.go (about)

     1  package downloader
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/iikira/iikira-go-utils/pcsverbose"
     8  	"github.com/iikira/iikira-go-utils/requester"
     9  	"github.com/iikira/iikira-go-utils/requester/rio/speeds"
    10  	"github.com/iikira/iikira-go-utils/requester/transfer"
    11  	"github.com/iikira/iikira-go-utils/utils/cachepool"
    12  	"io"
    13  	"net/http"
    14  	"sync"
    15  )
    16  
    17  type (
    18  	//Worker 工作单元
    19  	Worker struct {
    20  		totalSize    int64 // 整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误
    21  		wrange       *transfer.Range
    22  		speedsStat   *speeds.Speeds
    23  		id           int    //id
    24  		url          string //下载地址
    25  		referer      string //来源地址
    26  		acceptRanges string
    27  		client       *requester.HTTPClient
    28  		firstResp    *http.Response // 第一个响应
    29  		writerAt     io.WriterAt
    30  		writeMu      *sync.Mutex
    31  		execMu       sync.Mutex
    32  
    33  		pauseChan              chan struct{}
    34  		workerCancelFunc       context.CancelFunc
    35  		resetFunc              context.CancelFunc
    36  		readRespBodyCancelFunc func()
    37  		err                    error //错误信息
    38  		status                 WorkerStatus
    39  		downloadStatus         *transfer.DownloadStatus //总的下载状态
    40  	}
    41  
    42  	// WorkerList worker列表
    43  	WorkerList []*Worker
    44  )
    45  
    46  // Duplicate 构造新的列表
    47  func (wl WorkerList) Duplicate() WorkerList {
    48  	n := make(WorkerList, len(wl))
    49  	copy(n, wl)
    50  	return n
    51  }
    52  
    53  //NewWorker 初始化Worker
    54  func NewWorker(id int, durl string, writerAt io.WriterAt) *Worker {
    55  	return &Worker{
    56  		id:       id,
    57  		url:      durl,
    58  		writerAt: writerAt,
    59  	}
    60  }
    61  
    62  //ID 返回worker ID
    63  func (wer *Worker) ID() int {
    64  	return wer.id
    65  }
    66  
    67  func (wer *Worker) lazyInit() {
    68  	if wer.client == nil {
    69  		wer.client = requester.NewHTTPClient()
    70  	}
    71  	if wer.pauseChan == nil {
    72  		wer.pauseChan = make(chan struct{})
    73  	}
    74  	if wer.wrange == nil {
    75  		wer.wrange = &transfer.Range{}
    76  	}
    77  	if wer.wrange.LoadBegin() == 0 && wer.wrange.LoadEnd() == 0 {
    78  		// 取消多线程下载
    79  		wer.acceptRanges = ""
    80  		wer.wrange.StoreEnd(-2)
    81  	}
    82  	if wer.speedsStat == nil {
    83  		wer.speedsStat = &speeds.Speeds{}
    84  	}
    85  }
    86  
    87  // SetTotalSize 设置整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误
    88  func (wer *Worker) SetTotalSize(size int64) {
    89  	wer.totalSize = size
    90  }
    91  
    92  //SetClient 设置http客户端
    93  func (wer *Worker) SetClient(c *requester.HTTPClient) {
    94  	wer.client = c
    95  }
    96  
    97  //SetAcceptRange 设置AcceptRange
    98  func (wer *Worker) SetAcceptRange(acceptRanges string) {
    99  	wer.acceptRanges = acceptRanges
   100  }
   101  
   102  //SetRange 设置请求范围
   103  func (wer *Worker) SetRange(r *transfer.Range) {
   104  	if wer.wrange == nil {
   105  		wer.wrange = r
   106  		return
   107  	}
   108  	wer.wrange.StoreBegin(r.LoadBegin())
   109  	wer.wrange.StoreEnd(r.LoadEnd())
   110  }
   111  
   112  //SetReferer 设置来源
   113  func (wer *Worker) SetReferer(referer string) {
   114  	wer.referer = referer
   115  }
   116  
   117  //SetWriteMutex 设置数据写锁
   118  func (wer *Worker) SetWriteMutex(mu *sync.Mutex) {
   119  	wer.writeMu = mu
   120  }
   121  
   122  //SetDownloadStatus 增加其他需要统计的数据
   123  func (wer *Worker) SetDownloadStatus(downloadStatus *transfer.DownloadStatus) {
   124  	wer.downloadStatus = downloadStatus
   125  }
   126  
   127  //GetStatus 返回下载状态
   128  func (wer *Worker) GetStatus() WorkerStatuser {
   129  	// 空接口与空指针不等价
   130  	return &wer.status
   131  }
   132  
   133  //GetRange 返回worker范围
   134  func (wer *Worker) GetRange() *transfer.Range {
   135  	return wer.wrange
   136  }
   137  
   138  //GetSpeedsPerSecond 获取每秒的速度
   139  func (wer *Worker) GetSpeedsPerSecond() int64 {
   140  	return wer.speedsStat.GetSpeeds()
   141  }
   142  
   143  //Pause 暂停下载
   144  func (wer *Worker) Pause() {
   145  	wer.lazyInit()
   146  	if wer.acceptRanges == "" {
   147  		pcsverbose.Verbosef("WARNING: worker unsupport pause")
   148  		return
   149  	}
   150  
   151  	if wer.status.statusCode == StatusCodePaused {
   152  		return
   153  	}
   154  	wer.pauseChan <- struct{}{}
   155  	wer.status.statusCode = StatusCodePaused
   156  }
   157  
   158  //Resume 恢复下载
   159  func (wer *Worker) Resume() {
   160  	if wer.status.statusCode != StatusCodePaused {
   161  		return
   162  	}
   163  	go wer.Execute()
   164  }
   165  
   166  //Cancel 取消下载
   167  func (wer *Worker) Cancel() error {
   168  	if wer.workerCancelFunc == nil {
   169  		return errors.New("cancelFunc not set")
   170  	}
   171  	wer.workerCancelFunc()
   172  	if wer.readRespBodyCancelFunc != nil {
   173  		wer.readRespBodyCancelFunc()
   174  	}
   175  	return nil
   176  }
   177  
   178  //Reset 重设连接
   179  func (wer *Worker) Reset() {
   180  	if wer.resetFunc == nil {
   181  		pcsverbose.Verbosef("DEBUG: worker: resetFunc not set")
   182  		return
   183  	}
   184  	wer.resetFunc()
   185  	if wer.readRespBodyCancelFunc != nil {
   186  		wer.readRespBodyCancelFunc()
   187  	}
   188  	wer.ClearStatus()
   189  	go wer.Execute()
   190  }
   191  
   192  // Canceled 是否已经取消
   193  func (wer *Worker) Canceled() bool {
   194  	return wer.status.statusCode == StatusCodeCanceled
   195  }
   196  
   197  //Completed 是否已经完成
   198  func (wer *Worker) Completed() bool {
   199  	switch wer.status.statusCode {
   200  	case StatusCodeSuccessed, StatusCodeCanceled:
   201  		return true
   202  	default:
   203  		return false
   204  	}
   205  }
   206  
   207  //Failed 是否失败
   208  func (wer *Worker) Failed() bool {
   209  	switch wer.status.statusCode {
   210  	case StatusCodeFailed, StatusCodeInternalError, StatusCodeTooManyConnections, StatusCodeNetError:
   211  		return true
   212  	default:
   213  		return false
   214  	}
   215  }
   216  
   217  //ClearStatus 清空状态
   218  func (wer *Worker) ClearStatus() {
   219  	wer.status.statusCode = StatusCodeInit
   220  }
   221  
   222  //Err 返回worker错误
   223  func (wer *Worker) Err() error {
   224  	return wer.err
   225  }
   226  
   227  //Execute 执行任务
   228  func (wer *Worker) Execute() {
   229  	wer.lazyInit()
   230  
   231  	wer.execMu.Lock()
   232  	defer wer.execMu.Unlock()
   233  
   234  	wer.status.statusCode = StatusCodeInit
   235  	single := wer.acceptRanges == ""
   236  
   237  	// 如果已暂停, 退出
   238  	if wer.status.statusCode == StatusCodePaused {
   239  		return
   240  	}
   241  
   242  	if !single {
   243  		// 已完成
   244  		if rlen := wer.wrange.Len(); rlen <= 0 {
   245  			if rlen < 0 {
   246  				pcsverbose.Verbosef("DEBUG: RangeLen is negative at begin: %v, %d\n", wer.wrange, wer.wrange.Len())
   247  			}
   248  			wer.status.statusCode = StatusCodeSuccessed
   249  			return
   250  		}
   251  	}
   252  
   253  	workerCancelCtx, workerCancelFunc := context.WithCancel(context.Background())
   254  	wer.workerCancelFunc = workerCancelFunc
   255  	resetCtx, resetFunc := context.WithCancel(context.Background())
   256  	wer.resetFunc = resetFunc
   257  
   258  	header := map[string]string{}
   259  	if wer.referer != "" {
   260  		header["Referer"] = wer.referer
   261  	}
   262  	//检测是否支持range
   263  	if wer.acceptRanges != "" && wer.wrange.Len() >= 0 {
   264  		header["Range"] = fmt.Sprintf("%s=%d-%d", wer.acceptRanges, wer.wrange.LoadBegin(), wer.wrange.LoadEnd()-1)
   265  	}
   266  
   267  	wer.status.statusCode = StatusCodePending
   268  
   269  	var resp *http.Response
   270  	if wer.firstResp != nil {
   271  		resp = wer.firstResp // 使用第一个连接
   272  	} else {
   273  		resp, wer.err = wer.client.Req(http.MethodGet, wer.url, nil, header)
   274  	}
   275  	if resp != nil {
   276  		defer func() {
   277  			resp.Body.Close()
   278  			wer.firstResp = nil // 去掉第一个连接
   279  		}()
   280  		wer.readRespBodyCancelFunc = func() {
   281  			resp.Body.Close()
   282  		}
   283  	}
   284  	if wer.err != nil {
   285  		wer.status.statusCode = StatusCodeNetError
   286  		return
   287  	}
   288  
   289  	// 判断响应状态
   290  	switch resp.StatusCode {
   291  	case 200, 206:
   292  		// do nothing, continue
   293  	case 416: //Requested Range Not Satisfiable
   294  		fallthrough
   295  	case 403: // Forbidden
   296  		fallthrough
   297  	case 406: // Not Acceptable
   298  		wer.status.statusCode = StatusCodeNetError
   299  		wer.err = errors.New(resp.Status)
   300  		return
   301  	case 429, 509: // Too Many Requests
   302  		wer.status.SetStatusCode(StatusCodeTooManyConnections)
   303  		wer.err = errors.New(resp.Status)
   304  		return
   305  	default:
   306  		wer.status.statusCode = StatusCodeNetError
   307  		wer.err = fmt.Errorf("unexpected http status code, %d, %s", resp.StatusCode, resp.Status)
   308  		return
   309  	}
   310  
   311  	var (
   312  		contentLength = resp.ContentLength
   313  		rangeLength   = wer.wrange.Len()
   314  	)
   315  
   316  	if !single {
   317  		// 检查请求长度
   318  		if contentLength != rangeLength && wer.firstResp == nil { // 跳过检查第一个连接
   319  			wer.status.statusCode = StatusCodeNetError
   320  			wer.err = fmt.Errorf("Content-Length is unexpected: %d, need %d", contentLength, rangeLength)
   321  			return
   322  		}
   323  		// 检查总大小
   324  		if wer.totalSize > 0 {
   325  			total := ParseContentRange(resp.Header.Get("Content-Range"))
   326  			if total > 0 {
   327  				if total != wer.totalSize {
   328  					wer.status.statusCode = StatusCodeInternalError // 这里设置为内部错误, 强制停止下载
   329  					wer.err = fmt.Errorf("Content-Range total length is unexpected: %d, need %d", total, wer.totalSize)
   330  					return
   331  				}
   332  			}
   333  		}
   334  	}
   335  
   336  	var (
   337  		buf       = cachepool.SyncPool.Get().([]byte)
   338  		n, nn     int
   339  		n64, nn64 int64
   340  	)
   341  	defer cachepool.SyncPool.Put(buf)
   342  
   343  	for {
   344  		select {
   345  		case <-workerCancelCtx.Done(): //取消
   346  			wer.status.statusCode = StatusCodeCanceled
   347  			return
   348  		case <-resetCtx.Done(): //重设连接
   349  			wer.status.statusCode = StatusCodeReseted
   350  			return
   351  		case <-wer.pauseChan: //暂停
   352  			return
   353  		default:
   354  			wer.status.statusCode = StatusCodeDownloading
   355  
   356  			// 初始化数据
   357  			var readErr error
   358  			n = 0
   359  
   360  			// 读取数据
   361  			for n < len(buf) && readErr == nil && (single || wer.wrange.Len() > 0) {
   362  				nn, readErr = resp.Body.Read(buf[n:])
   363  				nn64 = int64(nn)
   364  
   365  				// 更新速度统计
   366  				if wer.downloadStatus != nil {
   367  					wer.downloadStatus.AddSpeedsDownloaded(nn64) // 限速在这里阻塞
   368  				}
   369  				wer.speedsStat.Add(nn64)
   370  				n += nn
   371  			}
   372  
   373  			if n > 0 && readErr == io.EOF {
   374  				readErr = io.ErrUnexpectedEOF
   375  			}
   376  
   377  			n64 = int64(n)
   378  
   379  			// 非单线程模式下
   380  			if !single {
   381  				rangeLength = wer.wrange.Len()
   382  
   383  				// 已完成 (未雨绸缪)
   384  				if rangeLength <= 0 {
   385  					wer.status.statusCode = StatusCodeCanceled
   386  					wer.err = errors.New("worker already complete")
   387  					return
   388  				}
   389  
   390  				if n64 > rangeLength {
   391  					// 数据大小不正常
   392  					n64 = rangeLength
   393  					n = int(rangeLength)
   394  					readErr = io.EOF
   395  				}
   396  			}
   397  
   398  			// 写入数据
   399  			if wer.writerAt != nil {
   400  				wer.status.statusCode = StatusCodeWaitToWrite
   401  				if wer.writeMu != nil {
   402  					wer.writeMu.Lock() // 加锁, 减轻硬盘的压力
   403  				}
   404  				_, wer.err = wer.writerAt.WriteAt(buf[:n], wer.wrange.Begin) // 写入数据
   405  				if wer.err != nil {
   406  					if wer.writeMu != nil {
   407  						wer.writeMu.Unlock() //解锁
   408  					}
   409  					wer.status.statusCode = StatusCodeInternalError
   410  					return
   411  				}
   412  
   413  				if wer.writeMu != nil {
   414  					wer.writeMu.Unlock() //解锁
   415  				}
   416  				wer.status.statusCode = StatusCodeDownloading
   417  			}
   418  
   419  			// 更新下载统计数据
   420  			wer.wrange.AddBegin(n64)
   421  			if wer.downloadStatus != nil {
   422  				wer.downloadStatus.AddDownloaded(n64)
   423  				if single {
   424  					wer.downloadStatus.AddTotalSize(n64)
   425  				}
   426  			}
   427  
   428  			if readErr != nil {
   429  				rlen := wer.wrange.Len()
   430  				switch {
   431  				case single && readErr == io.ErrUnexpectedEOF:
   432  					// 单线程判断下载成功
   433  					fallthrough
   434  				case readErr == io.EOF:
   435  					fallthrough
   436  				case rlen <= 0:
   437  					// 下载完成
   438  					// 小于0可能是因为 worker 被 duplicate
   439  					wer.status.statusCode = StatusCodeSuccessed
   440  					if rlen < 0 {
   441  						pcsverbose.Verbosef("DEBUG: RangeLen is negative at end: %v, %d\n", wer.wrange, wer.wrange.Len())
   442  					}
   443  					return
   444  				default:
   445  					// 其他错误, 返回
   446  					wer.status.statusCode = StatusCodeFailed
   447  					wer.err = readErr
   448  					return
   449  				}
   450  			}
   451  		}
   452  	}
   453  }