github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/stscreds/web_identity_provider_test.go (about) 1 //go:build go1.7 2 // +build go1.7 3 4 package stscreds_test 5 6 import ( 7 "net/http" 8 "reflect" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/aavshr/aws-sdk-go/aws" 14 "github.com/aavshr/aws-sdk-go/aws/awserr" 15 "github.com/aavshr/aws-sdk-go/aws/corehandlers" 16 "github.com/aavshr/aws-sdk-go/aws/credentials" 17 "github.com/aavshr/aws-sdk-go/aws/credentials/stscreds" 18 "github.com/aavshr/aws-sdk-go/aws/request" 19 "github.com/aavshr/aws-sdk-go/awstesting/unit" 20 "github.com/aavshr/aws-sdk-go/service/sts" 21 ) 22 23 func TestWebIdentityProviderRetrieve(t *testing.T) { 24 var reqCount int 25 cases := map[string]struct { 26 onSendReq func(*testing.T, *request.Request) 27 roleARN string 28 tokenFilepath string 29 sessionName string 30 duration time.Duration 31 expectedError string 32 expectedCredValue credentials.Value 33 }{ 34 "session name case": { 35 roleARN: "arn01234567890123456789", 36 tokenFilepath: "testdata/token.jwt", 37 sessionName: "foo", 38 onSendReq: func(t *testing.T, r *request.Request) { 39 input := r.Params.(*sts.AssumeRoleWithWebIdentityInput) 40 if e, a := "foo", *input.RoleSessionName; e != a { 41 t.Errorf("expected %v, but received %v", e, a) 42 } 43 if input.DurationSeconds != nil { 44 t.Errorf("expect no duration, got %v", *input.DurationSeconds) 45 } 46 47 data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput) 48 *data = sts.AssumeRoleWithWebIdentityOutput{ 49 Credentials: &sts.Credentials{ 50 Expiration: aws.Time(time.Now()), 51 AccessKeyId: aws.String("access-key-id"), 52 SecretAccessKey: aws.String("secret-access-key"), 53 SessionToken: aws.String("session-token"), 54 }, 55 } 56 }, 57 expectedCredValue: credentials.Value{ 58 AccessKeyID: "access-key-id", 59 SecretAccessKey: "secret-access-key", 60 SessionToken: "session-token", 61 ProviderName: stscreds.WebIdentityProviderName, 62 }, 63 }, 64 "with duration": { 65 roleARN: "arn01234567890123456789", 66 tokenFilepath: "testdata/token.jwt", 67 sessionName: "foo", 68 duration: 15 * time.Minute, 69 onSendReq: func(t *testing.T, r *request.Request) { 70 input := r.Params.(*sts.AssumeRoleWithWebIdentityInput) 71 if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a { 72 t.Errorf("expect %v duration, got %v", e, a) 73 } 74 75 data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput) 76 *data = sts.AssumeRoleWithWebIdentityOutput{ 77 Credentials: &sts.Credentials{ 78 Expiration: aws.Time(time.Now()), 79 AccessKeyId: aws.String("access-key-id"), 80 SecretAccessKey: aws.String("secret-access-key"), 81 SessionToken: aws.String("session-token"), 82 }, 83 } 84 }, 85 expectedCredValue: credentials.Value{ 86 AccessKeyID: "access-key-id", 87 SecretAccessKey: "secret-access-key", 88 SessionToken: "session-token", 89 ProviderName: stscreds.WebIdentityProviderName, 90 }, 91 }, 92 "invalid token retry": { 93 roleARN: "arn01234567890123456789", 94 tokenFilepath: "testdata/token.jwt", 95 sessionName: "foo", 96 onSendReq: func(t *testing.T, r *request.Request) { 97 input := r.Params.(*sts.AssumeRoleWithWebIdentityInput) 98 if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) { 99 t.Errorf("expected %v, but received %v", e, a) 100 } 101 102 if reqCount == 0 { 103 r.HTTPResponse.StatusCode = 400 104 r.Error = awserr.New(sts.ErrCodeInvalidIdentityTokenException, 105 "some error message", nil) 106 return 107 } 108 109 data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput) 110 *data = sts.AssumeRoleWithWebIdentityOutput{ 111 Credentials: &sts.Credentials{ 112 Expiration: aws.Time(time.Now()), 113 AccessKeyId: aws.String("access-key-id"), 114 SecretAccessKey: aws.String("secret-access-key"), 115 SessionToken: aws.String("session-token"), 116 }, 117 } 118 }, 119 expectedCredValue: credentials.Value{ 120 AccessKeyID: "access-key-id", 121 SecretAccessKey: "secret-access-key", 122 SessionToken: "session-token", 123 ProviderName: stscreds.WebIdentityProviderName, 124 }, 125 }, 126 } 127 128 for name, c := range cases { 129 t.Run(name, func(t *testing.T) { 130 reqCount = 0 131 132 svc := sts.New(unit.Session, &aws.Config{ 133 Logger: t, 134 }) 135 svc.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{ 136 Name: "custom send stub handler", 137 Fn: func(r *request.Request) { 138 r.HTTPResponse = &http.Response{ 139 StatusCode: 200, Header: http.Header{}, 140 } 141 c.onSendReq(t, r) 142 reqCount++ 143 }, 144 }) 145 svc.Handlers.UnmarshalMeta.Clear() 146 svc.Handlers.Unmarshal.Clear() 147 svc.Handlers.UnmarshalError.Clear() 148 149 p := stscreds.NewWebIdentityRoleProvider(svc, c.roleARN, c.sessionName, c.tokenFilepath) 150 p.Duration = c.duration 151 152 credValue, err := p.Retrieve() 153 if len(c.expectedError) != 0 { 154 if err == nil { 155 t.Fatalf("expect error, got none") 156 } 157 if e, a := c.expectedError, err.Error(); !strings.Contains(a, e) { 158 t.Fatalf("expect error to contain %v, got %v", e, a) 159 } 160 return 161 } 162 if err != nil { 163 t.Fatalf("expect no error, got %v", err) 164 } 165 166 if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) { 167 t.Errorf("expected %v, but received %v", e, a) 168 } 169 }) 170 } 171 }