github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/stscreds/assume_role_provider_test.go (about) 1 package stscreds 2 3 import ( 4 "fmt" 5 "testing" 6 "time" 7 8 "github.com/aavshr/aws-sdk-go/aws" 9 "github.com/aavshr/aws-sdk-go/aws/credentials" 10 "github.com/aavshr/aws-sdk-go/aws/request" 11 "github.com/aavshr/aws-sdk-go/service/sts" 12 ) 13 14 type stubSTS struct { 15 TestInput func(*sts.AssumeRoleInput) 16 } 17 18 func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { 19 if s.TestInput != nil { 20 s.TestInput(input) 21 } 22 expiry := time.Now().Add(60 * time.Minute) 23 return &sts.AssumeRoleOutput{ 24 Credentials: &sts.Credentials{ 25 // Just reflect the role arn to the provider. 26 AccessKeyId: input.RoleArn, 27 SecretAccessKey: aws.String("assumedSecretAccessKey"), 28 SessionToken: aws.String("assumedSessionToken"), 29 Expiration: &expiry, 30 }, 31 }, nil 32 } 33 34 type stubSTSWithContext struct { 35 stubSTS 36 called chan struct{} 37 } 38 39 func (s *stubSTSWithContext) AssumeRoleWithContext(context credentials.Context, input *sts.AssumeRoleInput, option ...request.Option) (*sts.AssumeRoleOutput, error) { 40 <-s.called 41 return s.stubSTS.AssumeRole(input) 42 } 43 44 func TestAssumeRoleProvider(t *testing.T) { 45 stub := &stubSTS{} 46 p := &AssumeRoleProvider{ 47 Client: stub, 48 RoleARN: "roleARN", 49 } 50 51 creds, err := p.Retrieve() 52 if err != nil { 53 t.Errorf("expect nil, got %v", err) 54 } 55 56 if e, a := "roleARN", creds.AccessKeyID; e != a { 57 t.Errorf("expect %v, got %v", e, a) 58 } 59 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a { 60 t.Errorf("expect %v, got %v", e, a) 61 } 62 if e, a := "assumedSessionToken", creds.SessionToken; e != a { 63 t.Errorf("expect %v, got %v", e, a) 64 } 65 } 66 67 func TestAssumeRoleProvider_WithTokenCode(t *testing.T) { 68 stub := &stubSTS{ 69 TestInput: func(in *sts.AssumeRoleInput) { 70 if e, a := "0123456789", *in.SerialNumber; e != a { 71 t.Errorf("expect %v, got %v", e, a) 72 } 73 if e, a := "code", *in.TokenCode; e != a { 74 t.Errorf("expect %v, got %v", e, a) 75 } 76 }, 77 } 78 p := &AssumeRoleProvider{ 79 Client: stub, 80 RoleARN: "roleARN", 81 SerialNumber: aws.String("0123456789"), 82 TokenCode: aws.String("code"), 83 } 84 85 creds, err := p.Retrieve() 86 if err != nil { 87 t.Errorf("expect nil, got %v", err) 88 } 89 90 if e, a := "roleARN", creds.AccessKeyID; e != a { 91 t.Errorf("expect %v, got %v", e, a) 92 } 93 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a { 94 t.Errorf("expect %v, got %v", e, a) 95 } 96 if e, a := "assumedSessionToken", creds.SessionToken; e != a { 97 t.Errorf("expect %v, got %v", e, a) 98 } 99 } 100 101 func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) { 102 stub := &stubSTS{ 103 TestInput: func(in *sts.AssumeRoleInput) { 104 if e, a := "0123456789", *in.SerialNumber; e != a { 105 t.Errorf("expect %v, got %v", e, a) 106 } 107 if e, a := "code", *in.TokenCode; e != a { 108 t.Errorf("expect %v, got %v", e, a) 109 } 110 }, 111 } 112 p := &AssumeRoleProvider{ 113 Client: stub, 114 RoleARN: "roleARN", 115 SerialNumber: aws.String("0123456789"), 116 TokenProvider: func() (string, error) { 117 return "code", nil 118 }, 119 } 120 121 creds, err := p.Retrieve() 122 if err != nil { 123 t.Errorf("expect nil, got %v", err) 124 } 125 126 if e, a := "roleARN", creds.AccessKeyID; e != a { 127 t.Errorf("expect %v, got %v", e, a) 128 } 129 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a { 130 t.Errorf("expect %v, got %v", e, a) 131 } 132 if e, a := "assumedSessionToken", creds.SessionToken; e != a { 133 t.Errorf("expect %v, got %v", e, a) 134 } 135 } 136 137 func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) { 138 stub := &stubSTS{ 139 TestInput: func(in *sts.AssumeRoleInput) { 140 t.Errorf("API request should not of been called") 141 }, 142 } 143 p := &AssumeRoleProvider{ 144 Client: stub, 145 RoleARN: "roleARN", 146 SerialNumber: aws.String("0123456789"), 147 TokenProvider: func() (string, error) { 148 return "", fmt.Errorf("error occurred") 149 }, 150 } 151 152 creds, err := p.Retrieve() 153 if err == nil { 154 t.Errorf("expect error") 155 } 156 157 if v := creds.AccessKeyID; len(v) != 0 { 158 t.Errorf("expect empty, got %v", v) 159 } 160 if v := creds.SecretAccessKey; len(v) != 0 { 161 t.Errorf("expect empty, got %v", v) 162 } 163 if v := creds.SessionToken; len(v) != 0 { 164 t.Errorf("expect empty, got %v", v) 165 } 166 } 167 168 func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) { 169 stub := &stubSTS{ 170 TestInput: func(in *sts.AssumeRoleInput) { 171 t.Errorf("API request should not of been called") 172 }, 173 } 174 p := &AssumeRoleProvider{ 175 Client: stub, 176 RoleARN: "roleARN", 177 SerialNumber: aws.String("0123456789"), 178 } 179 180 creds, err := p.Retrieve() 181 if err == nil { 182 t.Errorf("expect error") 183 } 184 185 if v := creds.AccessKeyID; len(v) != 0 { 186 t.Errorf("expect empty, got %v", v) 187 } 188 if v := creds.SecretAccessKey; len(v) != 0 { 189 t.Errorf("expect empty, got %v", v) 190 } 191 if v := creds.SessionToken; len(v) != 0 { 192 t.Errorf("expect empty, got %v", v) 193 } 194 } 195 196 func BenchmarkAssumeRoleProvider(b *testing.B) { 197 stub := &stubSTS{} 198 p := &AssumeRoleProvider{ 199 Client: stub, 200 RoleARN: "roleARN", 201 } 202 203 b.ResetTimer() 204 for i := 0; i < b.N; i++ { 205 if _, err := p.Retrieve(); err != nil { 206 b.Fatal(err) 207 } 208 } 209 } 210 211 func TestAssumeRoleProvider_WithTags(t *testing.T) { 212 stub := &stubSTS{ 213 TestInput: func(in *sts.AssumeRoleInput) { 214 if *in.TransitiveTagKeys[0] != "TagName" { 215 t.Errorf("TransitiveTagKeys not passed along") 216 } 217 if *in.Tags[0].Key != "TagName" || *in.Tags[0].Value != "TagValue" { 218 t.Errorf("Tags not passed along") 219 } 220 }, 221 } 222 p := &AssumeRoleProvider{ 223 Client: stub, 224 RoleARN: "roleARN", 225 Tags: []*sts.Tag{ 226 { 227 Key: aws.String("TagName"), 228 Value: aws.String("TagValue"), 229 }, 230 }, 231 TransitiveTagKeys: []*string{aws.String("TagName")}, 232 } 233 _, err := p.Retrieve() 234 if err != nil { 235 t.Errorf("expect error") 236 } 237 } 238 239 func TestAssumeRoleProvider_RetrieveWithContext(t *testing.T) { 240 stub := &stubSTSWithContext{ 241 called: make(chan struct{}), 242 } 243 p := &AssumeRoleProvider{ 244 Client: stub, 245 RoleARN: "roleARN", 246 } 247 248 go func() { 249 stub.called <- struct{}{} 250 }() 251 252 creds, err := p.RetrieveWithContext(aws.BackgroundContext()) 253 if err != nil { 254 t.Errorf("expect nil, got %v", err) 255 } 256 257 if e, a := "roleARN", creds.AccessKeyID; e != a { 258 t.Errorf("expect %v, got %v", e, a) 259 } 260 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a { 261 t.Errorf("expect %v, got %v", e, a) 262 } 263 if e, a := "assumedSessionToken", creds.SessionToken; e != a { 264 t.Errorf("expect %v, got %v", e, a) 265 } 266 }