github.com/aliyun/credentials-go@v1.4.7/credentials/providers/ram_role_arn.go (about)

     1  package providers
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"net/url"
     9  	"os"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	httputil "github.com/aliyun/credentials-go/credentials/internal/http"
    15  	"github.com/aliyun/credentials-go/credentials/internal/utils"
    16  )
    17  
    18  type assumedRoleUser struct {
    19  }
    20  
    21  type credentials struct {
    22  	SecurityToken   *string `json:"SecurityToken"`
    23  	Expiration      *string `json:"Expiration"`
    24  	AccessKeySecret *string `json:"AccessKeySecret"`
    25  	AccessKeyId     *string `json:"AccessKeyId"`
    26  }
    27  
    28  type assumeRoleResponse struct {
    29  	RequestID       *string          `json:"RequestId"`
    30  	AssumedRoleUser *assumedRoleUser `json:"AssumedRoleUser"`
    31  	Credentials     *credentials     `json:"Credentials"`
    32  }
    33  
    34  type sessionCredentials struct {
    35  	AccessKeyId     string
    36  	AccessKeySecret string
    37  	SecurityToken   string
    38  	Expiration      string
    39  }
    40  
    41  type HttpOptions struct {
    42  	Proxy string
    43  	// Connection timeout, in milliseconds.
    44  	ConnectTimeout int
    45  	// Read timeout, in milliseconds.
    46  	ReadTimeout int
    47  }
    48  
    49  type RAMRoleARNCredentialsProvider struct {
    50  	// for previous credentials
    51  	accessKeyId         string
    52  	accessKeySecret     string
    53  	securityToken       string
    54  	credentialsProvider CredentialsProvider
    55  
    56  	roleArn         string
    57  	roleSessionName string
    58  	durationSeconds int
    59  	policy          string
    60  	externalId      string
    61  	// for sts endpoint
    62  	stsRegionId string
    63  	enableVpc   bool
    64  	stsEndpoint string
    65  	// for http options
    66  	httpOptions *HttpOptions
    67  	// inner
    68  	expirationTimestamp  int64
    69  	lastUpdateTimestamp  int64
    70  	previousProviderName string
    71  	sessionCredentials   *sessionCredentials
    72  }
    73  
    74  type RAMRoleARNCredentialsProviderBuilder struct {
    75  	provider *RAMRoleARNCredentialsProvider
    76  }
    77  
    78  func NewRAMRoleARNCredentialsProviderBuilder() *RAMRoleARNCredentialsProviderBuilder {
    79  	return &RAMRoleARNCredentialsProviderBuilder{
    80  		provider: &RAMRoleARNCredentialsProvider{},
    81  	}
    82  }
    83  
    84  func (builder *RAMRoleARNCredentialsProviderBuilder) WithAccessKeyId(accessKeyId string) *RAMRoleARNCredentialsProviderBuilder {
    85  	builder.provider.accessKeyId = accessKeyId
    86  	return builder
    87  }
    88  
    89  func (builder *RAMRoleARNCredentialsProviderBuilder) WithAccessKeySecret(accessKeySecret string) *RAMRoleARNCredentialsProviderBuilder {
    90  	builder.provider.accessKeySecret = accessKeySecret
    91  	return builder
    92  }
    93  
    94  func (builder *RAMRoleARNCredentialsProviderBuilder) WithSecurityToken(securityToken string) *RAMRoleARNCredentialsProviderBuilder {
    95  	builder.provider.securityToken = securityToken
    96  	return builder
    97  }
    98  
    99  func (builder *RAMRoleARNCredentialsProviderBuilder) WithCredentialsProvider(credentialsProvider CredentialsProvider) *RAMRoleARNCredentialsProviderBuilder {
   100  	builder.provider.credentialsProvider = credentialsProvider
   101  	return builder
   102  }
   103  
   104  func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleArn(roleArn string) *RAMRoleARNCredentialsProviderBuilder {
   105  	builder.provider.roleArn = roleArn
   106  	return builder
   107  }
   108  
   109  func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsRegionId(regionId string) *RAMRoleARNCredentialsProviderBuilder {
   110  	builder.provider.stsRegionId = regionId
   111  	return builder
   112  }
   113  
   114  func (builder *RAMRoleARNCredentialsProviderBuilder) WithEnableVpc(enableVpc bool) *RAMRoleARNCredentialsProviderBuilder {
   115  	builder.provider.enableVpc = enableVpc
   116  	return builder
   117  }
   118  
   119  func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsEndpoint(endpoint string) *RAMRoleARNCredentialsProviderBuilder {
   120  	builder.provider.stsEndpoint = endpoint
   121  	return builder
   122  }
   123  
   124  func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *RAMRoleARNCredentialsProviderBuilder {
   125  	builder.provider.roleSessionName = roleSessionName
   126  	return builder
   127  }
   128  
   129  func (builder *RAMRoleARNCredentialsProviderBuilder) WithPolicy(policy string) *RAMRoleARNCredentialsProviderBuilder {
   130  	builder.provider.policy = policy
   131  	return builder
   132  }
   133  
   134  func (builder *RAMRoleARNCredentialsProviderBuilder) WithExternalId(externalId string) *RAMRoleARNCredentialsProviderBuilder {
   135  	builder.provider.externalId = externalId
   136  	return builder
   137  }
   138  
   139  func (builder *RAMRoleARNCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *RAMRoleARNCredentialsProviderBuilder {
   140  	builder.provider.durationSeconds = durationSeconds
   141  	return builder
   142  }
   143  
   144  func (builder *RAMRoleARNCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *RAMRoleARNCredentialsProviderBuilder {
   145  	builder.provider.httpOptions = httpOptions
   146  	return builder
   147  }
   148  
   149  func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleARNCredentialsProvider, err error) {
   150  	if builder.provider.credentialsProvider == nil {
   151  		if builder.provider.accessKeyId != "" && builder.provider.accessKeySecret != "" && builder.provider.securityToken != "" {
   152  			builder.provider.credentialsProvider, err = NewStaticSTSCredentialsProviderBuilder().
   153  				WithAccessKeyId(builder.provider.accessKeyId).
   154  				WithAccessKeySecret(builder.provider.accessKeySecret).
   155  				WithSecurityToken(builder.provider.securityToken).
   156  				Build()
   157  			if err != nil {
   158  				return
   159  			}
   160  		} else if builder.provider.accessKeyId != "" && builder.provider.accessKeySecret != "" {
   161  			builder.provider.credentialsProvider, err = NewStaticAKCredentialsProviderBuilder().
   162  				WithAccessKeyId(builder.provider.accessKeyId).
   163  				WithAccessKeySecret(builder.provider.accessKeySecret).
   164  				Build()
   165  			if err != nil {
   166  				return
   167  			}
   168  		} else {
   169  			err = errors.New("must specify a previous credentials provider to assume role")
   170  			return
   171  		}
   172  	}
   173  
   174  	if builder.provider.roleArn == "" {
   175  		if roleArn := os.Getenv("ALIBABA_CLOUD_ROLE_ARN"); roleArn != "" {
   176  			builder.provider.roleArn = roleArn
   177  		} else {
   178  			err = errors.New("the RoleArn is empty")
   179  			return
   180  		}
   181  	}
   182  
   183  	if builder.provider.roleSessionName == "" {
   184  		if roleSessionName := os.Getenv("ALIBABA_CLOUD_ROLE_SESSION_NAME"); roleSessionName != "" {
   185  			builder.provider.roleSessionName = roleSessionName
   186  		} else {
   187  			builder.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10)
   188  		}
   189  	}
   190  
   191  	// duration seconds
   192  	if builder.provider.durationSeconds == 0 {
   193  		// default to 3600
   194  		builder.provider.durationSeconds = 3600
   195  	}
   196  
   197  	if builder.provider.durationSeconds < 900 {
   198  		err = errors.New("session duration should be in the range of 900s - max session duration")
   199  		return
   200  	}
   201  
   202  	// sts endpoint
   203  	if builder.provider.stsEndpoint == "" {
   204  		if !builder.provider.enableVpc {
   205  			builder.provider.enableVpc = strings.ToLower(os.Getenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")) == "true"
   206  		}
   207  		prefix := "sts"
   208  		if builder.provider.enableVpc {
   209  			prefix = "sts-vpc"
   210  		}
   211  		if builder.provider.stsRegionId != "" {
   212  			builder.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, builder.provider.stsRegionId)
   213  		} else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" {
   214  			builder.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, region)
   215  		} else {
   216  			builder.provider.stsEndpoint = "sts.aliyuncs.com"
   217  		}
   218  	}
   219  
   220  	provider = builder.provider
   221  	return
   222  }
   223  
   224  func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (session *sessionCredentials, err error) {
   225  	method := "POST"
   226  	req := &httputil.Request{
   227  		Method:   method,
   228  		Protocol: "https",
   229  		Host:     provider.stsEndpoint,
   230  		Headers:  map[string]string{},
   231  	}
   232  
   233  	queries := make(map[string]string)
   234  	queries["Version"] = "2015-04-01"
   235  	queries["Action"] = "AssumeRole"
   236  	queries["Format"] = "JSON"
   237  	queries["Timestamp"] = utils.GetTimeInFormatISO8601()
   238  	queries["SignatureMethod"] = "HMAC-SHA1"
   239  	queries["SignatureVersion"] = "1.0"
   240  	queries["SignatureNonce"] = utils.GetNonce()
   241  	queries["AccessKeyId"] = cc.AccessKeyId
   242  
   243  	if cc.SecurityToken != "" {
   244  		queries["SecurityToken"] = cc.SecurityToken
   245  	}
   246  
   247  	bodyForm := make(map[string]string)
   248  	bodyForm["RoleArn"] = provider.roleArn
   249  	if provider.policy != "" {
   250  		bodyForm["Policy"] = provider.policy
   251  	}
   252  	if provider.externalId != "" {
   253  		bodyForm["ExternalId"] = provider.externalId
   254  	}
   255  	bodyForm["RoleSessionName"] = provider.roleSessionName
   256  	bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)
   257  	req.Form = bodyForm
   258  
   259  	// caculate signature
   260  	signParams := make(map[string]string)
   261  	for key, value := range queries {
   262  		signParams[key] = value
   263  	}
   264  	for key, value := range bodyForm {
   265  		signParams[key] = value
   266  	}
   267  
   268  	stringToSign := utils.GetURLFormedMap(signParams)
   269  	stringToSign = strings.Replace(stringToSign, "+", "%20", -1)
   270  	stringToSign = strings.Replace(stringToSign, "*", "%2A", -1)
   271  	stringToSign = strings.Replace(stringToSign, "%7E", "~", -1)
   272  	stringToSign = url.QueryEscape(stringToSign)
   273  	stringToSign = method + "&%2F&" + stringToSign
   274  	secret := cc.AccessKeySecret + "&"
   275  	queries["Signature"] = utils.ShaHmac1(stringToSign, secret)
   276  
   277  	req.Queries = queries
   278  
   279  	// set headers
   280  	req.Headers["Accept-Encoding"] = "identity"
   281  	req.Headers["Content-Type"] = "application/x-www-form-urlencoded"
   282  	req.Headers["x-acs-credentials-provider"] = cc.ProviderName
   283  
   284  	connectTimeout := 5 * time.Second
   285  	readTimeout := 10 * time.Second
   286  
   287  	if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
   288  		connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
   289  	}
   290  	if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
   291  		readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
   292  	}
   293  	if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
   294  		req.Proxy = provider.httpOptions.Proxy
   295  	}
   296  	req.ConnectTimeout = connectTimeout
   297  	req.ReadTimeout = readTimeout
   298  
   299  	res, err := httpDo(req)
   300  	if err != nil {
   301  		return
   302  	}
   303  
   304  	if res.StatusCode != http.StatusOK {
   305  		err = errors.New("refresh session token failed: " + string(res.Body))
   306  		return
   307  	}
   308  	var data assumeRoleResponse
   309  	err = json.Unmarshal(res.Body, &data)
   310  	if err != nil {
   311  		err = fmt.Errorf("refresh RoleArn sts token err, json.Unmarshal fail: %s", err.Error())
   312  		return
   313  	}
   314  	if data.Credentials == nil {
   315  		err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
   316  		return
   317  	}
   318  
   319  	if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil {
   320  		err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
   321  		return
   322  	}
   323  
   324  	session = &sessionCredentials{
   325  		AccessKeyId:     *data.Credentials.AccessKeyId,
   326  		AccessKeySecret: *data.Credentials.AccessKeySecret,
   327  		SecurityToken:   *data.Credentials.SecurityToken,
   328  		Expiration:      *data.Credentials.Expiration,
   329  	}
   330  	return
   331  }
   332  
   333  func (provider *RAMRoleARNCredentialsProvider) needUpdateCredential() (result bool) {
   334  	if provider.expirationTimestamp == 0 {
   335  		return true
   336  	}
   337  
   338  	return provider.expirationTimestamp-time.Now().Unix() <= 180
   339  }
   340  
   341  func (provider *RAMRoleARNCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
   342  	if provider.sessionCredentials == nil || provider.needUpdateCredential() {
   343  		// 获取前置凭证
   344  		previousCredentials, err1 := provider.credentialsProvider.GetCredentials()
   345  		if err1 != nil {
   346  			return nil, err1
   347  		}
   348  		sessionCredentials, err2 := provider.getCredentials(previousCredentials)
   349  		if err2 != nil {
   350  			return nil, err2
   351  		}
   352  
   353  		expirationTime, err := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
   354  		if err != nil {
   355  			return nil, err
   356  		}
   357  
   358  		provider.expirationTimestamp = expirationTime.Unix()
   359  		provider.lastUpdateTimestamp = time.Now().Unix()
   360  		provider.previousProviderName = previousCredentials.ProviderName
   361  		provider.sessionCredentials = sessionCredentials
   362  	}
   363  
   364  	cc = &Credentials{
   365  		AccessKeyId:     provider.sessionCredentials.AccessKeyId,
   366  		AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
   367  		SecurityToken:   provider.sessionCredentials.SecurityToken,
   368  		ProviderName:    fmt.Sprintf("%s/%s", provider.GetProviderName(), provider.previousProviderName),
   369  	}
   370  	return
   371  }
   372  
   373  func (provider *RAMRoleARNCredentialsProvider) GetProviderName() string {
   374  	return "ram_role_arn"
   375  }