github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/ssocreds/provider_test.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package ssocreds
     5  
     6  import (
     7  	"fmt"
     8  	"reflect"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/aavshr/aws-sdk-go/aws"
    13  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    14  	"github.com/aavshr/aws-sdk-go/aws/request"
    15  	"github.com/aavshr/aws-sdk-go/service/sso"
    16  	"github.com/aavshr/aws-sdk-go/service/sso/ssoiface"
    17  )
    18  
    19  type mockClient struct {
    20  	ssoiface.SSOAPI
    21  
    22  	t *testing.T
    23  
    24  	Output *sso.GetRoleCredentialsOutput
    25  	Err    error
    26  
    27  	ExpectedAccountID    string
    28  	ExpectedAccessToken  string
    29  	ExpectedRoleName     string
    30  	ExpectedClientRegion string
    31  
    32  	Response func(mockClient) (*sso.GetRoleCredentialsOutput, error)
    33  }
    34  
    35  func (m mockClient) GetRoleCredentialsWithContext(ctx aws.Context, params *sso.GetRoleCredentialsInput, _ ...request.Option) (*sso.GetRoleCredentialsOutput, error) {
    36  	m.t.Helper()
    37  
    38  	if len(m.ExpectedAccountID) > 0 {
    39  		if e, a := m.ExpectedAccountID, aws.StringValue(params.AccountId); e != a {
    40  			m.t.Errorf("expect %v, got %v", e, a)
    41  		}
    42  	}
    43  
    44  	if len(m.ExpectedAccessToken) > 0 {
    45  		if e, a := m.ExpectedAccessToken, aws.StringValue(params.AccessToken); e != a {
    46  			m.t.Errorf("expect %v, got %v", e, a)
    47  		}
    48  	}
    49  
    50  	if len(m.ExpectedRoleName) > 0 {
    51  		if e, a := m.ExpectedRoleName, aws.StringValue(params.RoleName); e != a {
    52  			m.t.Errorf("expect %v, got %v", e, a)
    53  		}
    54  	}
    55  
    56  	if m.Response == nil {
    57  		return &sso.GetRoleCredentialsOutput{}, nil
    58  	}
    59  
    60  	return m.Response(m)
    61  }
    62  
    63  func swapCacheLocation(dir string) func() {
    64  	original := defaultCacheLocation
    65  	defaultCacheLocation = func() string {
    66  		return dir
    67  	}
    68  	return func() {
    69  		defaultCacheLocation = original
    70  	}
    71  }
    72  
    73  func swapNowTime(referenceTime time.Time) func() {
    74  	original := nowTime
    75  	nowTime = func() time.Time {
    76  		return referenceTime
    77  	}
    78  	return func() {
    79  		nowTime = original
    80  	}
    81  }
    82  
    83  func TestProvider(t *testing.T) {
    84  	restoreCache := swapCacheLocation("testdata")
    85  	defer restoreCache()
    86  
    87  	restoreTime := swapNowTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC))
    88  	defer restoreTime()
    89  
    90  	cases := map[string]struct {
    91  		Client    mockClient
    92  		AccountID string
    93  		Region    string
    94  		RoleName  string
    95  		StartURL  string
    96  
    97  		ExpectedErr         bool
    98  		ExpectedCredentials credentials.Value
    99  		ExpectedExpire      time.Time
   100  	}{
   101  		"missing required parameter values": {
   102  			StartURL:    "https://invalid-required",
   103  			ExpectedErr: true,
   104  		},
   105  		"valid required parameter values": {
   106  			Client: mockClient{
   107  				ExpectedAccountID:    "012345678901",
   108  				ExpectedRoleName:     "TestRole",
   109  				ExpectedClientRegion: "us-west-2",
   110  				ExpectedAccessToken:  "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
   111  				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
   112  					return &sso.GetRoleCredentialsOutput{
   113  						RoleCredentials: &sso.RoleCredentials{
   114  							AccessKeyId:     aws.String("AccessKey"),
   115  							SecretAccessKey: aws.String("SecretKey"),
   116  							SessionToken:    aws.String("SessionToken"),
   117  							Expiration:      aws.Int64(1611177743123),
   118  						},
   119  					}, nil
   120  				},
   121  			},
   122  			AccountID: "012345678901",
   123  			Region:    "us-west-2",
   124  			RoleName:  "TestRole",
   125  			StartURL:  "https://valid-required-only",
   126  			ExpectedCredentials: credentials.Value{
   127  				AccessKeyID:     "AccessKey",
   128  				SecretAccessKey: "SecretKey",
   129  				SessionToken:    "SessionToken",
   130  				ProviderName:    ProviderName,
   131  			},
   132  			ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
   133  		},
   134  		"expired access token": {
   135  			StartURL:    "https://expired",
   136  			ExpectedErr: true,
   137  		},
   138  		"api error": {
   139  			Client: mockClient{
   140  				ExpectedAccountID:    "012345678901",
   141  				ExpectedRoleName:     "TestRole",
   142  				ExpectedClientRegion: "us-west-2",
   143  				ExpectedAccessToken:  "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
   144  				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
   145  					return nil, fmt.Errorf("api error")
   146  				},
   147  			},
   148  			AccountID:   "012345678901",
   149  			Region:      "us-west-2",
   150  			RoleName:    "TestRole",
   151  			StartURL:    "https://valid-required-only",
   152  			ExpectedErr: true,
   153  		},
   154  	}
   155  
   156  	for name, tt := range cases {
   157  		t.Run(name, func(t *testing.T) {
   158  			tt.Client.t = t
   159  
   160  			provider := &Provider{
   161  				Client:    tt.Client,
   162  				AccountID: tt.AccountID,
   163  				RoleName:  tt.RoleName,
   164  				StartURL:  tt.StartURL,
   165  			}
   166  
   167  			provider.Expiry.CurrentTime = nowTime
   168  
   169  			credentials, err := provider.Retrieve()
   170  			if (err != nil) != tt.ExpectedErr {
   171  				t.Errorf("expect error: %v", tt.ExpectedErr)
   172  			}
   173  
   174  			if e, a := tt.ExpectedCredentials, credentials; !reflect.DeepEqual(e, a) {
   175  				t.Errorf("expect %v, got %v", e, a)
   176  			}
   177  
   178  			if !tt.ExpectedExpire.IsZero() {
   179  				if e, a := tt.ExpectedExpire, provider.ExpiresAt(); !e.Equal(a) {
   180  					t.Errorf("expect %v, got %v", e, a)
   181  				}
   182  			}
   183  		})
   184  	}
   185  }