github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/cloud/aws/metadata.go (about)

     1  package aws
     2  
     3  import (
     4  	"context"
     5  	"crypto/md5" //nolint:gosec
     6  	"fmt"
     7  	"time"
     8  
     9  	awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
    10  	"github.com/aws/aws-sdk-go-v2/config"
    11  	"github.com/aws/aws-sdk-go-v2/credentials"
    12  	"github.com/aws/aws-sdk-go-v2/service/sts"
    13  	"github.com/cenkalti/backoff/v4"
    14  	"github.com/treeverse/lakefs/pkg/block/params"
    15  	"github.com/treeverse/lakefs/pkg/cloud"
    16  	"github.com/treeverse/lakefs/pkg/logging"
    17  )
    18  
    19  type MetadataProvider struct {
    20  	logger    logging.Logger
    21  	stsClient *sts.Client
    22  }
    23  
    24  func NewMetadataProvider(logger logging.Logger, params params.S3) (*MetadataProvider, error) {
    25  	// set up a session with a shorter timeout and no retries
    26  	const sessionMaxRetries = 0 // max number of retries on the client operation
    27  
    28  	// use a shorter timeout than default
    29  	// because the service can be inaccessible from networks
    30  	// which don't have an internet connection
    31  	const sessionTimeout = 5 * time.Second
    32  	/// params
    33  	var opts []func(*config.LoadOptions) error
    34  	if params.Region != "" {
    35  		opts = append(opts, config.WithRegion(params.Region))
    36  	}
    37  	if params.Profile != "" {
    38  		opts = append(opts, config.WithSharedConfigProfile(params.Profile))
    39  	}
    40  	if params.CredentialsFile != "" {
    41  		opts = append(opts, config.WithSharedCredentialsFiles([]string{params.CredentialsFile}))
    42  	}
    43  	if params.Credentials.AccessKeyID != "" {
    44  		opts = append(opts, config.WithCredentialsProvider(
    45  			credentials.NewStaticCredentialsProvider(
    46  				params.Credentials.AccessKeyID,
    47  				params.Credentials.SecretAccessKey,
    48  				params.Credentials.SessionToken,
    49  			),
    50  		))
    51  	}
    52  
    53  	cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	stsClient := sts.NewFromConfig(cfg, func(options *sts.Options) {
    59  		options.RetryMaxAttempts = sessionMaxRetries
    60  		options.HTTPClient = awshttp.NewBuildableClient().
    61  			WithTimeout(sessionTimeout)
    62  	})
    63  	return &MetadataProvider{logger: logger, stsClient: stsClient}, nil
    64  }
    65  
    66  func (m *MetadataProvider) GetMetadata() map[string]string {
    67  	if m.stsClient == nil {
    68  		return nil
    69  	}
    70  	const (
    71  		maxInterval    = 200 * time.Millisecond
    72  		maxElapsedTime = 3 * time.Second
    73  	)
    74  	bo := backoff.NewExponentialBackOff()
    75  	bo.MaxInterval = maxInterval
    76  	bo.MaxElapsedTime = maxElapsedTime
    77  	ctx := context.Background()
    78  	identity, err := backoff.RetryWithData(func() (*sts.GetCallerIdentityOutput, error) {
    79  		identity, err := m.stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
    80  		if err != nil {
    81  			m.logger.WithError(err).Warn("Tried to to get AWS account ID for BI")
    82  			return nil, err
    83  		}
    84  		return identity, nil
    85  	}, bo)
    86  	if err != nil {
    87  		m.logger.WithError(err).Warn("Failed to to get AWS account ID for BI")
    88  		return nil
    89  	}
    90  
    91  	return map[string]string{
    92  		cloud.IDKey:     fmt.Sprintf("%x", md5.Sum([]byte(*identity.Account))), //nolint:gosec
    93  		cloud.IDTypeKey: "aws_account_id",
    94  	}
    95  }