github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/file/s3file/s3file_test.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 //go:build !unit 6 // +build !unit 7 8 package s3file 9 10 import ( 11 "context" 12 "crypto/md5" 13 "crypto/sha256" 14 "flag" 15 "fmt" 16 "io" 17 "io/ioutil" 18 "math/rand" 19 "net/http" 20 "runtime/debug" 21 "strings" 22 "sync" 23 "sync/atomic" 24 "testing" 25 "time" 26 27 "github.com/aws/aws-sdk-go/aws" 28 "github.com/aws/aws-sdk-go/aws/awserr" 29 awsrequest "github.com/aws/aws-sdk-go/aws/request" 30 "github.com/aws/aws-sdk-go/service/s3/s3iface" 31 "github.com/Schaudge/grailbase/errors" 32 "github.com/Schaudge/grailbase/file" 33 "github.com/Schaudge/grailbase/file/internal/s3bufpool" 34 "github.com/Schaudge/grailbase/file/internal/testutil" 35 "github.com/Schaudge/grailbase/file/s3file/s3transport" 36 "github.com/Schaudge/grailbase/log" 37 "github.com/Schaudge/grailbase/retry" 38 "github.com/grailbio/testutil/assert" 39 "github.com/grailbio/testutil/s3test" 40 ) 41 42 var ( 43 s3BucketFlag = flag.String("s3-bucket", "", "If set, run a unittest against a real S3 bucket named in this flag") 44 s3DirFlag = flag.String("s3-dir", "", "S3 directory under -s3-bucket used by some unittests") 45 ) 46 47 type failingContentAt struct { 48 prob float64 // probability of failing requests 49 content []byte 50 failWithErr error 51 52 randMu sync.Mutex 53 rand *rand.Rand 54 } 55 56 func doReadAt(src []byte, off64 int64, dest []byte) (int, error) { 57 off := int(off64) 58 remaining := len(src) - off 59 if remaining <= 0 { 60 return 0, io.EOF 61 } 62 if len(dest) < remaining { 63 remaining = len(dest) 64 } 65 copy(dest, src[off:]) 66 return remaining, nil 67 } 68 69 func doWriteAt(src []byte, off64 int64, dest *[]byte) (int, error) { 70 off := int(off64) 71 if len(*dest) < off+len(src) { 72 tmp := make([]byte, off+len(src)) 73 copy(tmp, *dest) 74 *dest = tmp 75 } 76 copy((*dest)[off:], src) 77 return len(src), nil 78 } 79 80 func (c *failingContentAt) ReadAt(p []byte, off64 int64) (int, error) { 81 c.randMu.Lock() 82 pr := c.rand.Float64() 83 c.randMu.Unlock() 84 if pr < c.prob { 85 return 0, c.failWithErr 86 } 87 n := len(p) 88 if n > 1 { 89 c.randMu.Lock() 90 n = 1 + c.rand.Intn(n-1) 91 c.randMu.Unlock() 92 } 93 return doReadAt(c.content, off64, p[:n]) 94 } 95 96 func (c *failingContentAt) WriteAt(p []byte, off64 int64) (int, error) { 97 return doWriteAt(p, off64, &c.content) 98 } 99 100 func (c *failingContentAt) Size() int64 { 101 return int64(len(c.content)) 102 } 103 104 func (c *failingContentAt) Checksum() string { 105 return fmt.Sprintf("%x", md5.Sum(c.content)) 106 } 107 108 type pausingContentAt struct { 109 ready chan bool 110 content []byte 111 } 112 113 // ReadAt implements io.ReaderAt. 114 func (c *pausingContentAt) ReadAt(p []byte, off64 int64) (int, error) { 115 <-c.ready 116 return doReadAt(c.content, off64, p) 117 } 118 119 // WriteAt implements io.WriterAt 120 func (c *pausingContentAt) WriteAt(p []byte, off64 int64) (int, error) { 121 return doWriteAt(p, off64, &c.content) 122 } 123 124 // Size returns the size of the fake content. 125 func (c *pausingContentAt) Size() int64 { 126 return int64(len(c.content)) 127 } 128 129 func (c *pausingContentAt) Checksum() string { 130 return fmt.Sprintf("%x", md5.Sum(c.content)) 131 } 132 133 func newImpl(clients ...s3iface.S3API) *s3Impl { 134 return &s3Impl{ 135 clientsForAction: func(_ context.Context, _, _, _ string) ([]s3iface.S3API, error) { 136 return clients, nil 137 }, 138 } 139 } 140 141 func newClient(t *testing.T) *s3test.Client { return s3test.NewClient(t, "b") } 142 func errorClient(t *testing.T, err error) s3iface.S3API { 143 c := s3test.NewClient(t, "b") 144 c.Err = func(api string, input interface{}) error { 145 return err 146 } 147 return c 148 } 149 150 func TestS3(t *testing.T) { 151 ctx := context.Background() 152 impl := newImpl( 153 errorClient(t, awserr.New( 154 "", // TODO(swami): Use an AWS error code that represents a permission error. 155 "test permission error", 156 nil, 157 )), 158 newClient(t), 159 ) 160 testutil.TestStandard(ctx, t, impl, "s3://b/dir") 161 t.Run("readat", func(t *testing.T) { 162 testutil.TestConcurrentOffsetReads(ctx, t, impl, "s3://b/dir/readats.txt") 163 }) 164 } 165 166 func TestS3WithRetries(t *testing.T) { 167 tearDown := setZeroBackoffPolicy() 168 defer tearDown() 169 170 ctx := context.Background() 171 for iter := 0; iter < 50; iter++ { 172 randIntsC := make(chan int) 173 go func() { 174 r := rand.New(rand.NewSource(int64(iter))) 175 for { 176 randIntsC <- r.Intn(20) 177 } 178 }() 179 client := newClient(t) 180 client.Err = func(api string, input interface{}) error { 181 switch <-randIntsC { 182 case 0: 183 return awserr.New(awsrequest.ErrCodeSerialization, "injected serialization failure", nil) 184 case 1: 185 return awserr.New("RequestError", "send request failed", readConnResetError{}) 186 } 187 return nil 188 } 189 impl := newImpl(client) 190 testutil.TestStandard(ctx, t, impl, "s3://b/dir") 191 t.Run("readat", func(t *testing.T) { 192 testutil.TestConcurrentOffsetReads(ctx, t, impl, "s3://b/dir/readats.txt") 193 }) 194 } 195 } 196 197 // WriteFile creates a file with the given contents. Path should be of form 198 // s3://bucket/key. 199 func writeFile(ctx context.Context, t *testing.T, impl file.Implementation, path, data string) { 200 f, err := impl.Create(ctx, path) 201 assert.NoError(t, err) 202 _, err = f.Writer(ctx).Write([]byte(data)) 203 assert.NoError(t, err) 204 assert.NoError(t, f.Close(ctx)) 205 } 206 func TestListBucketRoot(t *testing.T) { 207 ctx := context.Background() 208 impl := newImpl(newClient(t)) 209 writeFile(ctx, t, impl, "s3://b/0.txt", "data") 210 211 l := impl.List(ctx, "s3://b", true) 212 assert.True(t, l.Scan(), "err: %v", l.Err()) 213 assert.EQ(t, "s3://b/0.txt", l.Path()) 214 assert.False(t, l.Scan()) 215 assert.NoError(t, l.Err()) 216 } 217 218 type readConnResetError struct{} 219 220 func (c readConnResetError) Temporary() bool { return false } 221 func (c readConnResetError) Error() string { return "read: connection reset" } 222 223 func TestErrors(t *testing.T) { 224 ctx := context.Background() 225 impl := newImpl( 226 errorClient(t, 227 awserr.New("", // TODO(swami): Use an AWS error code that represents a permission error. 228 fmt.Sprintf("test permission error: %s", string(debug.Stack())), 229 nil, 230 ), 231 ), 232 ) 233 234 _, err := impl.Create(ctx, "s3://b/junk0.txt") 235 assert.Regexp(t, err, "test permission error") 236 237 _, err = impl.Stat(ctx, "s3://b/junk0.txt") 238 assert.Regexp(t, err, "test permission error") 239 240 l := impl.List(ctx, "s3://b/foo", true) 241 assert.False(t, l.Scan()) 242 assert.Regexp(t, l.Err(), "test permission error") 243 } 244 245 func TestTransientErrors(t *testing.T) { 246 impl := newImpl(errorClient(t, awserr.New("RequestError", "send request failed", readConnResetError{}))) 247 ctx, cancel := context.WithCancel(context.Background()) 248 cancel() 249 _, err := impl.Stat(ctx, "s3://b/junk0.txt") 250 assert.True(t, errors.Is(errors.Canceled, err), "expected cancellation") 251 252 ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) 253 defer cancel() 254 _, err = impl.Stat(ctx, "s3://b/junk0.txt") 255 assert.Regexp(t, err, "ran out of time while waiting") 256 } 257 258 func TestWriteRetryAfterError(t *testing.T) { 259 tearDown := setZeroBackoffPolicy() 260 defer tearDown() 261 262 client := newClient(t) 263 impl := newImpl(client) 264 ctx := context.Background() 265 for i := 0; i < 10; i++ { 266 r := rand.New(rand.NewSource(0)) 267 client.Err = func(api string, input interface{}) error { 268 if r.Intn(3) == 0 { 269 fmt.Printf("write: api %s\n", api) 270 return awserr.New(awsrequest.ErrCodeSerialization, "test failure", nil) 271 } 272 return nil 273 } 274 writeFile(ctx, t, impl, "s3://b/0.txt", "data") 275 } 276 } 277 278 func TestReadRetryAfterError(t *testing.T) { 279 for errIdx, failWithErr := range []error{ 280 fmt.Errorf("failingContentAt synthetic error"), 281 readConnResetError{}, 282 } { 283 t.Run(fmt.Sprintf("error_%d", errIdx), func(t *testing.T) { 284 tearDown := setZeroBackoffPolicy() 285 defer tearDown() 286 287 client := newClient(t) 288 setContent := func(path string, prob float64, data string) { 289 c := &failingContentAt{ 290 prob: prob, 291 rand: rand.New(rand.NewSource(0)), 292 content: []byte(data), 293 failWithErr: failWithErr, 294 } 295 checksum := sha256.Sum256(c.content) 296 client.SetFileContentAt(path, c, fmt.Sprintf("%x", checksum[:])) 297 } 298 299 var contents string 300 { 301 l := []string{} 302 for i := 0; i < 1000; i++ { 303 l = append(l, fmt.Sprintf("D%d", i)) 304 } 305 contents = strings.Join(l, ",") 306 } 307 // Exercise parallel reading including partial last chunk. 308 tearDownRCB := setReadChunkBytes() 309 defer tearDownRCB() 310 311 assert.GT(t, len(contents)%ReadChunkBytes(), 0) 312 313 impl := newImpl(client) 314 ctx := context.Background() 315 316 setContent("junk0.txt", 0.3, contents) 317 for i := 0; i < 10; i++ { 318 f, err := impl.Open(ctx, "b/junk0.txt") 319 assert.NoError(t, err) 320 r := f.Reader(ctx) 321 data, err := ioutil.ReadAll(r) 322 assert.NoError(t, err) 323 assert.EQ(t, contents, string(data)) 324 assert.NoError(t, f.Close(ctx)) 325 } 326 327 // Simulate exhausting all allowed retries. Since the number of retries is unrestricted, 328 // the request is capped by MaxRetryDuration. To avoid a flaky time dependency, instead 329 // of using an actual deadline we just cancel the context. 330 tearDown = setFakeWithDeadline() 331 defer tearDown() 332 setContent("junk1.txt", 1.0 /*fail everything*/, contents) 333 { 334 f, err := impl.Open(ctx, "b/junk1.txt") 335 assert.NoError(t, err) 336 r := f.Reader(ctx) 337 _, err = ioutil.ReadAll(r) 338 assert.Regexp(t, err, failWithErr.Error()) 339 assert.NoError(t, f.Close(ctx)) 340 } 341 }) 342 } 343 } 344 345 func TestRetryWhenNotFound(t *testing.T) { 346 client := s3test.NewClient(t, "b") 347 348 impl := newImpl(client) 349 350 ctx := context.Background() 351 // By default, there is no retry. 352 _, err := impl.Open(ctx, "s3://b/file.txt") 353 assert.Regexp(t, err, "NoSuchKey") 354 355 doneCh := make(chan bool) 356 go func() { 357 _, err := impl.Open(ctx, "s3://b/file.txt", file.Opts{RetryWhenNotFound: true}) 358 assert.NoError(t, err) 359 doneCh <- true 360 }() 361 time.Sleep(1 * time.Second) 362 select { 363 case <-doneCh: 364 t.Fatal("should not reach here") 365 default: 366 } 367 writeFile(ctx, t, impl, "s3://b/file.txt", "data") 368 fmt.Println("wrote file") 369 <-doneCh 370 } 371 372 func TestCancellation(t *testing.T) { 373 client := s3test.NewClient(t, "b") 374 375 setContent := func(path, data string) *pausingContentAt { 376 c := &pausingContentAt{ready: make(chan bool, 1), content: []byte(data)} 377 checksum := sha256.Sum256(c.content) 378 client.SetFileContentAt(path, c, fmt.Sprintf("%x", checksum[:])) 379 return c 380 } 381 c0 := setContent("test0.txt", "hello") 382 _ = setContent("test1.txt", "goodbye") 383 384 impl := newImpl(client) 385 { 386 c0.ready <- true 387 // Reading c0 completes immediately. 388 ctx := context.Background() 389 f, err := impl.Open(ctx, "s3://b/test0.txt") 390 assert.NoError(t, err) 391 r := f.Reader(ctx) 392 data, err := ioutil.ReadAll(r) 393 assert.NoError(t, err) 394 assert.EQ(t, "hello", string(data)) 395 assert.NoError(t, f.Close(ctx)) 396 } 397 { 398 // Reading c1 will block. 399 f, err := impl.Open(context.Background(), "s3://b/test1.txt") 400 assert.NoError(t, err) 401 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 402 defer cancel() 403 r := f.Reader(ctx) 404 _, err = ioutil.ReadAll(r) 405 assert.True(t, errors.Is(errors.Canceled, err), "expected cancellation") 406 assert.True(t, errors.Is(errors.Canceled, f.Close(ctx)), "expected cancellation") 407 } 408 } 409 410 func testOverwriteWhileReading(t *testing.T, impl file.Implementation, pathPrefix string) { 411 ctx := context.Background() 412 path := pathPrefix + "/test.txt" 413 writeFile(ctx, t, impl, path, "test0") 414 f, err := impl.Open(ctx, path) 415 assert.NoError(t, err) 416 417 r := f.Reader(ctx) 418 data, err := ioutil.ReadAll(r) 419 assert.NoError(t, err) 420 assert.EQ(t, "test0", string(data)) 421 422 _, err = r.Seek(0, io.SeekStart) 423 assert.NoError(t, err) 424 425 writeFile(ctx, t, impl, path, "test0") 426 427 data, err = ioutil.ReadAll(r) 428 assert.NoError(t, err) 429 assert.EQ(t, "test0", string(data)) 430 431 _, err = r.Seek(0, io.SeekStart) 432 assert.NoError(t, err) 433 writeFile(ctx, t, impl, path, "test1") 434 _, err = ioutil.ReadAll(r) 435 assert.True(t, errors.Is(errors.Precondition, err), "err=%v", err) 436 } 437 438 func TestWriteLargeFile(t *testing.T) { 439 // Reduce the upload chunk size to issue concurrent upload requests to S3. 440 oldUploadPartSize := UploadPartSize 441 UploadPartSize = 128 442 defer func() { 443 UploadPartSize = oldUploadPartSize 444 }() 445 446 ctx := context.Background() 447 impl := newImpl(s3test.NewClient(t, "b")) 448 path := "s3://b/test.txt" 449 f, err := impl.Create(ctx, path) 450 assert.NoError(t, err) 451 r := rand.New(rand.NewSource(0)) 452 var want []byte 453 const iters = 400 454 for i := 0; i < iters; i++ { 455 n := r.Intn(1024) + 100 456 data := make([]byte, n) 457 n, err := r.Read(data) 458 assert.EQ(t, n, len(data)) 459 assert.NoError(t, err) 460 n, err = f.Writer(ctx).Write(data) 461 assert.EQ(t, n, len(data)) 462 assert.NoError(t, err) 463 want = append(want, data...) 464 } 465 assert.NoError(t, f.Close(ctx)) 466 467 // Read the file back and verify contents. 468 f, err = impl.Open(ctx, path) 469 assert.NoError(t, err) 470 got := make([]byte, len(want)) 471 n, _ := f.Reader(ctx).Read(got) 472 assert.EQ(t, n, len(want)) 473 assert.EQ(t, got, want) 474 assert.NoError(t, f.Close(ctx)) 475 } 476 477 func TestOverwriteWhileReading(t *testing.T) { 478 impl := newImpl(s3test.NewClient(t, "b")) 479 testOverwriteWhileReading(t, impl, "s3://b/test") 480 } 481 482 func TestNotExist(t *testing.T) { 483 impl := newImpl(s3test.NewClient(t, "b")) 484 ctx := context.Background() 485 // The s3test client fails tests for requests that attempt to 486 // access buckets other than the one specified, so we can 487 // test only missing keys here. 488 _, err := impl.Open(ctx, "b/notexist") 489 assert.True(t, errors.Is(errors.NotExist, err)) 490 } 491 492 func realBucketProviderOrSkip(t *testing.T) SessionProvider { 493 if *s3BucketFlag == "" { 494 t.Skip("Skipping. Set -s3-bucket to run the test.") 495 } 496 return NewDefaultProvider( 497 aws.NewConfig().WithHTTPClient(s3transport.DefaultClient()), 498 ) 499 } 500 501 func TestOverwriteWhileReadingAWS(t *testing.T) { 502 provider := realBucketProviderOrSkip(t) 503 impl := NewImplementation(provider, Options{}) 504 testOverwriteWhileReading(t, impl, fmt.Sprintf("s3://%s/tmp/testoverwrite", *s3BucketFlag)) 505 } 506 507 func TestPresignRequestsAWS(t *testing.T) { 508 provider := realBucketProviderOrSkip(t) 509 impl := NewImplementation(provider, Options{}) 510 ctx := context.Background() 511 const content = "file for testing presigned URLs\n" 512 path := fmt.Sprintf("s3://%s/tmp/testpresigned", *s3BucketFlag) 513 514 // Write the dummy file. 515 url, err := impl.Presign(ctx, path, "PUT", time.Minute) 516 if err != nil { 517 t.Fatal(err) 518 } 519 req, err := http.NewRequest(http.MethodPut, url, strings.NewReader(content)) 520 if err != nil { 521 t.Fatal(err) 522 } 523 resp, err := http.DefaultClient.Do(req) 524 if err != nil { 525 t.Fatal(err) 526 } 527 resp.Body.Close() 528 529 // Read the dummy file. 530 url, err = impl.Presign(ctx, path, "GET", time.Minute) 531 if err != nil { 532 t.Fatal(err) 533 } 534 resp, err = http.Get(url) 535 if err != nil { 536 t.Fatal(err) 537 } 538 defer resp.Body.Close() 539 respBytes, err := ioutil.ReadAll(resp.Body) 540 if err != nil { 541 t.Fatal(err) 542 } 543 if content != string(respBytes) { 544 t.Errorf("got: %q, want: %q", string(respBytes), content) 545 } 546 547 // Delete the dummy file. 548 url, err = impl.Presign(ctx, path, "DELETE", time.Minute) 549 if err != nil { 550 t.Fatal(err) 551 } 552 req, err = http.NewRequest(http.MethodDelete, url, strings.NewReader("")) 553 if err != nil { 554 t.Fatal(err) 555 } 556 resp, err = http.DefaultClient.Do(req) 557 if err != nil { 558 t.Fatal(err) 559 } 560 resp.Body.Close() 561 if _, err := impl.Stat(ctx, path); !errors.Is(errors.NotExist, err) { 562 t.Errorf("got: %v\nwant an error of kind NotExist", err) 563 } 564 } 565 566 func TestAWS(t *testing.T) { 567 provider := realBucketProviderOrSkip(t) 568 ctx := context.Background() 569 impl := NewImplementation(provider, Options{}) 570 testutil.TestStandard(ctx, t, impl, "s3://"+*s3BucketFlag+"/tmp") 571 t.Run("readat", func(t *testing.T) { 572 testutil.TestConcurrentOffsetReads(ctx, t, impl, "s3://"+*s3BucketFlag+"/tmp") 573 }) 574 } 575 576 func TestConcurrentUploadsAWS(t *testing.T) { 577 provider := realBucketProviderOrSkip(t) 578 impl := NewImplementation(provider, Options{}) 579 580 if *s3DirFlag == "" { 581 t.Skip("Skipping. Set -s3-bucket and -s3-dir to run the test.") 582 } 583 path := fmt.Sprintf("s3://%s/%s/test.txt", *s3BucketFlag, *s3DirFlag) 584 ctx := context.Background() 585 586 upload := func() { 587 f, err := impl.Create(ctx, path, file.Opts{IgnoreNoSuchUpload: true}) 588 if err != nil { 589 log.Panic(err) 590 } 591 _, err = f.Writer(ctx).Write([]byte("hello")) 592 if err != nil { 593 log.Panic(err) 594 } 595 if err := f.Close(ctx); err != nil { 596 log.Panic(err) 597 } 598 } 599 600 wg := sync.WaitGroup{} 601 n := uint64(0) 602 for i := 0; i < 4000; i++ { 603 wg.Add(1) 604 go func() { 605 upload() 606 if x := atomic.AddUint64(&n, 1); x%100 == 0 { 607 log.Printf("%d done", x) 608 } 609 wg.Done() 610 }() 611 } 612 wg.Wait() 613 } 614 615 func ExampleParseURL() { 616 scheme, bucket, key, err := ParseURL("s3://grail-bucket/dir/file") 617 fmt.Printf("scheme: %s, bucket: %s, key: %s, err: %v\n", scheme, bucket, key, err) 618 scheme, bucket, key, err = ParseURL("s3://grail-bucket/dir/") 619 fmt.Printf("scheme: %s, bucket: %s, key: %s, err: %v\n", scheme, bucket, key, err) 620 scheme, bucket, key, err = ParseURL("s3://grail-bucket") 621 fmt.Printf("scheme: %s, bucket: %s, key: %s, err: %v\n", scheme, bucket, key, err) 622 // Output: 623 // scheme: s3, bucket: grail-bucket, key: dir/file, err: <nil> 624 // scheme: s3, bucket: grail-bucket, key: dir/, err: <nil> 625 // scheme: s3, bucket: grail-bucket, key: , err: <nil> 626 } 627 628 func setZeroBackoffPolicy() (tearDown func()) { 629 oldPolicy := BackoffPolicy 630 BackoffPolicy = retry.Backoff(0, 0, 1.0) 631 return func() { BackoffPolicy = oldPolicy } 632 } 633 634 func setReadChunkBytes() (tearDown func()) { 635 old := s3bufpool.BufBytes 636 s3bufpool.SetBufSize(100) 637 return func() { s3bufpool.SetBufSize(old) } 638 } 639 640 func setFakeWithDeadline() (tearDown func()) { 641 old := WithDeadline 642 WithDeadline = func(ctx context.Context, deadline time.Time) (context.Context, context.CancelFunc) { 643 ctx, cancel := context.WithDeadline(ctx, deadline) 644 cancel() 645 return ctx, cancel 646 } 647 return func() { WithDeadline = old } 648 }