github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/block/s3/client_cache_test.go (about) 1 package s3_test 2 3 import ( 4 "context" 5 "errors" 6 "testing" 7 8 "github.com/aws/aws-sdk-go-v2/config" 9 awss3 "github.com/aws/aws-sdk-go-v2/service/s3" 10 "github.com/go-test/deep" 11 "github.com/treeverse/lakefs/pkg/block/params" 12 "github.com/treeverse/lakefs/pkg/block/s3" 13 "github.com/treeverse/lakefs/pkg/testutil" 14 ) 15 16 var errRegion = errors.New("failed to get region") 17 18 func TestClientCache(t *testing.T) { 19 const defaultRegion = "us-west-2" 20 ctx := context.Background() 21 cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(defaultRegion)) 22 testutil.Must(t, err) 23 24 tests := []struct { 25 name string 26 bucketToRegion map[string]string 27 bucketCalls []string 28 regionErrorsIndexes map[int]bool 29 }{ 30 { 31 name: "two_buckets_two_regions", 32 bucketToRegion: map[string]string{"us-bucket": "us-east-1", "eu-bucket": "eu-west-1"}, 33 bucketCalls: []string{"us-bucket", "us-bucket", "us-bucket", "eu-bucket", "eu-bucket", "eu-bucket"}, 34 }, 35 { 36 name: "multiple_buckets_two_regions", 37 bucketToRegion: map[string]string{"us-bucket-1": "us-east-1", "us-bucket-2": "us-east-1", "us-bucket-3": "us-east-1", "eu-bucket-1": "eu-west-1", "eu-bucket-2": "eu-west-1"}, 38 bucketCalls: []string{"us-bucket-1", "us-bucket-2", "us-bucket-3", "eu-bucket-1", "eu-bucket-2"}, 39 }, 40 { 41 name: "error_on_get_region", 42 bucketToRegion: map[string]string{"us-bucket": "us-east-1", "eu-bucket": "eu-west-1"}, 43 bucketCalls: []string{"us-bucket", "us-bucket", "us-bucket", "eu-bucket", "eu-bucket", "eu-bucket"}, 44 regionErrorsIndexes: map[int]bool{3: true}, 45 }, 46 { 47 name: "all_errors", 48 bucketToRegion: map[string]string{"us-bucket-1": "us-east-1", "us-bucket-2": "us-east-1", "us-bucket-3": "us-east-1", "eu-bucket-1": "eu-west-1", "eu-bucket-2": "eu-west-1"}, 49 bucketCalls: []string{"us-bucket-1", "us-bucket-2", "us-bucket-3", "eu-bucket-1", "eu-bucket-2"}, 50 regionErrorsIndexes: map[int]bool{0: true, 1: true, 2: true, 3: true, 4: true}, 51 }, 52 { 53 name: "alternating_regions", 54 bucketToRegion: map[string]string{"us-bucket-1": "us-east-1", "us-bucket-2": "us-east-1", "us-bucket-3": "us-east-1", "eu-bucket-1": "eu-west-1", "eu-bucket-2": "eu-west-1"}, 55 bucketCalls: []string{"us-bucket-1", "eu-bucket-1", "us-bucket-2", "eu-bucket-2", "us-bucket-3", "us-bucket-1", "eu-bucket-1", "us-bucket-2", "eu-bucket-2", "us-bucket-3"}, 56 }, 57 } 58 for _, test := range tests { 59 t.Run(test.name, func(t *testing.T) { 60 var callIdx int 61 var bucket string 62 actualClientsCreated := make(map[string]bool) 63 expectedClientsCreated := make(map[string]bool) 64 actualRegionFetch := make(map[string]bool) 65 expectedRegionFetch := make(map[string]bool) 66 67 c := s3.NewClientCache(cfg, params.S3{}) // params are ignored as we use custom client factory 68 69 c.SetClientFactory(func(region string) *awss3.Client { 70 if actualClientsCreated[region] { 71 t.Fatalf("client created more than once for a region") 72 } 73 actualClientsCreated[region] = true 74 return awss3.NewFromConfig(cfg, func(o *awss3.Options) { 75 o.Region = region 76 }) 77 }) 78 79 c.SetS3RegionGetter(func(ctx context.Context, bucket string) (string, error) { 80 if actualRegionFetch[bucket] { 81 t.Fatalf("region fetched more than once for bucket") 82 } 83 actualRegionFetch[bucket] = true 84 if test.regionErrorsIndexes[callIdx] { 85 return "", errRegion 86 } 87 return test.bucketToRegion[bucket], nil 88 }) 89 90 alreadyCalled := make(map[string]bool) 91 for callIdx, bucket = range test.bucketCalls { 92 expectedRegionFetch[bucket] = true // for every bucket, there should be exactly one region fetch 93 if _, ok := alreadyCalled[bucket]; !ok { 94 if test.regionErrorsIndexes[callIdx] { 95 // if there's an error, a client should be created for the default region 96 expectedClientsCreated[defaultRegion] = true 97 } else { 98 // for every region, a client is created exactly once 99 expectedClientsCreated[test.bucketToRegion[bucket]] = true 100 } 101 } 102 alreadyCalled[bucket] = true 103 c.Get(ctx, bucket) 104 } 105 if diff := deep.Equal(expectedClientsCreated, actualClientsCreated); diff != nil { 106 t.Fatal("unexpected client creation count: ", diff) 107 } 108 if diff := deep.Equal(expectedRegionFetch, actualRegionFetch); diff != nil { 109 t.Fatal("unexpected region fetch count. diff: ", diff) 110 } 111 }) 112 } 113 }