github.com/qjfoidnh/BaiduPCS-Go@v0.0.0-20231011165705-caa18a3765f3/requester/downloader/worker.go (about) 1 package downloader 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "github.com/qjfoidnh/BaiduPCS-Go/pcsutil/cachepool" 8 "github.com/qjfoidnh/BaiduPCS-Go/pcsverbose" 9 "github.com/qjfoidnh/BaiduPCS-Go/requester" 10 "github.com/qjfoidnh/BaiduPCS-Go/requester/rio/speeds" 11 "github.com/qjfoidnh/BaiduPCS-Go/requester/transfer" 12 "io" 13 "net/http" 14 "sync" 15 ) 16 17 type ( 18 //Worker 工作单元 19 Worker struct { 20 totalSize int64 // 整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误 21 wrange *transfer.Range 22 speedsStat *speeds.Speeds 23 id int //id 24 url string //下载地址 25 referer string //来源地址 26 acceptRanges string 27 client *requester.HTTPClient 28 firstResp *http.Response // 第一个响应 29 writerAt io.WriterAt 30 writeMu *sync.Mutex 31 execMu sync.Mutex 32 33 pauseChan chan struct{} 34 workerCancelFunc context.CancelFunc 35 resetFunc context.CancelFunc 36 readRespBodyCancelFunc func() 37 err error //错误信息 38 status WorkerStatus 39 downloadStatus *transfer.DownloadStatus //总的下载状态 40 } 41 42 // WorkerList worker列表 43 WorkerList []*Worker 44 ) 45 46 // Duplicate 构造新的列表 47 func (wl WorkerList) Duplicate() WorkerList { 48 n := make(WorkerList, len(wl)) 49 copy(n, wl) 50 return n 51 } 52 53 //NewWorker 初始化Worker 54 func NewWorker(id int, durl string, writerAt io.WriterAt) *Worker { 55 return &Worker{ 56 id: id, 57 url: durl, 58 writerAt: writerAt, 59 } 60 } 61 62 //ID 返回worker ID 63 func (wer *Worker) ID() int { 64 return wer.id 65 } 66 67 func (wer *Worker) lazyInit() { 68 if wer.client == nil { 69 wer.client = requester.NewHTTPClient() 70 } 71 if wer.pauseChan == nil { 72 wer.pauseChan = make(chan struct{}) 73 } 74 if wer.wrange == nil { 75 wer.wrange = &transfer.Range{} 76 } 77 if wer.wrange.LoadBegin() == 0 && wer.wrange.LoadEnd() == 0 { 78 // 取消多线程下载 79 wer.acceptRanges = "" 80 wer.wrange.StoreEnd(-2) 81 } 82 if wer.speedsStat == nil { 83 wer.speedsStat = &speeds.Speeds{} 84 } 85 } 86 87 // SetTotalSize 设置整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误 88 func (wer *Worker) SetTotalSize(size int64) { 89 wer.totalSize = size 90 } 91 92 //SetClient 设置http客户端 93 func (wer *Worker) SetClient(c *requester.HTTPClient) { 94 wer.client = c 95 } 96 97 //SetAcceptRange 设置AcceptRange 98 func (wer *Worker) SetAcceptRange(acceptRanges string) { 99 wer.acceptRanges = acceptRanges 100 } 101 102 //SetRange 设置请求范围 103 func (wer *Worker) SetRange(r *transfer.Range) { 104 if wer.wrange == nil { 105 wer.wrange = r 106 return 107 } 108 wer.wrange.StoreBegin(r.LoadBegin()) 109 wer.wrange.StoreEnd(r.LoadEnd()) 110 } 111 112 //SetReferer 设置来源 113 func (wer *Worker) SetReferer(referer string) { 114 wer.referer = referer 115 } 116 117 //SetWriteMutex 设置数据写锁 118 func (wer *Worker) SetWriteMutex(mu *sync.Mutex) { 119 wer.writeMu = mu 120 } 121 122 //SetDownloadStatus 增加其他需要统计的数据 123 func (wer *Worker) SetDownloadStatus(downloadStatus *transfer.DownloadStatus) { 124 wer.downloadStatus = downloadStatus 125 } 126 127 //GetStatus 返回下载状态 128 func (wer *Worker) GetStatus() WorkerStatuser { 129 // 空接口与空指针不等价 130 return &wer.status 131 } 132 133 //GetRange 返回worker范围 134 func (wer *Worker) GetRange() *transfer.Range { 135 return wer.wrange 136 } 137 138 //GetSpeedsPerSecond 获取每秒的速度 139 func (wer *Worker) GetSpeedsPerSecond() int64 { 140 return wer.speedsStat.GetSpeeds() 141 } 142 143 //Pause 暂停下载 144 func (wer *Worker) Pause() { 145 wer.lazyInit() 146 if wer.acceptRanges == "" { 147 pcsverbose.Verbosef("WARNING: worker unsupport pause") 148 return 149 } 150 151 if wer.status.statusCode == StatusCodePaused { 152 return 153 } 154 wer.pauseChan <- struct{}{} 155 wer.status.statusCode = StatusCodePaused 156 } 157 158 //Resume 恢复下载 159 func (wer *Worker) Resume() { 160 if wer.status.statusCode != StatusCodePaused { 161 return 162 } 163 go wer.Execute() 164 } 165 166 //Cancel 取消下载 167 func (wer *Worker) Cancel() error { 168 if wer.workerCancelFunc == nil { 169 return errors.New("cancelFunc not set") 170 } 171 wer.workerCancelFunc() 172 if wer.readRespBodyCancelFunc != nil { 173 wer.readRespBodyCancelFunc() 174 } 175 return nil 176 } 177 178 //Reset 重设连接 179 func (wer *Worker) Reset() { 180 if wer.resetFunc == nil { 181 pcsverbose.Verbosef("DEBUG: worker: resetFunc not set") 182 return 183 } 184 wer.resetFunc() 185 if wer.readRespBodyCancelFunc != nil { 186 wer.readRespBodyCancelFunc() 187 } 188 wer.ClearStatus() 189 go wer.Execute() 190 } 191 192 // Canceled 是否已经取消 193 func (wer *Worker) Canceled() bool { 194 return wer.status.statusCode == StatusCodeCanceled 195 } 196 197 //Completed 是否已经完成 198 func (wer *Worker) Completed() bool { 199 switch wer.status.statusCode { 200 case StatusCodeSuccessed, StatusCodeCanceled: 201 return true 202 default: 203 return false 204 } 205 } 206 207 //Failed 是否失败 208 func (wer *Worker) Failed() bool { 209 switch wer.status.statusCode { 210 case StatusCodeFailed, StatusCodeInternalError, StatusCodeTooManyConnections, StatusCodeNetError: 211 return true 212 default: 213 return false 214 } 215 } 216 217 //ClearStatus 清空状态 218 func (wer *Worker) ClearStatus() { 219 wer.status.statusCode = StatusCodeInit 220 } 221 222 //Err 返回worker错误 223 func (wer *Worker) Err() error { 224 return wer.err 225 } 226 227 //Execute 执行任务 228 func (wer *Worker) Execute() { 229 wer.lazyInit() 230 231 wer.execMu.Lock() 232 defer wer.execMu.Unlock() 233 234 wer.status.statusCode = StatusCodeInit 235 single := wer.acceptRanges == "" 236 237 // 如果已暂停, 退出 238 if wer.status.statusCode == StatusCodePaused { 239 return 240 } 241 242 if !single { 243 // 已完成 244 if rlen := wer.wrange.Len(); rlen <= 0 { 245 if rlen < 0 { 246 pcsverbose.Verbosef("DEBUG: RangeLen is negative at begin: %v, %d\n", wer.wrange, wer.wrange.Len()) 247 } 248 wer.status.statusCode = StatusCodeSuccessed 249 return 250 } 251 } 252 253 workerCancelCtx, workerCancelFunc := context.WithCancel(context.Background()) 254 wer.workerCancelFunc = workerCancelFunc 255 resetCtx, resetFunc := context.WithCancel(context.Background()) 256 wer.resetFunc = resetFunc 257 258 header := map[string]string{} 259 if wer.referer != "" { 260 header["Referer"] = wer.referer 261 } 262 //检测是否支持range 263 if wer.acceptRanges != "" && wer.wrange.Len() >= 0 { 264 header["Range"] = fmt.Sprintf("%s=%d-%d", wer.acceptRanges, wer.wrange.LoadBegin(), wer.wrange.LoadEnd()-1) 265 } 266 267 wer.status.statusCode = StatusCodePending 268 269 var resp *http.Response 270 if wer.firstResp != nil { 271 resp = wer.firstResp // 使用第一个连接 272 } else { 273 resp, wer.err = wer.client.Req(http.MethodGet, wer.url, nil, header) 274 } 275 if resp != nil { 276 defer func() { 277 resp.Body.Close() 278 wer.firstResp = nil // 去掉第一个连接 279 }() 280 wer.readRespBodyCancelFunc = func() { 281 resp.Body.Close() 282 } 283 } 284 if wer.err != nil { 285 wer.status.statusCode = StatusCodeNetError 286 return 287 } 288 289 // 判断响应状态 290 switch resp.StatusCode { 291 case 200, 206: 292 // do nothing, continue 293 case 416: //Requested Range Not Satisfiable 294 fallthrough 295 case 403: // Forbidden 296 fallthrough 297 case 404: // file block not exists 298 wer.status.statusCode = StatusCodeInternalError 299 wer.err = errors.New(resp.Status) 300 return 301 case 406: // Not Acceptable 302 wer.status.statusCode = StatusCodeNetError 303 wer.err = errors.New(resp.Status) 304 return 305 case 429, 509: // Too Many Requests 306 wer.status.SetStatusCode(StatusCodeTooManyConnections) 307 wer.err = errors.New(resp.Status) 308 return 309 default: 310 wer.status.statusCode = StatusCodeNetError 311 wer.err = fmt.Errorf("unexpected http status code, %d, %s", resp.StatusCode, resp.Status) 312 return 313 } 314 315 var ( 316 contentLength = resp.ContentLength 317 rangeLength = wer.wrange.Len() 318 ) 319 320 if !single { 321 // 检查请求长度 322 if contentLength != rangeLength && wer.firstResp == nil { // 跳过检查第一个连接 323 wer.status.statusCode = StatusCodeNetError 324 wer.err = fmt.Errorf("Content-Length is unexpected: %d, need %d", contentLength, rangeLength) 325 return 326 } 327 // 检查总大小 328 if wer.totalSize > 0 { 329 total := ParseContentRange(resp.Header.Get("Content-Range")) 330 if total > 0 { 331 if total != wer.totalSize { 332 wer.status.statusCode = StatusCodeInternalError // 这里设置为内部错误, 强制停止下载 333 wer.err = fmt.Errorf("Content-Range total length is unexpected: %d, need %d", total, wer.totalSize) 334 return 335 } 336 } 337 } 338 } 339 340 var ( 341 buf = cachepool.SyncPool.Get().([]byte) 342 n, nn int 343 n64, nn64 int64 344 ) 345 defer cachepool.SyncPool.Put(buf) 346 347 for { 348 select { 349 case <-workerCancelCtx.Done(): //取消 350 wer.status.statusCode = StatusCodeCanceled 351 return 352 case <-resetCtx.Done(): //重设连接 353 wer.status.statusCode = StatusCodeReseted 354 return 355 case <-wer.pauseChan: //暂停 356 return 357 default: 358 wer.status.statusCode = StatusCodeDownloading 359 360 // 初始化数据 361 var readErr error 362 n = 0 363 364 // 读取数据 365 for n < len(buf) && readErr == nil && (single || wer.wrange.Len() > 0) { 366 nn, readErr = resp.Body.Read(buf[n:]) 367 nn64 = int64(nn) 368 369 // 更新速度统计 370 if wer.downloadStatus != nil { 371 wer.downloadStatus.AddSpeedsDownloaded(nn64) // 限速在这里阻塞 372 } 373 wer.speedsStat.Add(nn64) 374 n += nn 375 } 376 377 if n > 0 && readErr == io.EOF { 378 readErr = io.ErrUnexpectedEOF 379 } 380 381 n64 = int64(n) 382 383 // 非单线程模式下 384 if !single { 385 rangeLength = wer.wrange.Len() 386 387 // 已完成 (未雨绸缪) 388 if rangeLength <= 0 { 389 wer.status.statusCode = StatusCodeCanceled 390 wer.err = errors.New("worker already complete") 391 return 392 } 393 394 if n64 > rangeLength { 395 // 数据大小不正常 396 n64 = rangeLength 397 n = int(rangeLength) 398 readErr = io.EOF 399 } 400 } 401 402 // 写入数据 403 if wer.writerAt != nil { 404 wer.status.statusCode = StatusCodeWaitToWrite 405 if wer.writeMu != nil { 406 wer.writeMu.Lock() // 加锁, 减轻硬盘的压力 407 } 408 _, wer.err = wer.writerAt.WriteAt(buf[:n], wer.wrange.Begin) // 写入数据 409 if wer.err != nil { 410 if wer.writeMu != nil { 411 wer.writeMu.Unlock() //解锁 412 } 413 wer.status.statusCode = StatusCodeInternalError 414 return 415 } 416 417 if wer.writeMu != nil { 418 wer.writeMu.Unlock() //解锁 419 } 420 wer.status.statusCode = StatusCodeDownloading 421 } 422 423 // 更新下载统计数据 424 wer.wrange.AddBegin(n64) 425 if wer.downloadStatus != nil { 426 wer.downloadStatus.AddDownloaded(n64) 427 if single { 428 wer.downloadStatus.AddTotalSize(n64) 429 } 430 } 431 432 if readErr != nil { 433 rlen := wer.wrange.Len() 434 switch { 435 case single && readErr == io.ErrUnexpectedEOF: 436 // 单线程判断下载成功 437 fallthrough 438 case readErr == io.EOF: 439 fallthrough 440 case rlen <= 0: 441 // 下载完成 442 // 小于0可能是因为 worker 被 duplicate 443 wer.status.statusCode = StatusCodeSuccessed 444 if rlen < 0 { 445 pcsverbose.Verbosef("DEBUG: RangeLen is negative at end: %v, %d\n", wer.wrange, wer.wrange.Len()) 446 } 447 return 448 default: 449 // 其他错误, 返回 450 wer.status.statusCode = StatusCodeFailed 451 wer.err = readErr 452 return 453 } 454 } 455 } 456 } 457 }