github.com/fzfile/BaiduPCS-Go@v0.0.0-20200606205115-4408961cf336/requester/downloader/worker.go (about) 1 package downloader 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "github.com/fzfile/BaiduPCS-Go/pcsutil/cachepool" 8 "github.com/fzfile/BaiduPCS-Go/pcsverbose" 9 "github.com/fzfile/BaiduPCS-Go/requester" 10 "github.com/fzfile/BaiduPCS-Go/requester/rio/speeds" 11 "github.com/fzfile/BaiduPCS-Go/requester/transfer" 12 "io" 13 "net/http" 14 "sync" 15 ) 16 17 type ( 18 //Worker 工作单元 19 Worker struct { 20 totalSize int64 // 整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误 21 wrange *transfer.Range 22 speedsStat *speeds.Speeds 23 id int //id 24 url string //下载地址 25 referer string //来源地址 26 acceptRanges string 27 client *requester.HTTPClient 28 firstResp *http.Response // 第一个响应 29 writerAt io.WriterAt 30 writeMu *sync.Mutex 31 execMu sync.Mutex 32 33 pauseChan chan struct{} 34 workerCancelFunc context.CancelFunc 35 resetFunc context.CancelFunc 36 readRespBodyCancelFunc func() 37 err error //错误信息 38 status WorkerStatus 39 downloadStatus *transfer.DownloadStatus //总的下载状态 40 } 41 42 // WorkerList worker列表 43 WorkerList []*Worker 44 ) 45 46 // Duplicate 构造新的列表 47 func (wl WorkerList) Duplicate() WorkerList { 48 n := make(WorkerList, len(wl)) 49 copy(n, wl) 50 return n 51 } 52 53 //NewWorker 初始化Worker 54 func NewWorker(id int, durl string, writerAt io.WriterAt) *Worker { 55 return &Worker{ 56 id: id, 57 url: durl, 58 writerAt: writerAt, 59 } 60 } 61 62 //ID 返回worker ID 63 func (wer *Worker) ID() int { 64 return wer.id 65 } 66 67 func (wer *Worker) lazyInit() { 68 if wer.client == nil { 69 wer.client = requester.NewHTTPClient() 70 } 71 if wer.pauseChan == nil { 72 wer.pauseChan = make(chan struct{}) 73 } 74 if wer.wrange == nil { 75 wer.wrange = &transfer.Range{} 76 } 77 if wer.wrange.LoadBegin() == 0 && wer.wrange.LoadEnd() == 0 { 78 // 取消多线程下载 79 wer.acceptRanges = "" 80 wer.wrange.StoreEnd(-2) 81 } 82 if wer.speedsStat == nil { 83 wer.speedsStat = &speeds.Speeds{} 84 } 85 } 86 87 // SetTotalSize 设置整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误 88 func (wer *Worker) SetTotalSize(size int64) { 89 wer.totalSize = size 90 } 91 92 //SetClient 设置http客户端 93 func (wer *Worker) SetClient(c *requester.HTTPClient) { 94 wer.client = c 95 } 96 97 //SetAcceptRange 设置AcceptRange 98 func (wer *Worker) SetAcceptRange(acceptRanges string) { 99 wer.acceptRanges = acceptRanges 100 } 101 102 //SetRange 设置请求范围 103 func (wer *Worker) SetRange(r *transfer.Range) { 104 if wer.wrange == nil { 105 wer.wrange = r 106 return 107 } 108 wer.wrange.StoreBegin(r.LoadBegin()) 109 wer.wrange.StoreEnd(r.LoadEnd()) 110 } 111 112 //SetReferer 设置来源 113 func (wer *Worker) SetReferer(referer string) { 114 wer.referer = referer 115 } 116 117 //SetWriteMutex 设置数据写锁 118 func (wer *Worker) SetWriteMutex(mu *sync.Mutex) { 119 wer.writeMu = mu 120 } 121 122 //SetDownloadStatus 增加其他需要统计的数据 123 func (wer *Worker) SetDownloadStatus(downloadStatus *transfer.DownloadStatus) { 124 wer.downloadStatus = downloadStatus 125 } 126 127 //GetStatus 返回下载状态 128 func (wer *Worker) GetStatus() WorkerStatuser { 129 // 空接口与空指针不等价 130 return &wer.status 131 } 132 133 //GetRange 返回worker范围 134 func (wer *Worker) GetRange() *transfer.Range { 135 return wer.wrange 136 } 137 138 //GetSpeedsPerSecond 获取每秒的速度 139 func (wer *Worker) GetSpeedsPerSecond() int64 { 140 return wer.speedsStat.GetSpeeds() 141 } 142 143 //Pause 暂停下载 144 func (wer *Worker) Pause() { 145 wer.lazyInit() 146 if wer.acceptRanges == "" { 147 pcsverbose.Verbosef("WARNING: worker unsupport pause") 148 return 149 } 150 151 if wer.status.statusCode == StatusCodePaused { 152 return 153 } 154 wer.pauseChan <- struct{}{} 155 wer.status.statusCode = StatusCodePaused 156 } 157 158 //Resume 恢复下载 159 func (wer *Worker) Resume() { 160 if wer.status.statusCode != StatusCodePaused { 161 return 162 } 163 go wer.Execute() 164 } 165 166 //Cancel 取消下载 167 func (wer *Worker) Cancel() error { 168 if wer.workerCancelFunc == nil { 169 return errors.New("cancelFunc not set") 170 } 171 wer.workerCancelFunc() 172 if wer.readRespBodyCancelFunc != nil { 173 wer.readRespBodyCancelFunc() 174 } 175 return nil 176 } 177 178 //Reset 重设连接 179 func (wer *Worker) Reset() { 180 if wer.resetFunc == nil { 181 pcsverbose.Verbosef("DEBUG: worker: resetFunc not set") 182 return 183 } 184 wer.resetFunc() 185 if wer.readRespBodyCancelFunc != nil { 186 wer.readRespBodyCancelFunc() 187 } 188 wer.ClearStatus() 189 go wer.Execute() 190 } 191 192 // Canceled 是否已经取消 193 func (wer *Worker) Canceled() bool { 194 return wer.status.statusCode == StatusCodeCanceled 195 } 196 197 //Completed 是否已经完成 198 func (wer *Worker) Completed() bool { 199 switch wer.status.statusCode { 200 case StatusCodeSuccessed, StatusCodeCanceled: 201 return true 202 default: 203 return false 204 } 205 } 206 207 //Failed 是否失败 208 func (wer *Worker) Failed() bool { 209 switch wer.status.statusCode { 210 case StatusCodeFailed, StatusCodeInternalError, StatusCodeTooManyConnections, StatusCodeNetError: 211 return true 212 default: 213 return false 214 } 215 } 216 217 //ClearStatus 清空状态 218 func (wer *Worker) ClearStatus() { 219 wer.status.statusCode = StatusCodeInit 220 } 221 222 //Err 返回worker错误 223 func (wer *Worker) Err() error { 224 return wer.err 225 } 226 227 //Execute 执行任务 228 func (wer *Worker) Execute() { 229 wer.lazyInit() 230 231 wer.execMu.Lock() 232 defer wer.execMu.Unlock() 233 234 wer.status.statusCode = StatusCodeInit 235 single := wer.acceptRanges == "" 236 237 // 如果已暂停, 退出 238 if wer.status.statusCode == StatusCodePaused { 239 return 240 } 241 242 if !single { 243 // 已完成 244 if rlen := wer.wrange.Len(); rlen <= 0 { 245 if rlen < 0 { 246 pcsverbose.Verbosef("DEBUG: RangeLen is negative at begin: %v, %d\n", wer.wrange, wer.wrange.Len()) 247 } 248 wer.status.statusCode = StatusCodeSuccessed 249 return 250 } 251 } 252 253 workerCancelCtx, workerCancelFunc := context.WithCancel(context.Background()) 254 wer.workerCancelFunc = workerCancelFunc 255 resetCtx, resetFunc := context.WithCancel(context.Background()) 256 wer.resetFunc = resetFunc 257 258 header := map[string]string{} 259 if wer.referer != "" { 260 header["Referer"] = wer.referer 261 } 262 //检测是否支持range 263 if wer.acceptRanges != "" && wer.wrange.Len() >= 0 { 264 header["Range"] = fmt.Sprintf("%s=%d-%d", wer.acceptRanges, wer.wrange.LoadBegin(), wer.wrange.LoadEnd()-1) 265 } 266 267 wer.status.statusCode = StatusCodePending 268 269 var resp *http.Response 270 if wer.firstResp != nil { 271 resp = wer.firstResp // 使用第一个连接 272 } else { 273 resp, wer.err = wer.client.Req(http.MethodGet, wer.url, nil, header) 274 } 275 if resp != nil { 276 defer func() { 277 resp.Body.Close() 278 wer.firstResp = nil // 去掉第一个连接 279 }() 280 wer.readRespBodyCancelFunc = func() { 281 resp.Body.Close() 282 } 283 } 284 if wer.err != nil { 285 wer.status.statusCode = StatusCodeNetError 286 return 287 } 288 289 // 判断响应状态 290 switch resp.StatusCode { 291 case 200, 206: 292 // do nothing, continue 293 case 416: //Requested Range Not Satisfiable 294 fallthrough 295 case 403: // Forbidden 296 fallthrough 297 case 406: // Not Acceptable 298 wer.status.statusCode = StatusCodeNetError 299 wer.err = errors.New(resp.Status) 300 return 301 case 429, 509: // Too Many Requests 302 wer.status.SetStatusCode(StatusCodeTooManyConnections) 303 wer.err = errors.New(resp.Status) 304 return 305 default: 306 wer.status.statusCode = StatusCodeNetError 307 wer.err = fmt.Errorf("unexpected http status code, %d, %s", resp.StatusCode, resp.Status) 308 return 309 } 310 311 var ( 312 contentLength = resp.ContentLength 313 rangeLength = wer.wrange.Len() 314 ) 315 316 if !single { 317 // 检查请求长度 318 if contentLength != rangeLength && wer.firstResp == nil { // 跳过检查第一个连接 319 wer.status.statusCode = StatusCodeNetError 320 wer.err = fmt.Errorf("Content-Length is unexpected: %d, need %d", contentLength, rangeLength) 321 return 322 } 323 // 检查总大小 324 if wer.totalSize > 0 { 325 total := ParseContentRange(resp.Header.Get("Content-Range")) 326 if total > 0 { 327 if total != wer.totalSize { 328 wer.status.statusCode = StatusCodeInternalError // 这里设置为内部错误, 强制停止下载 329 wer.err = fmt.Errorf("Content-Range total length is unexpected: %d, need %d", total, wer.totalSize) 330 return 331 } 332 } 333 } 334 } 335 336 var ( 337 buf = cachepool.SyncPool.Get().([]byte) 338 n, nn int 339 n64, nn64 int64 340 ) 341 defer cachepool.SyncPool.Put(buf) 342 343 for { 344 select { 345 case <-workerCancelCtx.Done(): //取消 346 wer.status.statusCode = StatusCodeCanceled 347 return 348 case <-resetCtx.Done(): //重设连接 349 wer.status.statusCode = StatusCodeReseted 350 return 351 case <-wer.pauseChan: //暂停 352 return 353 default: 354 wer.status.statusCode = StatusCodeDownloading 355 356 // 初始化数据 357 var readErr error 358 n = 0 359 360 // 读取数据 361 for n < len(buf) && readErr == nil && (single || wer.wrange.Len() > 0) { 362 nn, readErr = resp.Body.Read(buf[n:]) 363 nn64 = int64(nn) 364 365 // 更新速度统计 366 if wer.downloadStatus != nil { 367 wer.downloadStatus.AddSpeedsDownloaded(nn64) // 限速在这里阻塞 368 } 369 wer.speedsStat.Add(nn64) 370 n += nn 371 } 372 373 if n > 0 && readErr == io.EOF { 374 readErr = io.ErrUnexpectedEOF 375 } 376 377 n64 = int64(n) 378 379 // 非单线程模式下 380 if !single { 381 rangeLength = wer.wrange.Len() 382 383 // 已完成 (未雨绸缪) 384 if rangeLength <= 0 { 385 wer.status.statusCode = StatusCodeCanceled 386 wer.err = errors.New("worker already complete") 387 return 388 } 389 390 if n64 > rangeLength { 391 // 数据大小不正常 392 n64 = rangeLength 393 n = int(rangeLength) 394 readErr = io.EOF 395 } 396 } 397 398 // 写入数据 399 if wer.writerAt != nil { 400 wer.status.statusCode = StatusCodeWaitToWrite 401 if wer.writeMu != nil { 402 wer.writeMu.Lock() // 加锁, 减轻硬盘的压力 403 } 404 _, wer.err = wer.writerAt.WriteAt(buf[:n], wer.wrange.Begin) // 写入数据 405 if wer.err != nil { 406 if wer.writeMu != nil { 407 wer.writeMu.Unlock() //解锁 408 } 409 wer.status.statusCode = StatusCodeInternalError 410 return 411 } 412 413 if wer.writeMu != nil { 414 wer.writeMu.Unlock() //解锁 415 } 416 wer.status.statusCode = StatusCodeDownloading 417 } 418 419 // 更新下载统计数据 420 wer.wrange.AddBegin(n64) 421 if wer.downloadStatus != nil { 422 wer.downloadStatus.AddDownloaded(n64) 423 if single { 424 wer.downloadStatus.AddTotalSize(n64) 425 } 426 } 427 428 if readErr != nil { 429 rlen := wer.wrange.Len() 430 switch { 431 case single && readErr == io.ErrUnexpectedEOF: 432 // 单线程判断下载成功 433 fallthrough 434 case readErr == io.EOF: 435 fallthrough 436 case rlen <= 0: 437 // 下载完成 438 // 小于0可能是因为 worker 被 duplicate 439 wer.status.statusCode = StatusCodeSuccessed 440 if rlen < 0 { 441 pcsverbose.Verbosef("DEBUG: RangeLen is negative at end: %v, %d\n", wer.wrange, wer.wrange.Len()) 442 } 443 return 444 default: 445 // 其他错误, 返回 446 wer.status.statusCode = StatusCodeFailed 447 wer.err = readErr 448 return 449 } 450 } 451 } 452 } 453 }