github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/s3util/s3copy_test.go

package s3util

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"math/rand"
	"testing"
	"time"

	"github.com/Schaudge/grailbase/retry"
	"github.com/grailbio/testutil"
	"github.com/grailbio/testutil/s3test"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
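	"github.com/aws/aws-sdk-go/aws/session" // used only by the illustrative sketch at the end of this file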
	"github.com/aws/aws-sdk-go/service/s3"
)

const testBucket = "test_bucket"

var (
	testKeys  = map[string]*testutil.ByteContent{"test/x": content("some sample content")}
	errorKeys = map[string]error{
		"key_awscanceled":       awserr.New(request.CanceledErrorCode, "test", nil),
		"key_nosuchkey":         awserr.New(s3.ErrCodeNoSuchKey, "test", nil),
		"key_badrequest":        awserr.New("BadRequest", "test", nil),
		"key_canceled":          context.Canceled,
		"key_deadlineexceeded":  context.DeadlineExceeded,
		"key_awsrequesttimeout": awserr.New("RequestTimeout", "test", nil),
		"key_nestedEOFrequest":  awserr.New("MultipartUpload", "test", awserr.New("SerializationError", "test2", fmt.Errorf("unexpected EOF"))),
		"key_awsinternalerror":  awserr.New("InternalError", "test", nil),
	}
)
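
// The errorKeys above drive failure injection in newFailingTestClient below:
// a copy whose destination key appears in the map fails with the mapped
// error. TestCopyWithRetry and TestCopyMultipart rely on this to check which
// errors the copier retries (e.g. request timeouts, internal errors) and
// which abort the copy immediately (e.g. context cancellation).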

func newTestClient(t *testing.T) *s3test.Client {
	t.Helper()
	client := s3test.NewClient(t, testBucket)
	client.Region = "us-west-2"
	for k, v := range testKeys {
		client.SetFileContentAt(k, v, "")
	}
	return client
}

func newFailingTestClient(t *testing.T, fn *failN) *s3test.Client {
	t.Helper()
	client := newTestClient(t)
	client.Err = func(api string, input interface{}) error {
		switch api {
		case "UploadPartCopyWithContext":
			if upc, ok := input.(*s3.UploadPartCopyInput); ok {
				// Possibly fail the first part with an error based on the key.
				if *upc.PartNumber == int64(1) && fn.fail() {
					return errorKeys[*upc.Key]
				}
			}
		case "CopyObjectRequest":
			if req, ok := input.(*s3.CopyObjectInput); ok && fn.fail() {
				return errorKeys[*req.Key]
			}
		}
		return nil
	}
	return client
}

func TestBucketKey(t *testing.T) {
	for _, tc := range []struct {
		url, wantBucket, wantKey string
		wantErr                  bool
	}{
		{"s3://bucket/key", "bucket", "key", false},
		{"s3://some_other-bucket/very/long/key", "some_other-bucket", "very/long/key", false},
	} {
		gotB, gotK, gotE := bucketKey(tc.url)
		if tc.wantErr {
			if gotE == nil {
				t.Errorf("%s got no error, want error", tc.url)
			}
			continue
		}
		if gotE != nil {
			t.Errorf("%s got error %v, want no error", tc.url, gotE)
			continue
		}
		if got, want := gotB, tc.wantBucket; got != want {
			t.Errorf("%s: got bucket %s, want %s", tc.url, got, want)
		}
		if got, want := gotK, tc.wantKey; got != want {
			t.Errorf("%s: got key %s, want %s", tc.url, got, want)
		}
	}
}
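
// Example_bucketKey is a runnable doc example derived from the cases in
// TestBucketKey above: bucketKey splits an s3:// URL into bucket and key.
func Example_bucketKey() {
	bucket, key, err := bucketKey("s3://test_bucket/path/to/object")
	fmt.Println(bucket, key, err)
	// Output: test_bucket path/to/object <nil>
}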

func TestCopy(t *testing.T) {
	client := newTestClient(t)
	copier := NewCopier(client)

	srcKey, srcSize, dstKey := "test/x", testKeys["test/x"].Size(), "test/x_copy"
	srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, srcKey)
	dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, dstKey)

	checkObject(t, client, srcKey, testKeys[srcKey])
	if err := copier.Copy(context.Background(), srcUrl, dstUrl, srcSize, nil); err != nil {
		t.Fatal(err)
	}
	checkObject(t, client, dstKey, testKeys[srcKey])
}

func TestCopyWithRetry(t *testing.T) {
	client := newFailingTestClient(t, &failN{n: 2})
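	// The policy: exponential backoff from 10ms, doubling per attempt and
	// capped at 50ms, with 25% jitter and at most 4 retries. The failN{n: 2}
	// client injects two failures in total, well within that retry budget.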
	retrier := retry.MaxRetries(retry.Jitter(retry.Backoff(10*time.Millisecond, 50*time.Millisecond, 2), 0.25), 4)
	copier := NewCopierWithParams(client, retrier, 1<<10, 1<<10, testDebugger{t})

	for _, tc := range []struct {
		srcKey  string
		dstKey  string
		srcSize int64
	}{
		{
			srcKey:  "test/x",
			dstKey:  "key_awsrequesttimeout",
			srcSize: testKeys["test/x"].Size(),
		},
		{
			srcKey:  "test/x",
			dstKey:  "key_awsinternalerror",
			srcSize: testKeys["test/x"].Size(),
		},
	} {
		srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.srcKey)
		dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.dstKey)
		checkObject(t, client, tc.srcKey, testKeys[tc.srcKey])
		if err := copier.Copy(context.Background(), srcUrl, dstUrl, tc.srcSize, nil); err != nil {
			t.Fatal(err)
		}
		checkObject(t, client, tc.dstKey, testKeys[tc.srcKey])
	}
}

func TestCopyMultipart(t *testing.T) {
	bctx := context.Background()
	for _, tc := range []struct {
		client                 *s3test.Client
		dstKey                 string
		size, limit, partsize  int64
		useShortCtx, cancelCtx bool
		wantErr                bool
	}{
		// 100KiB of data, multi-part limit 50KiB, part size 10KiB.
		{newTestClient(t), "dst1", 100 << 10, 50 << 10, 10 << 10, false, false, false},
		// 50KiB of data, multi-part limit 50KiB, part size 10KiB.
		{newTestClient(t), "dst2", 50 << 10, 50 << 10, 10 << 10, false, false, false},
		// A 10ns timeout expires before the copy can finish.
		{newTestClient(t), "dst3", 100 << 10, 50 << 10, 10 << 10, true, false, true},
		// The context is canceled before the copy starts.
		{newTestClient(t), "dst4", 100 << 10, 50 << 10, 10 << 10, false, true, true},
		// Retryable errors injected on the first part (twice per case, each
		// case using its own client) are absorbed by the retrier.
		{newFailingTestClient(t, &failN{n: 2}), "key_badrequest", 100 << 10, 50 << 10, 10 << 10, false, false, false},
		{newFailingTestClient(t, &failN{n: 2}), "key_deadlineexceeded", 100 << 10, 50 << 10, 10 << 10, false, false, false},
		{newFailingTestClient(t, &failN{n: 2}), "key_awsrequesttimeout", 100 << 10, 50 << 10, 10 << 10, false, false, false},
		{newFailingTestClient(t, &failN{n: 2}), "key_nestedEOFrequest", 100 << 10, 50 << 10, 10 << 10, false, false, false},
		// Context cancellation is not retried.
		{newFailingTestClient(t, &failN{n: 2}), "key_canceled", 100 << 10, 50 << 10, 10 << 10, false, false, true},
		// More injected failures than the retry budget allows.
		{newFailingTestClient(t, &failN{n: defaultMaxRetries + 1}), "key_badrequest", 100 << 10, 50 << 10, 10 << 10, false, false, true},
	} {
		client := tc.client
		b := make([]byte, tc.size)
		if _, err := rand.Read(b); err != nil {
			t.Fatal(err)
		}
		srcKey, srcContent := "src", &testutil.ByteContent{Data: b}
		client.SetFileContentAt(srcKey, srcContent, "")
		checkObject(t, client, srcKey, srcContent)

		retrier := retry.MaxRetries(retry.Jitter(retry.Backoff(10*time.Millisecond, 50*time.Millisecond, 2), 0.25), defaultMaxRetries)
		copier := NewCopierWithParams(client, retrier, tc.limit, tc.partsize, testDebugger{t})

		ctx := bctx
		var cancel context.CancelFunc
		if tc.useShortCtx {
			ctx, cancel = context.WithTimeout(bctx, 10*time.Nanosecond)
		} else if tc.cancelCtx {
			ctx, cancel = context.WithCancel(bctx)
			cancel()
		}
		srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, srcKey)
		dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.dstKey)

		err := copier.Copy(ctx, srcUrl, dstUrl, tc.size, nil)
		if cancel != nil {
			cancel()
		}
		if tc.wantErr {
			if err == nil {
				t.Errorf("%s got no error, want error", tc.dstKey)
			}
			continue
		}
		if err != nil {
			t.Fatal(err)
		}
		checkObject(t, client, tc.dstKey, srcContent)
		if t.Failed() {
			t.Logf("case: %v", tc)
		}
	}
}

func content(s string) *testutil.ByteContent {
	return &testutil.ByteContent{Data: []byte(s)}
}

func checkObject(t *testing.T, client *s3test.Client, key string, c *testutil.ByteContent) {
	t.Helper()
	out, err := client.GetObject(&s3.GetObjectInput{
		Bucket: aws.String(testBucket),
		Key:    aws.String(key),
	})
	if err != nil {
		t.Fatal(err)
	}
	p, err := io.ReadAll(out.Body)
	if err != nil {
		t.Fatal(err)
	}
	if got, want := p, c.Data; !bytes.Equal(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}
}

// failN reports true for the first n calls to fail and false afterwards,
// until it is reset.
type failN struct {
	n, i int
}

func (p *failN) fail() bool {
	if p.i < p.n {
		p.i++
		return true
	}
	return false
}

func (p *failN) reset() {
	p.i = 0
}

type testDebugger struct{ *testing.T }

func (d testDebugger) Debugf(format string, args ...interface{}) { d.T.Logf(format, args...) }
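
// copyWithRealClient is an illustrative sketch, not exercised by any test:
// it shows the same Copier call sequence used above wired to a real S3
// client instead of the s3test fake. It assumes NewCopier accepts *s3.S3
// just as it accepts the fake; the region here is a placeholder.
func copyWithRealClient(ctx context.Context, srcUrl, dstUrl string, size int64) error {
	sess := session.Must(session.NewSession(&aws.Config{Region: aws.String("us-west-2")}))
	copier := NewCopier(s3.New(sess))
	// The final argument mirrors the nil passed to Copy in the tests above.
	return copier.Copy(ctx, srcUrl, dstUrl, size, nil)
}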