github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/s3manager/download_test.go (about) 1 //go:build go1.7 2 // +build go1.7 3 4 package s3manager_test 5 6 import ( 7 "bytes" 8 "encoding/xml" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net/http" 13 "reflect" 14 "regexp" 15 "strconv" 16 "strings" 17 "sync" 18 "sync/atomic" 19 "testing" 20 "time" 21 22 "github.com/aavshr/aws-sdk-go/aws" 23 "github.com/aavshr/aws-sdk-go/aws/awserr" 24 "github.com/aavshr/aws-sdk-go/aws/request" 25 "github.com/aavshr/aws-sdk-go/awstesting" 26 "github.com/aavshr/aws-sdk-go/awstesting/unit" 27 "github.com/aavshr/aws-sdk-go/internal/sdkio" 28 "github.com/aavshr/aws-sdk-go/service/s3" 29 "github.com/aavshr/aws-sdk-go/service/s3/internal/s3testing" 30 "github.com/aavshr/aws-sdk-go/service/s3/s3manager" 31 ) 32 33 func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) { 34 var m sync.Mutex 35 names := []string{} 36 ranges := []string{} 37 38 svc := s3.New(unit.Session) 39 svc.Handlers.Send.Clear() 40 svc.Handlers.Send.PushBack(func(r *request.Request) { 41 m.Lock() 42 defer m.Unlock() 43 44 names = append(names, r.Operation.Name) 45 ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range) 46 47 rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`) 48 rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range")) 49 start, _ := strconv.ParseInt(rng[1], 10, 64) 50 fin, _ := strconv.ParseInt(rng[2], 10, 64) 51 fin++ 52 53 if fin > int64(len(data)) { 54 fin = int64(len(data)) 55 } 56 57 bodyBytes := data[start:fin] 58 r.HTTPResponse = &http.Response{ 59 StatusCode: 200, 60 Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), 61 Header: http.Header{}, 62 } 63 r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", 64 start, fin-1, len(data))) 65 r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes))) 66 }) 67 68 return svc, &names, &ranges 69 } 70 71 func dlLoggingSvcNoChunk(data []byte) (*s3.S3, *[]string) { 72 var m sync.Mutex 73 names := []string{} 74 75 svc := s3.New(unit.Session) 76 svc.Handlers.Send.Clear() 77 svc.Handlers.Send.PushBack(func(r *request.Request) { 78 m.Lock() 79 defer m.Unlock() 80 81 names = append(names, r.Operation.Name) 82 83 r.HTTPResponse = &http.Response{ 84 StatusCode: 200, 85 Body: ioutil.NopCloser(bytes.NewReader(data[:])), 86 Header: http.Header{}, 87 } 88 r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(data))) 89 }) 90 91 return svc, &names 92 } 93 94 func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.S3, *[]string) { 95 var m sync.Mutex 96 names := []string{} 97 var index int 98 99 svc := s3.New(unit.Session) 100 svc.Handlers.Send.Clear() 101 svc.Handlers.Send.PushBack(func(r *request.Request) { 102 m.Lock() 103 defer m.Unlock() 104 105 names = append(names, r.Operation.Name) 106 107 var body io.Reader 108 if states[index] < 400 { 109 body = bytes.NewReader(data[:]) 110 } else { 111 var buffer bytes.Buffer 112 encoder := xml.NewEncoder(&buffer) 113 _ = encoder.Encode(&mockErrorResponse) 114 body = &buffer 115 } 116 117 r.HTTPResponse = &http.Response{ 118 StatusCode: states[index], 119 Body: ioutil.NopCloser(body), 120 Header: http.Header{}, 121 } 122 index++ 123 }) 124 125 return svc, &names 126 } 127 128 func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]string) { 129 var m sync.Mutex 130 names := []string{} 131 ranges := []string{} 132 var index int 133 134 svc := s3.New(unit.Session) 135 svc.Handlers.Send.Clear() 136 svc.Handlers.Send.PushBack(func(r *request.Request) { 137 m.Lock() 138 defer m.Unlock() 139 140 names = append(names, r.Operation.Name) 141 ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range) 142 143 rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`) 144 rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range")) 145 start, _ := strconv.ParseInt(rng[1], 10, 64) 146 fin, _ := strconv.ParseInt(rng[2], 10, 64) 147 fin++ 148 149 if fin >= int64(len(data)) { 150 fin = int64(len(data)) 151 } 152 153 // Setting start and finish to 0 because this state of 1 is suppose to 154 // be an error state of 416 155 if index == len(states)-1 { 156 start = 0 157 fin = 0 158 } 159 160 bodyBytes := data[start:fin] 161 162 r.HTTPResponse = &http.Response{ 163 StatusCode: states[index], 164 Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), 165 Header: http.Header{}, 166 } 167 r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/*", 168 start, fin-1)) 169 index++ 170 }) 171 172 return svc, &names 173 } 174 175 func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.S3, *[]string) { 176 var m sync.Mutex 177 names := []string{} 178 var index int 179 180 svc := s3.New(unit.Session, &aws.Config{ 181 MaxRetries: aws.Int(len(cases) - 1), 182 }) 183 svc.Handlers.Send.Clear() 184 svc.Handlers.Send.PushBack(func(r *request.Request) { 185 m.Lock() 186 defer m.Unlock() 187 188 names = append(names, r.Operation.Name) 189 190 c := cases[index] 191 192 r.HTTPResponse = &http.Response{ 193 StatusCode: http.StatusOK, 194 Body: ioutil.NopCloser(&c), 195 Header: http.Header{}, 196 } 197 r.HTTPResponse.Header.Set("Content-Range", 198 fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len)) 199 r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", c.Len)) 200 index++ 201 }) 202 203 return svc, &names 204 } 205 206 func TestDownloadOrder(t *testing.T) { 207 s, names, ranges := dlLoggingSvc(buf12MB) 208 209 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 210 d.Concurrency = 1 211 }) 212 213 w := aws.NewWriteAtBuffer(make([]byte, len(buf12MB))) 214 n, err := d.Download(w, &s3.GetObjectInput{ 215 Bucket: aws.String("bucket"), 216 Key: aws.String("key"), 217 }) 218 219 if err != nil { 220 t.Fatalf("expect no error, got %v", err) 221 } 222 if e, a := int64(len(buf12MB)), n; e != a { 223 t.Errorf("expect %d buffer length, got %d", e, a) 224 } 225 226 expectCalls := []string{"GetObject", "GetObject", "GetObject"} 227 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 228 t.Errorf("expect %v API calls, got %v", e, a) 229 } 230 231 expectRngs := []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"} 232 if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { 233 t.Errorf("expect %v ranges, got %v", e, a) 234 } 235 } 236 237 func TestDownloadZero(t *testing.T) { 238 s, names, ranges := dlLoggingSvc([]byte{}) 239 240 d := s3manager.NewDownloaderWithClient(s) 241 w := &aws.WriteAtBuffer{} 242 n, err := d.Download(w, &s3.GetObjectInput{ 243 Bucket: aws.String("bucket"), 244 Key: aws.String("key"), 245 }) 246 247 if err != nil { 248 t.Fatalf("expect no error, got %v", err) 249 } 250 if n != 0 { 251 t.Errorf("expect 0 bytes read, got %d", n) 252 } 253 expectCalls := []string{"GetObject"} 254 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 255 t.Errorf("expect %v API calls, got %v", e, a) 256 } 257 258 expectRngs := []string{"bytes=0-5242879"} 259 if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { 260 t.Errorf("expect %v ranges, got %v", e, a) 261 } 262 } 263 264 func TestDownloadSetPartSize(t *testing.T) { 265 s, names, ranges := dlLoggingSvc([]byte{1, 2, 3}) 266 267 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 268 d.Concurrency = 1 269 d.PartSize = 1 270 }) 271 w := &aws.WriteAtBuffer{} 272 n, err := d.Download(w, &s3.GetObjectInput{ 273 Bucket: aws.String("bucket"), 274 Key: aws.String("key"), 275 }) 276 277 if err != nil { 278 t.Fatalf("expect no error, got %v", err) 279 } 280 if e, a := int64(3), n; e != a { 281 t.Errorf("expect %d bytes read, got %d", e, a) 282 } 283 expectCalls := []string{"GetObject", "GetObject", "GetObject"} 284 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 285 t.Errorf("expect %v API calls, got %v", e, a) 286 } 287 expectRngs := []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"} 288 if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { 289 t.Errorf("expect %v ranges, got %v", e, a) 290 } 291 expectBytes := []byte{1, 2, 3} 292 if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { 293 t.Errorf("expect %v bytes, got %v", e, a) 294 } 295 } 296 297 func TestDownloadError(t *testing.T) { 298 s, names, _ := dlLoggingSvc([]byte{1, 2, 3}) 299 300 num := 0 301 s.Handlers.Send.PushBack(func(r *request.Request) { 302 num++ 303 if num > 1 { 304 r.HTTPResponse.StatusCode = 400 305 r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) 306 } 307 }) 308 309 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 310 d.Concurrency = 1 311 d.PartSize = 1 312 }) 313 w := &aws.WriteAtBuffer{} 314 n, err := d.Download(w, &s3.GetObjectInput{ 315 Bucket: aws.String("bucket"), 316 Key: aws.String("key"), 317 }) 318 319 if err == nil { 320 t.Fatalf("expect error, got none") 321 } 322 aerr := err.(awserr.Error) 323 if e, a := "BadRequest", aerr.Code(); e != a { 324 t.Errorf("expect %s error code, got %s", e, a) 325 } 326 if e, a := int64(1), n; e != a { 327 t.Errorf("expect %d bytes read, got %d", e, a) 328 } 329 expectCalls := []string{"GetObject", "GetObject"} 330 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 331 t.Errorf("expect %v API calls, got %v", e, a) 332 } 333 expectBytes := []byte{1} 334 if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { 335 t.Errorf("expect %v bytes, got %v", e, a) 336 } 337 } 338 339 func TestDownloadNonChunk(t *testing.T) { 340 s, names := dlLoggingSvcNoChunk(buf2MB) 341 342 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 343 d.Concurrency = 1 344 }) 345 w := &aws.WriteAtBuffer{} 346 n, err := d.Download(w, &s3.GetObjectInput{ 347 Bucket: aws.String("bucket"), 348 Key: aws.String("key"), 349 }) 350 351 if err != nil { 352 t.Fatalf("expect no error, got %v", err) 353 } 354 if e, a := int64(len(buf2MB)), n; e != a { 355 t.Errorf("expect %d bytes read, got %d", e, a) 356 } 357 expectCalls := []string{"GetObject"} 358 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 359 t.Errorf("expect %v API calls, got %v", e, a) 360 } 361 362 count := 0 363 for _, b := range w.Bytes() { 364 count += int(b) 365 } 366 if count != 0 { 367 t.Errorf("expect 0 count, got %d", count) 368 } 369 } 370 371 func TestDownloadNoContentRangeLength(t *testing.T) { 372 s, names := dlLoggingSvcNoContentRangeLength(buf2MB, []int{200, 416}) 373 374 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 375 d.Concurrency = 1 376 }) 377 w := &aws.WriteAtBuffer{} 378 n, err := d.Download(w, &s3.GetObjectInput{ 379 Bucket: aws.String("bucket"), 380 Key: aws.String("key"), 381 }) 382 383 if err != nil { 384 t.Fatalf("expect no error, got %v", err) 385 } 386 if e, a := int64(len(buf2MB)), n; e != a { 387 t.Errorf("expect %d bytes read, got %d", e, a) 388 } 389 expectCalls := []string{"GetObject", "GetObject"} 390 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 391 t.Errorf("expect %v API calls, got %v", e, a) 392 } 393 394 count := 0 395 for _, b := range w.Bytes() { 396 count += int(b) 397 } 398 if count != 0 { 399 t.Errorf("expect 0 count, got %d", count) 400 } 401 } 402 403 func TestDownloadContentRangeTotalAny(t *testing.T) { 404 s, names := dlLoggingSvcContentRangeTotalAny(buf2MB, []int{200, 416}) 405 406 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 407 d.Concurrency = 1 408 }) 409 w := &aws.WriteAtBuffer{} 410 n, err := d.Download(w, &s3.GetObjectInput{ 411 Bucket: aws.String("bucket"), 412 Key: aws.String("key"), 413 }) 414 415 if err != nil { 416 t.Fatalf("expect no error, got %v", err) 417 } 418 if e, a := int64(len(buf2MB)), n; e != a { 419 t.Errorf("expect %d bytes read, got %d", e, a) 420 } 421 expectCalls := []string{"GetObject", "GetObject"} 422 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 423 t.Errorf("expect %v API calls, got %v", e, a) 424 } 425 426 count := 0 427 for _, b := range w.Bytes() { 428 count += int(b) 429 } 430 if count != 0 { 431 t.Errorf("expect 0 count, got %d", count) 432 } 433 } 434 435 func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) { 436 s, names := dlLoggingSvcWithErrReader([]testErrReader{ 437 {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, 438 {Buf: []byte("123"), Len: 3, Err: io.EOF}, 439 }) 440 441 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 442 d.Concurrency = 1 443 }) 444 445 w := &aws.WriteAtBuffer{} 446 n, err := d.Download(w, &s3.GetObjectInput{ 447 Bucket: aws.String("bucket"), 448 Key: aws.String("key"), 449 }) 450 451 if err != nil { 452 t.Fatalf("expect no error, got %v", err) 453 } 454 if e, a := int64(3), n; e != a { 455 t.Errorf("expect %d bytes read, got %d", e, a) 456 } 457 expectCalls := []string{"GetObject", "GetObject"} 458 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 459 t.Errorf("expect %v API calls, got %v", e, a) 460 } 461 if e, a := "123", string(w.Bytes()); e != a { 462 t.Errorf("expect %q response, got %q", e, a) 463 } 464 } 465 466 func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) { 467 s, names := dlLoggingSvcWithErrReader([]testErrReader{ 468 {Buf: []byte("abc"), Len: 3, Err: io.EOF}, 469 }) 470 471 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 472 d.Concurrency = 1 473 }) 474 475 w := &aws.WriteAtBuffer{} 476 n, err := d.Download(w, &s3.GetObjectInput{ 477 Bucket: aws.String("bucket"), 478 Key: aws.String("key"), 479 }) 480 481 if err != nil { 482 t.Fatalf("expect no error, got %v", err) 483 } 484 if e, a := int64(3), n; e != a { 485 t.Errorf("expect %d bytes read, got %d", e, a) 486 } 487 expectCalls := []string{"GetObject"} 488 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 489 t.Errorf("expect %v API calls, got %v", e, a) 490 } 491 if e, a := "abc", string(w.Bytes()); e != a { 492 t.Errorf("expect %q response, got %q", e, a) 493 } 494 } 495 496 func TestDownloadPartBodyRetry_FailRetry(t *testing.T) { 497 s, names := dlLoggingSvcWithErrReader([]testErrReader{ 498 {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, 499 }) 500 501 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 502 d.Concurrency = 1 503 }) 504 505 w := &aws.WriteAtBuffer{} 506 n, err := d.Download(w, &s3.GetObjectInput{ 507 Bucket: aws.String("bucket"), 508 Key: aws.String("key"), 509 }) 510 511 if err == nil { 512 t.Fatalf("expect error, got none") 513 } 514 if e, a := "unexpected EOF", err.Error(); !strings.Contains(a, e) { 515 t.Errorf("expect %q error message to be in %q", e, a) 516 } 517 if e, a := int64(2), n; e != a { 518 t.Errorf("expect %d bytes read, got %d", e, a) 519 } 520 expectCalls := []string{"GetObject"} 521 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 522 t.Errorf("expect %v API calls, got %v", e, a) 523 } 524 if e, a := "ab", string(w.Bytes()); e != a { 525 t.Errorf("expect %q response, got %q", e, a) 526 } 527 } 528 529 func TestDownloadWithContextCanceled(t *testing.T) { 530 d := s3manager.NewDownloader(unit.Session) 531 532 params := s3.GetObjectInput{ 533 Bucket: aws.String("Bucket"), 534 Key: aws.String("Key"), 535 } 536 537 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 538 ctx.Error = fmt.Errorf("context canceled") 539 close(ctx.DoneCh) 540 541 w := &aws.WriteAtBuffer{} 542 543 _, err := d.DownloadWithContext(ctx, w, ¶ms) 544 if err == nil { 545 t.Fatalf("expected error, did not get one") 546 } 547 aerr := err.(awserr.Error) 548 if e, a := request.CanceledErrorCode, aerr.Code(); e != a { 549 t.Errorf("expected error code %q, got %q", e, a) 550 } 551 if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) { 552 t.Errorf("expected error message to contain %q, but did not %q", e, a) 553 } 554 } 555 556 func TestDownload_WithRange(t *testing.T) { 557 s, names, ranges := dlLoggingSvc([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) 558 559 d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) { 560 d.Concurrency = 10 // should be ignored 561 d.PartSize = 1 // should be ignored 562 }) 563 564 w := &aws.WriteAtBuffer{} 565 n, err := d.Download(w, &s3.GetObjectInput{ 566 Bucket: aws.String("bucket"), 567 Key: aws.String("key"), 568 Range: aws.String("bytes=2-6"), 569 }) 570 571 if err != nil { 572 t.Fatalf("expect no error, got %v", err) 573 } 574 if e, a := int64(5), n; e != a { 575 t.Errorf("expect %d bytes read, got %d", e, a) 576 } 577 expectCalls := []string{"GetObject"} 578 if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) { 579 t.Errorf("expect %v API calls, got %v", e, a) 580 } 581 expectRngs := []string{"bytes=2-6"} 582 if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { 583 t.Errorf("expect %v ranges, got %v", e, a) 584 } 585 expectBytes := []byte{2, 3, 4, 5, 6} 586 if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { 587 t.Errorf("expect %v bytes, got %v", e, a) 588 } 589 } 590 591 func TestDownload_WithFailure(t *testing.T) { 592 svc := s3.New(unit.Session) 593 svc.Handlers.Send.Clear() 594 595 reqCount := int64(0) 596 startingByte := 0 597 svc.Handlers.Send.PushBack(func(r *request.Request) { 598 switch atomic.LoadInt64(&reqCount) { 599 case 1: 600 // Give a chance for the multipart chunks to be queued up 601 time.Sleep(1 * time.Second) 602 603 r.HTTPResponse = &http.Response{ 604 Header: http.Header{}, 605 Body: ioutil.NopCloser(&bytes.Buffer{}), 606 } 607 r.Error = awserr.New("ConnectionError", "some connection error", nil) 608 r.Retryable = aws.Bool(false) 609 610 default: 611 body := bytes.NewReader(make([]byte, s3manager.DefaultDownloadPartSize)) 612 r.HTTPResponse = &http.Response{ 613 StatusCode: http.StatusOK, 614 Status: http.StatusText(http.StatusOK), 615 ContentLength: int64(body.Len()), 616 Body: ioutil.NopCloser(body), 617 Header: http.Header{}, 618 } 619 r.HTTPResponse.Header.Set("Content-Length", strconv.Itoa(body.Len())) 620 r.HTTPResponse.Header.Set("Content-Range", 621 fmt.Sprintf("bytes %d-%d/%d", startingByte, body.Len()-1, body.Len()*10)) 622 623 startingByte += body.Len() 624 if reqCount > 0 { 625 // sleep here to ensure context switching between goroutines 626 time.Sleep(25 * time.Millisecond) 627 } 628 } 629 630 atomic.AddInt64(&reqCount, 1) 631 }) 632 633 d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) { 634 d.Concurrency = 2 635 }) 636 637 w := &aws.WriteAtBuffer{} 638 params := s3.GetObjectInput{ 639 Bucket: aws.String("Bucket"), 640 Key: aws.String("Key"), 641 } 642 643 // Expect this request to exit quickly after failure 644 _, err := d.Download(w, ¶ms) 645 if err == nil { 646 t.Fatalf("expect error, got none") 647 } 648 649 if atomic.LoadInt64(&reqCount) > 3 { 650 t.Errorf("expect no more than 3 requests, but received %d", reqCount) 651 } 652 } 653 654 func TestDownloadBufferStrategy(t *testing.T) { 655 cases := map[string]struct { 656 partSize int64 657 strategy *recordedWriterReadFromProvider 658 expectedSize int64 659 }{ 660 "no strategy": { 661 partSize: s3manager.DefaultDownloadPartSize, 662 expectedSize: 10 * sdkio.MebiByte, 663 }, 664 "partSize modulo bufferSize == 0": { 665 partSize: 5 * sdkio.MebiByte, 666 strategy: &recordedWriterReadFromProvider{ 667 WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(int(sdkio.MebiByte)), // 1 MiB 668 }, 669 expectedSize: 10 * sdkio.MebiByte, // 10 MiB 670 }, 671 "partSize modulo bufferSize > 0": { 672 partSize: 5 * 1024 * 1204, // 5 MiB 673 strategy: &recordedWriterReadFromProvider{ 674 WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(2 * int(sdkio.MebiByte)), // 2 MiB 675 }, 676 expectedSize: 10 * sdkio.MebiByte, // 10 MiB 677 }, 678 } 679 680 for name, tCase := range cases { 681 t.Logf("starting case: %v", name) 682 683 expected := s3testing.GetTestBytes(int(tCase.expectedSize)) 684 685 svc, _, _ := dlLoggingSvc(expected) 686 687 d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) { 688 d.PartSize = tCase.partSize 689 if tCase.strategy != nil { 690 d.BufferProvider = tCase.strategy 691 } 692 }) 693 694 buffer := aws.NewWriteAtBuffer(make([]byte, len(expected))) 695 696 n, err := d.Download(buffer, &s3.GetObjectInput{ 697 Bucket: aws.String("bucket"), 698 Key: aws.String("key"), 699 }) 700 if err != nil { 701 t.Errorf("failed to download: %v", err) 702 } 703 704 if e, a := len(expected), int(n); e != a { 705 t.Errorf("expected %v, got %v downloaded bytes", e, a) 706 } 707 708 if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) { 709 t.Errorf("downloaded bytes did not match expected") 710 } 711 712 if tCase.strategy != nil { 713 if e, a := tCase.strategy.callbacksVended, tCase.strategy.callbacksExecuted; e != a { 714 t.Errorf("expected %v, got %v", e, a) 715 } 716 } 717 } 718 } 719 720 type testErrReader struct { 721 Buf []byte 722 Err error 723 Len int64 724 725 off int 726 } 727 728 func (r *testErrReader) Read(p []byte) (int, error) { 729 to := len(r.Buf) - r.off 730 731 n := copy(p, r.Buf[r.off:to]) 732 r.off += n 733 734 if n < len(p) { 735 return n, r.Err 736 737 } 738 739 return n, nil 740 } 741 742 func TestDownloadBufferStrategy_Errors(t *testing.T) { 743 expected := s3testing.GetTestBytes(int(10 * sdkio.MebiByte)) 744 745 svc, _, _ := dlLoggingSvc(expected) 746 strat := &recordedWriterReadFromProvider{ 747 WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(int(2 * sdkio.MebiByte)), 748 } 749 750 d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) { 751 d.PartSize = 5 * sdkio.MebiByte 752 d.BufferProvider = strat 753 d.Concurrency = 1 754 }) 755 756 seenOps := make(map[string]struct{}) 757 svc.Handlers.Send.PushFront(func(*request.Request) {}) 758 svc.Handlers.Send.AfterEachFn = func(item request.HandlerListRunItem) bool { 759 r := item.Request 760 761 if r.Operation.Name != "GetObject" { 762 return true 763 } 764 765 input := r.Params.(*s3.GetObjectInput) 766 767 fingerPrint := fmt.Sprintf("%s/%s/%s/%s", r.Operation.Name, *input.Bucket, *input.Key, *input.Range) 768 if _, ok := seenOps[fingerPrint]; ok { 769 return true 770 } 771 seenOps[fingerPrint] = struct{}{} 772 773 regex := regexp.MustCompile(`bytes=(\d+)-(\d+)`) 774 rng := regex.FindStringSubmatch(*input.Range) 775 start, _ := strconv.ParseInt(rng[1], 10, 64) 776 fin, _ := strconv.ParseInt(rng[2], 10, 64) 777 778 _, _ = io.Copy(ioutil.Discard, r.Body) 779 r.HTTPResponse = &http.Response{ 780 StatusCode: 200, 781 Body: aws.ReadSeekCloser(&badReader{err: io.ErrUnexpectedEOF}), 782 ContentLength: fin - start, 783 } 784 785 return false 786 } 787 788 buffer := aws.NewWriteAtBuffer(make([]byte, len(expected))) 789 790 n, err := d.Download(buffer, &s3.GetObjectInput{ 791 Bucket: aws.String("bucket"), 792 Key: aws.String("key"), 793 }) 794 if err != nil { 795 t.Errorf("failed to download: %v", err) 796 } 797 798 if e, a := len(expected), int(n); e != a { 799 t.Errorf("expected %v, got %v downloaded bytes", e, a) 800 } 801 802 if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) { 803 t.Errorf("downloaded bytes did not match expected") 804 } 805 806 if e, a := strat.callbacksVended, strat.callbacksExecuted; e != a { 807 t.Errorf("expected %v, got %v", e, a) 808 } 809 } 810 811 func TestDownloaderValidARN(t *testing.T) { 812 cases := map[string]struct { 813 input s3.GetObjectInput 814 wantErr bool 815 }{ 816 "standard bucket": { 817 input: s3.GetObjectInput{ 818 Bucket: aws.String("test-bucket"), 819 Key: aws.String("test-key"), 820 }, 821 }, 822 "accesspoint": { 823 input: s3.GetObjectInput{ 824 Bucket: aws.String("arn:aws:s3:us-west-2:123456789012:accesspoint/myap"), 825 Key: aws.String("test-key"), 826 }, 827 }, 828 "outpost accesspoint": { 829 input: s3.GetObjectInput{ 830 Bucket: aws.String("arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/accesspoint/myaccesspoint"), 831 Key: aws.String("test-key"), 832 }, 833 }, 834 "s3-object-lambda accesspoint": { 835 input: s3.GetObjectInput{ 836 Bucket: aws.String("arn:aws:s3-object-lambda:us-west-2:123456789012:accesspoint/myap"), 837 }, 838 wantErr: true, 839 }, 840 } 841 842 for name, tt := range cases { 843 t.Run(name, func(t *testing.T) { 844 client, _, _ := dlLoggingSvc(buf2MB) 845 846 client.Config.Region = aws.String("us-west-2") 847 client.ClientInfo.SigningRegion = "us-west-2" 848 849 downloader := s3manager.NewDownloaderWithClient(client, func(downloader *s3manager.Downloader) { 850 downloader.Concurrency = 1 851 }) 852 853 _, err := downloader.Download(&awstesting.DiscardAt{}, &tt.input) 854 if (err != nil) != tt.wantErr { 855 t.Errorf("err: %v, wantErr: %v", err, tt.wantErr) 856 } 857 }) 858 } 859 } 860 861 type recordedWriterReadFromProvider struct { 862 callbacksVended uint32 863 callbacksExecuted uint32 864 s3manager.WriterReadFromProvider 865 } 866 867 func (r *recordedWriterReadFromProvider) GetReadFrom(writer io.Writer) (s3manager.WriterReadFrom, func()) { 868 w, cleanup := r.WriterReadFromProvider.GetReadFrom(writer) 869 870 atomic.AddUint32(&r.callbacksVended, 1) 871 return w, func() { 872 atomic.AddUint32(&r.callbacksExecuted, 1) 873 cleanup() 874 } 875 } 876 877 type badReader struct { 878 err error 879 } 880 881 func (b *badReader) Read(p []byte) (int, error) { 882 tb := s3testing.GetTestBytes(len(p)) 883 copy(p, tb) 884 885 return len(p), b.err 886 } 887 888 var mockErrorResponse = struct { 889 XMLName xml.Name `xml:"Error"` 890 Code string `xml:"Code"` 891 Message string `xml:"Message"` 892 }{ 893 Code: "MOCK_S3_ERROR_CODE", 894 Message: "Mocked S3 Error Message", 895 }