github.com/blend/go-sdk@v1.20220411.3/vault/aws_auth.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package vault 9 10 import ( 11 "bytes" 12 "context" 13 "encoding/base64" 14 "encoding/json" 15 "io" 16 "net/http" 17 "net/url" 18 19 "github.com/aws/aws-sdk-go/aws" 20 "github.com/aws/aws-sdk-go/aws/credentials" 21 "github.com/aws/aws-sdk-go/aws/credentials/stscreds" 22 "github.com/aws/aws-sdk-go/aws/session" 23 "github.com/aws/aws-sdk-go/service/sts" 24 25 "github.com/blend/go-sdk/ex" 26 ) 27 28 // AWSAuth defines vault aws auth methods 29 type AWSAuth struct { 30 CredentialProvider CredentialProvider 31 } 32 33 // NewAWSAuth creates a new AWS struct 34 func NewAWSAuth(opts ...AWSAuthOption) (*AWSAuth, error) { 35 auth := &AWSAuth{ 36 CredentialProvider: GetIAMAuthCredentials, 37 } 38 var err error 39 for _, opt := range opts { 40 if err = opt(auth); err != nil { 41 return nil, err 42 } 43 } 44 return auth, nil 45 } 46 47 // AWSAuthOption mutates an AWSAuth instance 48 type AWSAuthOption func(*AWSAuth) error 49 50 // CredentialProvider defines the credential provider func interface 51 type CredentialProvider func(roleARN string) (*credentials.Credentials, error) 52 53 // OptAWSAuthCredentialProvider sets the credential provider 54 func OptAWSAuthCredentialProvider(cp CredentialProvider) AWSAuthOption { 55 return func(a *AWSAuth) error { 56 a.CredentialProvider = cp 57 return nil 58 } 59 } 60 61 // AWSIAMLogin returns a vault token given the instance role which invokes this function 62 func (a *AWSAuth) AWSIAMLogin(ctx context.Context, client HTTPClient, baseURL url.URL, roleName, roleARN, service, region string) (string, error) { 63 stsRequest, err := a.GetCallerIdentitySignedRequest(roleARN, service, region) 64 if err != nil { 65 return "", ex.New(err) 66 } 67 68 request, err := createVaultLoginRequest(roleName, baseURL, stsRequest) 69 if err != nil { 70 return "", ex.New(err) 71 } 72 73 res, err := client.Do(request) 74 if err != nil { 75 return "", ex.New(err) 76 } 77 defer res.Body.Close() 78 79 var response AWSAuthResponse 80 if err := json.NewDecoder(res.Body).Decode(&response); err != nil { 81 return "", ex.New(err) 82 } 83 if len(response.Errors) > 0 { 84 return "", ex.New("Error making aws get identity request", ex.OptMessagef("%+v", response.Errors)) 85 } 86 87 return response.Auth.ClientToken, nil 88 } 89 90 // GetCallerIdentitySignedRequest gets a signed caller identity request 91 func (a *AWSAuth) GetCallerIdentitySignedRequest(roleARN, service, region string) (*http.Request, error) { 92 credentials, err := a.CredentialProvider(roleARN) 93 if err != nil { 94 return nil, ex.New(err) 95 } 96 97 stsSession, err := session.NewSessionWithOptions(session.Options{ 98 Config: aws.Config{ 99 Credentials: credentials, 100 Region: ®ion, 101 }, 102 }) 103 if err != nil { 104 return nil, ex.New(err) 105 } 106 107 svc := sts.New(stsSession) 108 stsRequest, _ := svc.GetCallerIdentityRequest(nil) 109 err = stsRequest.Sign() 110 if err != nil { 111 return nil, ex.New(err) 112 } 113 114 return stsRequest.HTTPRequest, nil 115 } 116 117 // GetIAMAuthCredentials is a credential provider to be passed in as input into the AWSAuth struct 118 func GetIAMAuthCredentials(roleARN string) (*credentials.Credentials, error) { 119 session, err := session.NewSession() 120 if err != nil { 121 return nil, ex.New(err) 122 } 123 credentials := stscreds.NewCredentials(session, roleARN) 124 return credentials, nil 125 } 126 127 func createVaultLoginRequest(roleName string, baseURL url.URL, request *http.Request) (*http.Request, error) { 128 baseURL.Path = AWSAuthLoginPath 129 stsHeaders, err := json.Marshal(request.Header) 130 if err != nil { 131 return nil, ex.New(err) 132 } 133 134 body := map[string]string{ 135 "role": roleName, 136 "iam_http_request_method": MethodPost, 137 "iam_request_url": base64.StdEncoding.EncodeToString([]byte(request.URL.String())), 138 "iam_request_body": base64.StdEncoding.EncodeToString([]byte(STSGetIdentityBody)), 139 "iam_request_headers": base64.StdEncoding.EncodeToString(stsHeaders), 140 } 141 142 contents, err := json.Marshal(body) 143 if err != nil { 144 return nil, ex.New(err) 145 } 146 147 req := &http.Request{ 148 URL: &baseURL, 149 Method: MethodPost, 150 Body: io.NopCloser(bytes.NewReader(contents)), 151 } 152 153 req.GetBody = func() (io.ReadCloser, error) { 154 r := bytes.NewReader(contents) 155 return io.NopCloser(r), nil 156 } 157 158 req.ContentLength = int64(len(contents)) 159 if req.Header == nil { 160 req.Header = make(http.Header) 161 } 162 req.Header.Set(HeaderContentType, ContentTypeApplicationJSON) 163 164 return req, nil 165 }