github.com/blend/go-sdk@v1.20220411.3/vault/aws_auth.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package vault
     9  
    10  import (
    11  	"bytes"
    12  	"context"
    13  	"encoding/base64"
    14  	"encoding/json"
    15  	"io"
    16  	"net/http"
    17  	"net/url"
    18  
    19  	"github.com/aws/aws-sdk-go/aws"
    20  	"github.com/aws/aws-sdk-go/aws/credentials"
    21  	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
    22  	"github.com/aws/aws-sdk-go/aws/session"
    23  	"github.com/aws/aws-sdk-go/service/sts"
    24  
    25  	"github.com/blend/go-sdk/ex"
    26  )
    27  
    28  // AWSAuth defines vault aws auth methods
    29  type AWSAuth struct {
    30  	CredentialProvider CredentialProvider
    31  }
    32  
    33  // NewAWSAuth creates a new AWS struct
    34  func NewAWSAuth(opts ...AWSAuthOption) (*AWSAuth, error) {
    35  	auth := &AWSAuth{
    36  		CredentialProvider: GetIAMAuthCredentials,
    37  	}
    38  	var err error
    39  	for _, opt := range opts {
    40  		if err = opt(auth); err != nil {
    41  			return nil, err
    42  		}
    43  	}
    44  	return auth, nil
    45  }
    46  
    47  // AWSAuthOption mutates an AWSAuth instance
    48  type AWSAuthOption func(*AWSAuth) error
    49  
    50  // CredentialProvider defines the credential provider func interface
    51  type CredentialProvider func(roleARN string) (*credentials.Credentials, error)
    52  
    53  // OptAWSAuthCredentialProvider sets the credential provider
    54  func OptAWSAuthCredentialProvider(cp CredentialProvider) AWSAuthOption {
    55  	return func(a *AWSAuth) error {
    56  		a.CredentialProvider = cp
    57  		return nil
    58  	}
    59  }
    60  
    61  // AWSIAMLogin returns a vault token given the instance role which invokes this function
    62  func (a *AWSAuth) AWSIAMLogin(ctx context.Context, client HTTPClient, baseURL url.URL, roleName, roleARN, service, region string) (string, error) {
    63  	stsRequest, err := a.GetCallerIdentitySignedRequest(roleARN, service, region)
    64  	if err != nil {
    65  		return "", ex.New(err)
    66  	}
    67  
    68  	request, err := createVaultLoginRequest(roleName, baseURL, stsRequest)
    69  	if err != nil {
    70  		return "", ex.New(err)
    71  	}
    72  
    73  	res, err := client.Do(request)
    74  	if err != nil {
    75  		return "", ex.New(err)
    76  	}
    77  	defer res.Body.Close()
    78  
    79  	var response AWSAuthResponse
    80  	if err := json.NewDecoder(res.Body).Decode(&response); err != nil {
    81  		return "", ex.New(err)
    82  	}
    83  	if len(response.Errors) > 0 {
    84  		return "", ex.New("Error making aws get identity request", ex.OptMessagef("%+v", response.Errors))
    85  	}
    86  
    87  	return response.Auth.ClientToken, nil
    88  }
    89  
    90  // GetCallerIdentitySignedRequest gets a signed caller identity request
    91  func (a *AWSAuth) GetCallerIdentitySignedRequest(roleARN, service, region string) (*http.Request, error) {
    92  	credentials, err := a.CredentialProvider(roleARN)
    93  	if err != nil {
    94  		return nil, ex.New(err)
    95  	}
    96  
    97  	stsSession, err := session.NewSessionWithOptions(session.Options{
    98  		Config: aws.Config{
    99  			Credentials: credentials,
   100  			Region:      &region,
   101  		},
   102  	})
   103  	if err != nil {
   104  		return nil, ex.New(err)
   105  	}
   106  
   107  	svc := sts.New(stsSession)
   108  	stsRequest, _ := svc.GetCallerIdentityRequest(nil)
   109  	err = stsRequest.Sign()
   110  	if err != nil {
   111  		return nil, ex.New(err)
   112  	}
   113  
   114  	return stsRequest.HTTPRequest, nil
   115  }
   116  
   117  // GetIAMAuthCredentials is a credential provider to be passed in as input into the AWSAuth struct
   118  func GetIAMAuthCredentials(roleARN string) (*credentials.Credentials, error) {
   119  	session, err := session.NewSession()
   120  	if err != nil {
   121  		return nil, ex.New(err)
   122  	}
   123  	credentials := stscreds.NewCredentials(session, roleARN)
   124  	return credentials, nil
   125  }
   126  
   127  func createVaultLoginRequest(roleName string, baseURL url.URL, request *http.Request) (*http.Request, error) {
   128  	baseURL.Path = AWSAuthLoginPath
   129  	stsHeaders, err := json.Marshal(request.Header)
   130  	if err != nil {
   131  		return nil, ex.New(err)
   132  	}
   133  
   134  	body := map[string]string{
   135  		"role":                    roleName,
   136  		"iam_http_request_method": MethodPost,
   137  		"iam_request_url":         base64.StdEncoding.EncodeToString([]byte(request.URL.String())),
   138  		"iam_request_body":        base64.StdEncoding.EncodeToString([]byte(STSGetIdentityBody)),
   139  		"iam_request_headers":     base64.StdEncoding.EncodeToString(stsHeaders),
   140  	}
   141  
   142  	contents, err := json.Marshal(body)
   143  	if err != nil {
   144  		return nil, ex.New(err)
   145  	}
   146  
   147  	req := &http.Request{
   148  		URL:    &baseURL,
   149  		Method: MethodPost,
   150  		Body:   io.NopCloser(bytes.NewReader(contents)),
   151  	}
   152  
   153  	req.GetBody = func() (io.ReadCloser, error) {
   154  		r := bytes.NewReader(contents)
   155  		return io.NopCloser(r), nil
   156  	}
   157  
   158  	req.ContentLength = int64(len(contents))
   159  	if req.Header == nil {
   160  		req.Header = make(http.Header)
   161  	}
   162  	req.Header.Set(HeaderContentType, ContentTypeApplicationJSON)
   163  
   164  	return req, nil
   165  }