github.com/fzfile/BaiduPCS-Go@v0.0.0-20200606205115-4408961cf336/requester/downloader/downloader.go (about) 1 // Package downloader 多线程下载器, 重构版 2 package downloader 3 4 import ( 5 "context" 6 "errors" 7 "github.com/fzfile/BaiduPCS-Go/pcsutil" 8 "github.com/fzfile/BaiduPCS-Go/pcsutil/cachepool" 9 "github.com/fzfile/BaiduPCS-Go/pcsutil/prealloc" 10 "github.com/fzfile/BaiduPCS-Go/pcsutil/waitgroup" 11 "github.com/fzfile/BaiduPCS-Go/pcsverbose" 12 "github.com/fzfile/BaiduPCS-Go/requester" 13 "github.com/fzfile/BaiduPCS-Go/requester/rio/speeds" 14 "github.com/fzfile/BaiduPCS-Go/requester/transfer" 15 "io" 16 "net/http" 17 "sync" 18 "time" 19 ) 20 21 const ( 22 // DefaultAcceptRanges 默认的 Accept-Ranges 23 DefaultAcceptRanges = "bytes" 24 ) 25 26 type ( 27 // Downloader 下载 28 Downloader struct { 29 onExecuteEvent requester.Event //开始下载事件 30 onSuccessEvent requester.Event //成功下载事件 31 onFinishEvent requester.Event //结束下载事件 32 onPauseEvent requester.Event //暂停下载事件 33 onResumeEvent requester.Event //恢复下载事件 34 onCancelEvent requester.Event //取消下载事件 35 onDownloadStatusEvent DownloadStatusFunc //状态处理事件 36 37 monitorCancelFunc context.CancelFunc 38 39 firstInfo *DownloadFirstInfo // 初始信息 40 loadBalancerCompareFunc LoadBalancerCompareFunc // 负载均衡检测函数 41 durlCheckFunc DURLCheckFunc // 下载url检测函数 42 statusCodeBodyCheckFunc StatusCodeBodyCheckFunc 43 executeTime time.Time 44 durl string 45 loadBalansers []string 46 writer io.WriterAt 47 client *requester.HTTPClient 48 config *Config 49 monitor *Monitor 50 instanceState *InstanceState 51 } 52 53 // DURLCheckFunc 下载URL检测函数 54 DURLCheckFunc func(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) 55 // StatusCodeBodyCheckFunc 响应状态码出错的检查函数 56 StatusCodeBodyCheckFunc func(respBody io.Reader) error 57 ) 58 59 //NewDownloader 初始化Downloader 60 func NewDownloader(durl string, writer io.WriterAt, config *Config) (der *Downloader) { 61 der = &Downloader{ 62 durl: durl, 63 config: config, 64 writer: writer, 65 } 66 67 return 68 } 69 70 // SetFirstInfo 设置初始信息 71 // 如果设置了此值, 将忽略检测url 72 func (der *Downloader) SetFirstInfo(i *DownloadFirstInfo) { 73 der.firstInfo = i 74 } 75 76 //SetClient 设置http客户端 77 func (der *Downloader) SetClient(client *requester.HTTPClient) { 78 der.client = client 79 } 80 81 // SetDURLCheckFunc 设置下载URL检测函数 82 func (der *Downloader) SetDURLCheckFunc(f DURLCheckFunc) { 83 der.durlCheckFunc = f 84 } 85 86 // SetLoadBalancerCompareFunc 设置负载均衡检测函数 87 func (der *Downloader) SetLoadBalancerCompareFunc(f LoadBalancerCompareFunc) { 88 der.loadBalancerCompareFunc = f 89 } 90 91 //SetStatusCodeBodyCheckFunc 设置响应状态码出错的检查函数, 当FirstCheckMethod不为HEAD时才有效 92 func (der *Downloader) SetStatusCodeBodyCheckFunc(f StatusCodeBodyCheckFunc) { 93 der.statusCodeBodyCheckFunc = f 94 } 95 96 func (der *Downloader) lazyInit() { 97 if der.config == nil { 98 der.config = NewConfig() 99 } 100 if der.client == nil { 101 der.client = requester.NewHTTPClient() 102 der.client.SetTimeout(20 * time.Minute) 103 } 104 if der.monitor == nil { 105 der.monitor = NewMonitor() 106 } 107 if der.durlCheckFunc == nil { 108 der.durlCheckFunc = DefaultDURLCheckFunc 109 } 110 if der.loadBalancerCompareFunc == nil { 111 der.loadBalancerCompareFunc = DefaultLoadBalancerCompareFunc 112 } 113 } 114 115 // SelectParallel 获取合适的 parallel 116 func (der *Downloader) SelectParallel(single bool, maxParallel int, totalSize int64, instanceRangeList transfer.RangeList) (parallel int) { 117 isRange := instanceRangeList != nil && len(instanceRangeList) > 0 118 if single { //不支持多线程 119 parallel = 1 120 } else if isRange { 121 parallel = len(instanceRangeList) 122 } else { 123 parallel = der.config.MaxParallel 124 if int64(parallel) > totalSize/int64(MinParallelSize) { 125 parallel = int(totalSize/int64(MinParallelSize)) + 1 126 } 127 } 128 129 if parallel < 1 { 130 parallel = 1 131 } 132 return 133 } 134 135 // SelectBlockSizeAndInitRangeGen 获取合适的 BlockSize, 和初始化 RangeGen 136 func (der *Downloader) SelectBlockSizeAndInitRangeGen(single bool, status *transfer.DownloadStatus, parallel int) (blockSize int64, initErr error) { 137 // Range 生成器 138 if single { // 单线程 139 blockSize = -1 140 return 141 } 142 gen := status.RangeListGen() 143 if gen == nil { 144 switch der.config.Mode { 145 case transfer.RangeGenMode_Default: 146 gen = transfer.NewRangeListGenDefault(status.TotalSize(), 0, 0, parallel) 147 blockSize = gen.LoadBlockSize() 148 case transfer.RangeGenMode_BlockSize: 149 b2 := status.TotalSize()/int64(parallel) + 1 150 if b2 > der.config.BlockSize { // 选小的BlockSize, 以更高并发 151 blockSize = der.config.BlockSize 152 } else { 153 blockSize = b2 154 } 155 156 gen = transfer.NewRangeListGenBlockSize(status.TotalSize(), 0, blockSize) 157 default: 158 initErr = transfer.ErrUnknownRangeGenMode 159 return 160 } 161 } else { 162 blockSize = gen.LoadBlockSize() 163 } 164 status.SetRangeListGen(gen) 165 return 166 } 167 168 // SelectCacheSize 获取合适的 cacheSize 169 func (der *Downloader) SelectCacheSize(confCacheSize int, blockSize int64) (cacheSize int) { 170 if blockSize > 0 && int64(confCacheSize) > blockSize { 171 // 如果 cache size 过高, 则调低 172 cacheSize = int(blockSize) 173 } else { 174 cacheSize = confCacheSize 175 } 176 return 177 } 178 179 // DefaultDURLCheckFunc 默认的 DURLCheckFunc 180 func DefaultDURLCheckFunc(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) { 181 resp, err = client.Req(http.MethodGet, durl, nil, nil) 182 if err != nil { 183 if resp != nil { 184 resp.Body.Close() 185 } 186 return 0, nil, err 187 } 188 return resp.ContentLength, resp, nil 189 } 190 191 func (der *Downloader) checkLoadBalancers() *LoadBalancerResponseList { 192 var ( 193 loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1) 194 handleLoadBalancer = func(req *http.Request) { 195 if req == nil { 196 return 197 } 198 199 if der.config.TryHTTP { 200 req.URL.Scheme = "http" 201 } 202 203 loadBalancer := &LoadBalancerResponse{ 204 URL: req.URL.String(), 205 Referer: req.Referer(), 206 } 207 208 loadBalancerResponses = append(loadBalancerResponses, loadBalancer) 209 pcsverbose.Verbosef("DEBUG: load balance task: URL: %s, Referer: %s\n", loadBalancer.URL, loadBalancer.Referer) 210 } 211 ) 212 213 // 加入第一个 214 loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{ 215 URL: der.durl, 216 }) 217 218 // 负载均衡 219 wg := waitgroup.NewWaitGroup(10) 220 privTimeout := der.client.Client.Timeout 221 der.client.SetTimeout(5 * time.Second) 222 for _, loadBalanser := range der.loadBalansers { 223 wg.AddDelta() 224 go func(loadBalanser string) { 225 defer wg.Done() 226 227 subContentLength, subResp, subErr := der.durlCheckFunc(der.client, loadBalanser) 228 if subResp != nil { 229 subResp.Body.Close() // 不读Body, 马上关闭连接 230 } 231 if subErr != nil { 232 pcsverbose.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr) 233 return 234 } 235 236 // 检测状态码 237 switch subResp.StatusCode / 100 { 238 case 2: // succeed 239 case 4, 5: // error 240 var err error 241 if der.statusCodeBodyCheckFunc != nil { 242 err = der.statusCodeBodyCheckFunc(subResp.Body) 243 } else { 244 err = errors.New(subResp.Status) 245 } 246 pcsverbose.Verbosef("DEBUG: loadBalanser Status Error: %s\n", err) 247 return 248 } 249 250 // 检测长度 251 if der.firstInfo.ContentLength != subContentLength { 252 pcsverbose.Verbosef("DEBUG: loadBalanser Content-Length not equal to main server\n") 253 return 254 } 255 256 if !der.loadBalancerCompareFunc(der.firstInfo.ToMap(), subResp) { 257 pcsverbose.Verbosef("DEBUG: loadBalanser not equal to main server\n") 258 return 259 } 260 261 handleLoadBalancer(subResp.Request) 262 }(loadBalanser) 263 } 264 wg.Wait() 265 der.client.SetTimeout(privTimeout) 266 267 loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses) 268 return loadBalancerResponseList 269 } 270 271 //Execute 开始任务 272 func (der *Downloader) Execute() error { 273 der.lazyInit() 274 275 var ( 276 resp *http.Response 277 ) 278 if der.firstInfo == nil { 279 // 检测 280 contentLength, resp, err := der.durlCheckFunc(der.client, der.durl) 281 if err != nil { 282 return err 283 } 284 285 // 检测网络错误 286 switch resp.StatusCode / 100 { 287 case 2: // succeed 288 case 4, 5: // error 289 if der.statusCodeBodyCheckFunc != nil { 290 err = der.statusCodeBodyCheckFunc(resp.Body) 291 resp.Body.Close() // 关闭连接 292 if err != nil { 293 return err 294 } 295 } 296 return errors.New(resp.Status) 297 } 298 299 acceptRanges := resp.Header.Get("Accept-Ranges") 300 if contentLength < 0 { 301 acceptRanges = "" 302 } else { 303 acceptRanges = DefaultAcceptRanges 304 } 305 306 // 初始化firstInfo 307 der.firstInfo = &DownloadFirstInfo{ 308 ContentLength: contentLength, 309 ContentMD5: resp.Header.Get("Content-MD5"), 310 ContentCRC32: resp.Header.Get("x-bs-meta-crc32"), 311 AcceptRanges: acceptRanges, 312 Referer: resp.Header.Get("Referer"), 313 } 314 pcsverbose.Verbosef("DEBUG: download task: URL: %s, Referer: %s\n", resp.Request.URL, resp.Request.Referer()) 315 } else { 316 if der.firstInfo.AcceptRanges == "" { 317 der.firstInfo.AcceptRanges = DefaultAcceptRanges 318 } 319 } 320 321 var ( 322 loadBalancerResponseList = der.checkLoadBalancers() 323 single = der.firstInfo.AcceptRanges == "" 324 bii *transfer.DownloadInstanceInfo 325 ) 326 327 if !single { 328 //load breakpoint 329 //服务端不支持多线程时, 不记录断点 330 err := der.initInstanceState(der.config.InstanceStateStorageFormat) 331 if err != nil { 332 return err 333 } 334 bii = der.instanceState.Get() 335 } 336 337 var ( 338 isInstance = bii != nil // 是否存在断点信息 339 status *transfer.DownloadStatus 340 ) 341 if !isInstance { 342 bii = &transfer.DownloadInstanceInfo{} 343 } 344 345 if bii.DownloadStatus != nil { 346 // 使用断点信息的状态 347 status = bii.DownloadStatus 348 } else { 349 // 新建状态 350 status = transfer.NewDownloadStatus() 351 status.SetTotalSize(der.firstInfo.ContentLength) 352 } 353 354 // 设置限速 355 if der.config.MaxRate > 0 { 356 rl := speeds.NewRateLimit(der.config.MaxRate) 357 status.SetRateLimit(rl) 358 defer rl.Stop() 359 } 360 361 // 数据处理 362 parallel := der.SelectParallel(single, der.config.MaxParallel, status.TotalSize(), bii.Ranges) // 实际的下载并行量 363 blockSize, err := der.SelectBlockSizeAndInitRangeGen(single, status, parallel) // 实际的BlockSize 364 if err != nil { 365 return err 366 } 367 368 cacheSize := der.SelectCacheSize(der.config.CacheSize, blockSize) // 实际下载缓存 369 cachepool.SetSyncPoolSize(cacheSize) // 调整pool大小 370 371 pcsverbose.Verbosef("DEBUG: download task CREATED: parallel: %d, cache size: %d\n", parallel, cacheSize) 372 373 der.monitor.InitMonitorCapacity(parallel) 374 375 var writer Writer 376 if !der.config.IsTest { 377 // 尝试修剪文件 378 if fder, ok := der.writer.(Fder); ok { 379 err = prealloc.PreAlloc(fder.Fd(), status.TotalSize()) 380 if err != nil { 381 pcsverbose.Verbosef("DEBUG: truncate file error: %s\n", err) 382 } 383 } 384 writer = der.writer // 非测试模式, 赋值writer 385 } 386 387 // 数据平均分配给各个线程 388 isRange := bii.Ranges != nil && len(bii.Ranges) > 0 389 if !isRange { 390 // 没有使用断点续传 391 // 分配线程 392 bii.Ranges = make(transfer.RangeList, 0, parallel) 393 if single { // 单线程 394 bii.Ranges = append(bii.Ranges, &transfer.Range{}) 395 } else { 396 gen := status.RangeListGen() 397 for i := 0; i < cap(bii.Ranges); i++ { 398 _, r := gen.GenRange() 399 if r == nil { // 没有了(不正常) 400 break 401 } 402 bii.Ranges = append(bii.Ranges, r) 403 } 404 } 405 } 406 407 var ( 408 writeMu = &sync.Mutex{} 409 ) 410 for k, r := range bii.Ranges { 411 loadBalancer := loadBalancerResponseList.SequentialGet() 412 if loadBalancer == nil { 413 continue 414 } 415 416 worker := NewWorker(k, loadBalancer.URL, writer) 417 worker.SetClient(der.client) 418 worker.SetWriteMutex(writeMu) 419 worker.SetReferer(loadBalancer.Referer) 420 worker.SetTotalSize(der.firstInfo.ContentLength) 421 422 // 使用第一个连接 423 // 断点续传时不使用 424 if k == 0 && !isInstance { 425 worker.firstResp = resp 426 } 427 428 worker.SetAcceptRange(der.firstInfo.AcceptRanges) 429 worker.SetRange(r) // 分配Range 430 der.monitor.Append(worker) 431 } 432 433 der.monitor.SetStatus(status) 434 435 // 服务器不支持断点续传, 或者单线程下载, 都不重载worker 436 der.monitor.SetReloadWorker(parallel > 1) 437 438 moniterCtx, moniterCancelFunc := context.WithCancel(context.Background()) 439 der.monitorCancelFunc = moniterCancelFunc 440 441 der.monitor.SetInstanceState(der.instanceState) 442 443 // 开始执行 444 der.executeTime = time.Now() 445 pcsutil.Trigger(der.onExecuteEvent) 446 der.downloadStatusEvent() // 启动执行状态处理事件 447 der.monitor.Execute(moniterCtx) 448 449 // 检查错误 450 err = der.monitor.Err() 451 if err == nil { // 成功 452 pcsutil.Trigger(der.onSuccessEvent) 453 if !single { 454 der.removeInstanceState() // 移除断点续传文件 455 } 456 } 457 458 // 执行结束 459 pcsutil.Trigger(der.onFinishEvent) 460 return err 461 } 462 463 //downloadStatusEvent 执行状态处理事件 464 func (der *Downloader) downloadStatusEvent() { 465 if der.onDownloadStatusEvent == nil { 466 return 467 } 468 469 status := der.monitor.Status() 470 go func() { 471 ticker := time.NewTicker(1 * time.Second) 472 defer ticker.Stop() 473 for { 474 select { 475 case <-der.monitor.completed: 476 return 477 case <-ticker.C: 478 der.onDownloadStatusEvent(status, der.monitor.RangeWorker) 479 } 480 } 481 }() 482 } 483 484 //Pause 暂停 485 func (der *Downloader) Pause() { 486 if der.monitor == nil { 487 return 488 } 489 pcsutil.Trigger(der.onPauseEvent) 490 der.monitor.Pause() 491 } 492 493 //Resume 恢复 494 func (der *Downloader) Resume() { 495 if der.monitor == nil { 496 return 497 } 498 pcsutil.Trigger(der.onResumeEvent) 499 der.monitor.Resume() 500 } 501 502 //Cancel 取消 503 func (der *Downloader) Cancel() { 504 if der.monitor == nil { 505 return 506 } 507 pcsutil.Trigger(der.onCancelEvent) 508 pcsutil.Trigger(der.monitorCancelFunc) 509 } 510 511 //OnExecute 设置开始下载事件 512 func (der *Downloader) OnExecute(onExecuteEvent requester.Event) { 513 der.onExecuteEvent = onExecuteEvent 514 } 515 516 //OnSuccess 设置成功下载事件 517 func (der *Downloader) OnSuccess(onSuccessEvent requester.Event) { 518 der.onSuccessEvent = onSuccessEvent 519 } 520 521 //OnFinish 设置结束下载事件 522 func (der *Downloader) OnFinish(onFinishEvent requester.Event) { 523 der.onFinishEvent = onFinishEvent 524 } 525 526 //OnPause 设置暂停下载事件 527 func (der *Downloader) OnPause(onPauseEvent requester.Event) { 528 der.onPauseEvent = onPauseEvent 529 } 530 531 //OnResume 设置恢复下载事件 532 func (der *Downloader) OnResume(onResumeEvent requester.Event) { 533 der.onResumeEvent = onResumeEvent 534 } 535 536 //OnCancel 设置取消下载事件 537 func (der *Downloader) OnCancel(onCancelEvent requester.Event) { 538 der.onCancelEvent = onCancelEvent 539 } 540 541 //OnDownloadStatusEvent 设置状态处理函数 542 func (der *Downloader) OnDownloadStatusEvent(f DownloadStatusFunc) { 543 der.onDownloadStatusEvent = f 544 }