github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/stscreds/assume_role_provider_test.go (about)

     1  package stscreds
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/aavshr/aws-sdk-go/aws"
     9  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    10  	"github.com/aavshr/aws-sdk-go/aws/request"
    11  	"github.com/aavshr/aws-sdk-go/service/sts"
    12  )
    13  
    14  type stubSTS struct {
    15  	TestInput func(*sts.AssumeRoleInput)
    16  }
    17  
    18  func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
    19  	if s.TestInput != nil {
    20  		s.TestInput(input)
    21  	}
    22  	expiry := time.Now().Add(60 * time.Minute)
    23  	return &sts.AssumeRoleOutput{
    24  		Credentials: &sts.Credentials{
    25  			// Just reflect the role arn to the provider.
    26  			AccessKeyId:     input.RoleArn,
    27  			SecretAccessKey: aws.String("assumedSecretAccessKey"),
    28  			SessionToken:    aws.String("assumedSessionToken"),
    29  			Expiration:      &expiry,
    30  		},
    31  	}, nil
    32  }
    33  
    34  type stubSTSWithContext struct {
    35  	stubSTS
    36  	called chan struct{}
    37  }
    38  
    39  func (s *stubSTSWithContext) AssumeRoleWithContext(context credentials.Context, input *sts.AssumeRoleInput, option ...request.Option) (*sts.AssumeRoleOutput, error) {
    40  	<-s.called
    41  	return s.stubSTS.AssumeRole(input)
    42  }
    43  
    44  func TestAssumeRoleProvider(t *testing.T) {
    45  	stub := &stubSTS{}
    46  	p := &AssumeRoleProvider{
    47  		Client:  stub,
    48  		RoleARN: "roleARN",
    49  	}
    50  
    51  	creds, err := p.Retrieve()
    52  	if err != nil {
    53  		t.Errorf("expect nil, got %v", err)
    54  	}
    55  
    56  	if e, a := "roleARN", creds.AccessKeyID; e != a {
    57  		t.Errorf("expect %v, got %v", e, a)
    58  	}
    59  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
    60  		t.Errorf("expect %v, got %v", e, a)
    61  	}
    62  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
    63  		t.Errorf("expect %v, got %v", e, a)
    64  	}
    65  }
    66  
    67  func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
    68  	stub := &stubSTS{
    69  		TestInput: func(in *sts.AssumeRoleInput) {
    70  			if e, a := "0123456789", *in.SerialNumber; e != a {
    71  				t.Errorf("expect %v, got %v", e, a)
    72  			}
    73  			if e, a := "code", *in.TokenCode; e != a {
    74  				t.Errorf("expect %v, got %v", e, a)
    75  			}
    76  		},
    77  	}
    78  	p := &AssumeRoleProvider{
    79  		Client:       stub,
    80  		RoleARN:      "roleARN",
    81  		SerialNumber: aws.String("0123456789"),
    82  		TokenCode:    aws.String("code"),
    83  	}
    84  
    85  	creds, err := p.Retrieve()
    86  	if err != nil {
    87  		t.Errorf("expect nil, got %v", err)
    88  	}
    89  
    90  	if e, a := "roleARN", creds.AccessKeyID; e != a {
    91  		t.Errorf("expect %v, got %v", e, a)
    92  	}
    93  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
    94  		t.Errorf("expect %v, got %v", e, a)
    95  	}
    96  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
    97  		t.Errorf("expect %v, got %v", e, a)
    98  	}
    99  }
   100  
   101  func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
   102  	stub := &stubSTS{
   103  		TestInput: func(in *sts.AssumeRoleInput) {
   104  			if e, a := "0123456789", *in.SerialNumber; e != a {
   105  				t.Errorf("expect %v, got %v", e, a)
   106  			}
   107  			if e, a := "code", *in.TokenCode; e != a {
   108  				t.Errorf("expect %v, got %v", e, a)
   109  			}
   110  		},
   111  	}
   112  	p := &AssumeRoleProvider{
   113  		Client:       stub,
   114  		RoleARN:      "roleARN",
   115  		SerialNumber: aws.String("0123456789"),
   116  		TokenProvider: func() (string, error) {
   117  			return "code", nil
   118  		},
   119  	}
   120  
   121  	creds, err := p.Retrieve()
   122  	if err != nil {
   123  		t.Errorf("expect nil, got %v", err)
   124  	}
   125  
   126  	if e, a := "roleARN", creds.AccessKeyID; e != a {
   127  		t.Errorf("expect %v, got %v", e, a)
   128  	}
   129  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
   130  		t.Errorf("expect %v, got %v", e, a)
   131  	}
   132  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
   133  		t.Errorf("expect %v, got %v", e, a)
   134  	}
   135  }
   136  
   137  func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
   138  	stub := &stubSTS{
   139  		TestInput: func(in *sts.AssumeRoleInput) {
   140  			t.Errorf("API request should not of been called")
   141  		},
   142  	}
   143  	p := &AssumeRoleProvider{
   144  		Client:       stub,
   145  		RoleARN:      "roleARN",
   146  		SerialNumber: aws.String("0123456789"),
   147  		TokenProvider: func() (string, error) {
   148  			return "", fmt.Errorf("error occurred")
   149  		},
   150  	}
   151  
   152  	creds, err := p.Retrieve()
   153  	if err == nil {
   154  		t.Errorf("expect error")
   155  	}
   156  
   157  	if v := creds.AccessKeyID; len(v) != 0 {
   158  		t.Errorf("expect empty, got %v", v)
   159  	}
   160  	if v := creds.SecretAccessKey; len(v) != 0 {
   161  		t.Errorf("expect empty, got %v", v)
   162  	}
   163  	if v := creds.SessionToken; len(v) != 0 {
   164  		t.Errorf("expect empty, got %v", v)
   165  	}
   166  }
   167  
   168  func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
   169  	stub := &stubSTS{
   170  		TestInput: func(in *sts.AssumeRoleInput) {
   171  			t.Errorf("API request should not of been called")
   172  		},
   173  	}
   174  	p := &AssumeRoleProvider{
   175  		Client:       stub,
   176  		RoleARN:      "roleARN",
   177  		SerialNumber: aws.String("0123456789"),
   178  	}
   179  
   180  	creds, err := p.Retrieve()
   181  	if err == nil {
   182  		t.Errorf("expect error")
   183  	}
   184  
   185  	if v := creds.AccessKeyID; len(v) != 0 {
   186  		t.Errorf("expect empty, got %v", v)
   187  	}
   188  	if v := creds.SecretAccessKey; len(v) != 0 {
   189  		t.Errorf("expect empty, got %v", v)
   190  	}
   191  	if v := creds.SessionToken; len(v) != 0 {
   192  		t.Errorf("expect empty, got %v", v)
   193  	}
   194  }
   195  
   196  func BenchmarkAssumeRoleProvider(b *testing.B) {
   197  	stub := &stubSTS{}
   198  	p := &AssumeRoleProvider{
   199  		Client:  stub,
   200  		RoleARN: "roleARN",
   201  	}
   202  
   203  	b.ResetTimer()
   204  	for i := 0; i < b.N; i++ {
   205  		if _, err := p.Retrieve(); err != nil {
   206  			b.Fatal(err)
   207  		}
   208  	}
   209  }
   210  
   211  func TestAssumeRoleProvider_WithTags(t *testing.T) {
   212  	stub := &stubSTS{
   213  		TestInput: func(in *sts.AssumeRoleInput) {
   214  			if *in.TransitiveTagKeys[0] != "TagName" {
   215  				t.Errorf("TransitiveTagKeys not passed along")
   216  			}
   217  			if *in.Tags[0].Key != "TagName" || *in.Tags[0].Value != "TagValue" {
   218  				t.Errorf("Tags not passed along")
   219  			}
   220  		},
   221  	}
   222  	p := &AssumeRoleProvider{
   223  		Client:  stub,
   224  		RoleARN: "roleARN",
   225  		Tags: []*sts.Tag{
   226  			{
   227  				Key:   aws.String("TagName"),
   228  				Value: aws.String("TagValue"),
   229  			},
   230  		},
   231  		TransitiveTagKeys: []*string{aws.String("TagName")},
   232  	}
   233  	_, err := p.Retrieve()
   234  	if err != nil {
   235  		t.Errorf("expect error")
   236  	}
   237  }
   238  
   239  func TestAssumeRoleProvider_RetrieveWithContext(t *testing.T) {
   240  	stub := &stubSTSWithContext{
   241  		called: make(chan struct{}),
   242  	}
   243  	p := &AssumeRoleProvider{
   244  		Client:  stub,
   245  		RoleARN: "roleARN",
   246  	}
   247  
   248  	go func() {
   249  		stub.called <- struct{}{}
   250  	}()
   251  
   252  	creds, err := p.RetrieveWithContext(aws.BackgroundContext())
   253  	if err != nil {
   254  		t.Errorf("expect nil, got %v", err)
   255  	}
   256  
   257  	if e, a := "roleARN", creds.AccessKeyID; e != a {
   258  		t.Errorf("expect %v, got %v", e, a)
   259  	}
   260  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
   261  		t.Errorf("expect %v, got %v", e, a)
   262  	}
   263  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
   264  		t.Errorf("expect %v, got %v", e, a)
   265  	}
   266  }