github.com/cavaliergopher/grab/v3@v3.0.1/client.go (about) 1 package grab 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "net/http" 9 "os" 10 "path/filepath" 11 "sync" 12 "sync/atomic" 13 "time" 14 ) 15 16 // HTTPClient provides an interface allowing us to perform HTTP requests. 17 type HTTPClient interface { 18 Do(req *http.Request) (*http.Response, error) 19 } 20 21 // truncater is a private interface allowing different response 22 // Writers to be truncated 23 type truncater interface { 24 Truncate(size int64) error 25 } 26 27 // A Client is a file download client. 28 // 29 // Clients are safe for concurrent use by multiple goroutines. 30 type Client struct { 31 // HTTPClient specifies the http.Client which will be used for communicating 32 // with the remote server during the file transfer. 33 HTTPClient HTTPClient 34 35 // UserAgent specifies the User-Agent string which will be set in the 36 // headers of all requests made by this client. 37 // 38 // The user agent string may be overridden in the headers of each request. 39 UserAgent string 40 41 // BufferSize specifies the size in bytes of the buffer that is used for 42 // transferring all requested files. Larger buffers may result in faster 43 // throughput but will use more memory and result in less frequent updates 44 // to the transfer progress statistics. The BufferSize of each request can 45 // be overridden on each Request object. Default: 32KB. 46 BufferSize int 47 } 48 49 // NewClient returns a new file download Client, using default configuration. 50 func NewClient() *Client { 51 return &Client{ 52 UserAgent: "grab", 53 HTTPClient: &http.Client{ 54 Transport: &http.Transport{ 55 Proxy: http.ProxyFromEnvironment, 56 }, 57 }, 58 } 59 } 60 61 // DefaultClient is the default client and is used by all Get convenience 62 // functions. 63 var DefaultClient = NewClient() 64 65 // Do sends a file transfer request and returns a file transfer response, 66 // following policy (e.g. redirects, cookies, auth) as configured on the 67 // client's HTTPClient. 68 // 69 // Like http.Get, Do blocks while the transfer is initiated, but returns as soon 70 // as the transfer has started transferring in a background goroutine, or if it 71 // failed early. 72 // 73 // An error is returned via Response.Err if caused by client policy (such as 74 // CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err 75 // will block the caller until the transfer is completed, successfully or 76 // otherwise. 77 func (c *Client) Do(req *Request) *Response { 78 // cancel will be called on all code-paths via closeResponse 79 ctx, cancel := context.WithCancel(req.Context()) 80 req = req.WithContext(ctx) 81 resp := &Response{ 82 Request: req, 83 Start: time.Now(), 84 Done: make(chan struct{}, 0), 85 Filename: req.Filename, 86 ctx: ctx, 87 cancel: cancel, 88 bufferSize: req.BufferSize, 89 } 90 if resp.bufferSize == 0 { 91 // default to Client.BufferSize 92 resp.bufferSize = c.BufferSize 93 } 94 95 // Run state-machine while caller is blocked to initialize the file transfer. 96 // Must never transition to the copyFile state - this happens next in another 97 // goroutine. 98 c.run(resp, c.statFileInfo) 99 100 // Run copyFile in a new goroutine. copyFile will no-op if the transfer is 101 // already complete or failed. 102 go c.run(resp, c.copyFile) 103 return resp 104 } 105 106 // DoChannel executes all requests sent through the given Request channel, one 107 // at a time, until it is closed by another goroutine. The caller is blocked 108 // until the Request channel is closed and all transfers have completed. All 109 // responses are sent through the given Response channel as soon as they are 110 // received from the remote servers and can be used to track the progress of 111 // each download. 112 // 113 // Slow Response receivers will cause a worker to block and therefore delay the 114 // start of the transfer for an already initiated connection - potentially 115 // causing a server timeout. It is the caller's responsibility to ensure a 116 // sufficient buffer size is used for the Response channel to prevent this. 117 // 118 // If an error occurs during any of the file transfers it will be accessible via 119 // the associated Response.Err function. 120 func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) { 121 // TODO: enable cancelling of batch jobs 122 for req := range reqch { 123 resp := c.Do(req) 124 respch <- resp 125 <-resp.Done 126 } 127 } 128 129 // DoBatch executes all the given requests using the given number of concurrent 130 // workers. Control is passed back to the caller as soon as the workers are 131 // initiated. 132 // 133 // If the requested number of workers is less than one, a worker will be created 134 // for every request. I.e. all requests will be executed concurrently. 135 // 136 // If an error occurs during any of the file transfers it will be accessible via 137 // call to the associated Response.Err. 138 // 139 // The returned Response channel is closed only after all of the given Requests 140 // have completed, successfully or otherwise. 141 func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response { 142 if workers < 1 { 143 workers = len(requests) 144 } 145 reqch := make(chan *Request, len(requests)) 146 respch := make(chan *Response, len(requests)) 147 wg := sync.WaitGroup{} 148 for i := 0; i < workers; i++ { 149 wg.Add(1) 150 go func() { 151 c.DoChannel(reqch, respch) 152 wg.Done() 153 }() 154 } 155 156 // queue requests 157 go func() { 158 for _, req := range requests { 159 reqch <- req 160 } 161 close(reqch) 162 wg.Wait() 163 close(respch) 164 }() 165 return respch 166 } 167 168 // An stateFunc is an action that mutates the state of a Response and returns 169 // the next stateFunc to be called. 170 type stateFunc func(*Response) stateFunc 171 172 // run calls the given stateFunc function and all subsequent returned stateFuncs 173 // until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc 174 // should mutate the state of the given Response until it has completed 175 // downloading or failed. 176 func (c *Client) run(resp *Response, f stateFunc) { 177 for { 178 select { 179 case <-resp.ctx.Done(): 180 if resp.IsComplete() { 181 return 182 } 183 resp.err = resp.ctx.Err() 184 f = c.closeResponse 185 186 default: 187 // keep working 188 } 189 if f = f(resp); f == nil { 190 return 191 } 192 } 193 } 194 195 // statFileInfo retrieves FileInfo for any local file matching 196 // Response.Filename. 197 // 198 // If the file does not exist, is a directory, or its name is unknown the next 199 // stateFunc is headRequest. 200 // 201 // If the file exists, Response.fi is set and the next stateFunc is 202 // validateLocal. 203 // 204 // If an error occurs, the next stateFunc is closeResponse. 205 func (c *Client) statFileInfo(resp *Response) stateFunc { 206 if resp.Request.NoStore || resp.Filename == "" { 207 return c.headRequest 208 } 209 fi, err := os.Stat(resp.Filename) 210 if err != nil { 211 if os.IsNotExist(err) { 212 return c.headRequest 213 } 214 resp.err = err 215 return c.closeResponse 216 } 217 if fi.IsDir() { 218 resp.Filename = "" 219 return c.headRequest 220 } 221 resp.fi = fi 222 return c.validateLocal 223 } 224 225 // validateLocal compares a local copy of the downloaded file to the remote 226 // file. 227 // 228 // An error is returned if the local file is larger than the remote file, or 229 // Request.SkipExisting is true. 230 // 231 // If the existing file matches the length of the remote file, the next 232 // stateFunc is checksumFile. 233 // 234 // If the local file is smaller than the remote file and the remote server is 235 // known to support ranged requests, the next stateFunc is getRequest. 236 func (c *Client) validateLocal(resp *Response) stateFunc { 237 if resp.Request.SkipExisting { 238 resp.err = ErrFileExists 239 return c.closeResponse 240 } 241 242 // determine target file size 243 expectedSize := resp.Request.Size 244 if expectedSize == 0 && resp.HTTPResponse != nil { 245 expectedSize = resp.HTTPResponse.ContentLength 246 } 247 248 if expectedSize == 0 { 249 // size is either actually 0 or unknown 250 // if unknown, we ask the remote server 251 // if known to be 0, we proceed with a GET 252 return c.headRequest 253 } 254 255 if expectedSize == resp.fi.Size() { 256 // local file matches remote file size - wrap it up 257 resp.DidResume = true 258 resp.bytesResumed = resp.fi.Size() 259 return c.checksumFile 260 } 261 262 if resp.Request.NoResume { 263 // local file should be overwritten 264 return c.getRequest 265 } 266 267 if expectedSize >= 0 && expectedSize < resp.fi.Size() { 268 // remote size is known, is smaller than local size and we want to resume 269 resp.err = ErrBadLength 270 return c.closeResponse 271 } 272 273 if resp.CanResume { 274 // set resume range on GET request 275 resp.Request.HTTPRequest.Header.Set( 276 "Range", 277 fmt.Sprintf("bytes=%d-", resp.fi.Size())) 278 resp.DidResume = true 279 resp.bytesResumed = resp.fi.Size() 280 return c.getRequest 281 } 282 return c.headRequest 283 } 284 285 func (c *Client) checksumFile(resp *Response) stateFunc { 286 if resp.Request.hash == nil { 287 return c.closeResponse 288 } 289 if resp.Filename == "" { 290 panic("grab: developer error: filename not set") 291 } 292 if resp.Size() < 0 { 293 panic("grab: developer error: size unknown") 294 } 295 req := resp.Request 296 297 // compute checksum 298 var sum []byte 299 sum, resp.err = resp.checksumUnsafe() 300 if resp.err != nil { 301 return c.closeResponse 302 } 303 304 // compare checksum 305 if !bytes.Equal(sum, req.checksum) { 306 resp.err = ErrBadChecksum 307 if !resp.Request.NoStore && req.deleteOnError { 308 if err := os.Remove(resp.Filename); err != nil { 309 // err should be os.PathError and include file path 310 resp.err = fmt.Errorf( 311 "cannot remove downloaded file with checksum mismatch: %v", 312 err) 313 } 314 } 315 } 316 return c.closeResponse 317 } 318 319 // doHTTPRequest sends a HTTP Request and returns the response 320 func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) { 321 if c.UserAgent != "" && req.Header.Get("User-Agent") == "" { 322 req.Header.Set("User-Agent", c.UserAgent) 323 } 324 return c.HTTPClient.Do(req) 325 } 326 327 func (c *Client) headRequest(resp *Response) stateFunc { 328 if resp.optionsKnown { 329 return c.getRequest 330 } 331 resp.optionsKnown = true 332 333 if resp.Request.NoResume { 334 return c.getRequest 335 } 336 337 if resp.Filename != "" && resp.fi == nil { 338 // destination path is already known and does not exist 339 return c.getRequest 340 } 341 342 hreq := new(http.Request) 343 *hreq = *resp.Request.HTTPRequest 344 hreq.Method = "HEAD" 345 346 resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq) 347 if resp.err != nil { 348 return c.closeResponse 349 } 350 resp.HTTPResponse.Body.Close() 351 352 if resp.HTTPResponse.StatusCode != http.StatusOK { 353 return c.getRequest 354 } 355 356 // In case of redirects during HEAD, record the final URL and use it 357 // instead of the original URL when sending future requests. 358 // This way we avoid sending potentially unsupported requests to 359 // the original URL, e.g. "Range", since it was the final URL 360 // that advertised its support. 361 resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL 362 resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host 363 364 return c.readResponse 365 } 366 367 func (c *Client) getRequest(resp *Response) stateFunc { 368 resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest) 369 if resp.err != nil { 370 return c.closeResponse 371 } 372 373 // TODO: check Content-Range 374 375 // check status code 376 if !resp.Request.IgnoreBadStatusCodes { 377 if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 { 378 resp.err = StatusCodeError(resp.HTTPResponse.StatusCode) 379 return c.closeResponse 380 } 381 } 382 383 return c.readResponse 384 } 385 386 func (c *Client) readResponse(resp *Response) stateFunc { 387 if resp.HTTPResponse == nil { 388 panic("grab: developer error: Response.HTTPResponse is nil") 389 } 390 391 // check expected size 392 resp.sizeUnsafe = resp.HTTPResponse.ContentLength 393 if resp.sizeUnsafe >= 0 { 394 // remote size is known 395 resp.sizeUnsafe += resp.bytesResumed 396 if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe { 397 resp.err = ErrBadLength 398 return c.closeResponse 399 } 400 } 401 402 // check filename 403 if resp.Filename == "" { 404 filename, err := guessFilename(resp.HTTPResponse) 405 if err != nil { 406 resp.err = err 407 return c.closeResponse 408 } 409 // Request.Filename will be empty or a directory 410 resp.Filename = filepath.Join(resp.Request.Filename, filename) 411 } 412 413 if !resp.Request.NoStore && resp.requestMethod() == "HEAD" { 414 if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" { 415 resp.CanResume = true 416 } 417 return c.statFileInfo 418 } 419 return c.openWriter 420 } 421 422 // openWriter opens the destination file for writing and seeks to the location 423 // from whence the file transfer will resume. 424 // 425 // Requires that Response.Filename and resp.DidResume are already be set. 426 func (c *Client) openWriter(resp *Response) stateFunc { 427 if !resp.Request.NoStore && !resp.Request.NoCreateDirectories { 428 resp.err = mkdirp(resp.Filename) 429 if resp.err != nil { 430 return c.closeResponse 431 } 432 } 433 434 if resp.Request.NoStore { 435 resp.writer = &resp.storeBuffer 436 } else { 437 // compute write flags 438 flag := os.O_CREATE | os.O_WRONLY 439 if resp.fi != nil { 440 if resp.DidResume { 441 flag = os.O_APPEND | os.O_WRONLY 442 } else { 443 // truncate later in copyFile, if not cancelled 444 // by BeforeCopy hook 445 flag = os.O_WRONLY 446 } 447 } 448 449 // open file 450 f, err := os.OpenFile(resp.Filename, flag, 0666) 451 if err != nil { 452 resp.err = err 453 return c.closeResponse 454 } 455 resp.writer = f 456 457 // seek to start or end 458 whence := os.SEEK_SET 459 if resp.bytesResumed > 0 { 460 whence = os.SEEK_END 461 } 462 _, resp.err = f.Seek(0, whence) 463 if resp.err != nil { 464 return c.closeResponse 465 } 466 } 467 468 // init transfer 469 if resp.bufferSize < 1 { 470 resp.bufferSize = 32 * 1024 471 } 472 b := make([]byte, resp.bufferSize) 473 resp.transfer = newTransfer( 474 resp.Request.Context(), 475 resp.Request.RateLimiter, 476 resp.writer, 477 resp.HTTPResponse.Body, 478 b) 479 480 // next step is copyFile, but this will be called later in another goroutine 481 return nil 482 } 483 484 // copy transfers content for a HTTP connection established via Client.do() 485 func (c *Client) copyFile(resp *Response) stateFunc { 486 if resp.IsComplete() { 487 return nil 488 } 489 490 // run BeforeCopy hook 491 if f := resp.Request.BeforeCopy; f != nil { 492 resp.err = f(resp) 493 if resp.err != nil { 494 return c.closeResponse 495 } 496 } 497 498 var bytesCopied int64 499 if resp.transfer == nil { 500 panic("grab: developer error: Response.transfer is nil") 501 } 502 503 // We waited to truncate the file in openWriter() to make sure 504 // the BeforeCopy didn't cancel the copy. If this was an existing 505 // file that is not going to be resumed, truncate the contents. 506 if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume { 507 t.Truncate(0) 508 } 509 510 bytesCopied, resp.err = resp.transfer.copy() 511 if resp.err != nil { 512 return c.closeResponse 513 } 514 closeWriter(resp) 515 516 // set file timestamp 517 if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime { 518 resp.err = setLastModified(resp.HTTPResponse, resp.Filename) 519 if resp.err != nil { 520 return c.closeResponse 521 } 522 } 523 524 // update transfer size if previously unknown 525 if resp.Size() < 0 { 526 discoveredSize := resp.bytesResumed + bytesCopied 527 atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize) 528 if resp.Request.Size > 0 && resp.Request.Size != discoveredSize { 529 resp.err = ErrBadLength 530 return c.closeResponse 531 } 532 } 533 534 // run AfterCopy hook 535 if f := resp.Request.AfterCopy; f != nil { 536 resp.err = f(resp) 537 if resp.err != nil { 538 return c.closeResponse 539 } 540 } 541 542 return c.checksumFile 543 } 544 545 func closeWriter(resp *Response) { 546 if closer, ok := resp.writer.(io.Closer); ok { 547 closer.Close() 548 } 549 resp.writer = nil 550 } 551 552 // close finalizes the Response 553 func (c *Client) closeResponse(resp *Response) stateFunc { 554 if resp.IsComplete() { 555 panic("grab: developer error: response already closed") 556 } 557 558 resp.fi = nil 559 closeWriter(resp) 560 resp.closeResponseBody() 561 562 resp.End = time.Now() 563 close(resp.Done) 564 if resp.cancel != nil { 565 resp.cancel() 566 } 567 568 return nil 569 }