github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/endpointcreds/provider.go (about)

     1  // Package endpointcreds provides support for retrieving credentials from an
     2  // arbitrary HTTP endpoint.
     3  //
     4  // The credentials endpoint Provider can receive both static and refreshable
     5  // credentials that will expire. Credentials are static when an "Expiration"
     6  // value is not provided in the endpoint's response.
     7  //
     8  // Static credentials will never expire once they have been retrieved. The format
     9  // of the static credentials response:
    10  //    {
    11  //        "AccessKeyId" : "MUA...",
    12  //        "SecretAccessKey" : "/7PC5om...."
    13  //    }
    14  //
    15  // Refreshable credentials will expire within the "ExpiryWindow" of the Expiration
    16  // value in the response. The format of the refreshable credentials response:
    17  //    {
    18  //        "AccessKeyId" : "MUA...",
    19  //        "SecretAccessKey" : "/7PC5om....",
    20  //        "Token" : "AQoDY....=",
    21  //        "Expiration" : "2016-02-25T06:03:31Z"
    22  //    }
    23  //
    24  // Errors should be returned in the following format and only returned with 400
    25  // or 500 HTTP status codes.
    26  //    {
    27  //        "code": "ErrorCode",
    28  //        "message": "Helpful error message."
    29  //    }
    30  package endpointcreds
    31  
    32  import (
    33  	"encoding/json"
    34  	"time"
    35  
    36  	"github.com/aavshr/aws-sdk-go/aws"
    37  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    38  	"github.com/aavshr/aws-sdk-go/aws/client"
    39  	"github.com/aavshr/aws-sdk-go/aws/client/metadata"
    40  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    41  	"github.com/aavshr/aws-sdk-go/aws/request"
    42  	"github.com/aavshr/aws-sdk-go/private/protocol/json/jsonutil"
    43  )
    44  
    45  // ProviderName is the name of the credentials provider.
    46  const ProviderName = `CredentialsEndpointProvider`
    47  
    48  // Provider satisfies the credentials.Provider interface, and is a client to
    49  // retrieve credentials from an arbitrary endpoint.
    50  type Provider struct {
    51  	staticCreds bool
    52  	credentials.Expiry
    53  
    54  	// Requires a AWS Client to make HTTP requests to the endpoint with.
    55  	// the Endpoint the request will be made to is provided by the aws.Config's
    56  	// Endpoint value.
    57  	Client *client.Client
    58  
    59  	// ExpiryWindow will allow the credentials to trigger refreshing prior to
    60  	// the credentials actually expiring. This is beneficial so race conditions
    61  	// with expiring credentials do not cause request to fail unexpectedly
    62  	// due to ExpiredTokenException exceptions.
    63  	//
    64  	// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
    65  	// 10 seconds before the credentials are actually expired.
    66  	//
    67  	// If ExpiryWindow is 0 or less it will be ignored.
    68  	ExpiryWindow time.Duration
    69  
    70  	// Optional authorization token value if set will be used as the value of
    71  	// the Authorization header of the endpoint credential request.
    72  	AuthorizationToken string
    73  }
    74  
    75  // NewProviderClient returns a credentials Provider for retrieving AWS credentials
    76  // from arbitrary endpoint.
    77  func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) credentials.Provider {
    78  	p := &Provider{
    79  		Client: client.New(
    80  			cfg,
    81  			metadata.ClientInfo{
    82  				ServiceName: "CredentialsEndpoint",
    83  				Endpoint:    endpoint,
    84  			},
    85  			handlers,
    86  		),
    87  	}
    88  
    89  	p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler)
    90  	p.Client.Handlers.UnmarshalError.PushBack(unmarshalError)
    91  	p.Client.Handlers.Validate.Clear()
    92  	p.Client.Handlers.Validate.PushBack(validateEndpointHandler)
    93  
    94  	for _, option := range options {
    95  		option(p)
    96  	}
    97  
    98  	return p
    99  }
   100  
   101  // NewCredentialsClient returns a pointer to a new Credentials object
   102  // wrapping the endpoint credentials Provider.
   103  func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
   104  	return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
   105  }
   106  
   107  // IsExpired returns true if the credentials retrieved are expired, or not yet
   108  // retrieved.
   109  func (p *Provider) IsExpired() bool {
   110  	if p.staticCreds {
   111  		return false
   112  	}
   113  	return p.Expiry.IsExpired()
   114  }
   115  
   116  // Retrieve will attempt to request the credentials from the endpoint the Provider
   117  // was configured for. And error will be returned if the retrieval fails.
   118  func (p *Provider) Retrieve() (credentials.Value, error) {
   119  	return p.RetrieveWithContext(aws.BackgroundContext())
   120  }
   121  
   122  // RetrieveWithContext will attempt to request the credentials from the endpoint the Provider
   123  // was configured for. And error will be returned if the retrieval fails.
   124  func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
   125  	resp, err := p.getCredentials(ctx)
   126  	if err != nil {
   127  		return credentials.Value{ProviderName: ProviderName},
   128  			awserr.New("CredentialsEndpointError", "failed to load credentials", err)
   129  	}
   130  
   131  	if resp.Expiration != nil {
   132  		p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
   133  	} else {
   134  		p.staticCreds = true
   135  	}
   136  
   137  	return credentials.Value{
   138  		AccessKeyID:     resp.AccessKeyID,
   139  		SecretAccessKey: resp.SecretAccessKey,
   140  		SessionToken:    resp.Token,
   141  		ProviderName:    ProviderName,
   142  	}, nil
   143  }
   144  
   145  type getCredentialsOutput struct {
   146  	Expiration      *time.Time
   147  	AccessKeyID     string
   148  	SecretAccessKey string
   149  	Token           string
   150  }
   151  
   152  type errorOutput struct {
   153  	Code    string `json:"code"`
   154  	Message string `json:"message"`
   155  }
   156  
   157  func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error) {
   158  	op := &request.Operation{
   159  		Name:       "GetCredentials",
   160  		HTTPMethod: "GET",
   161  	}
   162  
   163  	out := &getCredentialsOutput{}
   164  	req := p.Client.NewRequest(op, nil, out)
   165  	req.SetContext(ctx)
   166  	req.HTTPRequest.Header.Set("Accept", "application/json")
   167  	if authToken := p.AuthorizationToken; len(authToken) != 0 {
   168  		req.HTTPRequest.Header.Set("Authorization", authToken)
   169  	}
   170  
   171  	return out, req.Send()
   172  }
   173  
   174  func validateEndpointHandler(r *request.Request) {
   175  	if len(r.ClientInfo.Endpoint) == 0 {
   176  		r.Error = aws.ErrMissingEndpoint
   177  	}
   178  }
   179  
   180  func unmarshalHandler(r *request.Request) {
   181  	defer r.HTTPResponse.Body.Close()
   182  
   183  	out := r.Data.(*getCredentialsOutput)
   184  	if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
   185  		r.Error = awserr.New(request.ErrCodeSerialization,
   186  			"failed to decode endpoint credentials",
   187  			err,
   188  		)
   189  	}
   190  }
   191  
   192  func unmarshalError(r *request.Request) {
   193  	defer r.HTTPResponse.Body.Close()
   194  
   195  	var errOut errorOutput
   196  	err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body)
   197  	if err != nil {
   198  		r.Error = awserr.NewRequestFailure(
   199  			awserr.New(request.ErrCodeSerialization,
   200  				"failed to decode error message", err),
   201  			r.HTTPResponse.StatusCode,
   202  			r.RequestID,
   203  		)
   204  		return
   205  	}
   206  
   207  	// Response body format is not consistent between metadata endpoints.
   208  	// Grab the error message as a string and include that as the source error
   209  	r.Error = awserr.New(errOut.Code, errOut.Message, nil)
   210  }