github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/s3util/s3copy_test.go (about) 1 package s3util 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io/ioutil" 8 "math/rand" 9 "testing" 10 "time" 11 12 "github.com/Schaudge/grailbase/retry" 13 "github.com/grailbio/testutil" 14 "github.com/grailbio/testutil/s3test" 15 16 "github.com/aws/aws-sdk-go/aws" 17 "github.com/aws/aws-sdk-go/aws/awserr" 18 "github.com/aws/aws-sdk-go/aws/request" 19 "github.com/aws/aws-sdk-go/service/s3" 20 ) 21 22 const testBucket = "test_bucket" 23 24 var ( 25 testKeys = map[string]*testutil.ByteContent{"test/x": content("some sample content")} 26 errorKeys = map[string]error{ 27 "key_awscanceled": awserr.New(request.CanceledErrorCode, "test", nil), 28 "key_nosuchkey": awserr.New(s3.ErrCodeNoSuchKey, "test", nil), 29 "key_badrequest": awserr.New("BadRequest", "test", nil), 30 "key_canceled": context.Canceled, 31 "key_deadlineexceeded": context.DeadlineExceeded, 32 "key_awsrequesttimeout": awserr.New("RequestTimeout", "test", nil), 33 "key_nestedEOFrequest": awserr.New("MultipartUpload", "test", awserr.New("SerializationError", "test2", fmt.Errorf("unexpected EOF"))), 34 "key_awsinternalerror": awserr.New("InternalError", "test", nil), 35 } 36 ) 37 38 func newTestClient(t *testing.T) *s3test.Client { 39 t.Helper() 40 client := s3test.NewClient(t, testBucket) 41 client.Region = "us-west-2" 42 for k, v := range testKeys { 43 client.SetFileContentAt(k, v, "") 44 } 45 return client 46 } 47 48 func newFailingTestClient(t *testing.T, fn *failN) *s3test.Client { 49 t.Helper() 50 client := newTestClient(t) 51 client.Err = func(api string, input interface{}) error { 52 switch api { 53 case "UploadPartCopyWithContext": 54 if upc, ok := input.(*s3.UploadPartCopyInput); ok { 55 // Possibly fail the first part with an error based on the key 56 if *upc.PartNumber == int64(1) && fn.fail() { 57 return errorKeys[*upc.Key] 58 } 59 } 60 case "CopyObjectRequest": 61 if req, ok := 
input.(*s3.CopyObjectInput); ok && fn.fail() { 62 return errorKeys[*req.Key] 63 } 64 } 65 return nil 66 } 67 return client 68 } 69 70 func TestBucketKey(t *testing.T) { 71 for _, tc := range []struct { 72 url, wantBucket, wantKey string 73 wantErr bool 74 }{ 75 {"s3://bucket/key", "bucket", "key", false}, 76 {"s3://some_other-bucket/very/long/key", "some_other-bucket", "very/long/key", false}, 77 } { 78 gotB, gotK, gotE := bucketKey(tc.url) 79 if tc.wantErr && gotE == nil { 80 t.Errorf("%s got no error, want error", tc.url) 81 continue 82 } 83 if got, want := gotB, tc.wantBucket; got != want { 84 t.Errorf("got %s want %s", got, want) 85 } 86 if got, want := gotK, tc.wantKey; got != want { 87 t.Errorf("got %s want %s", got, want) 88 } 89 } 90 } 91 92 func TestCopy(t *testing.T) { 93 client := newTestClient(t) 94 copier := NewCopier(client) 95 96 srcKey, srcSize, dstKey := "test/x", testKeys["test/x"].Size(), "test/x_copy" 97 srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, srcKey) 98 dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, dstKey) 99 100 checkObject(t, client, srcKey, testKeys[srcKey]) 101 if err := copier.Copy(context.Background(), srcUrl, dstUrl, srcSize, nil); err != nil { 102 t.Fatal(err) 103 } 104 checkObject(t, client, dstKey, testKeys[srcKey]) 105 } 106 107 func TestCopyWithRetry(t *testing.T) { 108 client := newFailingTestClient(t, &failN{n: 2}) 109 retrier := retry.MaxRetries(retry.Jitter(retry.Backoff(10*time.Millisecond, 50*time.Millisecond, 2), 0.25), 4) 110 copier := NewCopierWithParams(client, retrier, 1<<10, 1<<10, testDebugger{t}) 111 112 for _, tc := range []struct { 113 srcKey string 114 dstKey string 115 srcSize int64 116 }{ 117 { 118 srcKey: "test/x", 119 dstKey: "key_awsrequesttimeout", 120 srcSize: testKeys["test/x"].Size(), 121 }, 122 { 123 srcKey: "test/x", 124 dstKey: "key_awsinternalerror", 125 srcSize: testKeys["test/x"].Size(), 126 }, 127 } { 128 srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.srcKey) 129 dstUrl := 
fmt.Sprintf("s3://%s/%s", testBucket, tc.dstKey) 130 checkObject(t, client, tc.srcKey, testKeys[tc.srcKey]) 131 if err := copier.Copy(context.Background(), srcUrl, dstUrl, tc.srcSize, nil); err != nil { 132 t.Fatal(err) 133 } 134 checkObject(t, client, tc.dstKey, testKeys[tc.srcKey]) 135 } 136 137 } 138 139 func TestCopyMultipart(t *testing.T) { 140 bctx := context.Background() 141 for _, tc := range []struct { 142 client *s3test.Client 143 dstKey string 144 size, limit, partsize int64 145 useShortCtx, cancelCtx bool 146 wantErr bool 147 }{ 148 // 100KiB of data, multi-part limit 50KiB, part size 10KiB 149 {newTestClient(t), "dst1", 100 << 10, 50 << 10, 10 << 10, false, false, false}, 150 // 50KiB of data, multi-part limit 50KiB, part size 10KiB 151 {newTestClient(t), "dst2", 50 << 10, 50 << 10, 10 << 10, false, false, false}, 152 {newTestClient(t), "dst3", 100 << 10, 50 << 10, 10 << 10, true, false, true}, 153 {newTestClient(t), "dst4", 100 << 10, 50 << 10, 10 << 10, false, true, true}, 154 {newFailingTestClient(t, &failN{n: 2}), "key_badrequest", 100 << 10, 50 << 10, 10 << 10, false, false, false}, 155 {newFailingTestClient(t, &failN{n: 2}), "key_deadlineexceeded", 100 << 10, 50 << 10, 10 << 10, false, false, false}, 156 {newFailingTestClient(t, &failN{n: 2}), "key_awsrequesttimeout", 100 << 10, 50 << 10, 10 << 10, false, false, false}, 157 {newFailingTestClient(t, &failN{n: 2}), "key_nestedEOFrequest", 100 << 10, 50 << 10, 10 << 10, false, false, false}, 158 {newFailingTestClient(t, &failN{n: 2}), "key_canceled", 100 << 10, 50 << 10, 10 << 10, false, false, true}, 159 {newFailingTestClient(t, &failN{n: defaultMaxRetries + 1}), "key_badrequest", 100 << 10, 50 << 10, 10 << 10, false, false, true}, 160 } { 161 client := tc.client 162 b := make([]byte, tc.size) 163 if _, err := rand.Read(b); err != nil { 164 t.Fatal(err) 165 } 166 srcKey, srcContent := "src", &testutil.ByteContent{Data: b} 167 client.SetFileContentAt(srcKey, srcContent, "") 168 checkObject(t, 
client, srcKey, srcContent) 169 170 retrier := retry.MaxRetries(retry.Jitter(retry.Backoff(10*time.Millisecond, 50*time.Millisecond, 2), 0.25), defaultMaxRetries) 171 copier := NewCopierWithParams(client, retrier, tc.limit, tc.partsize, testDebugger{t}) 172 173 ctx := bctx 174 var cancel context.CancelFunc 175 if tc.useShortCtx { 176 ctx, cancel = context.WithTimeout(bctx, 10*time.Nanosecond) 177 } else if tc.cancelCtx { 178 ctx, cancel = context.WithCancel(bctx) 179 cancel() 180 } 181 srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, srcKey) 182 dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.dstKey) 183 184 err := copier.Copy(ctx, srcUrl, dstUrl, tc.size, nil) 185 if cancel != nil { 186 cancel() 187 } 188 if tc.wantErr { 189 if err == nil { 190 t.Errorf("%s got no error, want error", tc.dstKey) 191 } 192 continue 193 } 194 if err != nil { 195 t.Fatal(err) 196 } 197 checkObject(t, client, tc.dstKey, srcContent) 198 if t.Failed() { 199 t.Logf("case: %v", tc) 200 } 201 } 202 } 203 204 func content(s string) *testutil.ByteContent { 205 return &testutil.ByteContent{Data: []byte(s)} 206 } 207 208 func checkObject(t *testing.T, client *s3test.Client, key string, c *testutil.ByteContent) { 209 t.Helper() 210 out, err := client.GetObject(&s3.GetObjectInput{ 211 Bucket: aws.String(testBucket), 212 Key: aws.String(key), 213 }) 214 if err != nil { 215 t.Fatal(err) 216 } 217 p, err := ioutil.ReadAll(out.Body) 218 if err != nil { 219 t.Fatal(err) 220 } 221 if got, want := p, c.Data; !bytes.Equal(got, want) { 222 t.Errorf("got %v, want %v", got, want) 223 } 224 } 225 226 // failN returns true n times when fail() is called and then returns false, until its reset. 
type failN struct {
	n, i int
}

// fail reports whether this call should be treated as a failure; it keeps
// returning true until n calls have been consumed, then returns false.
func (p *failN) fail() bool {
	if p.i >= p.n {
		return false
	}
	p.i++
	return true
}

// reset re-arms the counter so fail returns true n more times.
func (p *failN) reset() {
	p.i = 0
}

// testDebugger adapts a testing.T to the copier's debug-logging interface.
type testDebugger struct{ *testing.T }

func (d testDebugger) Debugf(format string, args ...interface{}) { d.T.Logf(format, args...) }