github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/block/s3/client_cache_test.go (about)

     1  package s3_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"testing"
     7  
     8  	"github.com/aws/aws-sdk-go-v2/config"
     9  	awss3 "github.com/aws/aws-sdk-go-v2/service/s3"
    10  	"github.com/go-test/deep"
    11  	"github.com/treeverse/lakefs/pkg/block/params"
    12  	"github.com/treeverse/lakefs/pkg/block/s3"
    13  	"github.com/treeverse/lakefs/pkg/testutil"
    14  )
    15  
    16  var errRegion = errors.New("failed to get region")
    17  
    18  func TestClientCache(t *testing.T) {
    19  	const defaultRegion = "us-west-2"
    20  	ctx := context.Background()
    21  	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(defaultRegion))
    22  	testutil.Must(t, err)
    23  
    24  	tests := []struct {
    25  		name                string
    26  		bucketToRegion      map[string]string
    27  		bucketCalls         []string
    28  		regionErrorsIndexes map[int]bool
    29  	}{
    30  		{
    31  			name:           "two_buckets_two_regions",
    32  			bucketToRegion: map[string]string{"us-bucket": "us-east-1", "eu-bucket": "eu-west-1"},
    33  			bucketCalls:    []string{"us-bucket", "us-bucket", "us-bucket", "eu-bucket", "eu-bucket", "eu-bucket"},
    34  		},
    35  		{
    36  			name:           "multiple_buckets_two_regions",
    37  			bucketToRegion: map[string]string{"us-bucket-1": "us-east-1", "us-bucket-2": "us-east-1", "us-bucket-3": "us-east-1", "eu-bucket-1": "eu-west-1", "eu-bucket-2": "eu-west-1"},
    38  			bucketCalls:    []string{"us-bucket-1", "us-bucket-2", "us-bucket-3", "eu-bucket-1", "eu-bucket-2"},
    39  		},
    40  		{
    41  			name:                "error_on_get_region",
    42  			bucketToRegion:      map[string]string{"us-bucket": "us-east-1", "eu-bucket": "eu-west-1"},
    43  			bucketCalls:         []string{"us-bucket", "us-bucket", "us-bucket", "eu-bucket", "eu-bucket", "eu-bucket"},
    44  			regionErrorsIndexes: map[int]bool{3: true},
    45  		},
    46  		{
    47  			name:                "all_errors",
    48  			bucketToRegion:      map[string]string{"us-bucket-1": "us-east-1", "us-bucket-2": "us-east-1", "us-bucket-3": "us-east-1", "eu-bucket-1": "eu-west-1", "eu-bucket-2": "eu-west-1"},
    49  			bucketCalls:         []string{"us-bucket-1", "us-bucket-2", "us-bucket-3", "eu-bucket-1", "eu-bucket-2"},
    50  			regionErrorsIndexes: map[int]bool{0: true, 1: true, 2: true, 3: true, 4: true},
    51  		},
    52  		{
    53  			name:           "alternating_regions",
    54  			bucketToRegion: map[string]string{"us-bucket-1": "us-east-1", "us-bucket-2": "us-east-1", "us-bucket-3": "us-east-1", "eu-bucket-1": "eu-west-1", "eu-bucket-2": "eu-west-1"},
    55  			bucketCalls:    []string{"us-bucket-1", "eu-bucket-1", "us-bucket-2", "eu-bucket-2", "us-bucket-3", "us-bucket-1", "eu-bucket-1", "us-bucket-2", "eu-bucket-2", "us-bucket-3"},
    56  		},
    57  	}
    58  	for _, test := range tests {
    59  		t.Run(test.name, func(t *testing.T) {
    60  			var callIdx int
    61  			var bucket string
    62  			actualClientsCreated := make(map[string]bool)
    63  			expectedClientsCreated := make(map[string]bool)
    64  			actualRegionFetch := make(map[string]bool)
    65  			expectedRegionFetch := make(map[string]bool)
    66  
    67  			c := s3.NewClientCache(cfg, params.S3{}) // params are ignored as we use custom client factory
    68  
    69  			c.SetClientFactory(func(region string) *awss3.Client {
    70  				if actualClientsCreated[region] {
    71  					t.Fatalf("client created more than once for a region")
    72  				}
    73  				actualClientsCreated[region] = true
    74  				return awss3.NewFromConfig(cfg, func(o *awss3.Options) {
    75  					o.Region = region
    76  				})
    77  			})
    78  
    79  			c.SetS3RegionGetter(func(ctx context.Context, bucket string) (string, error) {
    80  				if actualRegionFetch[bucket] {
    81  					t.Fatalf("region fetched more than once for bucket")
    82  				}
    83  				actualRegionFetch[bucket] = true
    84  				if test.regionErrorsIndexes[callIdx] {
    85  					return "", errRegion
    86  				}
    87  				return test.bucketToRegion[bucket], nil
    88  			})
    89  
    90  			alreadyCalled := make(map[string]bool)
    91  			for callIdx, bucket = range test.bucketCalls {
    92  				expectedRegionFetch[bucket] = true // for every bucket, there should be exactly one region fetch
    93  				if _, ok := alreadyCalled[bucket]; !ok {
    94  					if test.regionErrorsIndexes[callIdx] {
    95  						// if there's an error, a client should be created for the default region
    96  						expectedClientsCreated[defaultRegion] = true
    97  					} else {
    98  						// for every region, a client is created exactly once
    99  						expectedClientsCreated[test.bucketToRegion[bucket]] = true
   100  					}
   101  				}
   102  				alreadyCalled[bucket] = true
   103  				c.Get(ctx, bucket)
   104  			}
   105  			if diff := deep.Equal(expectedClientsCreated, actualClientsCreated); diff != nil {
   106  				t.Fatal("unexpected client creation count: ", diff)
   107  			}
   108  			if diff := deep.Equal(expectedRegionFetch, actualRegionFetch); diff != nil {
   109  				t.Fatal("unexpected region fetch count. diff: ", diff)
   110  			}
   111  		})
   112  	}
   113  }