github.com/argoproj/argo-cd/v3@v3.2.1/util/helm/creds.go (about)

     1  package helm
     2  
     3  import (
     4  	"crypto/tls"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/golang-jwt/jwt/v5"
    15  	gocache "github.com/patrickmn/go-cache"
    16  	log "github.com/sirupsen/logrus"
    17  
    18  	argoutils "github.com/argoproj/argo-cd/v3/util"
    19  	"github.com/argoproj/argo-cd/v3/util/env"
    20  	"github.com/argoproj/argo-cd/v3/util/workloadidentity"
    21  )
    22  
    23  // In memory cache for storing Azure tokens
    24  var azureTokenCache *gocache.Cache
    25  
    26  func init() {
    27  	azureTokenCache = gocache.New(gocache.NoExpiration, 0)
    28  }
    29  
    30  // StoreToken stores a token in the cache
    31  func storeAzureToken(key, token string, expiration time.Duration) {
    32  	azureTokenCache.Set(key, token, expiration)
    33  }
    34  
    35  type Creds interface {
    36  	GetUsername() string
    37  	GetPassword() (string, error)
    38  	GetCAPath() string
    39  	GetCertData() []byte
    40  	GetKeyData() []byte
    41  	GetInsecureSkipVerify() bool
    42  }
    43  
    44  var _ Creds = HelmCreds{}
    45  
    46  type HelmCreds struct {
    47  	Username           string
    48  	Password           string
    49  	CAPath             string
    50  	CertData           []byte
    51  	KeyData            []byte
    52  	InsecureSkipVerify bool
    53  }
    54  
    55  func (creds HelmCreds) GetUsername() string {
    56  	return creds.Username
    57  }
    58  
    59  func (creds HelmCreds) GetPassword() (string, error) {
    60  	return creds.Password, nil
    61  }
    62  
    63  func (creds HelmCreds) GetCAPath() string {
    64  	return creds.CAPath
    65  }
    66  
    67  func (creds HelmCreds) GetCertData() []byte {
    68  	return creds.CertData
    69  }
    70  
    71  func (creds HelmCreds) GetKeyData() []byte {
    72  	return creds.KeyData
    73  }
    74  
    75  func (creds HelmCreds) GetInsecureSkipVerify() bool {
    76  	return creds.InsecureSkipVerify
    77  }
    78  
    79  var _ Creds = AzureWorkloadIdentityCreds{}
    80  
    81  type AzureWorkloadIdentityCreds struct {
    82  	repoURL            string
    83  	CAPath             string
    84  	CertData           []byte
    85  	KeyData            []byte
    86  	InsecureSkipVerify bool
    87  	tokenProvider      workloadidentity.TokenProvider
    88  }
    89  
    90  func (creds AzureWorkloadIdentityCreds) GetUsername() string {
    91  	return workloadidentity.EmptyGuid
    92  }
    93  
    94  func (creds AzureWorkloadIdentityCreds) GetPassword() (string, error) {
    95  	return creds.GetAccessToken()
    96  }
    97  
    98  func (creds AzureWorkloadIdentityCreds) GetCAPath() string {
    99  	return creds.CAPath
   100  }
   101  
   102  func (creds AzureWorkloadIdentityCreds) GetCertData() []byte {
   103  	return creds.CertData
   104  }
   105  
   106  func (creds AzureWorkloadIdentityCreds) GetKeyData() []byte {
   107  	return creds.KeyData
   108  }
   109  
   110  func (creds AzureWorkloadIdentityCreds) GetInsecureSkipVerify() bool {
   111  	return creds.InsecureSkipVerify
   112  }
   113  
   114  func NewAzureWorkloadIdentityCreds(repoURL string, caPath string, certData []byte, keyData []byte, insecureSkipVerify bool, tokenProvider workloadidentity.TokenProvider) AzureWorkloadIdentityCreds {
   115  	return AzureWorkloadIdentityCreds{
   116  		repoURL:            repoURL,
   117  		CAPath:             caPath,
   118  		CertData:           certData,
   119  		KeyData:            keyData,
   120  		InsecureSkipVerify: insecureSkipVerify,
   121  		tokenProvider:      tokenProvider,
   122  	}
   123  }
   124  
   125  func (creds AzureWorkloadIdentityCreds) GetAccessToken() (string, error) {
   126  	registryHost := strings.Split(creds.repoURL, "/")[0]
   127  
   128  	// Compute hash as key for refresh token in the cache
   129  	key, err := argoutils.GenerateCacheKey("accesstoken-%s", registryHost)
   130  	if err != nil {
   131  		return "", fmt.Errorf("failed to compute key for cache: %w", err)
   132  	}
   133  
   134  	// Check cache for GitHub transport which helps fetch an API token
   135  	t, found := azureTokenCache.Get(key)
   136  	if found {
   137  		fmt.Println("access token found token in cache")
   138  		return t.(string), nil
   139  	}
   140  
   141  	tokenParams, err := creds.challengeAzureContainerRegistry(registryHost)
   142  	if err != nil {
   143  		return "", fmt.Errorf("failed to challenge Azure Container Registry: %w", err)
   144  	}
   145  
   146  	token, err := creds.getAccessTokenAfterChallenge(tokenParams)
   147  	if err != nil {
   148  		return "", fmt.Errorf("failed to get Azure access token after challenge: %w", err)
   149  	}
   150  
   151  	tokenExpiry, err := getJWTExpiry(token)
   152  	if err != nil {
   153  		log.Warnf("failed to get token expiry from JWT: %v, using current time as fallback", err)
   154  		tokenExpiry = time.Now()
   155  	}
   156  
   157  	cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(tokenExpiry)
   158  	if cacheExpiry > 0 {
   159  		storeAzureToken(key, token, cacheExpiry)
   160  	}
   161  	return token, nil
   162  }
   163  
   164  func getJWTExpiry(token string) (time.Time, error) {
   165  	parser := jwt.NewParser()
   166  	claims := jwt.MapClaims{}
   167  	_, _, err := parser.ParseUnverified(token, claims)
   168  	if err != nil {
   169  		return time.Time{}, fmt.Errorf("failed to parse JWT: %w", err)
   170  	}
   171  	exp, err := claims.GetExpirationTime()
   172  	if err != nil {
   173  		return time.Time{}, fmt.Errorf("'exp' claim not found or invalid in token: %w", err)
   174  	}
   175  	return time.UnixMilli(exp.UnixMilli()), nil
   176  }
   177  
   178  func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams map[string]string) (string, error) {
   179  	realm := tokenParams["realm"]
   180  	service := tokenParams["service"]
   181  
   182  	armTokenScope := env.StringFromEnv("AZURE_ARM_TOKEN_RESOURCE", "https://management.core.windows.net")
   183  	armAccessToken, err := creds.tokenProvider.GetToken(armTokenScope + "/.default")
   184  	if err != nil {
   185  		return "", fmt.Errorf("failed to get Azure access token: %w", err)
   186  	}
   187  
   188  	parsedURL, _ := url.Parse(realm)
   189  	parsedURL.Path = "/oauth2/exchange"
   190  	refreshTokenURL := parsedURL.String()
   191  
   192  	client := &http.Client{
   193  		Timeout: 10 * time.Second,
   194  		Transport: &http.Transport{
   195  			TLSClientConfig: &tls.Config{
   196  				InsecureSkipVerify: creds.GetInsecureSkipVerify(),
   197  			},
   198  		},
   199  	}
   200  
   201  	formValues := url.Values{}
   202  	formValues.Add("grant_type", "access_token")
   203  	formValues.Add("service", service)
   204  	formValues.Add("access_token", armAccessToken.AccessToken)
   205  
   206  	resp, err := client.PostForm(refreshTokenURL, formValues)
   207  	if err != nil {
   208  		return "", fmt.Errorf("unable to connect to registry '%w'", err)
   209  	}
   210  
   211  	defer resp.Body.Close()
   212  	body, err := io.ReadAll(resp.Body)
   213  
   214  	if resp.StatusCode != http.StatusOK {
   215  		return "", fmt.Errorf("failed to get refresh token: %s", resp.Status)
   216  	}
   217  
   218  	if err != nil {
   219  		return "", fmt.Errorf("failed to read response body: %w", err)
   220  	}
   221  
   222  	type Response struct {
   223  		RefreshToken string `json:"refresh_token"`
   224  	}
   225  
   226  	var res Response
   227  	err = json.Unmarshal(body, &res)
   228  	if err != nil {
   229  		return "", fmt.Errorf("failed to unmarshal response body: %w", err)
   230  	}
   231  
   232  	return res.RefreshToken, nil
   233  }
   234  
   235  func (creds AzureWorkloadIdentityCreds) challengeAzureContainerRegistry(azureContainerRegistry string) (map[string]string, error) {
   236  	requestURL := fmt.Sprintf("https://%s/v2/", azureContainerRegistry)
   237  
   238  	client := &http.Client{
   239  		Timeout: 10 * time.Second,
   240  		Transport: &http.Transport{
   241  			TLSClientConfig: &tls.Config{
   242  				InsecureSkipVerify: creds.GetInsecureSkipVerify(),
   243  			},
   244  		},
   245  	}
   246  
   247  	req, err := http.NewRequest(http.MethodGet, requestURL, http.NoBody)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  
   252  	resp, err := client.Do(req)
   253  	if err != nil {
   254  		return nil, fmt.Errorf("unable to connect to registry '%w'", err)
   255  	}
   256  
   257  	defer resp.Body.Close()
   258  
   259  	if resp.StatusCode != http.StatusUnauthorized || resp.Header.Get("Www-Authenticate") == "" {
   260  		return nil, fmt.Errorf("registry '%s' did not issue a challenge", azureContainerRegistry)
   261  	}
   262  
   263  	authenticate := resp.Header.Get("Www-Authenticate")
   264  	tokens := strings.Split(authenticate, " ")
   265  
   266  	if !strings.EqualFold(tokens[0], "bearer") {
   267  		return nil, fmt.Errorf("registry does not allow 'Bearer' authentication, got '%s'", tokens[0])
   268  	}
   269  
   270  	tokenParams := make(map[string]string)
   271  
   272  	for _, token := range strings.Split(tokens[1], ",") {
   273  		kvPair := strings.Split(token, "=")
   274  		tokenParams[kvPair[0]] = strings.Trim(kvPair[1], "\"")
   275  	}
   276  
   277  	if _, realmExists := tokenParams["realm"]; !realmExists {
   278  		return nil, errors.New("realm parameter not found in challenge")
   279  	}
   280  
   281  	if _, serviceExists := tokenParams["service"]; !serviceExists {
   282  		return nil, errors.New("service parameter not found in challenge")
   283  	}
   284  
   285  	return tokenParams, nil
   286  }