github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/stscreds/web_identity_provider_test.go (about)

     1  //go:build go1.7
     2  // +build go1.7
     3  
     4  package stscreds_test
     5  
     6  import (
     7  	"net/http"
     8  	"reflect"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/aavshr/aws-sdk-go/aws"
    14  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    15  	"github.com/aavshr/aws-sdk-go/aws/corehandlers"
    16  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    17  	"github.com/aavshr/aws-sdk-go/aws/credentials/stscreds"
    18  	"github.com/aavshr/aws-sdk-go/aws/request"
    19  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    20  	"github.com/aavshr/aws-sdk-go/service/sts"
    21  )
    22  
    23  func TestWebIdentityProviderRetrieve(t *testing.T) {
    24  	var reqCount int
    25  	cases := map[string]struct {
    26  		onSendReq         func(*testing.T, *request.Request)
    27  		roleARN           string
    28  		tokenFilepath     string
    29  		sessionName       string
    30  		duration          time.Duration
    31  		expectedError     string
    32  		expectedCredValue credentials.Value
    33  	}{
    34  		"session name case": {
    35  			roleARN:       "arn01234567890123456789",
    36  			tokenFilepath: "testdata/token.jwt",
    37  			sessionName:   "foo",
    38  			onSendReq: func(t *testing.T, r *request.Request) {
    39  				input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
    40  				if e, a := "foo", *input.RoleSessionName; e != a {
    41  					t.Errorf("expected %v, but received %v", e, a)
    42  				}
    43  				if input.DurationSeconds != nil {
    44  					t.Errorf("expect no duration, got %v", *input.DurationSeconds)
    45  				}
    46  
    47  				data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
    48  				*data = sts.AssumeRoleWithWebIdentityOutput{
    49  					Credentials: &sts.Credentials{
    50  						Expiration:      aws.Time(time.Now()),
    51  						AccessKeyId:     aws.String("access-key-id"),
    52  						SecretAccessKey: aws.String("secret-access-key"),
    53  						SessionToken:    aws.String("session-token"),
    54  					},
    55  				}
    56  			},
    57  			expectedCredValue: credentials.Value{
    58  				AccessKeyID:     "access-key-id",
    59  				SecretAccessKey: "secret-access-key",
    60  				SessionToken:    "session-token",
    61  				ProviderName:    stscreds.WebIdentityProviderName,
    62  			},
    63  		},
    64  		"with duration": {
    65  			roleARN:       "arn01234567890123456789",
    66  			tokenFilepath: "testdata/token.jwt",
    67  			sessionName:   "foo",
    68  			duration:      15 * time.Minute,
    69  			onSendReq: func(t *testing.T, r *request.Request) {
    70  				input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
    71  				if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a {
    72  					t.Errorf("expect %v duration, got %v", e, a)
    73  				}
    74  
    75  				data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
    76  				*data = sts.AssumeRoleWithWebIdentityOutput{
    77  					Credentials: &sts.Credentials{
    78  						Expiration:      aws.Time(time.Now()),
    79  						AccessKeyId:     aws.String("access-key-id"),
    80  						SecretAccessKey: aws.String("secret-access-key"),
    81  						SessionToken:    aws.String("session-token"),
    82  					},
    83  				}
    84  			},
    85  			expectedCredValue: credentials.Value{
    86  				AccessKeyID:     "access-key-id",
    87  				SecretAccessKey: "secret-access-key",
    88  				SessionToken:    "session-token",
    89  				ProviderName:    stscreds.WebIdentityProviderName,
    90  			},
    91  		},
    92  		"invalid token retry": {
    93  			roleARN:       "arn01234567890123456789",
    94  			tokenFilepath: "testdata/token.jwt",
    95  			sessionName:   "foo",
    96  			onSendReq: func(t *testing.T, r *request.Request) {
    97  				input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
    98  				if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
    99  					t.Errorf("expected %v, but received %v", e, a)
   100  				}
   101  
   102  				if reqCount == 0 {
   103  					r.HTTPResponse.StatusCode = 400
   104  					r.Error = awserr.New(sts.ErrCodeInvalidIdentityTokenException,
   105  						"some error message", nil)
   106  					return
   107  				}
   108  
   109  				data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
   110  				*data = sts.AssumeRoleWithWebIdentityOutput{
   111  					Credentials: &sts.Credentials{
   112  						Expiration:      aws.Time(time.Now()),
   113  						AccessKeyId:     aws.String("access-key-id"),
   114  						SecretAccessKey: aws.String("secret-access-key"),
   115  						SessionToken:    aws.String("session-token"),
   116  					},
   117  				}
   118  			},
   119  			expectedCredValue: credentials.Value{
   120  				AccessKeyID:     "access-key-id",
   121  				SecretAccessKey: "secret-access-key",
   122  				SessionToken:    "session-token",
   123  				ProviderName:    stscreds.WebIdentityProviderName,
   124  			},
   125  		},
   126  	}
   127  
   128  	for name, c := range cases {
   129  		t.Run(name, func(t *testing.T) {
   130  			reqCount = 0
   131  
   132  			svc := sts.New(unit.Session, &aws.Config{
   133  				Logger: t,
   134  			})
   135  			svc.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{
   136  				Name: "custom send stub handler",
   137  				Fn: func(r *request.Request) {
   138  					r.HTTPResponse = &http.Response{
   139  						StatusCode: 200, Header: http.Header{},
   140  					}
   141  					c.onSendReq(t, r)
   142  					reqCount++
   143  				},
   144  			})
   145  			svc.Handlers.UnmarshalMeta.Clear()
   146  			svc.Handlers.Unmarshal.Clear()
   147  			svc.Handlers.UnmarshalError.Clear()
   148  
   149  			p := stscreds.NewWebIdentityRoleProvider(svc, c.roleARN, c.sessionName, c.tokenFilepath)
   150  			p.Duration = c.duration
   151  
   152  			credValue, err := p.Retrieve()
   153  			if len(c.expectedError) != 0 {
   154  				if err == nil {
   155  					t.Fatalf("expect error, got none")
   156  				}
   157  				if e, a := c.expectedError, err.Error(); !strings.Contains(a, e) {
   158  					t.Fatalf("expect error to contain %v, got %v", e, a)
   159  				}
   160  				return
   161  			}
   162  			if err != nil {
   163  				t.Fatalf("expect no error, got %v", err)
   164  			}
   165  
   166  			if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
   167  				t.Errorf("expected %v, but received %v", e, a)
   168  			}
   169  		})
   170  	}
   171  }