github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/internal/credproviders/imds_provider.go (about)

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package credproviders
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"net/http"
    15  	"net/url"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/internal"
    19  	"go.mongodb.org/mongo-driver/internal/aws/credentials"
    20  )
    21  
    const (
	// azureURI is the Azure IMDS endpoint from which an OAuth2 access token is
	// requested.
	azureURI = "http://169.254.169.254/metadata/identity/oauth2/token"

	// AzureProviderName is the ProviderName recorded in credentials.Value for
	// credentials produced by AzureProvider.
	AzureProviderName = "AzureProvider"
)
    28  
    // An AzureProvider retrieves credentials from Azure IMDS.
//
// It keeps only the expiration time of the most recently retrieved token; the
// token itself is returned to the caller and not stored on the provider.
type AzureProvider struct {
	// httpClient sends the token request to the metadata endpoint.
	httpClient   *http.Client
	// expiration is the instant after which IsExpired reports true. The zero
	// value means no expiration has been recorded yet.
	expiration   time.Time
	// expiryWindow is subtracted from each token's reported lifetime when
	// computing expiration, so credentials are treated as expired early.
	expiryWindow time.Duration
}
    35  
    36  // NewAzureProvider returns a pointer to an Azure credential provider.
    37  func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *AzureProvider {
    38  	return &AzureProvider{
    39  		httpClient:   httpClient,
    40  		expiration:   time.Time{},
    41  		expiryWindow: expiryWindow,
    42  	}
    43  }
    44  
    45  // RetrieveWithContext retrieves the keys from the Azure service.
    46  func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
    47  	v := credentials.Value{ProviderName: AzureProviderName}
    48  	req, err := http.NewRequest(http.MethodGet, azureURI, nil)
    49  	if err != nil {
    50  		return v, internal.WrapErrorf(err, "unable to retrieve Azure credentials")
    51  	}
    52  	q := make(url.Values)
    53  	q.Set("api-version", "2018-02-01")
    54  	q.Set("resource", "https://vault.azure.net")
    55  	req.URL.RawQuery = q.Encode()
    56  	req.Header.Set("Metadata", "true")
    57  	req.Header.Set("Accept", "application/json")
    58  
    59  	resp, err := a.httpClient.Do(req.WithContext(ctx))
    60  	if err != nil {
    61  		return v, internal.WrapErrorf(err, "unable to retrieve Azure credentials")
    62  	}
    63  	defer resp.Body.Close()
    64  	body, err := ioutil.ReadAll(resp.Body)
    65  	if err != nil {
    66  		return v, internal.WrapErrorf(err, "unable to retrieve Azure credentials: error reading response body")
    67  	}
    68  	if resp.StatusCode != http.StatusOK {
    69  		return v, internal.WrapErrorf(err, "unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
    70  	}
    71  	var tokenResponse struct {
    72  		AccessToken string `json:"access_token"`
    73  		ExpiresIn   string `json:"expires_in"`
    74  	}
    75  	// Attempt to read body as JSON
    76  	err = json.Unmarshal(body, &tokenResponse)
    77  	if err != nil {
    78  		return v, internal.WrapErrorf(err, "unable to retrieve Azure credentials: error reading body JSON. Response body: %s", body)
    79  	}
    80  	if tokenResponse.AccessToken == "" {
    81  		return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body)
    82  	}
    83  	v.SessionToken = tokenResponse.AccessToken
    84  
    85  	expiresIn, err := time.ParseDuration(tokenResponse.ExpiresIn + "s")
    86  	if err != nil {
    87  		return v, err
    88  	}
    89  	if expiration := expiresIn - a.expiryWindow; expiration > 0 {
    90  		a.expiration = time.Now().Add(expiration)
    91  	}
    92  
    93  	return v, err
    94  }
    95  
    96  // Retrieve retrieves the keys from the Azure service.
    97  func (a *AzureProvider) Retrieve() (credentials.Value, error) {
    98  	return a.RetrieveWithContext(context.Background())
    99  }
   100  
   101  // IsExpired returns if the credentials have been retrieved.
   102  func (a *AzureProvider) IsExpired() bool {
   103  	return a.expiration.Before(time.Now())
   104  }