github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/s3manager/download_test.go (about)

     1  //go:build go1.7
     2  // +build go1.7
     3  
     4  package s3manager_test
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/xml"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net/http"
    13  	"reflect"
    14  	"regexp"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/aavshr/aws-sdk-go/aws"
    23  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    24  	"github.com/aavshr/aws-sdk-go/aws/request"
    25  	"github.com/aavshr/aws-sdk-go/awstesting"
    26  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    27  	"github.com/aavshr/aws-sdk-go/internal/sdkio"
    28  	"github.com/aavshr/aws-sdk-go/service/s3"
    29  	"github.com/aavshr/aws-sdk-go/service/s3/internal/s3testing"
    30  	"github.com/aavshr/aws-sdk-go/service/s3/s3manager"
    31  )
    32  
    33  func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
    34  	var m sync.Mutex
    35  	names := []string{}
    36  	ranges := []string{}
    37  
    38  	svc := s3.New(unit.Session)
    39  	svc.Handlers.Send.Clear()
    40  	svc.Handlers.Send.PushBack(func(r *request.Request) {
    41  		m.Lock()
    42  		defer m.Unlock()
    43  
    44  		names = append(names, r.Operation.Name)
    45  		ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
    46  
    47  		rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
    48  		rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
    49  		start, _ := strconv.ParseInt(rng[1], 10, 64)
    50  		fin, _ := strconv.ParseInt(rng[2], 10, 64)
    51  		fin++
    52  
    53  		if fin > int64(len(data)) {
    54  			fin = int64(len(data))
    55  		}
    56  
    57  		bodyBytes := data[start:fin]
    58  		r.HTTPResponse = &http.Response{
    59  			StatusCode: 200,
    60  			Body:       ioutil.NopCloser(bytes.NewReader(bodyBytes)),
    61  			Header:     http.Header{},
    62  		}
    63  		r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
    64  			start, fin-1, len(data)))
    65  		r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
    66  	})
    67  
    68  	return svc, &names, &ranges
    69  }
    70  
    71  func dlLoggingSvcNoChunk(data []byte) (*s3.S3, *[]string) {
    72  	var m sync.Mutex
    73  	names := []string{}
    74  
    75  	svc := s3.New(unit.Session)
    76  	svc.Handlers.Send.Clear()
    77  	svc.Handlers.Send.PushBack(func(r *request.Request) {
    78  		m.Lock()
    79  		defer m.Unlock()
    80  
    81  		names = append(names, r.Operation.Name)
    82  
    83  		r.HTTPResponse = &http.Response{
    84  			StatusCode: 200,
    85  			Body:       ioutil.NopCloser(bytes.NewReader(data[:])),
    86  			Header:     http.Header{},
    87  		}
    88  		r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(data)))
    89  	})
    90  
    91  	return svc, &names
    92  }
    93  
    94  func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.S3, *[]string) {
    95  	var m sync.Mutex
    96  	names := []string{}
    97  	var index int
    98  
    99  	svc := s3.New(unit.Session)
   100  	svc.Handlers.Send.Clear()
   101  	svc.Handlers.Send.PushBack(func(r *request.Request) {
   102  		m.Lock()
   103  		defer m.Unlock()
   104  
   105  		names = append(names, r.Operation.Name)
   106  
   107  		var body io.Reader
   108  		if states[index] < 400 {
   109  			body = bytes.NewReader(data[:])
   110  		} else {
   111  			var buffer bytes.Buffer
   112  			encoder := xml.NewEncoder(&buffer)
   113  			_ = encoder.Encode(&mockErrorResponse)
   114  			body = &buffer
   115  		}
   116  
   117  		r.HTTPResponse = &http.Response{
   118  			StatusCode: states[index],
   119  			Body:       ioutil.NopCloser(body),
   120  			Header:     http.Header{},
   121  		}
   122  		index++
   123  	})
   124  
   125  	return svc, &names
   126  }
   127  
   128  func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]string) {
   129  	var m sync.Mutex
   130  	names := []string{}
   131  	ranges := []string{}
   132  	var index int
   133  
   134  	svc := s3.New(unit.Session)
   135  	svc.Handlers.Send.Clear()
   136  	svc.Handlers.Send.PushBack(func(r *request.Request) {
   137  		m.Lock()
   138  		defer m.Unlock()
   139  
   140  		names = append(names, r.Operation.Name)
   141  		ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
   142  
   143  		rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
   144  		rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
   145  		start, _ := strconv.ParseInt(rng[1], 10, 64)
   146  		fin, _ := strconv.ParseInt(rng[2], 10, 64)
   147  		fin++
   148  
   149  		if fin >= int64(len(data)) {
   150  			fin = int64(len(data))
   151  		}
   152  
   153  		// Setting start and finish to 0 because this state of 1 is suppose to
   154  		// be an error state of 416
   155  		if index == len(states)-1 {
   156  			start = 0
   157  			fin = 0
   158  		}
   159  
   160  		bodyBytes := data[start:fin]
   161  
   162  		r.HTTPResponse = &http.Response{
   163  			StatusCode: states[index],
   164  			Body:       ioutil.NopCloser(bytes.NewReader(bodyBytes)),
   165  			Header:     http.Header{},
   166  		}
   167  		r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/*",
   168  			start, fin-1))
   169  		index++
   170  	})
   171  
   172  	return svc, &names
   173  }
   174  
   175  func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.S3, *[]string) {
   176  	var m sync.Mutex
   177  	names := []string{}
   178  	var index int
   179  
   180  	svc := s3.New(unit.Session, &aws.Config{
   181  		MaxRetries: aws.Int(len(cases) - 1),
   182  	})
   183  	svc.Handlers.Send.Clear()
   184  	svc.Handlers.Send.PushBack(func(r *request.Request) {
   185  		m.Lock()
   186  		defer m.Unlock()
   187  
   188  		names = append(names, r.Operation.Name)
   189  
   190  		c := cases[index]
   191  
   192  		r.HTTPResponse = &http.Response{
   193  			StatusCode: http.StatusOK,
   194  			Body:       ioutil.NopCloser(&c),
   195  			Header:     http.Header{},
   196  		}
   197  		r.HTTPResponse.Header.Set("Content-Range",
   198  			fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len))
   199  		r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", c.Len))
   200  		index++
   201  	})
   202  
   203  	return svc, &names
   204  }
   205  
   206  func TestDownloadOrder(t *testing.T) {
   207  	s, names, ranges := dlLoggingSvc(buf12MB)
   208  
   209  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   210  		d.Concurrency = 1
   211  	})
   212  
   213  	w := aws.NewWriteAtBuffer(make([]byte, len(buf12MB)))
   214  	n, err := d.Download(w, &s3.GetObjectInput{
   215  		Bucket: aws.String("bucket"),
   216  		Key:    aws.String("key"),
   217  	})
   218  
   219  	if err != nil {
   220  		t.Fatalf("expect no error, got %v", err)
   221  	}
   222  	if e, a := int64(len(buf12MB)), n; e != a {
   223  		t.Errorf("expect %d buffer length, got %d", e, a)
   224  	}
   225  
   226  	expectCalls := []string{"GetObject", "GetObject", "GetObject"}
   227  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   228  		t.Errorf("expect %v API calls, got %v", e, a)
   229  	}
   230  
   231  	expectRngs := []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}
   232  	if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
   233  		t.Errorf("expect %v ranges, got %v", e, a)
   234  	}
   235  }
   236  
   237  func TestDownloadZero(t *testing.T) {
   238  	s, names, ranges := dlLoggingSvc([]byte{})
   239  
   240  	d := s3manager.NewDownloaderWithClient(s)
   241  	w := &aws.WriteAtBuffer{}
   242  	n, err := d.Download(w, &s3.GetObjectInput{
   243  		Bucket: aws.String("bucket"),
   244  		Key:    aws.String("key"),
   245  	})
   246  
   247  	if err != nil {
   248  		t.Fatalf("expect no error, got %v", err)
   249  	}
   250  	if n != 0 {
   251  		t.Errorf("expect 0 bytes read, got %d", n)
   252  	}
   253  	expectCalls := []string{"GetObject"}
   254  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   255  		t.Errorf("expect %v API calls, got %v", e, a)
   256  	}
   257  
   258  	expectRngs := []string{"bytes=0-5242879"}
   259  	if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
   260  		t.Errorf("expect %v ranges, got %v", e, a)
   261  	}
   262  }
   263  
   264  func TestDownloadSetPartSize(t *testing.T) {
   265  	s, names, ranges := dlLoggingSvc([]byte{1, 2, 3})
   266  
   267  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   268  		d.Concurrency = 1
   269  		d.PartSize = 1
   270  	})
   271  	w := &aws.WriteAtBuffer{}
   272  	n, err := d.Download(w, &s3.GetObjectInput{
   273  		Bucket: aws.String("bucket"),
   274  		Key:    aws.String("key"),
   275  	})
   276  
   277  	if err != nil {
   278  		t.Fatalf("expect no error, got %v", err)
   279  	}
   280  	if e, a := int64(3), n; e != a {
   281  		t.Errorf("expect %d bytes read, got %d", e, a)
   282  	}
   283  	expectCalls := []string{"GetObject", "GetObject", "GetObject"}
   284  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   285  		t.Errorf("expect %v API calls, got %v", e, a)
   286  	}
   287  	expectRngs := []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}
   288  	if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
   289  		t.Errorf("expect %v ranges, got %v", e, a)
   290  	}
   291  	expectBytes := []byte{1, 2, 3}
   292  	if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
   293  		t.Errorf("expect %v bytes, got %v", e, a)
   294  	}
   295  }
   296  
   297  func TestDownloadError(t *testing.T) {
   298  	s, names, _ := dlLoggingSvc([]byte{1, 2, 3})
   299  
   300  	num := 0
   301  	s.Handlers.Send.PushBack(func(r *request.Request) {
   302  		num++
   303  		if num > 1 {
   304  			r.HTTPResponse.StatusCode = 400
   305  			r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
   306  		}
   307  	})
   308  
   309  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   310  		d.Concurrency = 1
   311  		d.PartSize = 1
   312  	})
   313  	w := &aws.WriteAtBuffer{}
   314  	n, err := d.Download(w, &s3.GetObjectInput{
   315  		Bucket: aws.String("bucket"),
   316  		Key:    aws.String("key"),
   317  	})
   318  
   319  	if err == nil {
   320  		t.Fatalf("expect error, got none")
   321  	}
   322  	aerr := err.(awserr.Error)
   323  	if e, a := "BadRequest", aerr.Code(); e != a {
   324  		t.Errorf("expect %s error code, got %s", e, a)
   325  	}
   326  	if e, a := int64(1), n; e != a {
   327  		t.Errorf("expect %d bytes read, got %d", e, a)
   328  	}
   329  	expectCalls := []string{"GetObject", "GetObject"}
   330  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   331  		t.Errorf("expect %v API calls, got %v", e, a)
   332  	}
   333  	expectBytes := []byte{1}
   334  	if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
   335  		t.Errorf("expect %v bytes, got %v", e, a)
   336  	}
   337  }
   338  
   339  func TestDownloadNonChunk(t *testing.T) {
   340  	s, names := dlLoggingSvcNoChunk(buf2MB)
   341  
   342  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   343  		d.Concurrency = 1
   344  	})
   345  	w := &aws.WriteAtBuffer{}
   346  	n, err := d.Download(w, &s3.GetObjectInput{
   347  		Bucket: aws.String("bucket"),
   348  		Key:    aws.String("key"),
   349  	})
   350  
   351  	if err != nil {
   352  		t.Fatalf("expect no error, got %v", err)
   353  	}
   354  	if e, a := int64(len(buf2MB)), n; e != a {
   355  		t.Errorf("expect %d bytes read, got %d", e, a)
   356  	}
   357  	expectCalls := []string{"GetObject"}
   358  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   359  		t.Errorf("expect %v API calls, got %v", e, a)
   360  	}
   361  
   362  	count := 0
   363  	for _, b := range w.Bytes() {
   364  		count += int(b)
   365  	}
   366  	if count != 0 {
   367  		t.Errorf("expect 0 count, got %d", count)
   368  	}
   369  }
   370  
   371  func TestDownloadNoContentRangeLength(t *testing.T) {
   372  	s, names := dlLoggingSvcNoContentRangeLength(buf2MB, []int{200, 416})
   373  
   374  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   375  		d.Concurrency = 1
   376  	})
   377  	w := &aws.WriteAtBuffer{}
   378  	n, err := d.Download(w, &s3.GetObjectInput{
   379  		Bucket: aws.String("bucket"),
   380  		Key:    aws.String("key"),
   381  	})
   382  
   383  	if err != nil {
   384  		t.Fatalf("expect no error, got %v", err)
   385  	}
   386  	if e, a := int64(len(buf2MB)), n; e != a {
   387  		t.Errorf("expect %d bytes read, got %d", e, a)
   388  	}
   389  	expectCalls := []string{"GetObject", "GetObject"}
   390  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   391  		t.Errorf("expect %v API calls, got %v", e, a)
   392  	}
   393  
   394  	count := 0
   395  	for _, b := range w.Bytes() {
   396  		count += int(b)
   397  	}
   398  	if count != 0 {
   399  		t.Errorf("expect 0 count, got %d", count)
   400  	}
   401  }
   402  
   403  func TestDownloadContentRangeTotalAny(t *testing.T) {
   404  	s, names := dlLoggingSvcContentRangeTotalAny(buf2MB, []int{200, 416})
   405  
   406  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   407  		d.Concurrency = 1
   408  	})
   409  	w := &aws.WriteAtBuffer{}
   410  	n, err := d.Download(w, &s3.GetObjectInput{
   411  		Bucket: aws.String("bucket"),
   412  		Key:    aws.String("key"),
   413  	})
   414  
   415  	if err != nil {
   416  		t.Fatalf("expect no error, got %v", err)
   417  	}
   418  	if e, a := int64(len(buf2MB)), n; e != a {
   419  		t.Errorf("expect %d bytes read, got %d", e, a)
   420  	}
   421  	expectCalls := []string{"GetObject", "GetObject"}
   422  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   423  		t.Errorf("expect %v API calls, got %v", e, a)
   424  	}
   425  
   426  	count := 0
   427  	for _, b := range w.Bytes() {
   428  		count += int(b)
   429  	}
   430  	if count != 0 {
   431  		t.Errorf("expect 0 count, got %d", count)
   432  	}
   433  }
   434  
   435  func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) {
   436  	s, names := dlLoggingSvcWithErrReader([]testErrReader{
   437  		{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
   438  		{Buf: []byte("123"), Len: 3, Err: io.EOF},
   439  	})
   440  
   441  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   442  		d.Concurrency = 1
   443  	})
   444  
   445  	w := &aws.WriteAtBuffer{}
   446  	n, err := d.Download(w, &s3.GetObjectInput{
   447  		Bucket: aws.String("bucket"),
   448  		Key:    aws.String("key"),
   449  	})
   450  
   451  	if err != nil {
   452  		t.Fatalf("expect no error, got %v", err)
   453  	}
   454  	if e, a := int64(3), n; e != a {
   455  		t.Errorf("expect %d bytes read, got %d", e, a)
   456  	}
   457  	expectCalls := []string{"GetObject", "GetObject"}
   458  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   459  		t.Errorf("expect %v API calls, got %v", e, a)
   460  	}
   461  	if e, a := "123", string(w.Bytes()); e != a {
   462  		t.Errorf("expect %q response, got %q", e, a)
   463  	}
   464  }
   465  
   466  func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) {
   467  	s, names := dlLoggingSvcWithErrReader([]testErrReader{
   468  		{Buf: []byte("abc"), Len: 3, Err: io.EOF},
   469  	})
   470  
   471  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   472  		d.Concurrency = 1
   473  	})
   474  
   475  	w := &aws.WriteAtBuffer{}
   476  	n, err := d.Download(w, &s3.GetObjectInput{
   477  		Bucket: aws.String("bucket"),
   478  		Key:    aws.String("key"),
   479  	})
   480  
   481  	if err != nil {
   482  		t.Fatalf("expect no error, got %v", err)
   483  	}
   484  	if e, a := int64(3), n; e != a {
   485  		t.Errorf("expect %d bytes read, got %d", e, a)
   486  	}
   487  	expectCalls := []string{"GetObject"}
   488  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   489  		t.Errorf("expect %v API calls, got %v", e, a)
   490  	}
   491  	if e, a := "abc", string(w.Bytes()); e != a {
   492  		t.Errorf("expect %q response, got %q", e, a)
   493  	}
   494  }
   495  
   496  func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
   497  	s, names := dlLoggingSvcWithErrReader([]testErrReader{
   498  		{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
   499  	})
   500  
   501  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   502  		d.Concurrency = 1
   503  	})
   504  
   505  	w := &aws.WriteAtBuffer{}
   506  	n, err := d.Download(w, &s3.GetObjectInput{
   507  		Bucket: aws.String("bucket"),
   508  		Key:    aws.String("key"),
   509  	})
   510  
   511  	if err == nil {
   512  		t.Fatalf("expect error, got none")
   513  	}
   514  	if e, a := "unexpected EOF", err.Error(); !strings.Contains(a, e) {
   515  		t.Errorf("expect %q error message to be in %q", e, a)
   516  	}
   517  	if e, a := int64(2), n; e != a {
   518  		t.Errorf("expect %d bytes read, got %d", e, a)
   519  	}
   520  	expectCalls := []string{"GetObject"}
   521  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   522  		t.Errorf("expect %v API calls, got %v", e, a)
   523  	}
   524  	if e, a := "ab", string(w.Bytes()); e != a {
   525  		t.Errorf("expect %q response, got %q", e, a)
   526  	}
   527  }
   528  
   529  func TestDownloadWithContextCanceled(t *testing.T) {
   530  	d := s3manager.NewDownloader(unit.Session)
   531  
   532  	params := s3.GetObjectInput{
   533  		Bucket: aws.String("Bucket"),
   534  		Key:    aws.String("Key"),
   535  	}
   536  
   537  	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
   538  	ctx.Error = fmt.Errorf("context canceled")
   539  	close(ctx.DoneCh)
   540  
   541  	w := &aws.WriteAtBuffer{}
   542  
   543  	_, err := d.DownloadWithContext(ctx, w, &params)
   544  	if err == nil {
   545  		t.Fatalf("expected error, did not get one")
   546  	}
   547  	aerr := err.(awserr.Error)
   548  	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
   549  		t.Errorf("expected error code %q, got %q", e, a)
   550  	}
   551  	if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
   552  		t.Errorf("expected error message to contain %q, but did not %q", e, a)
   553  	}
   554  }
   555  
   556  func TestDownload_WithRange(t *testing.T) {
   557  	s, names, ranges := dlLoggingSvc([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
   558  
   559  	d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
   560  		d.Concurrency = 10 // should be ignored
   561  		d.PartSize = 1     // should be ignored
   562  	})
   563  
   564  	w := &aws.WriteAtBuffer{}
   565  	n, err := d.Download(w, &s3.GetObjectInput{
   566  		Bucket: aws.String("bucket"),
   567  		Key:    aws.String("key"),
   568  		Range:  aws.String("bytes=2-6"),
   569  	})
   570  
   571  	if err != nil {
   572  		t.Fatalf("expect no error, got %v", err)
   573  	}
   574  	if e, a := int64(5), n; e != a {
   575  		t.Errorf("expect %d bytes read, got %d", e, a)
   576  	}
   577  	expectCalls := []string{"GetObject"}
   578  	if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
   579  		t.Errorf("expect %v API calls, got %v", e, a)
   580  	}
   581  	expectRngs := []string{"bytes=2-6"}
   582  	if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
   583  		t.Errorf("expect %v ranges, got %v", e, a)
   584  	}
   585  	expectBytes := []byte{2, 3, 4, 5, 6}
   586  	if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
   587  		t.Errorf("expect %v bytes, got %v", e, a)
   588  	}
   589  }
   590  
   591  func TestDownload_WithFailure(t *testing.T) {
   592  	svc := s3.New(unit.Session)
   593  	svc.Handlers.Send.Clear()
   594  
   595  	reqCount := int64(0)
   596  	startingByte := 0
   597  	svc.Handlers.Send.PushBack(func(r *request.Request) {
   598  		switch atomic.LoadInt64(&reqCount) {
   599  		case 1:
   600  			// Give a chance for the multipart chunks to be queued up
   601  			time.Sleep(1 * time.Second)
   602  
   603  			r.HTTPResponse = &http.Response{
   604  				Header: http.Header{},
   605  				Body:   ioutil.NopCloser(&bytes.Buffer{}),
   606  			}
   607  			r.Error = awserr.New("ConnectionError", "some connection error", nil)
   608  			r.Retryable = aws.Bool(false)
   609  
   610  		default:
   611  			body := bytes.NewReader(make([]byte, s3manager.DefaultDownloadPartSize))
   612  			r.HTTPResponse = &http.Response{
   613  				StatusCode:    http.StatusOK,
   614  				Status:        http.StatusText(http.StatusOK),
   615  				ContentLength: int64(body.Len()),
   616  				Body:          ioutil.NopCloser(body),
   617  				Header:        http.Header{},
   618  			}
   619  			r.HTTPResponse.Header.Set("Content-Length", strconv.Itoa(body.Len()))
   620  			r.HTTPResponse.Header.Set("Content-Range",
   621  				fmt.Sprintf("bytes %d-%d/%d", startingByte, body.Len()-1, body.Len()*10))
   622  
   623  			startingByte += body.Len()
   624  			if reqCount > 0 {
   625  				// sleep here to ensure context switching between goroutines
   626  				time.Sleep(25 * time.Millisecond)
   627  			}
   628  		}
   629  
   630  		atomic.AddInt64(&reqCount, 1)
   631  	})
   632  
   633  	d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
   634  		d.Concurrency = 2
   635  	})
   636  
   637  	w := &aws.WriteAtBuffer{}
   638  	params := s3.GetObjectInput{
   639  		Bucket: aws.String("Bucket"),
   640  		Key:    aws.String("Key"),
   641  	}
   642  
   643  	// Expect this request to exit quickly after failure
   644  	_, err := d.Download(w, &params)
   645  	if err == nil {
   646  		t.Fatalf("expect error, got none")
   647  	}
   648  
   649  	if atomic.LoadInt64(&reqCount) > 3 {
   650  		t.Errorf("expect no more than 3 requests, but received %d", reqCount)
   651  	}
   652  }
   653  
   654  func TestDownloadBufferStrategy(t *testing.T) {
   655  	cases := map[string]struct {
   656  		partSize     int64
   657  		strategy     *recordedWriterReadFromProvider
   658  		expectedSize int64
   659  	}{
   660  		"no strategy": {
   661  			partSize:     s3manager.DefaultDownloadPartSize,
   662  			expectedSize: 10 * sdkio.MebiByte,
   663  		},
   664  		"partSize modulo bufferSize == 0": {
   665  			partSize: 5 * sdkio.MebiByte,
   666  			strategy: &recordedWriterReadFromProvider{
   667  				WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(int(sdkio.MebiByte)), // 1 MiB
   668  			},
   669  			expectedSize: 10 * sdkio.MebiByte, // 10 MiB
   670  		},
   671  		"partSize modulo bufferSize > 0": {
   672  			partSize: 5 * 1024 * 1204, // 5 MiB
   673  			strategy: &recordedWriterReadFromProvider{
   674  				WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(2 * int(sdkio.MebiByte)), // 2 MiB
   675  			},
   676  			expectedSize: 10 * sdkio.MebiByte, // 10 MiB
   677  		},
   678  	}
   679  
   680  	for name, tCase := range cases {
   681  		t.Logf("starting case: %v", name)
   682  
   683  		expected := s3testing.GetTestBytes(int(tCase.expectedSize))
   684  
   685  		svc, _, _ := dlLoggingSvc(expected)
   686  
   687  		d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
   688  			d.PartSize = tCase.partSize
   689  			if tCase.strategy != nil {
   690  				d.BufferProvider = tCase.strategy
   691  			}
   692  		})
   693  
   694  		buffer := aws.NewWriteAtBuffer(make([]byte, len(expected)))
   695  
   696  		n, err := d.Download(buffer, &s3.GetObjectInput{
   697  			Bucket: aws.String("bucket"),
   698  			Key:    aws.String("key"),
   699  		})
   700  		if err != nil {
   701  			t.Errorf("failed to download: %v", err)
   702  		}
   703  
   704  		if e, a := len(expected), int(n); e != a {
   705  			t.Errorf("expected %v, got %v downloaded bytes", e, a)
   706  		}
   707  
   708  		if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) {
   709  			t.Errorf("downloaded bytes did not match expected")
   710  		}
   711  
   712  		if tCase.strategy != nil {
   713  			if e, a := tCase.strategy.callbacksVended, tCase.strategy.callbacksExecuted; e != a {
   714  				t.Errorf("expected %v, got %v", e, a)
   715  			}
   716  		}
   717  	}
   718  }
   719  
   720  type testErrReader struct {
   721  	Buf []byte
   722  	Err error
   723  	Len int64
   724  
   725  	off int
   726  }
   727  
   728  func (r *testErrReader) Read(p []byte) (int, error) {
   729  	to := len(r.Buf) - r.off
   730  
   731  	n := copy(p, r.Buf[r.off:to])
   732  	r.off += n
   733  
   734  	if n < len(p) {
   735  		return n, r.Err
   736  
   737  	}
   738  
   739  	return n, nil
   740  }
   741  
// TestDownloadBufferStrategy_Errors downloads a 10 MiB object in 5 MiB parts
// through a recording buffer provider while the first attempt at every
// distinct GetObject range is sabotaged with a body that fails with
// io.ErrUnexpectedEOF. It expects the download to still succeed with the
// complete, correct content and every vended cleanup callback executed.
func TestDownloadBufferStrategy_Errors(t *testing.T) {
	expected := s3testing.GetTestBytes(int(10 * sdkio.MebiByte))

	svc, _, _ := dlLoggingSvc(expected)
	strat := &recordedWriterReadFromProvider{
		WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(int(2 * sdkio.MebiByte)),
	}

	d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
		d.PartSize = 5 * sdkio.MebiByte
		d.BufferProvider = strat
		d.Concurrency = 1
	})

	// Tracks which operation/bucket/key/range combinations have already been
	// failed once, so only the first attempt of each is sabotaged.
	seenOps := make(map[string]struct{})
	// A no-op handler is placed ahead of the mock Send handler installed by
	// dlLoggingSvc; AfterEachFn below runs after it.
	// NOTE(review): presumably this exists so AfterEachFn gets a chance to
	// short-circuit before the real mock responds — confirm against
	// request.HandlerList.
	svc.Handlers.Send.PushFront(func(*request.Request) {})
	svc.Handlers.Send.AfterEachFn = func(item request.HandlerListRunItem) bool {
		r := item.Request

		if r.Operation.Name != "GetObject" {
			return true
		}

		input := r.Params.(*s3.GetObjectInput)

		fingerPrint := fmt.Sprintf("%s/%s/%s/%s", r.Operation.Name, *input.Bucket, *input.Key, *input.Range)
		if _, ok := seenOps[fingerPrint]; ok {
			// Already failed once for this range; let the request proceed.
			return true
		}
		seenOps[fingerPrint] = struct{}{}

		regex := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
		rng := regex.FindStringSubmatch(*input.Range)
		start, _ := strconv.ParseInt(rng[1], 10, 64)
		fin, _ := strconv.ParseInt(rng[2], 10, 64)

		// Drain the request body, then substitute a 200 response whose body
		// yields test bytes but always reports io.ErrUnexpectedEOF.
		_, _ = io.Copy(ioutil.Discard, r.Body)
		r.HTTPResponse = &http.Response{
			StatusCode:    200,
			Body:          aws.ReadSeekCloser(&badReader{err: io.ErrUnexpectedEOF}),
			ContentLength: fin - start,
		}

		// NOTE(review): returning false presumably stops the remaining Send
		// handlers from running for this attempt — confirm.
		return false
	}

	buffer := aws.NewWriteAtBuffer(make([]byte, len(expected)))

	n, err := d.Download(buffer, &s3.GetObjectInput{
		Bucket: aws.String("bucket"),
		Key:    aws.String("key"),
	})
	if err != nil {
		t.Errorf("failed to download: %v", err)
	}

	if e, a := len(expected), int(n); e != a {
		t.Errorf("expected %v, got %v downloaded bytes", e, a)
	}

	if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) {
		t.Errorf("downloaded bytes did not match expected")
	}

	// Every cleanup callback handed out by the provider must have been run,
	// including those for the sabotaged attempts.
	if e, a := strat.callbacksVended, strat.callbacksExecuted; e != a {
		t.Errorf("expected %v, got %v", e, a)
	}
}
   810  
   811  func TestDownloaderValidARN(t *testing.T) {
   812  	cases := map[string]struct {
   813  		input   s3.GetObjectInput
   814  		wantErr bool
   815  	}{
   816  		"standard bucket": {
   817  			input: s3.GetObjectInput{
   818  				Bucket: aws.String("test-bucket"),
   819  				Key:    aws.String("test-key"),
   820  			},
   821  		},
   822  		"accesspoint": {
   823  			input: s3.GetObjectInput{
   824  				Bucket: aws.String("arn:aws:s3:us-west-2:123456789012:accesspoint/myap"),
   825  				Key:    aws.String("test-key"),
   826  			},
   827  		},
   828  		"outpost accesspoint": {
   829  			input: s3.GetObjectInput{
   830  				Bucket: aws.String("arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/accesspoint/myaccesspoint"),
   831  				Key:    aws.String("test-key"),
   832  			},
   833  		},
   834  		"s3-object-lambda accesspoint": {
   835  			input: s3.GetObjectInput{
   836  				Bucket: aws.String("arn:aws:s3-object-lambda:us-west-2:123456789012:accesspoint/myap"),
   837  			},
   838  			wantErr: true,
   839  		},
   840  	}
   841  
   842  	for name, tt := range cases {
   843  		t.Run(name, func(t *testing.T) {
   844  			client, _, _ := dlLoggingSvc(buf2MB)
   845  
   846  			client.Config.Region = aws.String("us-west-2")
   847  			client.ClientInfo.SigningRegion = "us-west-2"
   848  
   849  			downloader := s3manager.NewDownloaderWithClient(client, func(downloader *s3manager.Downloader) {
   850  				downloader.Concurrency = 1
   851  			})
   852  
   853  			_, err := downloader.Download(&awstesting.DiscardAt{}, &tt.input)
   854  			if (err != nil) != tt.wantErr {
   855  				t.Errorf("err: %v, wantErr: %v", err, tt.wantErr)
   856  			}
   857  		})
   858  	}
   859  }
   860  
   861  type recordedWriterReadFromProvider struct {
   862  	callbacksVended   uint32
   863  	callbacksExecuted uint32
   864  	s3manager.WriterReadFromProvider
   865  }
   866  
   867  func (r *recordedWriterReadFromProvider) GetReadFrom(writer io.Writer) (s3manager.WriterReadFrom, func()) {
   868  	w, cleanup := r.WriterReadFromProvider.GetReadFrom(writer)
   869  
   870  	atomic.AddUint32(&r.callbacksVended, 1)
   871  	return w, func() {
   872  		atomic.AddUint32(&r.callbacksExecuted, 1)
   873  		cleanup()
   874  	}
   875  }
   876  
   877  type badReader struct {
   878  	err error
   879  }
   880  
   881  func (b *badReader) Read(p []byte) (int, error) {
   882  	tb := s3testing.GetTestBytes(len(p))
   883  	copy(p, tb)
   884  
   885  	return len(p), b.err
   886  }
   887  
// mockErrorResponse is the error document that
// dlLoggingSvcNoContentRangeLength XML-encodes into the response body for
// status codes of 400 and above. It marshals as an <Error> element with
// <Code> and <Message> children.
var mockErrorResponse = struct {
	XMLName xml.Name `xml:"Error"`
	Code    string   `xml:"Code"`
	Message string   `xml:"Message"`
}{
	Code:    "MOCK_S3_ERROR_CODE",
	Message: "Mocked S3 Error Message",
}