// Source: github.com/aliyun/credentials-go@v1.4.7/credentials/providers/ram_role_arn_test.go

     1  package providers
     2  
     3  import (
     4  	"errors"
     5  	"os"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	httputil "github.com/aliyun/credentials-go/credentials/internal/http"
    11  	"github.com/aliyun/credentials-go/credentials/internal/utils"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
    16  	rollback := utils.Memory("ALIBABA_CLOUD_STS_REGION", "ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")
    17  	defer func() {
    18  		rollback()
    19  	}()
    20  	// case 1: no credentials provider
    21  	_, err := NewRAMRoleARNCredentialsProviderBuilder().
    22  		Build()
    23  	assert.EqualError(t, err, "must specify a previous credentials provider to assume role")
    24  
    25  	// case 2: no role arn
    26  	akProvider, err := NewStaticAKCredentialsProviderBuilder().
    27  		WithAccessKeyId("akid").
    28  		WithAccessKeySecret("aksecret").
    29  		Build()
    30  	assert.Nil(t, err)
    31  	_, err = NewRAMRoleARNCredentialsProviderBuilder().
    32  		WithCredentialsProvider(akProvider).
    33  		Build()
    34  	assert.EqualError(t, err, "the RoleArn is empty")
    35  
    36  	// case 3: check default role session name
    37  	p, err := NewRAMRoleARNCredentialsProviderBuilder().
    38  		WithCredentialsProvider(akProvider).
    39  		WithRoleArn("roleArn").
    40  		Build()
    41  	assert.Nil(t, err)
    42  	assert.True(t, strings.HasPrefix(p.roleSessionName, "credentials-go-"))
    43  
    44  	// case 4: check default duration seconds
    45  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
    46  		WithCredentialsProvider(akProvider).
    47  		WithRoleArn("roleArn").Build()
    48  	assert.Nil(t, err)
    49  	assert.Equal(t, 3600, p.durationSeconds)
    50  
    51  	// case 5: check invalid duration seconds
    52  	_, err = NewRAMRoleARNCredentialsProviderBuilder().
    53  		WithCredentialsProvider(akProvider).
    54  		WithRoleArn("roleArn").
    55  		WithDurationSeconds(100).
    56  		Build()
    57  	assert.EqualError(t, err, "session duration should be in the range of 900s - max session duration")
    58  
    59  	// case 6: check all duration seconds
    60  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
    61  		WithCredentialsProvider(akProvider).
    62  		WithRoleArn("roleArn").
    63  		WithStsRegionId("cn-hangzhou").
    64  		WithEnableVpc(true).
    65  		WithPolicy("policy").
    66  		WithExternalId("externalId").
    67  		WithRoleSessionName("rsn").
    68  		WithDurationSeconds(1000).
    69  		Build()
    70  	assert.Nil(t, err)
    71  	assert.Equal(t, "rsn", p.roleSessionName)
    72  	assert.Equal(t, "roleArn", p.roleArn)
    73  	assert.Equal(t, "policy", p.policy)
    74  	assert.Equal(t, "externalId", p.externalId)
    75  	assert.Equal(t, "cn-hangzhou", p.stsRegionId)
    76  	assert.Equal(t, 1000, p.durationSeconds)
    77  	// sts endpoint with sts region
    78  	assert.Equal(t, "sts-vpc.cn-hangzhou.aliyuncs.com", p.stsEndpoint)
    79  
    80  	// case 7: check default sts endpoint
    81  	os.Setenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED", "1")
    82  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
    83  		WithCredentialsProvider(akProvider).
    84  		WithRoleArn("roleArn").
    85  		WithPolicy("policy").
    86  		WithExternalId("externalId").
    87  		WithRoleSessionName("rsn").
    88  		WithDurationSeconds(1000).
    89  		Build()
    90  	assert.Nil(t, err)
    91  	assert.Equal(t, "rsn", p.roleSessionName)
    92  	assert.Equal(t, "roleArn", p.roleArn)
    93  	assert.Equal(t, "policy", p.policy)
    94  	assert.Equal(t, "externalId", p.externalId)
    95  	assert.Equal(t, "", p.stsRegionId)
    96  	assert.Equal(t, 1000, p.durationSeconds)
    97  	assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint)
    98  
    99  	// case 8: check sts endpoint with env
   100  	os.Setenv("ALIBABA_CLOUD_STS_REGION", "cn-hangzhou")
   101  	os.Setenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED", "True")
   102  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
   103  		WithCredentialsProvider(akProvider).
   104  		WithRoleArn("roleArn").
   105  		WithPolicy("policy").
   106  		WithExternalId("externalId").
   107  		WithRoleSessionName("rsn").
   108  		WithDurationSeconds(1000).
   109  		Build()
   110  	assert.Nil(t, err)
   111  	assert.Equal(t, "sts-vpc.cn-hangzhou.aliyuncs.com", p.stsEndpoint)
   112  
   113  	// case 9: check sts endpoint with sts endpoint
   114  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
   115  		WithCredentialsProvider(akProvider).
   116  		WithRoleArn("roleArn").
   117  		WithStsEndpoint("sts.cn-shanghai.aliyuncs.com").
   118  		WithPolicy("policy").
   119  		WithExternalId("externalId").
   120  		WithRoleSessionName("rsn").
   121  		WithDurationSeconds(1000).
   122  		Build()
   123  	assert.Nil(t, err)
   124  	assert.Equal(t, "rsn", p.roleSessionName)
   125  	assert.Equal(t, "roleArn", p.roleArn)
   126  	assert.Equal(t, "policy", p.policy)
   127  	assert.Equal(t, "externalId", p.externalId)
   128  	assert.Equal(t, "", p.stsRegionId)
   129  	assert.Equal(t, 1000, p.durationSeconds)
   130  	assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint)
   131  
   132  	// case 10: check ak&sk
   133  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
   134  		WithAccessKeyId("ak").
   135  		WithAccessKeySecret("sk").
   136  		WithRoleArn("roleArn").
   137  		WithStsEndpoint("sts.cn-shanghai.aliyuncs.com").
   138  		WithPolicy("policy").
   139  		WithExternalId("externalId").
   140  		WithRoleSessionName("rsn").
   141  		WithDurationSeconds(1000).
   142  		Build()
   143  	assert.Nil(t, err)
   144  	cre, err := p.credentialsProvider.GetCredentials()
   145  	assert.Nil(t, err)
   146  	assert.Equal(t, "ak", cre.AccessKeyId)
   147  	assert.Equal(t, "sk", cre.AccessKeySecret)
   148  	assert.Equal(t, "static_ak", cre.ProviderName)
   149  
   150  	// case 11: check ak&sk&token
   151  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
   152  		WithAccessKeyId("ak").
   153  		WithAccessKeySecret("sk").
   154  		WithSecurityToken("token").
   155  		WithRoleArn("roleArn").
   156  		WithStsEndpoint("sts.cn-shanghai.aliyuncs.com").
   157  		WithPolicy("policy").
   158  		WithExternalId("externalId").
   159  		WithRoleSessionName("rsn").
   160  		WithDurationSeconds(1000).
   161  		Build()
   162  	assert.Nil(t, err)
   163  	cre, err = p.credentialsProvider.GetCredentials()
   164  	assert.Nil(t, err)
   165  	assert.Equal(t, "ak", cre.AccessKeyId)
   166  	assert.Equal(t, "sk", cre.AccessKeySecret)
   167  	assert.Equal(t, "token", cre.SecurityToken)
   168  	assert.Equal(t, "static_sts", cre.ProviderName)
   169  }
   170  
   171  func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) {
   172  	originHttpDo := httpDo
   173  	defer func() { httpDo = originHttpDo }()
   174  
   175  	akProvider, err := NewStaticAKCredentialsProviderBuilder().
   176  		WithAccessKeyId("akid").
   177  		WithAccessKeySecret("aksecret").
   178  		Build()
   179  	assert.Nil(t, err)
   180  	p, err := NewRAMRoleARNCredentialsProviderBuilder().
   181  		WithCredentialsProvider(akProvider).
   182  		WithRoleArn("roleArn").
   183  		WithRoleSessionName("rsn").
   184  		WithDurationSeconds(1000).
   185  		Build()
   186  	assert.Nil(t, err)
   187  
   188  	cc, err := akProvider.GetCredentials()
   189  	assert.Nil(t, err)
   190  
   191  	// case 1: server error
   192  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   193  		err = errors.New("mock server error")
   194  		return
   195  	}
   196  	_, err = p.getCredentials(cc)
   197  	assert.NotNil(t, err)
   198  	assert.Equal(t, "mock server error", err.Error())
   199  
   200  	// case 2: 4xx error
   201  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   202  		res = &httputil.Response{
   203  			StatusCode: 400,
   204  			Body:       []byte("4xx error"),
   205  		}
   206  		return
   207  	}
   208  
   209  	_, err = p.getCredentials(cc)
   210  	assert.NotNil(t, err)
   211  	assert.Equal(t, "refresh session token failed: 4xx error", err.Error())
   212  
   213  	// case 3: invalid json
   214  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   215  		res = &httputil.Response{
   216  			StatusCode: 200,
   217  			Body:       []byte("invalid json"),
   218  		}
   219  		return
   220  	}
   221  	_, err = p.getCredentials(cc)
   222  	assert.NotNil(t, err)
   223  	assert.Equal(t, "refresh RoleArn sts token err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error())
   224  
   225  	// case 4: empty response json
   226  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   227  		res = &httputil.Response{
   228  			StatusCode: 200,
   229  			Body:       []byte("null"),
   230  		}
   231  		return
   232  	}
   233  	_, err = p.getCredentials(cc)
   234  	assert.NotNil(t, err)
   235  	assert.Equal(t, "refresh RoleArn sts token err, fail to get credentials", err.Error())
   236  
   237  	// case 5: empty session ak response json
   238  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   239  		res = &httputil.Response{
   240  			StatusCode: 200,
   241  			Body:       []byte(`{"Credentials": {}}`),
   242  		}
   243  		return
   244  	}
   245  	_, err = p.getCredentials(cc)
   246  	assert.NotNil(t, err)
   247  	assert.Equal(t, "refresh RoleArn sts token err, fail to get credentials", err.Error())
   248  
   249  	// case 6: mock ok value
   250  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   251  		res = &httputil.Response{
   252  			StatusCode: 200,
   253  			Body:       []byte(`{"Credentials": {"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token"}}`),
   254  		}
   255  		return
   256  	}
   257  	creds, err := p.getCredentials(cc)
   258  	assert.Nil(t, err)
   259  	assert.Equal(t, "saki", creds.AccessKeyId)
   260  	assert.Equal(t, "saks", creds.AccessKeySecret)
   261  	assert.Equal(t, "token", creds.SecurityToken)
   262  	assert.Equal(t, "2021-10-20T04:27:09Z", creds.Expiration)
   263  
   264  	// needUpdateCredential
   265  	assert.True(t, p.needUpdateCredential())
   266  	p.expirationTimestamp = time.Now().Unix()
   267  	assert.True(t, p.needUpdateCredential())
   268  
   269  	p.expirationTimestamp = time.Now().Unix() + 300
   270  	assert.False(t, p.needUpdateCredential())
   271  }
   272  
   273  func TestRAMRoleARNCredentialsProvider_getCredentialsWithRequestCheck(t *testing.T) {
   274  	originHttpDo := httpDo
   275  	defer func() { httpDo = originHttpDo }()
   276  
   277  	stsProvider, err := NewStaticSTSCredentialsProviderBuilder().
   278  		WithAccessKeyId("akid").
   279  		WithAccessKeySecret("aksecret").
   280  		WithSecurityToken("ststoken").
   281  		Build()
   282  	assert.Nil(t, err)
   283  	p, err := NewRAMRoleARNCredentialsProviderBuilder().
   284  		WithCredentialsProvider(stsProvider).
   285  		WithRoleArn("roleArn").
   286  		WithRoleSessionName("rsn").
   287  		WithDurationSeconds(1000).
   288  		WithPolicy("policy").
   289  		WithStsRegionId("cn-beijing").
   290  		WithExternalId("externalId").
   291  		Build()
   292  	assert.Nil(t, err)
   293  
   294  	// case 1: server error
   295  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   296  		assert.Equal(t, "sts.cn-beijing.aliyuncs.com", req.Host)
   297  		assert.Equal(t, "ststoken", req.Queries["SecurityToken"])
   298  		assert.Equal(t, "policy", req.Form["Policy"])
   299  		assert.Equal(t, "roleArn", req.Form["RoleArn"])
   300  		assert.Equal(t, "rsn", req.Form["RoleSessionName"])
   301  		assert.Equal(t, "1000", req.Form["DurationSeconds"])
   302  
   303  		err = errors.New("mock server error")
   304  		return
   305  	}
   306  
   307  	cc, err := stsProvider.GetCredentials()
   308  	assert.Nil(t, err)
   309  	_, err = p.getCredentials(cc)
   310  	assert.NotNil(t, err)
   311  	assert.Equal(t, "mock server error", err.Error())
   312  }
   313  
   314  type errorCredentialsProvider struct {
   315  }
   316  
   317  func (p *errorCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
   318  	err = errors.New("get credentials failed")
   319  	return
   320  }
   321  
   322  func (p *errorCredentialsProvider) GetProviderName() string {
   323  	return "error_credentials_provider"
   324  }
   325  
   326  func TestRAMRoleARNCredentialsProviderGetCredentials(t *testing.T) {
   327  	originHttpDo := httpDo
   328  	defer func() { httpDo = originHttpDo }()
   329  
   330  	// case 0: get previous credentials failed
   331  	p, err := NewRAMRoleARNCredentialsProviderBuilder().
   332  		WithCredentialsProvider(&errorCredentialsProvider{}).
   333  		WithRoleArn("roleArn").
   334  		WithRoleSessionName("rsn").
   335  		WithDurationSeconds(1000).
   336  		Build()
   337  	assert.Nil(t, err)
   338  	_, err = p.GetCredentials()
   339  	assert.Equal(t, "get credentials failed", err.Error())
   340  
   341  	akProvider, err := NewStaticAKCredentialsProviderBuilder().
   342  		WithAccessKeyId("akid").
   343  		WithAccessKeySecret("aksecret").
   344  		Build()
   345  	assert.Nil(t, err)
   346  
   347  	p, err = NewRAMRoleARNCredentialsProviderBuilder().
   348  		WithCredentialsProvider(akProvider).
   349  		WithRoleArn("roleArn").
   350  		WithRoleSessionName("rsn").
   351  		WithDurationSeconds(1000).
   352  		Build()
   353  	assert.Nil(t, err)
   354  
   355  	// case 1: get credentials failed
   356  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   357  		err = errors.New("mock server error")
   358  		return
   359  	}
   360  	_, err = p.GetCredentials()
   361  	assert.NotNil(t, err)
   362  	assert.Equal(t, "mock server error", err.Error())
   363  
   364  	// case 2: get invalid expiration
   365  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   366  		res = &httputil.Response{
   367  			StatusCode: 200,
   368  			Body:       []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"invalidexpiration","SecurityToken":"ststoken"}}`),
   369  		}
   370  		return
   371  	}
   372  	_, err = p.GetCredentials()
   373  	assert.NotNil(t, err)
   374  	assert.Equal(t, "parsing time \"invalidexpiration\" as \"2006-01-02T15:04:05Z\": cannot parse \"invalidexpiration\" as \"2006\"", err.Error())
   375  
   376  	// case 3: happy result
   377  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   378  		res = &httputil.Response{
   379  			StatusCode: 200,
   380  			Body:       []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"ststoken"}}`),
   381  		}
   382  		return
   383  	}
   384  	cc, err := p.GetCredentials()
   385  	assert.Nil(t, err)
   386  	assert.Equal(t, "akid", cc.AccessKeyId)
   387  	assert.Equal(t, "aksecret", cc.AccessKeySecret)
   388  	assert.Equal(t, "ststoken", cc.SecurityToken)
   389  	assert.Equal(t, "ram_role_arn/static_ak", cc.ProviderName)
   390  	assert.True(t, p.needUpdateCredential())
   391  	// get credentials again
   392  	cc, err = p.GetCredentials()
   393  	assert.Nil(t, err)
   394  	assert.Equal(t, "akid", cc.AccessKeyId)
   395  	assert.Equal(t, "aksecret", cc.AccessKeySecret)
   396  	assert.Equal(t, "ststoken", cc.SecurityToken)
   397  	assert.Equal(t, "ram_role_arn/static_ak", cc.ProviderName)
   398  	assert.True(t, p.needUpdateCredential())
   399  
   400  	pp, err := NewRAMRoleARNCredentialsProviderBuilder().
   401  		WithCredentialsProvider(p).
   402  		WithRoleArn("roleArn").
   403  		WithRoleSessionName("rsn").
   404  		WithDurationSeconds(1000).
   405  		Build()
   406  	assert.Nil(t, err)
   407  	cc, err = pp.GetCredentials()
   408  	assert.Nil(t, err)
   409  	assert.Equal(t, "akid", cc.AccessKeyId)
   410  	assert.Equal(t, "aksecret", cc.AccessKeySecret)
   411  	assert.Equal(t, "ststoken", cc.SecurityToken)
   412  	assert.True(t, pp.needUpdateCredential())
   413  	assert.Equal(t, "ram_role_arn/ram_role_arn/static_ak", cc.ProviderName)
   414  }
   415  
   416  func TestRAMRoleARNCredentialsProviderGetCredentialsWithError(t *testing.T) {
   417  	akProvider, err := NewStaticAKCredentialsProviderBuilder().
   418  		WithAccessKeyId("akid").
   419  		WithAccessKeySecret("aksecret").
   420  		Build()
   421  	assert.Nil(t, err)
   422  	p, err := NewRAMRoleARNCredentialsProviderBuilder().
   423  		WithCredentialsProvider(akProvider).
   424  		WithRoleArn("roleArn").
   425  		WithRoleSessionName("rsn").
   426  		WithDurationSeconds(1000).
   427  		Build()
   428  	assert.Nil(t, err)
   429  	_, err = p.GetCredentials()
   430  	assert.NotNil(t, err)
   431  	assert.Contains(t, err.Error(), "InvalidAccessKeyId.NotFound")
   432  }
   433  
   434  func TestRAMRoleARNCredentialsProviderWithHttpOptions(t *testing.T) {
   435  	akProvider, err := NewStaticAKCredentialsProviderBuilder().
   436  		WithAccessKeyId("akid").
   437  		WithAccessKeySecret("aksecret").
   438  		Build()
   439  	assert.Nil(t, err)
   440  	p, err := NewRAMRoleARNCredentialsProviderBuilder().
   441  		WithCredentialsProvider(akProvider).
   442  		WithRoleArn("roleArn").
   443  		WithRoleSessionName("rsn").
   444  		WithDurationSeconds(1000).
   445  		WithHttpOptions(&HttpOptions{
   446  			ConnectTimeout: 1000,
   447  			ReadTimeout:    1000,
   448  			Proxy:          "localhost:3999",
   449  		}).
   450  		Build()
   451  	assert.Nil(t, err)
   452  	_, err = p.GetCredentials()
   453  	assert.NotNil(t, err)
   454  	assert.Contains(t, err.Error(), "proxyconnect tcp:")
   455  }