github.com/aliyun/credentials-go@v1.4.7/credentials/providers/oidc.go (about)

     1  package providers
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"os"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	httputil "github.com/aliyun/credentials-go/credentials/internal/http"
    15  	"github.com/aliyun/credentials-go/credentials/internal/utils"
    16  )
    17  
    18  type OIDCCredentialsProvider struct {
    19  	oidcProviderARN   string
    20  	oidcTokenFilePath string
    21  	roleArn           string
    22  	roleSessionName   string
    23  	durationSeconds   int
    24  	policy            string
    25  	// for sts endpoint
    26  	stsRegionId string
    27  	enableVpc   bool
    28  	stsEndpoint string
    29  
    30  	lastUpdateTimestamp int64
    31  	expirationTimestamp int64
    32  	sessionCredentials  *sessionCredentials
    33  	// for http options
    34  	httpOptions *HttpOptions
    35  }
    36  
    37  type OIDCCredentialsProviderBuilder struct {
    38  	provider *OIDCCredentialsProvider
    39  }
    40  
    41  func NewOIDCCredentialsProviderBuilder() *OIDCCredentialsProviderBuilder {
    42  	return &OIDCCredentialsProviderBuilder{
    43  		provider: &OIDCCredentialsProvider{},
    44  	}
    45  }
    46  
    47  func (b *OIDCCredentialsProviderBuilder) WithOIDCProviderARN(oidcProviderArn string) *OIDCCredentialsProviderBuilder {
    48  	b.provider.oidcProviderARN = oidcProviderArn
    49  	return b
    50  }
    51  
    52  func (b *OIDCCredentialsProviderBuilder) WithOIDCTokenFilePath(oidcTokenFilePath string) *OIDCCredentialsProviderBuilder {
    53  	b.provider.oidcTokenFilePath = oidcTokenFilePath
    54  	return b
    55  }
    56  
    57  func (b *OIDCCredentialsProviderBuilder) WithRoleArn(roleArn string) *OIDCCredentialsProviderBuilder {
    58  	b.provider.roleArn = roleArn
    59  	return b
    60  }
    61  
    62  func (b *OIDCCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *OIDCCredentialsProviderBuilder {
    63  	b.provider.roleSessionName = roleSessionName
    64  	return b
    65  }
    66  
    67  func (b *OIDCCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *OIDCCredentialsProviderBuilder {
    68  	b.provider.durationSeconds = durationSeconds
    69  	return b
    70  }
    71  
    72  func (b *OIDCCredentialsProviderBuilder) WithStsRegionId(regionId string) *OIDCCredentialsProviderBuilder {
    73  	b.provider.stsRegionId = regionId
    74  	return b
    75  }
    76  
    77  func (b *OIDCCredentialsProviderBuilder) WithEnableVpc(enableVpc bool) *OIDCCredentialsProviderBuilder {
    78  	b.provider.enableVpc = enableVpc
    79  	return b
    80  }
    81  
    82  func (b *OIDCCredentialsProviderBuilder) WithPolicy(policy string) *OIDCCredentialsProviderBuilder {
    83  	b.provider.policy = policy
    84  	return b
    85  }
    86  
    87  func (b *OIDCCredentialsProviderBuilder) WithSTSEndpoint(stsEndpoint string) *OIDCCredentialsProviderBuilder {
    88  	b.provider.stsEndpoint = stsEndpoint
    89  	return b
    90  }
    91  
    92  func (b *OIDCCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *OIDCCredentialsProviderBuilder {
    93  	b.provider.httpOptions = httpOptions
    94  	return b
    95  }
    96  
    97  func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvider, err error) {
    98  	if b.provider.roleSessionName == "" {
    99  		b.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10)
   100  	}
   101  
   102  	if b.provider.oidcTokenFilePath == "" {
   103  		b.provider.oidcTokenFilePath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE")
   104  	}
   105  
   106  	if b.provider.oidcTokenFilePath == "" {
   107  		err = errors.New("the OIDCTokenFilePath is empty")
   108  		return
   109  	}
   110  
   111  	if b.provider.oidcProviderARN == "" {
   112  		b.provider.oidcProviderARN = os.Getenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN")
   113  	}
   114  
   115  	if b.provider.oidcProviderARN == "" {
   116  		err = errors.New("the OIDCProviderARN is empty")
   117  		return
   118  	}
   119  
   120  	if b.provider.roleArn == "" {
   121  		b.provider.roleArn = os.Getenv("ALIBABA_CLOUD_ROLE_ARN")
   122  	}
   123  
   124  	if b.provider.roleArn == "" {
   125  		err = errors.New("the RoleArn is empty")
   126  		return
   127  	}
   128  
   129  	if b.provider.durationSeconds == 0 {
   130  		b.provider.durationSeconds = 3600
   131  	}
   132  
   133  	if b.provider.durationSeconds < 900 {
   134  		err = errors.New("the Assume Role session duration should be in the range of 15min - max duration seconds")
   135  	}
   136  
   137  	if b.provider.stsEndpoint == "" {
   138  		if !b.provider.enableVpc {
   139  			b.provider.enableVpc = strings.ToLower(os.Getenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")) == "true"
   140  		}
   141  		prefix := "sts"
   142  		if b.provider.enableVpc {
   143  			prefix = "sts-vpc"
   144  		}
   145  		if b.provider.stsRegionId != "" {
   146  			b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, b.provider.stsRegionId)
   147  		} else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" {
   148  			b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, region)
   149  		} else {
   150  			b.provider.stsEndpoint = "sts.aliyuncs.com"
   151  		}
   152  	}
   153  
   154  	provider = b.provider
   155  	return
   156  }
   157  
   158  func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) {
   159  	req := &httputil.Request{
   160  		Method:   "POST",
   161  		Protocol: "https",
   162  		Host:     provider.stsEndpoint,
   163  		Headers:  map[string]string{},
   164  	}
   165  
   166  	connectTimeout := 5 * time.Second
   167  	readTimeout := 10 * time.Second
   168  
   169  	if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
   170  		connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
   171  	}
   172  	if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
   173  		readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
   174  	}
   175  	if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
   176  		req.Proxy = provider.httpOptions.Proxy
   177  	}
   178  	req.ConnectTimeout = connectTimeout
   179  	req.ReadTimeout = readTimeout
   180  
   181  	queries := make(map[string]string)
   182  	queries["Version"] = "2015-04-01"
   183  	queries["Action"] = "AssumeRoleWithOIDC"
   184  	queries["Format"] = "JSON"
   185  	queries["Timestamp"] = utils.GetTimeInFormatISO8601()
   186  	req.Queries = queries
   187  
   188  	bodyForm := make(map[string]string)
   189  	bodyForm["RoleArn"] = provider.roleArn
   190  	bodyForm["OIDCProviderArn"] = provider.oidcProviderARN
   191  	token, err := ioutil.ReadFile(provider.oidcTokenFilePath)
   192  	if err != nil {
   193  		return
   194  	}
   195  
   196  	bodyForm["OIDCToken"] = string(token)
   197  	if provider.policy != "" {
   198  		bodyForm["Policy"] = provider.policy
   199  	}
   200  
   201  	bodyForm["RoleSessionName"] = provider.roleSessionName
   202  	bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)
   203  	req.Form = bodyForm
   204  
   205  	// set headers
   206  	req.Headers["Accept-Encoding"] = "identity"
   207  	res, err := httpDo(req)
   208  	if err != nil {
   209  		return
   210  	}
   211  
   212  	if res.StatusCode != http.StatusOK {
   213  		message := "get session token failed: "
   214  		err = errors.New(message + string(res.Body))
   215  		return
   216  	}
   217  	var data assumeRoleResponse
   218  	err = json.Unmarshal(res.Body, &data)
   219  	if err != nil {
   220  		err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error())
   221  		return
   222  	}
   223  	if data.Credentials == nil {
   224  		err = fmt.Errorf("get oidc sts token err, fail to get credentials")
   225  		return
   226  	}
   227  
   228  	if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil {
   229  		err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
   230  		return
   231  	}
   232  
   233  	session = &sessionCredentials{
   234  		AccessKeyId:     *data.Credentials.AccessKeyId,
   235  		AccessKeySecret: *data.Credentials.AccessKeySecret,
   236  		SecurityToken:   *data.Credentials.SecurityToken,
   237  		Expiration:      *data.Credentials.Expiration,
   238  	}
   239  	return
   240  }
   241  
   242  func (provider *OIDCCredentialsProvider) needUpdateCredential() (result bool) {
   243  	if provider.expirationTimestamp == 0 {
   244  		return true
   245  	}
   246  
   247  	return provider.expirationTimestamp-time.Now().Unix() <= 180
   248  }
   249  
   250  func (provider *OIDCCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
   251  	if provider.sessionCredentials == nil || provider.needUpdateCredential() {
   252  		sessionCredentials, err1 := provider.getCredentials()
   253  		if err1 != nil {
   254  			return nil, err1
   255  		}
   256  
   257  		provider.sessionCredentials = sessionCredentials
   258  		expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
   259  		if err2 != nil {
   260  			return nil, err2
   261  		}
   262  
   263  		provider.lastUpdateTimestamp = time.Now().Unix()
   264  		provider.expirationTimestamp = expirationTime.Unix()
   265  	}
   266  
   267  	cc = &Credentials{
   268  		AccessKeyId:     provider.sessionCredentials.AccessKeyId,
   269  		AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
   270  		SecurityToken:   provider.sessionCredentials.SecurityToken,
   271  		ProviderName:    provider.GetProviderName(),
   272  	}
   273  	return
   274  }
   275  
   276  func (provider *OIDCCredentialsProvider) GetProviderName() string {
   277  	return "oidc_role_arn"
   278  }