github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/internal/credproviders/assume_role_provider.go (about)

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package credproviders
     8  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"time"

	"go.mongodb.org/mongo-driver/internal/aws/credentials"
	"go.mongodb.org/mongo-driver/internal/uuid"
)
    21  
    22  const (
    23  	// assumeRoleProviderName provides a name of assume role provider
    24  	assumeRoleProviderName = "AssumeRoleProvider"
    25  
    26  	stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15`
    27  )
    28  
    29  // An AssumeRoleProvider retrieves credentials for assume role with web identity.
    30  type AssumeRoleProvider struct {
    31  	AwsRoleArnEnv              EnvVar
    32  	AwsWebIdentityTokenFileEnv EnvVar
    33  	AwsRoleSessionNameEnv      EnvVar
    34  
    35  	httpClient *http.Client
    36  	expiration time.Time
    37  
    38  	// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
    39  	// This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions.
    40  	//
    41  	// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
    42  	// 10 seconds before the credentials are actually expired.
    43  	expiryWindow time.Duration
    44  }
    45  
    46  // NewAssumeRoleProvider returns a pointer to an assume role provider.
    47  func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
    48  	return &AssumeRoleProvider{
    49  		// AwsRoleArnEnv is the environment variable for AWS_ROLE_ARN
    50  		AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"),
    51  		// AwsWebIdentityTokenFileEnv is the environment variable for AWS_WEB_IDENTITY_TOKEN_FILE
    52  		AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"),
    53  		// AwsRoleSessionNameEnv is the environment variable for AWS_ROLE_SESSION_NAME
    54  		AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"),
    55  		httpClient:            httpClient,
    56  		expiryWindow:          expiryWindow,
    57  	}
    58  }
    59  
    60  // RetrieveWithContext retrieves the keys from the AWS service.
    61  func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
    62  	const defaultHTTPTimeout = 10 * time.Second
    63  
    64  	v := credentials.Value{ProviderName: assumeRoleProviderName}
    65  
    66  	roleArn := a.AwsRoleArnEnv.Get()
    67  	tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
    68  	if tokenFile == "" && roleArn == "" {
    69  		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
    70  	}
    71  	if tokenFile != "" && roleArn == "" {
    72  		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
    73  	}
    74  	if tokenFile == "" && roleArn != "" {
    75  		return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
    76  	}
    77  	token, err := ioutil.ReadFile(tokenFile)
    78  	if err != nil {
    79  		return v, err
    80  	}
    81  
    82  	sessionName := a.AwsRoleSessionNameEnv.Get()
    83  	if sessionName == "" {
    84  		// Use a UUID if the RoleSessionName is not given.
    85  		id, err := uuid.New()
    86  		if err != nil {
    87  			return v, err
    88  		}
    89  		sessionName = id.String()
    90  	}
    91  
    92  	fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token))
    93  
    94  	req, err := http.NewRequest(http.MethodPost, fullURI, nil)
    95  	if err != nil {
    96  		return v, err
    97  	}
    98  	req.Header.Set("Accept", "application/json")
    99  
   100  	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
   101  	defer cancel()
   102  	resp, err := a.httpClient.Do(req.WithContext(ctx))
   103  	if err != nil {
   104  		return v, err
   105  	}
   106  	defer resp.Body.Close()
   107  	if resp.StatusCode != http.StatusOK {
   108  		return v, fmt.Errorf("response failure: %s", resp.Status)
   109  	}
   110  
   111  	var stsResp struct {
   112  		Response struct {
   113  			Result struct {
   114  				Credentials struct {
   115  					AccessKeyID     string  `json:"AccessKeyId"`
   116  					SecretAccessKey string  `json:"SecretAccessKey"`
   117  					Token           string  `json:"SessionToken"`
   118  					Expiration      float64 `json:"Expiration"`
   119  				} `json:"Credentials"`
   120  			} `json:"AssumeRoleWithWebIdentityResult"`
   121  		} `json:"AssumeRoleWithWebIdentityResponse"`
   122  	}
   123  
   124  	err = json.NewDecoder(resp.Body).Decode(&stsResp)
   125  	if err != nil {
   126  		return v, err
   127  	}
   128  	v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID
   129  	v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey
   130  	v.SessionToken = stsResp.Response.Result.Credentials.Token
   131  	if !v.HasKeys() {
   132  		return v, errors.New("failed to retrieve web identity keys")
   133  	}
   134  	sec := int64(stsResp.Response.Result.Credentials.Expiration)
   135  	a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)
   136  
   137  	return v, nil
   138  }
   139  
   140  // Retrieve retrieves the keys from the AWS service.
   141  func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
   142  	return a.RetrieveWithContext(context.Background())
   143  }
   144  
   145  // IsExpired returns true if the credentials are expired.
   146  func (a *AssumeRoleProvider) IsExpired() bool {
   147  	return a.expiration.Before(time.Now())
   148  }