github.com/qjfoidnh/BaiduPCS-Go@v0.0.0-20231011165705-caa18a3765f3/requester/downloader/downloader.go (about) 1 // Package downloader 多线程下载器, 重构版 2 package downloader 3 4 import ( 5 "context" 6 "errors" 7 "github.com/qjfoidnh/BaiduPCS-Go/pcsutil" 8 "github.com/qjfoidnh/BaiduPCS-Go/pcsutil/cachepool" 9 "github.com/qjfoidnh/BaiduPCS-Go/pcsutil/converter" 10 "github.com/qjfoidnh/BaiduPCS-Go/pcsutil/prealloc" 11 "github.com/qjfoidnh/BaiduPCS-Go/pcsutil/waitgroup" 12 "github.com/qjfoidnh/BaiduPCS-Go/pcsverbose" 13 "github.com/qjfoidnh/BaiduPCS-Go/requester" 14 "github.com/qjfoidnh/BaiduPCS-Go/requester/rio/speeds" 15 "github.com/qjfoidnh/BaiduPCS-Go/requester/transfer" 16 "io" 17 "net/http" 18 "sync" 19 "time" 20 ) 21 22 const ( 23 // DefaultAcceptRanges 默认的 Accept-Ranges 24 DefaultAcceptRanges = "bytes" 25 ) 26 27 var BlockSizeList = [6]int64{128*converter.KB, 256*converter.KB, 1024*converter.KB, 2*converter.MB, 4*converter.MB, 999*converter.GB} 28 29 type ( 30 // Downloader 下载 31 Downloader struct { 32 onExecuteEvent requester.Event //开始下载事件 33 onSuccessEvent requester.Event //成功下载事件 34 onFinishEvent requester.Event //结束下载事件 35 onPauseEvent requester.Event //暂停下载事件 36 onResumeEvent requester.Event //恢复下载事件 37 onCancelEvent requester.Event //取消下载事件 38 onDownloadStatusEvent DownloadStatusFunc //状态处理事件 39 40 monitorCancelFunc context.CancelFunc 41 42 firstInfo *DownloadFirstInfo // 初始信息 43 loadBalancerCompareFunc LoadBalancerCompareFunc // 负载均衡检测函数 44 durlCheckFunc DURLCheckFunc // 下载url检测函数 45 statusCodeBodyCheckFunc StatusCodeBodyCheckFunc 46 executeTime time.Time 47 durl string 48 loadBalansers []string 49 writer io.WriterAt 50 client *requester.HTTPClient 51 config *Config 52 monitor *Monitor 53 instanceState *InstanceState 54 } 55 56 // DURLCheckFunc 下载URL检测函数 57 DURLCheckFunc func(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) 58 // StatusCodeBodyCheckFunc 响应状态码出错的检查函数 59 StatusCodeBodyCheckFunc func(respBody io.Reader) error 60 ) 61 62 //NewDownloader 初始化Downloader 63 func NewDownloader(durl string, writer io.WriterAt, config *Config) (der *Downloader) { 64 der = &Downloader{ 65 durl: durl, 66 config: config, 67 writer: writer, 68 } 69 70 return 71 } 72 73 // SetFirstInfo 设置初始信息 74 // 如果设置了此值, 将忽略检测url 75 func (der *Downloader) SetFirstInfo(i *DownloadFirstInfo) { 76 der.firstInfo = i 77 } 78 79 //SetClient 设置http客户端 80 func (der *Downloader) SetClient(client *requester.HTTPClient) { 81 der.client = client 82 } 83 84 // SetDURLCheckFunc 设置下载URL检测函数 85 func (der *Downloader) SetDURLCheckFunc(f DURLCheckFunc) { 86 der.durlCheckFunc = f 87 } 88 89 func (der *Downloader) SetFileContentLength(length int64) { 90 if der.firstInfo == nil { 91 der.firstInfo = &DownloadFirstInfo{ 92 ContentLength: length, 93 AcceptRanges : DefaultAcceptRanges, 94 } 95 } 96 } 97 // SetLoadBalancerCompareFunc 设置负载均衡检测函数 98 func (der *Downloader) SetLoadBalancerCompareFunc(f LoadBalancerCompareFunc) { 99 der.loadBalancerCompareFunc = f 100 } 101 102 //SetStatusCodeBodyCheckFunc 设置响应状态码出错的检查函数, 当FirstCheckMethod不为HEAD时才有效 103 func (der *Downloader) SetStatusCodeBodyCheckFunc(f StatusCodeBodyCheckFunc) { 104 der.statusCodeBodyCheckFunc = f 105 } 106 107 func (der *Downloader) lazyInit() { 108 if der.config == nil { 109 der.config = NewConfig() 110 } 111 if der.client == nil { 112 der.client = requester.NewHTTPClient() 113 der.client.SetTimeout(5 * time.Minute) 114 } 115 if der.monitor == nil { 116 der.monitor = NewMonitor() 117 } 118 if der.durlCheckFunc == nil { 119 der.durlCheckFunc = DefaultDURLCheckFunc 120 } 121 if der.loadBalancerCompareFunc == nil { 122 der.loadBalancerCompareFunc = DefaultLoadBalancerCompareFunc 123 } 124 } 125 126 // SelectParallel 获取合适的 parallel 127 func (der *Downloader) SelectParallel(single bool, maxParallel int, totalSize int64, instanceRangeList transfer.RangeList) (parallel int) { 128 isRange := instanceRangeList != nil && len(instanceRangeList) > 0 129 if single { //不支持多线程 130 parallel = 1 131 } else if isRange { 132 parallel = len(instanceRangeList) 133 } else { 134 parallel = der.config.MaxParallel 135 if int64(parallel) > totalSize/int64(MinParallelSize) { 136 parallel = int(totalSize/int64(MinParallelSize)) + 1 137 } 138 } 139 140 if parallel < 1 { 141 parallel = 1 142 } 143 return 144 } 145 146 // SelectBlockSizeAndInitRangeGen 获取合适的 BlockSize, 和初始化 RangeGen 147 func (der *Downloader) SelectBlockSizeAndInitRangeGen(single bool, status *transfer.DownloadStatus, parallel int) (blockSize int64, initErr error) { 148 // Range 生成器 149 if single { // 单线程 150 blockSize = -1 151 return 152 } 153 gen := status.RangeListGen() 154 if gen == nil { 155 switch der.config.Mode { 156 case transfer.RangeGenMode_Default: 157 gen = transfer.NewRangeListGenDefault(status.TotalSize(), 0, 0, parallel) 158 blockSize = gen.LoadBlockSize() 159 case transfer.RangeGenMode_BlockSize: 160 //b2 := status.TotalSize()/int64(parallel) + 1 161 //if b2 > der.config.BlockSize { // 选小的BlockSize, 以更高并发 162 // blockSize = der.config.BlockSize 163 //} else { 164 // blockSize = b2 165 //} 166 totalSize := status.TotalSize() 167 if totalSize < 2 * converter.MB { 168 blockSize = BlockSizeList[1] 169 } else if totalSize < 10 * converter.MB { 170 blockSize = BlockSizeList[2] 171 } else if totalSize < 80 * converter.MB { 172 blockSize = BlockSizeList[3] 173 } else { 174 blockSize = BlockSizeList[4] 175 } 176 gen = transfer.NewRangeListGenBlockSize(totalSize, 0, blockSize) 177 default: 178 initErr = transfer.ErrUnknownRangeGenMode 179 return 180 } 181 } else { 182 blockSize = gen.LoadBlockSize() 183 } 184 status.SetRangeListGen(gen) 185 return 186 } 187 188 // SelectCacheSize 获取合适的 cacheSize 189 func (der *Downloader) SelectCacheSize(confCacheSize int, blockSize int64) (cacheSize int) { 190 if blockSize > 0 && int64(confCacheSize) > blockSize { 191 // 如果 cache size 过高, 则调低 192 cacheSize = int(blockSize) 193 } else { 194 cacheSize = confCacheSize 195 } 196 return 197 } 198 199 // DefaultDURLCheckFunc 默认的 DURLCheckFunc 200 func DefaultDURLCheckFunc(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) { 201 resp, err = client.Req(http.MethodGet, durl, nil, nil) 202 if err != nil { 203 if resp != nil { 204 resp.Body.Close() 205 } 206 return 0, nil, err 207 } 208 return resp.ContentLength, resp, nil 209 } 210 211 func (der *Downloader) checkLoadBalancers() *LoadBalancerResponseList { 212 var ( 213 loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1) 214 handleLoadBalancer = func(req *http.Request) { 215 if req == nil { 216 return 217 } 218 219 if der.config.TryHTTP { 220 req.URL.Scheme = "http" 221 } 222 223 loadBalancer := &LoadBalancerResponse{ 224 URL: req.URL.String(), 225 Referer: req.Referer(), 226 } 227 228 loadBalancerResponses = append(loadBalancerResponses, loadBalancer) 229 pcsverbose.Verbosef("DEBUG: load balance task: URL: %s, Referer: %s\n", loadBalancer.URL, loadBalancer.Referer) 230 } 231 ) 232 233 // 加入第一个 234 loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{ 235 URL: der.durl, 236 }) 237 238 // 多下载服务器的负载均衡, 在本项目中无意义,因为百度只用单下载服务器 239 wg := waitgroup.NewWaitGroup(4) 240 privTimeout := der.client.Client.Timeout 241 der.client.SetTimeout(5 * time.Second) 242 for _, loadBalanser := range der.loadBalansers { // 这里服务器列表数量为0,所以逻辑实际未生效 243 wg.AddDelta() 244 go func(loadBalanser string) { 245 defer wg.Done() 246 247 subContentLength, subResp, subErr := der.durlCheckFunc(der.client, loadBalanser) 248 if subResp != nil { 249 subResp.Body.Close() // 不读Body, 马上关闭连接 250 } 251 if subErr != nil { 252 pcsverbose.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr) 253 return 254 } 255 256 // 检测状态码 257 switch subResp.StatusCode / 100 { 258 case 2: // succeed 259 case 4, 5: // error 260 var err error 261 if der.statusCodeBodyCheckFunc != nil { 262 err = der.statusCodeBodyCheckFunc(subResp.Body) 263 } else { 264 err = errors.New(subResp.Status) 265 } 266 pcsverbose.Verbosef("DEBUG: loadBalanser Status Error: %s\n", err) 267 return 268 } 269 270 // 检测长度 271 if der.firstInfo.ContentLength != subContentLength { 272 pcsverbose.Verbosef("DEBUG: loadBalanser Content-Length not equal to main server\n") 273 return 274 } 275 276 if !der.loadBalancerCompareFunc(der.firstInfo.ToMap(), subResp) { 277 pcsverbose.Verbosef("DEBUG: loadBalanser not equal to main server\n") 278 return 279 } 280 281 handleLoadBalancer(subResp.Request) 282 }(loadBalanser) 283 } 284 wg.Wait() 285 der.client.SetTimeout(privTimeout) 286 287 loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses) 288 return loadBalancerResponseList 289 } 290 291 //Execute 开始任务 292 func (der *Downloader) Execute() error { 293 der.lazyInit() 294 var ( 295 resp *http.Response 296 ) 297 if der.firstInfo == nil { 298 // 检测 299 contentLength, resp, err := der.durlCheckFunc(der.client, der.durl) 300 if err != nil { 301 return err 302 } 303 304 // 检测网络错误 305 switch resp.StatusCode / 100 { 306 case 2: // succeed 307 case 4, 5: // error 308 if der.statusCodeBodyCheckFunc != nil { 309 err = der.statusCodeBodyCheckFunc(resp.Body) 310 resp.Body.Close() // 关闭连接 311 if err != nil { 312 return err 313 } 314 } 315 return errors.New(resp.Status) 316 } 317 318 acceptRanges := resp.Header.Get("Accept-Ranges") 319 if contentLength < 0 { 320 acceptRanges = "" 321 } else { 322 acceptRanges = DefaultAcceptRanges 323 } 324 325 // 初始化firstInfo 326 der.firstInfo = &DownloadFirstInfo{ 327 ContentLength: contentLength, 328 ContentMD5: resp.Header.Get("Content-MD5"), 329 ContentCRC32: resp.Header.Get("x-bs-meta-crc32"), 330 AcceptRanges: acceptRanges, 331 Referer: resp.Header.Get("Referer"), 332 } 333 pcsverbose.Verbosef("DEBUG: download task: URL: %s, Referer: %s\n", resp.Request.URL, resp.Request.Referer()) 334 } else { 335 if der.firstInfo.AcceptRanges == "" { 336 der.firstInfo.AcceptRanges = DefaultAcceptRanges 337 } 338 } 339 340 var ( 341 loadBalancerResponseList = der.checkLoadBalancers() 342 single = der.firstInfo.AcceptRanges == "" 343 bii *transfer.DownloadInstanceInfo 344 ) 345 346 if !single { 347 //load breakpoint 348 //服务端不支持多线程时, 不记录断点 349 err := der.initInstanceState(der.config.InstanceStateStorageFormat) 350 if err != nil { 351 return err 352 } 353 bii = der.instanceState.Get() 354 } 355 356 var ( 357 isInstance = bii != nil // 是否存在断点信息 358 status *transfer.DownloadStatus 359 ) 360 if !isInstance { 361 bii = &transfer.DownloadInstanceInfo{} 362 } 363 364 if bii.DownloadStatus != nil { 365 // 使用断点信息的状态 366 status = bii.DownloadStatus 367 } else { 368 // 新建状态 369 status = transfer.NewDownloadStatus() 370 status.SetTotalSize(der.firstInfo.ContentLength) 371 } 372 373 // 设置限速 374 if der.config.MaxRate > 0 { 375 rl := speeds.NewRateLimit(der.config.MaxRate) 376 status.SetRateLimit(rl) 377 defer rl.Stop() 378 } 379 380 // 数据处理 381 parallel := der.SelectParallel(single, der.config.MaxParallel, status.TotalSize(), bii.Ranges) // 实际的下载并行量 382 blockSize, err := der.SelectBlockSizeAndInitRangeGen(single, status, parallel) // 实际的BlockSize 383 if err != nil { 384 return err 385 } 386 387 cacheSize := der.SelectCacheSize(der.config.CacheSize, blockSize) // 实际下载缓存 388 cachepool.SetSyncPoolSize(cacheSize) // 调整pool大小 389 390 pcsverbose.Verbosef("DEBUG: download task CREATED: parallel: %d, cache size: %d\n", parallel, cacheSize) 391 392 der.monitor.InitMonitorCapacity(parallel) 393 394 var writer Writer 395 if !der.config.IsTest { 396 // 尝试修剪文件 397 if fder, ok := der.writer.(Fder); ok { 398 err = prealloc.PreAlloc(fder.Fd(), status.TotalSize()) 399 if err != nil { 400 pcsverbose.Verbosef("DEBUG: truncate file error: %s\n", err) 401 } 402 } 403 writer = der.writer // 非测试模式, 赋值writer 404 } 405 406 // 数据平均分配给各个线程 407 isRange := bii.Ranges != nil && len(bii.Ranges) > 0 408 if !isRange { 409 // 没有使用断点续传 410 // 分配线程 411 bii.Ranges = make(transfer.RangeList, 0, parallel) 412 if single { // 单线程 413 bii.Ranges = append(bii.Ranges, &transfer.Range{}) 414 } else { 415 gen := status.RangeListGen() 416 for i := 0; i < cap(bii.Ranges); i++ { 417 _, r := gen.GenRange() 418 if r == nil { // 没有了(不正常) 419 break 420 } 421 bii.Ranges = append(bii.Ranges, r) 422 } 423 } 424 } 425 426 var ( 427 writeMu = &sync.Mutex{} 428 ) 429 for k, r := range bii.Ranges { 430 loadBalancer := loadBalancerResponseList.SequentialGet() 431 if loadBalancer == nil { 432 continue 433 } 434 435 worker := NewWorker(k, loadBalancer.URL, writer) 436 worker.SetClient(der.client) 437 worker.SetWriteMutex(writeMu) 438 worker.SetReferer(loadBalancer.Referer) 439 worker.SetTotalSize(der.firstInfo.ContentLength) 440 441 // 使用第一个连接 442 // 断点续传时不使用 443 if k == 0 && !isInstance { 444 worker.firstResp = resp 445 } 446 447 worker.SetAcceptRange(der.firstInfo.AcceptRanges) 448 worker.SetRange(r) // 分配Range 449 der.monitor.Append(worker) 450 } 451 452 der.monitor.SetStatus(status) 453 454 // 服务器不支持断点续传, 或者单线程下载, 都不重载worker 455 der.monitor.SetReloadWorker(parallel > 1) 456 457 moniterCtx, moniterCancelFunc := context.WithCancel(context.Background()) 458 der.monitorCancelFunc = moniterCancelFunc 459 460 der.monitor.SetInstanceState(der.instanceState) 461 462 // 开始执行 463 der.executeTime = time.Now() 464 pcsutil.Trigger(der.onExecuteEvent) 465 der.downloadStatusEvent() // 启动执行状态处理事件 466 der.monitor.Execute(moniterCtx) 467 468 // 检查错误 469 err = der.monitor.Err() 470 if err == nil { // 成功 471 pcsutil.Trigger(der.onSuccessEvent) 472 if !single { 473 der.removeInstanceState() // 移除断点续传文件 474 } 475 } 476 477 // 执行结束 478 pcsutil.Trigger(der.onFinishEvent) 479 return err 480 } 481 482 //downloadStatusEvent 执行状态处理事件 483 func (der *Downloader) downloadStatusEvent() { 484 if der.onDownloadStatusEvent == nil { 485 return 486 } 487 488 status := der.monitor.Status() 489 go func() { 490 ticker := time.NewTicker(1 * time.Second) 491 defer ticker.Stop() 492 for { 493 select { 494 case <-der.monitor.completed: 495 return 496 case <-ticker.C: 497 der.onDownloadStatusEvent(status, der.monitor.RangeWorker) 498 } 499 } 500 }() 501 } 502 503 //Pause 暂停 504 func (der *Downloader) Pause() { 505 if der.monitor == nil { 506 return 507 } 508 pcsutil.Trigger(der.onPauseEvent) 509 der.monitor.Pause() 510 } 511 512 //Resume 恢复 513 func (der *Downloader) Resume() { 514 if der.monitor == nil { 515 return 516 } 517 pcsutil.Trigger(der.onResumeEvent) 518 der.monitor.Resume() 519 } 520 521 //Cancel 取消 522 func (der *Downloader) Cancel() { 523 if der.monitor == nil { 524 return 525 } 526 pcsutil.Trigger(der.onCancelEvent) 527 pcsutil.Trigger(der.monitorCancelFunc) 528 } 529 530 //OnExecute 设置开始下载事件 531 func (der *Downloader) OnExecute(onExecuteEvent requester.Event) { 532 der.onExecuteEvent = onExecuteEvent 533 } 534 535 //OnSuccess 设置成功下载事件 536 func (der *Downloader) OnSuccess(onSuccessEvent requester.Event) { 537 der.onSuccessEvent = onSuccessEvent 538 } 539 540 //OnFinish 设置结束下载事件 541 func (der *Downloader) OnFinish(onFinishEvent requester.Event) { 542 der.onFinishEvent = onFinishEvent 543 } 544 545 //OnPause 设置暂停下载事件 546 func (der *Downloader) OnPause(onPauseEvent requester.Event) { 547 der.onPauseEvent = onPauseEvent 548 } 549 550 //OnResume 设置恢复下载事件 551 func (der *Downloader) OnResume(onResumeEvent requester.Event) { 552 der.onResumeEvent = onResumeEvent 553 } 554 555 //OnCancel 设置取消下载事件 556 func (der *Downloader) OnCancel(onCancelEvent requester.Event) { 557 der.onCancelEvent = onCancelEvent 558 } 559 560 //OnDownloadStatusEvent 设置状态处理函数 561 func (der *Downloader) OnDownloadStatusEvent(f DownloadStatusFunc) { 562 der.onDownloadStatusEvent = f 563 }