sigs.k8s.io/cluster-api-provider-aws@v1.5.5/pkg/cloud/identity/identity_test.go (about)

     1  /*
     2  Copyright 2021 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8  	http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package identity
    18  
    19  import (
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/aws/aws-sdk-go/aws"
    24  	"github.com/aws/aws-sdk-go/aws/credentials"
    25  	"github.com/aws/aws-sdk-go/service/sts"
    26  	"github.com/golang/mock/gomock"
    27  	"github.com/google/go-cmp/cmp"
    28  	. "github.com/onsi/gomega"
    29  	"github.com/pkg/errors"
    30  	corev1 "k8s.io/api/core/v1"
    31  	"k8s.io/utils/pointer"
    32  
    33  	infrav1 "sigs.k8s.io/cluster-api-provider-aws/api/v1beta1"
    34  	"sigs.k8s.io/cluster-api-provider-aws/pkg/cloud/services/sts/mock_stsiface"
    35  )
    36  
    37  func TestAWSStaticPrincipalTypeProvider(t *testing.T) {
    38  	mockCtrl := gomock.NewController(t)
    39  	defer mockCtrl.Finish()
    40  
    41  	secret := &corev1.Secret{
    42  		Data: map[string][]byte{
    43  			"AccessKeyID":     []byte("static-AccessKeyID"),
    44  			"SecretAccessKey": []byte("static-SecretAccessKey"),
    45  		},
    46  	}
    47  
    48  	var staticProvider AWSPrincipalTypeProvider = NewAWSStaticPrincipalTypeProvider(&infrav1.AWSClusterStaticIdentity{}, secret)
    49  
    50  	stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
    51  	roleIdentity := &infrav1.AWSClusterRoleIdentity{
    52  		Spec: infrav1.AWSClusterRoleIdentitySpec{
    53  			AWSRoleSpec: infrav1.AWSRoleSpec{
    54  				RoleArn:         "arn:*:iam::*:role/aws-role/firstroleprovider",
    55  				SessionName:     "first-role-provider-session",
    56  				DurationSeconds: 900,
    57  			},
    58  		},
    59  	}
    60  
    61  	var roleProvider AWSPrincipalTypeProvider = &AWSRolePrincipalTypeProvider{
    62  		credentials:    nil,
    63  		Principal:      roleIdentity,
    64  		sourceProvider: &staticProvider,
    65  		stsClient:      stsMock,
    66  	}
    67  
    68  	roleIdentity2 := &infrav1.AWSClusterRoleIdentity{
    69  		Spec: infrav1.AWSClusterRoleIdentitySpec{
    70  			AWSRoleSpec: infrav1.AWSRoleSpec{
    71  				RoleArn:         "arn:*:iam::*:role/aws-role/secondroleprovider",
    72  				SessionName:     "second-role-provider-session",
    73  				DurationSeconds: 900,
    74  			},
    75  		},
    76  	}
    77  
    78  	var roleProvider2 AWSPrincipalTypeProvider = &AWSRolePrincipalTypeProvider{
    79  		credentials:    nil,
    80  		Principal:      roleIdentity2,
    81  		sourceProvider: &roleProvider,
    82  		stsClient:      stsMock,
    83  	}
    84  
    85  	testCases := []struct {
    86  		name      string
    87  		provider  AWSPrincipalTypeProvider
    88  		expect    func(m *mock_stsiface.MockSTSAPIMockRecorder)
    89  		expectErr bool
    90  		value     credentials.Value
    91  	}{
    92  		{
    93  			name:      "Static provider successfully retrieves",
    94  			provider:  staticProvider,
    95  			expect:    func(m *mock_stsiface.MockSTSAPIMockRecorder) {},
    96  			expectErr: false,
    97  			value: credentials.Value{
    98  				AccessKeyID:     "static-AccessKeyID",
    99  				SecretAccessKey: "static-SecretAccessKey",
   100  				ProviderName:    "StaticProvider",
   101  			},
   102  		},
   103  		{
   104  			name:     "Role provider with static provider source successfully retrieves",
   105  			provider: roleProvider,
   106  			expect: func(m *mock_stsiface.MockSTSAPIMockRecorder) {
   107  				m.AssumeRoleWithContext(gomock.Any(), &sts.AssumeRoleInput{
   108  					RoleArn:         aws.String(roleIdentity.Spec.RoleArn),
   109  					RoleSessionName: aws.String(roleIdentity.Spec.SessionName),
   110  					DurationSeconds: pointer.Int64Ptr(int64(roleIdentity.Spec.DurationSeconds)),
   111  				}).Return(&sts.AssumeRoleOutput{
   112  					Credentials: &sts.Credentials{
   113  						AccessKeyId:     aws.String("assumedAccessKeyId"),
   114  						SecretAccessKey: aws.String("assumedSecretAccessKey"),
   115  						SessionToken:    aws.String("assumedSessionToken"),
   116  						Expiration:      aws.Time(time.Now()),
   117  					},
   118  				}, nil)
   119  			},
   120  			expectErr: false,
   121  			value: credentials.Value{
   122  				AccessKeyID:     "assumedAccessKeyId",
   123  				SecretAccessKey: "assumedSecretAccessKey",
   124  				SessionToken:    "assumedSessionToken",
   125  				ProviderName:    "AssumeRoleProvider",
   126  			},
   127  		},
   128  		{
   129  			name:     "Role provider with role provider source successfully retrieves",
   130  			provider: roleProvider2,
   131  			expect: func(m *mock_stsiface.MockSTSAPIMockRecorder) {
   132  				m.AssumeRoleWithContext(gomock.Any(), &sts.AssumeRoleInput{
   133  					RoleArn:         aws.String(roleIdentity.Spec.RoleArn),
   134  					RoleSessionName: aws.String(roleIdentity.Spec.SessionName),
   135  					DurationSeconds: pointer.Int64Ptr(int64(roleIdentity.Spec.DurationSeconds)),
   136  				}).Return(&sts.AssumeRoleOutput{
   137  					Credentials: &sts.Credentials{
   138  						AccessKeyId:     aws.String("assumedAccessKeyId"),
   139  						SecretAccessKey: aws.String("assumedSecretAccessKey"),
   140  						SessionToken:    aws.String("assumedSessionToken"),
   141  						Expiration:      aws.Time(time.Now().AddDate(+1, 0, 0)),
   142  					},
   143  				}, nil)
   144  
   145  				m.AssumeRoleWithContext(gomock.Any(), &sts.AssumeRoleInput{
   146  					RoleArn:         aws.String(roleIdentity2.Spec.RoleArn),
   147  					RoleSessionName: aws.String(roleIdentity2.Spec.SessionName),
   148  					DurationSeconds: pointer.Int64Ptr(int64(roleIdentity2.Spec.DurationSeconds)),
   149  				}).Return(&sts.AssumeRoleOutput{
   150  					Credentials: &sts.Credentials{
   151  						AccessKeyId:     aws.String("assumedAccessKeyId2"),
   152  						SecretAccessKey: aws.String("assumedSecretAccessKey2"),
   153  						SessionToken:    aws.String("assumedSessionToken2"),
   154  						Expiration:      aws.Time(time.Now()),
   155  					},
   156  				}, nil)
   157  			},
   158  			expectErr: false,
   159  			value: credentials.Value{
   160  				AccessKeyID:     "assumedAccessKeyId2",
   161  				SecretAccessKey: "assumedSecretAccessKey2",
   162  				SessionToken:    "assumedSessionToken2",
   163  				ProviderName:    "AssumeRoleProvider",
   164  			},
   165  		},
   166  		{
   167  			name:     "Role provider with role provider source fails to retrieve when the source's source cannot assume source",
   168  			provider: roleProvider2,
   169  			expect: func(m *mock_stsiface.MockSTSAPIMockRecorder) {
   170  				roleProvider.(*AWSRolePrincipalTypeProvider).credentials.Expire()
   171  				roleProvider2.(*AWSRolePrincipalTypeProvider).credentials.Expire()
   172  				// AssumeRoleWithContext() call is not needed for roleIdentity as it has unexpired credentials
   173  				m.AssumeRoleWithContext(gomock.Any(), &sts.AssumeRoleInput{
   174  					RoleArn:         aws.String(roleIdentity.Spec.RoleArn),
   175  					RoleSessionName: aws.String(roleIdentity.Spec.SessionName),
   176  					DurationSeconds: pointer.Int64Ptr(int64(roleIdentity.Spec.DurationSeconds)),
   177  				}).Return(&sts.AssumeRoleOutput{}, errors.New("Not authorized to assume role"))
   178  			},
   179  			expectErr: true,
   180  		},
   181  	}
   182  
   183  	for _, tc := range testCases {
   184  		t.Run(tc.name, func(t *testing.T) {
   185  			g := NewWithT(t)
   186  
   187  			tc.expect(stsMock.EXPECT())
   188  			value, err := tc.provider.Retrieve()
   189  			if tc.expectErr {
   190  				g.Expect(err).ToNot(BeNil())
   191  				return
   192  			}
   193  
   194  			g.Expect(err).To(BeNil())
   195  
   196  			if !cmp.Equal(tc.value, value) {
   197  				t.Fatal("Did not get expected result")
   198  			}
   199  		})
   200  	}
   201  }