github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/internal/credproviders/ec2_provider.go (about)

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package credproviders
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"errors"
    13  	"fmt"
    14  	"io/ioutil"
    15  	"net/http"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/internal/aws/credentials"
    19  )
    20  
    // ec2ProviderName is the ProviderName stamped on every credentials.Value this
    // provider returns.
    const ec2ProviderName = "EC2Provider"

    // Instance metadata endpoint and the paths queried beneath it.
    const (
    	awsEC2URI       = "http://169.254.169.254/"
    	awsEC2TokenPath = "latest/api/token"
    	awsEC2RolePath  = "latest/meta-data/iam/security-credentials/"
    )

    // defaultHTTPTimeout caps each individual metadata request.
    const defaultHTTPTimeout = 10 * time.Second
    31  
    // EC2Provider obtains AWS credentials from the EC2 instance metadata service
    // and remembers when they should be refreshed. Build one with NewEC2Provider.
    type EC2Provider struct {
    	// httpClient sends every metadata request.
    	httpClient *http.Client
    	// expiration is the instant after which IsExpired reports true. It is
    	// stored already pulled forward by expiryWindow.
    	expiration time.Time
    	// expiryWindow is the lead time, ahead of the real credential expiry, at
    	// which a refresh should be triggered so requests do not fail because the
    	// credentials lapse unexpectedly. A 10s window, for instance, makes
    	// IsExpired return true 10 seconds before the credentials truly expire.
    	expiryWindow time.Duration
    }
    44  
    45  // NewEC2Provider returns a pointer to an EC2 credential provider.
    46  func NewEC2Provider(httpClient *http.Client, expiryWindow time.Duration) *EC2Provider {
    47  	return &EC2Provider{
    48  		httpClient:   httpClient,
    49  		expiryWindow: expiryWindow,
    50  	}
    51  }
    52  
    53  func (e *EC2Provider) getToken(ctx context.Context) (string, error) {
    54  	req, err := http.NewRequest(http.MethodPut, awsEC2URI+awsEC2TokenPath, nil)
    55  	if err != nil {
    56  		return "", err
    57  	}
    58  	const defaultEC2TTLSeconds = "30"
    59  	req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", defaultEC2TTLSeconds)
    60  
    61  	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
    62  	defer cancel()
    63  	resp, err := e.httpClient.Do(req.WithContext(ctx))
    64  	if err != nil {
    65  		return "", err
    66  	}
    67  	defer resp.Body.Close()
    68  	if resp.StatusCode != http.StatusOK {
    69  		return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
    70  	}
    71  
    72  	token, err := ioutil.ReadAll(resp.Body)
    73  	if err != nil {
    74  		return "", err
    75  	}
    76  	if len(token) == 0 {
    77  		return "", errors.New("unable to retrieve token from EC2 metadata")
    78  	}
    79  	return string(token), nil
    80  }
    81  
    82  func (e *EC2Provider) getRoleName(ctx context.Context, token string) (string, error) {
    83  	req, err := http.NewRequest(http.MethodGet, awsEC2URI+awsEC2RolePath, nil)
    84  	if err != nil {
    85  		return "", err
    86  	}
    87  	req.Header.Set("X-aws-ec2-metadata-token", token)
    88  
    89  	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
    90  	defer cancel()
    91  	resp, err := e.httpClient.Do(req.WithContext(ctx))
    92  	if err != nil {
    93  		return "", err
    94  	}
    95  	defer resp.Body.Close()
    96  	if resp.StatusCode != http.StatusOK {
    97  		return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
    98  	}
    99  
   100  	role, err := ioutil.ReadAll(resp.Body)
   101  	if err != nil {
   102  		return "", err
   103  	}
   104  	if len(role) == 0 {
   105  		return "", errors.New("unable to retrieve role_name from EC2 metadata")
   106  	}
   107  	return string(role), nil
   108  }
   109  
   110  func (e *EC2Provider) getCredentials(ctx context.Context, token string, role string) (credentials.Value, time.Time, error) {
   111  	v := credentials.Value{ProviderName: ec2ProviderName}
   112  
   113  	pathWithRole := awsEC2URI + awsEC2RolePath + role
   114  	req, err := http.NewRequest(http.MethodGet, pathWithRole, nil)
   115  	if err != nil {
   116  		return v, time.Time{}, err
   117  	}
   118  	req.Header.Set("X-aws-ec2-metadata-token", token)
   119  	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
   120  	defer cancel()
   121  	resp, err := e.httpClient.Do(req.WithContext(ctx))
   122  	if err != nil {
   123  		return v, time.Time{}, err
   124  	}
   125  	defer resp.Body.Close()
   126  	if resp.StatusCode != http.StatusOK {
   127  		return v, time.Time{}, fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
   128  	}
   129  
   130  	var ec2Resp struct {
   131  		AccessKeyID     string    `json:"AccessKeyId"`
   132  		SecretAccessKey string    `json:"SecretAccessKey"`
   133  		Token           string    `json:"Token"`
   134  		Expiration      time.Time `json:"Expiration"`
   135  	}
   136  
   137  	err = json.NewDecoder(resp.Body).Decode(&ec2Resp)
   138  	if err != nil {
   139  		return v, time.Time{}, err
   140  	}
   141  
   142  	v.AccessKeyID = ec2Resp.AccessKeyID
   143  	v.SecretAccessKey = ec2Resp.SecretAccessKey
   144  	v.SessionToken = ec2Resp.Token
   145  
   146  	return v, ec2Resp.Expiration, nil
   147  }
   148  
   149  // RetrieveWithContext retrieves the keys from the AWS service.
   150  func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
   151  	v := credentials.Value{ProviderName: ec2ProviderName}
   152  
   153  	token, err := e.getToken(ctx)
   154  	if err != nil {
   155  		return v, err
   156  	}
   157  
   158  	role, err := e.getRoleName(ctx, token)
   159  	if err != nil {
   160  		return v, err
   161  	}
   162  
   163  	v, exp, err := e.getCredentials(ctx, token, role)
   164  	if err != nil {
   165  		return v, err
   166  	}
   167  	if !v.HasKeys() {
   168  		return v, errors.New("failed to retrieve EC2 keys")
   169  	}
   170  	e.expiration = exp.Add(-e.expiryWindow)
   171  
   172  	return v, nil
   173  }
   174  
   175  // Retrieve retrieves the keys from the AWS service.
   176  func (e *EC2Provider) Retrieve() (credentials.Value, error) {
   177  	return e.RetrieveWithContext(context.Background())
   178  }
   179  
   180  // IsExpired returns true if the credentials are expired.
   181  func (e *EC2Provider) IsExpired() bool {
   182  	return e.expiration.Before(time.Now())
   183  }