github.com/aliyun/credentials-go@v1.4.7/credentials/providers/oidc_test.go (about)

     1  package providers
     2  
     3  import (
     4  	"errors"
     5  	"os"
     6  	"path"
     7  	"strings"
     8  	"testing"
     9  	"time"
    10  
    11  	httputil "github.com/aliyun/credentials-go/credentials/internal/http"
    12  	"github.com/aliyun/credentials-go/credentials/internal/utils"
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  func TestOIDCCredentialsProviderGetCredentialsWithError(t *testing.T) {
    17  	wd, _ := os.Getwd()
    18  	p, err := NewOIDCCredentialsProviderBuilder().
    19  		// read a normal token
    20  		WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")).
    21  		WithOIDCProviderARN("provider-arn").
    22  		WithRoleArn("roleArn").
    23  		WithRoleSessionName("rsn").
    24  		WithPolicy("policy").
    25  		WithDurationSeconds(1000).
    26  		WithHttpOptions(&HttpOptions{
    27  			ConnectTimeout: 10000,
    28  		}).
    29  		Build()
    30  
    31  	assert.Nil(t, err)
    32  	assert.Equal(t, 10000, p.httpOptions.ConnectTimeout)
    33  	_, err = p.GetCredentials()
    34  	assert.NotNil(t, err)
    35  	assert.Contains(t, err.Error(), "AuthenticationFail.NoPermission")
    36  }
    37  
    38  func TestNewOIDCCredentialsProvider(t *testing.T) {
    39  	rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN", "ALIBABA_CLOUD_STS_REGION", "ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")
    40  	defer func() {
    41  		rollback()
    42  	}()
    43  
    44  	_, err := NewOIDCCredentialsProviderBuilder().Build()
    45  	assert.NotNil(t, err)
    46  	assert.Equal(t, "the OIDCTokenFilePath is empty", err.Error())
    47  
    48  	_, err = NewOIDCCredentialsProviderBuilder().WithOIDCTokenFilePath("/path/to/invalid/oidc.token").Build()
    49  	assert.NotNil(t, err)
    50  	assert.Equal(t, "the OIDCProviderARN is empty", err.Error())
    51  
    52  	_, err = NewOIDCCredentialsProviderBuilder().
    53  		WithOIDCTokenFilePath("/path/to/invalid/oidc.token").
    54  		WithOIDCProviderARN("provider-arn").
    55  		Build()
    56  	assert.NotNil(t, err)
    57  	assert.Equal(t, "the RoleArn is empty", err.Error())
    58  
    59  	p, err := NewOIDCCredentialsProviderBuilder().
    60  		WithOIDCTokenFilePath("/path/to/invalid/oidc.token").
    61  		WithOIDCProviderARN("provider-arn").
    62  		WithRoleArn("roleArn").
    63  		Build()
    64  	assert.Nil(t, err)
    65  
    66  	assert.Equal(t, "/path/to/invalid/oidc.token", p.oidcTokenFilePath)
    67  	assert.True(t, strings.HasPrefix(p.roleSessionName, "credentials-go-"))
    68  	assert.Equal(t, 3600, p.durationSeconds)
    69  
    70  	_, err = NewOIDCCredentialsProviderBuilder().
    71  		WithOIDCTokenFilePath("/path/to/invalid/oidc.token").
    72  		WithOIDCProviderARN("provider-arn").
    73  		WithRoleArn("roleArn").
    74  		WithDurationSeconds(100).
    75  		Build()
    76  	assert.NotNil(t, err)
    77  	assert.Equal(t, "the Assume Role session duration should be in the range of 15min - max duration seconds", err.Error())
    78  
    79  	os.Setenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "/path/from/env")
    80  	os.Setenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "provider_arn_from_env")
    81  	os.Setenv("ALIBABA_CLOUD_ROLE_ARN", "role_arn_from_env")
    82  
    83  	p, err = NewOIDCCredentialsProviderBuilder().
    84  		Build()
    85  	assert.Nil(t, err)
    86  
    87  	assert.Equal(t, "/path/from/env", p.oidcTokenFilePath)
    88  	assert.Equal(t, "provider_arn_from_env", p.oidcProviderARN)
    89  	assert.Equal(t, "role_arn_from_env", p.roleArn)
    90  	// sts endpoint: default
    91  	assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint)
    92  
    93  	// sts endpoint: with sts endpoint env
    94  	os.Setenv("ALIBABA_CLOUD_STS_REGION", "cn-hangzhou")
    95  	os.Setenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED", "true")
    96  	p, err = NewOIDCCredentialsProviderBuilder().
    97  		Build()
    98  	assert.Nil(t, err)
    99  	assert.Equal(t, "sts-vpc.cn-hangzhou.aliyuncs.com", p.stsEndpoint)
   100  
   101  	// sts endpoint: with sts endpoint
   102  	p, err = NewOIDCCredentialsProviderBuilder().
   103  		WithSTSEndpoint("sts.cn-shanghai.aliyuncs.com").
   104  		Build()
   105  	assert.Nil(t, err)
   106  	assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint)
   107  
   108  	// sts endpoint: with sts regionId
   109  	p, err = NewOIDCCredentialsProviderBuilder().
   110  		WithStsRegionId("cn-beijing").
   111  		WithEnableVpc(true).
   112  		Build()
   113  	assert.Nil(t, err)
   114  	assert.Equal(t, "sts-vpc.cn-beijing.aliyuncs.com", p.stsEndpoint)
   115  
   116  	os.Setenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED", "false")
   117  	p, err = NewOIDCCredentialsProviderBuilder().
   118  		WithOIDCTokenFilePath("/path/to/invalid/oidc.token").
   119  		WithOIDCProviderARN("provider-arn").
   120  		WithRoleArn("roleArn").
   121  		WithRoleSessionName("rsn").
   122  		WithStsRegionId("cn-hangzhou").
   123  		WithPolicy("policy").
   124  		Build()
   125  	assert.Nil(t, err)
   126  
   127  	assert.Equal(t, "/path/to/invalid/oidc.token", p.oidcTokenFilePath)
   128  	assert.Equal(t, "provider-arn", p.oidcProviderARN)
   129  	assert.Equal(t, "roleArn", p.roleArn)
   130  	assert.Equal(t, "rsn", p.roleSessionName)
   131  	assert.Equal(t, "cn-hangzhou", p.stsRegionId)
   132  	assert.Equal(t, "policy", p.policy)
   133  	assert.Equal(t, 3600, p.durationSeconds)
   134  	assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint)
   135  }
   136  
   137  func TestOIDCCredentialsProvider_getCredentials(t *testing.T) {
   138  	originHttpDo := httpDo
   139  	defer func() { httpDo = originHttpDo }()
   140  
   141  	// case 0: invalid oidc token file path
   142  	p, err := NewOIDCCredentialsProviderBuilder().
   143  		WithOIDCTokenFilePath("/path/to/invalid/oidc.token").
   144  		WithOIDCProviderARN("provider-arn").
   145  		WithRoleArn("roleArn").
   146  		WithRoleSessionName("rsn").
   147  		WithStsRegionId("cn-hangzhou").
   148  		WithPolicy("policy").
   149  		Build()
   150  	assert.Nil(t, err)
   151  
   152  	_, err = p.getCredentials()
   153  	assert.NotNil(t, err)
   154  	assert.Equal(t, "open /path/to/invalid/oidc.token: no such file or directory", err.Error())
   155  
   156  	// case 1: mock new http request failed
   157  	wd, _ := os.Getwd()
   158  	p, err = NewOIDCCredentialsProviderBuilder().
   159  		// read a normal token
   160  		WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")).
   161  		WithOIDCProviderARN("provider-arn").
   162  		WithRoleArn("roleArn").
   163  		WithRoleSessionName("rsn").
   164  		WithStsRegionId("cn-hangzhou").
   165  		WithPolicy("policy").
   166  		Build()
   167  	assert.Nil(t, err)
   168  
   169  	// case 2: server error
   170  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   171  		err = errors.New("mock server error")
   172  		return
   173  	}
   174  	_, err = p.getCredentials()
   175  	assert.NotNil(t, err)
   176  	assert.Equal(t, "mock server error", err.Error())
   177  
   178  	// case 3: 4xx error
   179  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   180  		res = &httputil.Response{
   181  			StatusCode: 400,
   182  			Body:       []byte("4xx error"),
   183  		}
   184  		return
   185  	}
   186  	_, err = p.getCredentials()
   187  	assert.NotNil(t, err)
   188  	assert.Equal(t, "get session token failed: 4xx error", err.Error())
   189  
   190  	// case 4: invalid json
   191  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   192  		res = &httputil.Response{
   193  			StatusCode: 200,
   194  			Body:       []byte("invalid json"),
   195  		}
   196  		return
   197  	}
   198  	_, err = p.getCredentials()
   199  	assert.NotNil(t, err)
   200  	assert.Equal(t, "get oidc sts token err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error())
   201  
   202  	// case 5: empty response json
   203  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   204  		res = &httputil.Response{
   205  			StatusCode: 200,
   206  			Body:       []byte("null"),
   207  		}
   208  		return
   209  	}
   210  	_, err = p.getCredentials()
   211  	assert.NotNil(t, err)
   212  	assert.Equal(t, "get oidc sts token err, fail to get credentials", err.Error())
   213  
   214  	// case 6: empty session ak response json
   215  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   216  		res = &httputil.Response{
   217  			StatusCode: 200,
   218  			Body:       []byte(`{"Credentials": {}}`),
   219  		}
   220  		return
   221  	}
   222  	_, err = p.getCredentials()
   223  	assert.NotNil(t, err)
   224  	assert.Equal(t, "refresh RoleArn sts token err, fail to get credentials", err.Error())
   225  
   226  	// case 7: mock ok value
   227  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   228  		res = &httputil.Response{
   229  			StatusCode: 200,
   230  			Body:       []byte(`{"Credentials": {"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token"}}`),
   231  		}
   232  		return
   233  	}
   234  	creds, err := p.getCredentials()
   235  	assert.Nil(t, err)
   236  	assert.Equal(t, "saki", creds.AccessKeyId)
   237  	assert.Equal(t, "saks", creds.AccessKeySecret)
   238  	assert.Equal(t, "token", creds.SecurityToken)
   239  	assert.Equal(t, "2021-10-20T04:27:09Z", creds.Expiration)
   240  
   241  	// needUpdateCredential
   242  	assert.True(t, p.needUpdateCredential())
   243  	p.expirationTimestamp = time.Now().Unix()
   244  	assert.True(t, p.needUpdateCredential())
   245  
   246  	p.expirationTimestamp = time.Now().Unix() + 300
   247  	assert.False(t, p.needUpdateCredential())
   248  }
   249  
   250  func TestOIDCCredentialsProvider_getCredentialsWithRequestCheck(t *testing.T) {
   251  	originHttpDo := httpDo
   252  	defer func() { httpDo = originHttpDo }()
   253  
   254  	// case 1: mock new http request failed
   255  	wd, _ := os.Getwd()
   256  	p, err := NewOIDCCredentialsProviderBuilder().
   257  		// read a normal token
   258  		WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")).
   259  		WithOIDCProviderARN("provider-arn").
   260  		WithRoleArn("roleArn").
   261  		WithRoleSessionName("rsn").
   262  		WithPolicy("policy").
   263  		WithDurationSeconds(1000).
   264  		Build()
   265  
   266  	assert.Nil(t, err)
   267  
   268  	// case 1: server error
   269  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   270  		assert.Equal(t, "sts.aliyuncs.com", req.Host)
   271  		assert.Equal(t, "AssumeRoleWithOIDC", req.Queries["Action"])
   272  		assert.Equal(t, "policy", req.Form["Policy"])
   273  		assert.Equal(t, "roleArn", req.Form["RoleArn"])
   274  		assert.Equal(t, "rsn", req.Form["RoleSessionName"])
   275  		assert.Equal(t, "1000", req.Form["DurationSeconds"])
   276  
   277  		err = errors.New("mock server error")
   278  		return
   279  	}
   280  	_, err = p.getCredentials()
   281  	assert.NotNil(t, err)
   282  	assert.Equal(t, "mock server error", err.Error())
   283  }
   284  
   285  func TestOIDCCredentialsProviderGetCredentials(t *testing.T) {
   286  	originHttpDo := httpDo
   287  	defer func() { httpDo = originHttpDo }()
   288  
   289  	// case 1: mock new http request failed
   290  	wd, _ := os.Getwd()
   291  	p, err := NewOIDCCredentialsProviderBuilder().
   292  		// read a normal token
   293  		WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")).
   294  		WithOIDCProviderARN("provider-arn").
   295  		WithRoleArn("roleArn").
   296  		WithRoleSessionName("rsn").
   297  		WithPolicy("policy").
   298  		WithDurationSeconds(1000).
   299  		Build()
   300  
   301  	assert.Nil(t, err)
   302  
   303  	// case 2: get credentials failed
   304  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   305  		err = errors.New("mock server error")
   306  		return
   307  	}
   308  	_, err = p.GetCredentials()
   309  	assert.NotNil(t, err)
   310  	assert.Equal(t, "mock server error", err.Error())
   311  
   312  	// case 2: get invalid expiration
   313  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   314  		res = &httputil.Response{
   315  			StatusCode: 200,
   316  			Body:       []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"invalidexpiration","SecurityToken":"ststoken"}}`),
   317  		}
   318  		return
   319  	}
   320  	_, err = p.GetCredentials()
   321  	assert.NotNil(t, err)
   322  	assert.Equal(t, "parsing time \"invalidexpiration\" as \"2006-01-02T15:04:05Z\": cannot parse \"invalidexpiration\" as \"2006\"", err.Error())
   323  
   324  	// case 3: happy result
   325  	httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
   326  		res = &httputil.Response{
   327  			StatusCode: 200,
   328  			Body:       []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"ststoken"}}`),
   329  		}
   330  		return
   331  	}
   332  	cc, err := p.GetCredentials()
   333  	assert.Nil(t, err)
   334  	assert.Equal(t, "akid", cc.AccessKeyId)
   335  	assert.Equal(t, "aksecret", cc.AccessKeySecret)
   336  	assert.Equal(t, "ststoken", cc.SecurityToken)
   337  	assert.Equal(t, "oidc_role_arn", cc.ProviderName)
   338  	assert.True(t, p.needUpdateCredential())
   339  }
   340  
   341  func TestOIDCCredentialsProviderGetCredentialsWithHttpOptions(t *testing.T) {
   342  	wd, _ := os.Getwd()
   343  	p, err := NewOIDCCredentialsProviderBuilder().
   344  		// read a normal token
   345  		WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")).
   346  		WithOIDCProviderARN("provider-arn").
   347  		WithRoleArn("roleArn").
   348  		WithRoleSessionName("rsn").
   349  		WithPolicy("policy").
   350  		WithDurationSeconds(1000).
   351  		WithHttpOptions(&HttpOptions{
   352  			ConnectTimeout: 1000,
   353  			ReadTimeout:    1000,
   354  			Proxy:          "localhost:3999",
   355  		}).
   356  		Build()
   357  
   358  	assert.Nil(t, err)
   359  	_, err = p.GetCredentials()
   360  	assert.NotNil(t, err)
   361  	assert.Contains(t, err.Error(), "proxyconnect tcp:")
   362  }