github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/ssocreds/provider_test.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package ssocreds 5 6 import ( 7 "fmt" 8 "reflect" 9 "testing" 10 "time" 11 12 "github.com/aavshr/aws-sdk-go/aws" 13 "github.com/aavshr/aws-sdk-go/aws/credentials" 14 "github.com/aavshr/aws-sdk-go/aws/request" 15 "github.com/aavshr/aws-sdk-go/service/sso" 16 "github.com/aavshr/aws-sdk-go/service/sso/ssoiface" 17 ) 18 19 type mockClient struct { 20 ssoiface.SSOAPI 21 22 t *testing.T 23 24 Output *sso.GetRoleCredentialsOutput 25 Err error 26 27 ExpectedAccountID string 28 ExpectedAccessToken string 29 ExpectedRoleName string 30 ExpectedClientRegion string 31 32 Response func(mockClient) (*sso.GetRoleCredentialsOutput, error) 33 } 34 35 func (m mockClient) GetRoleCredentialsWithContext(ctx aws.Context, params *sso.GetRoleCredentialsInput, _ ...request.Option) (*sso.GetRoleCredentialsOutput, error) { 36 m.t.Helper() 37 38 if len(m.ExpectedAccountID) > 0 { 39 if e, a := m.ExpectedAccountID, aws.StringValue(params.AccountId); e != a { 40 m.t.Errorf("expect %v, got %v", e, a) 41 } 42 } 43 44 if len(m.ExpectedAccessToken) > 0 { 45 if e, a := m.ExpectedAccessToken, aws.StringValue(params.AccessToken); e != a { 46 m.t.Errorf("expect %v, got %v", e, a) 47 } 48 } 49 50 if len(m.ExpectedRoleName) > 0 { 51 if e, a := m.ExpectedRoleName, aws.StringValue(params.RoleName); e != a { 52 m.t.Errorf("expect %v, got %v", e, a) 53 } 54 } 55 56 if m.Response == nil { 57 return &sso.GetRoleCredentialsOutput{}, nil 58 } 59 60 return m.Response(m) 61 } 62 63 func swapCacheLocation(dir string) func() { 64 original := defaultCacheLocation 65 defaultCacheLocation = func() string { 66 return dir 67 } 68 return func() { 69 defaultCacheLocation = original 70 } 71 } 72 73 func swapNowTime(referenceTime time.Time) func() { 74 original := nowTime 75 nowTime = func() time.Time { 76 return referenceTime 77 } 78 return func() { 79 nowTime = original 80 } 81 } 82 83 func TestProvider(t *testing.T) { 84 restoreCache := swapCacheLocation("testdata") 85 defer restoreCache() 86 87 restoreTime := swapNowTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC)) 88 defer restoreTime() 89 90 cases := map[string]struct { 91 Client mockClient 92 AccountID string 93 Region string 94 RoleName string 95 StartURL string 96 97 ExpectedErr bool 98 ExpectedCredentials credentials.Value 99 ExpectedExpire time.Time 100 }{ 101 "missing required parameter values": { 102 StartURL: "https://invalid-required", 103 ExpectedErr: true, 104 }, 105 "valid required parameter values": { 106 Client: mockClient{ 107 ExpectedAccountID: "012345678901", 108 ExpectedRoleName: "TestRole", 109 ExpectedClientRegion: "us-west-2", 110 ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", 111 Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) { 112 return &sso.GetRoleCredentialsOutput{ 113 RoleCredentials: &sso.RoleCredentials{ 114 AccessKeyId: aws.String("AccessKey"), 115 SecretAccessKey: aws.String("SecretKey"), 116 SessionToken: aws.String("SessionToken"), 117 Expiration: aws.Int64(1611177743123), 118 }, 119 }, nil 120 }, 121 }, 122 AccountID: "012345678901", 123 Region: "us-west-2", 124 RoleName: "TestRole", 125 StartURL: "https://valid-required-only", 126 ExpectedCredentials: credentials.Value{ 127 AccessKeyID: "AccessKey", 128 SecretAccessKey: "SecretKey", 129 SessionToken: "SessionToken", 130 ProviderName: ProviderName, 131 }, 132 ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC), 133 }, 134 "expired access token": { 135 StartURL: "https://expired", 136 ExpectedErr: true, 137 }, 138 "api error": { 139 Client: mockClient{ 140 ExpectedAccountID: "012345678901", 141 ExpectedRoleName: "TestRole", 142 ExpectedClientRegion: "us-west-2", 143 ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", 144 Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) { 145 return nil, fmt.Errorf("api error") 146 }, 147 }, 148 AccountID: "012345678901", 149 Region: "us-west-2", 150 RoleName: "TestRole", 151 StartURL: "https://valid-required-only", 152 ExpectedErr: true, 153 }, 154 } 155 156 for name, tt := range cases { 157 t.Run(name, func(t *testing.T) { 158 tt.Client.t = t 159 160 provider := &Provider{ 161 Client: tt.Client, 162 AccountID: tt.AccountID, 163 RoleName: tt.RoleName, 164 StartURL: tt.StartURL, 165 } 166 167 provider.Expiry.CurrentTime = nowTime 168 169 credentials, err := provider.Retrieve() 170 if (err != nil) != tt.ExpectedErr { 171 t.Errorf("expect error: %v", tt.ExpectedErr) 172 } 173 174 if e, a := tt.ExpectedCredentials, credentials; !reflect.DeepEqual(e, a) { 175 t.Errorf("expect %v, got %v", e, a) 176 } 177 178 if !tt.ExpectedExpire.IsZero() { 179 if e, a := tt.ExpectedExpire, provider.ExpiresAt(); !e.Equal(a) { 180 t.Errorf("expect %v, got %v", e, a) 181 } 182 } 183 }) 184 } 185 }