github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/s3manager/upload_internal_test.go (about)

     1  //go:build go1.7
     2  // +build go1.7
     3  
     4  package s3manager
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	random "math/rand"
    12  	"net/http"
    13  	"strconv"
    14  	"sync"
    15  	"sync/atomic"
    16  	"testing"
    17  
    18  	"github.com/aavshr/aws-sdk-go/aws"
    19  	"github.com/aavshr/aws-sdk-go/aws/request"
    20  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    21  	"github.com/aavshr/aws-sdk-go/internal/sdkio"
    22  	"github.com/aavshr/aws-sdk-go/service/s3"
    23  	"github.com/aavshr/aws-sdk-go/service/s3/internal/s3testing"
    24  )
    25  
// respBody is a canned CompleteMultipartUploadOutput XML document returned as
// the HTTP response body by several of the mocked Send handlers in this file.
// Those handlers clear the Unmarshal/UnmarshalMeta/UnmarshalError handler
// lists and fill r.Data directly, so this XML is not decoded into the output
// structs; it only gives the mocked response a plausible, non-empty body.
const respBody = `<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUploadOutput>
   <Location>mockValue</Location>
   <Bucket>mockValue</Bucket>
   <Key>mockValue</Key>
   <ETag>mockValue</ETag>
</CompleteMultipartUploadOutput>`
    33  
    34  type testReader struct {
    35  	br *bytes.Reader
    36  	m  sync.Mutex
    37  }
    38  
    39  func (r *testReader) Read(p []byte) (n int, err error) {
    40  	r.m.Lock()
    41  	defer r.m.Unlock()
    42  	return r.br.Read(p)
    43  }
    44  
    45  func TestUploadByteSlicePool(t *testing.T) {
    46  	cases := map[string]struct {
    47  		PartSize      int64
    48  		FileSize      int64
    49  		Concurrency   int
    50  		ExAllocations uint64
    51  	}{
    52  		"single part, single concurrency": {
    53  			PartSize:      sdkio.MebiByte * 5,
    54  			FileSize:      sdkio.MebiByte * 5,
    55  			ExAllocations: 2,
    56  			Concurrency:   1,
    57  		},
    58  		"multi-part, single concurrency": {
    59  			PartSize:      sdkio.MebiByte * 5,
    60  			FileSize:      sdkio.MebiByte * 10,
    61  			ExAllocations: 2,
    62  			Concurrency:   1,
    63  		},
    64  		"multi-part, multiple concurrency": {
    65  			PartSize:      sdkio.MebiByte * 5,
    66  			FileSize:      sdkio.MebiByte * 20,
    67  			ExAllocations: 3,
    68  			Concurrency:   2,
    69  		},
    70  	}
    71  
    72  	for name, tt := range cases {
    73  		t.Run(name, func(t *testing.T) {
    74  			var p *recordedPartPool
    75  
    76  			unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
    77  				p = newRecordedPartPool(sliceSize)
    78  				return p
    79  			})
    80  			defer unswap()
    81  
    82  			sess := unit.Session.Copy()
    83  			svc := s3.New(sess)
    84  			svc.Handlers.Unmarshal.Clear()
    85  			svc.Handlers.UnmarshalMeta.Clear()
    86  			svc.Handlers.UnmarshalError.Clear()
    87  			svc.Handlers.Send.Clear()
    88  			svc.Handlers.Send.PushFront(func(r *request.Request) {
    89  				if r.Body != nil {
    90  					io.Copy(ioutil.Discard, r.Body)
    91  				}
    92  
    93  				r.HTTPResponse = &http.Response{
    94  					StatusCode: 200,
    95  					Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
    96  				}
    97  
    98  				switch data := r.Data.(type) {
    99  				case *s3.CreateMultipartUploadOutput:
   100  					data.UploadId = aws.String("UPLOAD-ID")
   101  				case *s3.UploadPartOutput:
   102  					data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
   103  				case *s3.CompleteMultipartUploadOutput:
   104  					data.Location = aws.String("https://location")
   105  					data.VersionId = aws.String("VERSION-ID")
   106  				case *s3.PutObjectOutput:
   107  					data.VersionId = aws.String("VERSION-ID")
   108  				}
   109  			})
   110  
   111  			uploader := NewUploaderWithClient(svc, func(u *Uploader) {
   112  				u.PartSize = tt.PartSize
   113  				u.Concurrency = tt.Concurrency
   114  			})
   115  
   116  			expected := s3testing.GetTestBytes(int(tt.FileSize))
   117  			_, err := uploader.Upload(&UploadInput{
   118  				Bucket: aws.String("bucket"),
   119  				Key:    aws.String("key"),
   120  				Body:   &testReader{br: bytes.NewReader(expected)},
   121  			})
   122  			if err != nil {
   123  				t.Errorf("expected no error, but got %v", err)
   124  			}
   125  
   126  			if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
   127  				t.Fatalf("expected zero outsnatding pool parts, got %d", v)
   128  			}
   129  
   130  			gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs)
   131  
   132  			t.Logf("total gets %v, total allocations %v", gets, allocs)
   133  			if e, a := tt.ExAllocations, allocs; a > e {
   134  				t.Errorf("expected %v allocations, got %v", e, a)
   135  			}
   136  		})
   137  	}
   138  }
   139  
   140  func TestUploadByteSlicePool_Failures(t *testing.T) {
   141  	cases := map[string]struct {
   142  		PartSize   int64
   143  		FileSize   int64
   144  		Operations []string
   145  	}{
   146  		"single part": {
   147  			PartSize: sdkio.MebiByte * 5,
   148  			FileSize: sdkio.MebiByte * 4,
   149  			Operations: []string{
   150  				"PutObject",
   151  			},
   152  		},
   153  		"multi-part": {
   154  			PartSize: sdkio.MebiByte * 5,
   155  			FileSize: sdkio.MebiByte * 10,
   156  			Operations: []string{
   157  				"CreateMultipartUpload",
   158  				"UploadPart",
   159  				"CompleteMultipartUpload",
   160  			},
   161  		},
   162  	}
   163  
   164  	for name, tt := range cases {
   165  		t.Run(name, func(t *testing.T) {
   166  			for _, operation := range tt.Operations {
   167  				t.Run(operation, func(t *testing.T) {
   168  					var p *recordedPartPool
   169  
   170  					unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
   171  						p = newRecordedPartPool(sliceSize)
   172  						return p
   173  					})
   174  					defer unswap()
   175  
   176  					sess := unit.Session.Copy()
   177  					svc := s3.New(sess)
   178  					svc.Handlers.Unmarshal.Clear()
   179  					svc.Handlers.UnmarshalMeta.Clear()
   180  					svc.Handlers.UnmarshalError.Clear()
   181  					svc.Handlers.Send.Clear()
   182  					svc.Handlers.Send.PushFront(func(r *request.Request) {
   183  						if r.Body != nil {
   184  							io.Copy(ioutil.Discard, r.Body)
   185  						}
   186  
   187  						if r.Operation.Name == operation {
   188  							r.Retryable = aws.Bool(false)
   189  							r.Error = fmt.Errorf("request error")
   190  							r.HTTPResponse = &http.Response{
   191  								StatusCode: 500,
   192  								Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
   193  							}
   194  							return
   195  						}
   196  
   197  						r.HTTPResponse = &http.Response{
   198  							StatusCode: 200,
   199  							Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
   200  						}
   201  
   202  						switch data := r.Data.(type) {
   203  						case *s3.CreateMultipartUploadOutput:
   204  							data.UploadId = aws.String("UPLOAD-ID")
   205  						case *s3.UploadPartOutput:
   206  							data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
   207  						case *s3.CompleteMultipartUploadOutput:
   208  							data.Location = aws.String("https://location")
   209  							data.VersionId = aws.String("VERSION-ID")
   210  						case *s3.PutObjectOutput:
   211  							data.VersionId = aws.String("VERSION-ID")
   212  						}
   213  					})
   214  
   215  					uploader := NewUploaderWithClient(svc, func(u *Uploader) {
   216  						u.Concurrency = 1
   217  						u.PartSize = tt.PartSize
   218  					})
   219  
   220  					expected := s3testing.GetTestBytes(int(tt.FileSize))
   221  					_, err := uploader.Upload(&UploadInput{
   222  						Bucket: aws.String("bucket"),
   223  						Key:    aws.String("key"),
   224  						Body:   &testReader{br: bytes.NewReader(expected)},
   225  					})
   226  					if err == nil {
   227  						t.Fatalf("expected error but got none")
   228  					}
   229  
   230  					if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
   231  						t.Fatalf("expected zero outsnatding pool parts, got %d", v)
   232  					}
   233  				})
   234  			}
   235  		})
   236  	}
   237  }
   238  
   239  func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) {
   240  	var (
   241  		pools []*recordedPartPool
   242  		mtx   sync.Mutex
   243  	)
   244  
   245  	unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
   246  		mtx.Lock()
   247  		defer mtx.Unlock()
   248  		b := newRecordedPartPool(sliceSize)
   249  		pools = append(pools, b)
   250  		return b
   251  	})
   252  	defer unswap()
   253  
   254  	sess := unit.Session.Copy()
   255  	svc := s3.New(sess)
   256  	svc.Handlers.Unmarshal.Clear()
   257  	svc.Handlers.UnmarshalMeta.Clear()
   258  	svc.Handlers.UnmarshalError.Clear()
   259  	svc.Handlers.Send.Clear()
   260  	svc.Handlers.Send.PushFront(func(r *request.Request) {
   261  		if r.Body != nil {
   262  			io.Copy(ioutil.Discard, r.Body)
   263  		}
   264  
   265  		r.HTTPResponse = &http.Response{
   266  			StatusCode: 200,
   267  			Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
   268  		}
   269  
   270  		switch data := r.Data.(type) {
   271  		case *s3.CreateMultipartUploadOutput:
   272  			data.UploadId = aws.String("UPLOAD-ID")
   273  		case *s3.UploadPartOutput:
   274  			data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
   275  		case *s3.CompleteMultipartUploadOutput:
   276  			data.Location = aws.String("https://location")
   277  			data.VersionId = aws.String("VERSION-ID")
   278  		case *s3.PutObjectOutput:
   279  			data.VersionId = aws.String("VERSION-ID")
   280  		}
   281  	})
   282  
   283  	uploader := NewUploaderWithClient(svc, func(u *Uploader) {
   284  		u.PartSize = 5 * sdkio.MebiByte
   285  		u.Concurrency = 2
   286  	})
   287  
   288  	var wg sync.WaitGroup
   289  	for i := 0; i < 2; i++ {
   290  		wg.Add(2)
   291  		go func() {
   292  			defer wg.Done()
   293  			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
   294  			_, err := uploader.Upload(&UploadInput{
   295  				Bucket: aws.String("bucket"),
   296  				Key:    aws.String("key"),
   297  				Body:   &testReader{br: bytes.NewReader(expected)},
   298  			})
   299  			if err != nil {
   300  				t.Errorf("expected no error, but got %v", err)
   301  			}
   302  		}()
   303  		go func() {
   304  			defer wg.Done()
   305  			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
   306  			_, err := uploader.Upload(&UploadInput{
   307  				Bucket: aws.String("bucket"),
   308  				Key:    aws.String("key"),
   309  				Body:   &testReader{br: bytes.NewReader(expected)},
   310  			}, func(u *Uploader) {
   311  				u.PartSize = 6 * sdkio.MebiByte
   312  			})
   313  			if err != nil {
   314  				t.Errorf("expected no error, but got %v", err)
   315  			}
   316  		}()
   317  	}
   318  
   319  	wg.Wait()
   320  
   321  	if e, a := 3, len(pools); e != a {
   322  		t.Errorf("expected %v, got %v", e, a)
   323  	}
   324  
   325  	for _, p := range pools {
   326  		if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
   327  			t.Fatalf("expected zero outsnatding pool parts, got %d", v)
   328  		}
   329  
   330  		t.Logf("total gets %v, total allocations %v",
   331  			atomic.LoadUint64(&p.recordedGets),
   332  			atomic.LoadUint64(&p.recordedAllocs))
   333  	}
   334  }
   335  
   336  func BenchmarkPools(b *testing.B) {
   337  	cases := []struct {
   338  		PartSize      int64
   339  		FileSize      int64
   340  		Concurrency   int
   341  		ExAllocations uint64
   342  	}{
   343  		0: {
   344  			PartSize:    sdkio.MebiByte * 5,
   345  			FileSize:    sdkio.MebiByte * 5,
   346  			Concurrency: 1,
   347  		},
   348  		1: {
   349  			PartSize:    sdkio.MebiByte * 5,
   350  			FileSize:    sdkio.MebiByte * 10,
   351  			Concurrency: 1,
   352  		},
   353  		2: {
   354  			PartSize:    sdkio.MebiByte * 5,
   355  			FileSize:    sdkio.MebiByte * 20,
   356  			Concurrency: 2,
   357  		},
   358  		3: {
   359  			PartSize:    sdkio.MebiByte * 5,
   360  			FileSize:    sdkio.MebiByte * 250,
   361  			Concurrency: 10,
   362  		},
   363  	}
   364  
   365  	sess := unit.Session.Copy()
   366  	svc := s3.New(sess)
   367  	svc.Handlers.Unmarshal.Clear()
   368  	svc.Handlers.UnmarshalMeta.Clear()
   369  	svc.Handlers.UnmarshalError.Clear()
   370  	svc.Handlers.Send.Clear()
   371  	svc.Handlers.Send.PushFront(func(r *request.Request) {
   372  		if r.Body != nil {
   373  			io.Copy(ioutil.Discard, r.Body)
   374  		}
   375  
   376  		r.HTTPResponse = &http.Response{
   377  			StatusCode: 200,
   378  			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
   379  		}
   380  
   381  		switch data := r.Data.(type) {
   382  		case *s3.CreateMultipartUploadOutput:
   383  			data.UploadId = aws.String("UPLOAD-ID")
   384  		case *s3.UploadPartOutput:
   385  			data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
   386  		case *s3.CompleteMultipartUploadOutput:
   387  			data.Location = aws.String("https://location")
   388  			data.VersionId = aws.String("VERSION-ID")
   389  		case *s3.PutObjectOutput:
   390  			data.VersionId = aws.String("VERSION-ID")
   391  		}
   392  	})
   393  
   394  	pools := map[string]func(sliceSize int64) byteSlicePool{
   395  		"sync.Pool": func(sliceSize int64) byteSlicePool {
   396  			return newSyncSlicePool(sliceSize)
   397  		},
   398  		"custom": func(sliceSize int64) byteSlicePool {
   399  			return newMaxSlicePool(sliceSize)
   400  		},
   401  	}
   402  
   403  	for name, poolFunc := range pools {
   404  		b.Run(name, func(b *testing.B) {
   405  			unswap := swapByteSlicePool(poolFunc)
   406  			defer unswap()
   407  			for i, c := range cases {
   408  				b.Run(strconv.Itoa(i), func(b *testing.B) {
   409  					uploader := NewUploaderWithClient(svc, func(u *Uploader) {
   410  						u.PartSize = c.PartSize
   411  						u.Concurrency = c.Concurrency
   412  					})
   413  
   414  					expected := s3testing.GetTestBytes(int(c.FileSize))
   415  					b.ResetTimer()
   416  					_, err := uploader.Upload(&UploadInput{
   417  						Bucket: aws.String("bucket"),
   418  						Key:    aws.String("key"),
   419  						Body:   &testReader{br: bytes.NewReader(expected)},
   420  					})
   421  					if err != nil {
   422  						b.Fatalf("expected no error, but got %v", err)
   423  					}
   424  				})
   425  			}
   426  		})
   427  	}
   428  }