github.com/aavshr/aws-sdk-go@v1.41.3/aws/session/credentials_test.go (about)

     1  //go:build go1.7
     2  // +build go1.7
     3  
     4  package session
     5  
     6  import (
     7  	"fmt"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"os"
    12  	"path/filepath"
    13  	"reflect"
    14  	"runtime"
    15  	"strconv"
    16  	"strings"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/aavshr/aws-sdk-go/aws"
    21  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    22  	"github.com/aavshr/aws-sdk-go/aws/defaults"
    23  	"github.com/aavshr/aws-sdk-go/aws/endpoints"
    24  	"github.com/aavshr/aws-sdk-go/aws/request"
    25  	"github.com/aavshr/aws-sdk-go/internal/sdktesting"
    26  	"github.com/aavshr/aws-sdk-go/internal/shareddefaults"
    27  	"github.com/aavshr/aws-sdk-go/private/protocol"
    28  	"github.com/aavshr/aws-sdk-go/service/sts"
    29  )
    30  
    31  func newEc2MetadataServer(key, secret string, closeAfterGetCreds bool) *httptest.Server {
    32  	var server *httptest.Server
    33  	server = httptest.NewServer(http.HandlerFunc(
    34  		func(w http.ResponseWriter, r *http.Request) {
    35  			if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
    36  				w.Write([]byte(fmt.Sprintf(ec2MetadataResponse, key, secret)))
    37  
    38  				if closeAfterGetCreds {
    39  					go server.Close()
    40  				}
    41  			} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
    42  				w.Write([]byte("RoleName"))
    43  			} else {
    44  				w.Write([]byte(""))
    45  			}
    46  		}))
    47  
    48  	return server
    49  }
    50  
    51  func setupCredentialsEndpoints(t *testing.T) (endpoints.Resolver, func()) {
    52  	origECSEndpoint := shareddefaults.ECSContainerCredentialsURI
    53  
    54  	ecsMetadataServer := httptest.NewServer(http.HandlerFunc(
    55  		func(w http.ResponseWriter, r *http.Request) {
    56  			if r.URL.Path == "/ECS" {
    57  				w.Write([]byte(ecsResponse))
    58  			} else {
    59  				w.Write([]byte(""))
    60  			}
    61  		}))
    62  	shareddefaults.ECSContainerCredentialsURI = ecsMetadataServer.URL
    63  
    64  	ec2MetadataServer := newEc2MetadataServer("ec2_key", "ec2_secret", false)
    65  
    66  	stsServer := httptest.NewServer(http.HandlerFunc(
    67  		func(w http.ResponseWriter, r *http.Request) {
    68  			if err := r.ParseForm(); err != nil {
    69  				w.WriteHeader(500)
    70  				return
    71  			}
    72  
    73  			form := r.Form
    74  
    75  			switch form.Get("Action") {
    76  			case "AssumeRole":
    77  				w.Write([]byte(fmt.Sprintf(
    78  					assumeRoleRespMsg,
    79  					time.Now().
    80  						Add(15*time.Minute).
    81  						Format(protocol.ISO8601TimeFormat))))
    82  				return
    83  			case "AssumeRoleWithWebIdentity":
    84  				w.Write([]byte(fmt.Sprintf(assumeRoleWithWebIdentityResponse,
    85  					time.Now().
    86  						Add(15*time.Minute).
    87  						Format(protocol.ISO8601TimeFormat))))
    88  				return
    89  			default:
    90  				w.WriteHeader(404)
    91  				return
    92  			}
    93  		}))
    94  
    95  	ssoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    96  		w.Write([]byte(fmt.Sprintf(
    97  			getRoleCredentialsResponse,
    98  			time.Now().
    99  				Add(15*time.Minute).
   100  				UnixNano()/int64(time.Millisecond))))
   101  	}))
   102  
   103  	resolver := endpoints.ResolverFunc(
   104  		func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
   105  			switch service {
   106  			case "ec2metadata":
   107  				return endpoints.ResolvedEndpoint{
   108  					URL: ec2MetadataServer.URL,
   109  				}, nil
   110  			case "sts":
   111  				return endpoints.ResolvedEndpoint{
   112  					URL: stsServer.URL,
   113  				}, nil
   114  			case "portal.sso":
   115  				return endpoints.ResolvedEndpoint{
   116  					URL: ssoServer.URL,
   117  				}, nil
   118  			default:
   119  				return endpoints.ResolvedEndpoint{},
   120  					fmt.Errorf("unknown service endpoint, %s", service)
   121  			}
   122  		})
   123  
   124  	return resolver, func() {
   125  		shareddefaults.ECSContainerCredentialsURI = origECSEndpoint
   126  		ecsMetadataServer.Close()
   127  		ec2MetadataServer.Close()
   128  		ssoServer.Close()
   129  		stsServer.Close()
   130  	}
   131  }
   132  
   133  func TestSharedConfigCredentialSource(t *testing.T) {
   134  	const configFileForWindows = "testdata/credential_source_config_for_windows"
   135  	const configFile = "testdata/credential_source_config"
   136  
   137  	cases := []struct {
   138  		name                   string
   139  		profile                string
   140  		sessOptProfile         string
   141  		sessOptEC2IMDSEndpoint string
   142  		expectedError          error
   143  		expectedAccessKey      string
   144  		expectedSecretKey      string
   145  		expectedSessionToken   string
   146  		expectedChain          []string
   147  		init                   func() (func(), error)
   148  		dependentOnOS          bool
   149  	}{
   150  		{
   151  			name:          "credential source and source profile",
   152  			profile:       "invalid_source_and_credential_source",
   153  			expectedError: ErrSharedConfigSourceCollision,
   154  			init: func() (func(), error) {
   155  				os.Setenv("AWS_ACCESS_KEY", "access_key")
   156  				os.Setenv("AWS_SECRET_KEY", "secret_key")
   157  				return func() {}, nil
   158  			},
   159  		},
   160  		{
   161  			name:                 "env var credential source",
   162  			sessOptProfile:       "env_var_credential_source",
   163  			expectedAccessKey:    "AKID",
   164  			expectedSecretKey:    "SECRET",
   165  			expectedSessionToken: "SESSION_TOKEN",
   166  			expectedChain: []string{
   167  				"assume_role_w_creds_role_arn_env",
   168  			},
   169  			init: func() (func(), error) {
   170  				os.Setenv("AWS_ACCESS_KEY", "access_key")
   171  				os.Setenv("AWS_SECRET_KEY", "secret_key")
   172  				return func() {}, nil
   173  			},
   174  		},
   175  		{
   176  			name:    "ec2metadata credential source",
   177  			profile: "ec2metadata",
   178  			expectedChain: []string{
   179  				"assume_role_w_creds_role_arn_ec2",
   180  			},
   181  			expectedAccessKey:    "AKID",
   182  			expectedSecretKey:    "SECRET",
   183  			expectedSessionToken: "SESSION_TOKEN",
   184  		},
   185  		{
   186  			name:                 "ec2metadata custom EC2 IMDS endpoint, env var",
   187  			profile:              "not-exists-profile",
   188  			expectedAccessKey:    "ec2_custom_key",
   189  			expectedSecretKey:    "ec2_custom_secret",
   190  			expectedSessionToken: "token",
   191  			init: func() (func(), error) {
   192  				altServer := newEc2MetadataServer("ec2_custom_key", "ec2_custom_secret", true)
   193  				os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", altServer.URL)
   194  				return func() {}, nil
   195  			},
   196  		},
   197  		{
   198  			name:                 "ecs container credential source",
   199  			profile:              "ecscontainer",
   200  			expectedAccessKey:    "AKID",
   201  			expectedSecretKey:    "SECRET",
   202  			expectedSessionToken: "SESSION_TOKEN",
   203  			expectedChain: []string{
   204  				"assume_role_w_creds_role_arn_ecs",
   205  			},
   206  			init: func() (func(), error) {
   207  				os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
   208  				return func() {}, nil
   209  			},
   210  		},
   211  		{
   212  			name:                 "chained assume role with env creds",
   213  			profile:              "chained_assume_role",
   214  			expectedAccessKey:    "AKID",
   215  			expectedSecretKey:    "SECRET",
   216  			expectedSessionToken: "SESSION_TOKEN",
   217  			expectedChain: []string{
   218  				"assume_role_w_creds_role_arn_chain",
   219  				"assume_role_w_creds_role_arn_ec2",
   220  			},
   221  		},
   222  		{
   223  			name:              "credential process with no ARN set",
   224  			profile:           "cred_proc_no_arn_set",
   225  			dependentOnOS:     true,
   226  			expectedAccessKey: "cred_proc_akid",
   227  			expectedSecretKey: "cred_proc_secret",
   228  		},
   229  		{
   230  			name:                 "credential process with ARN set",
   231  			profile:              "cred_proc_arn_set",
   232  			dependentOnOS:        true,
   233  			expectedAccessKey:    "AKID",
   234  			expectedSecretKey:    "SECRET",
   235  			expectedSessionToken: "SESSION_TOKEN",
   236  			expectedChain: []string{
   237  				"assume_role_w_creds_proc_role_arn",
   238  			},
   239  		},
   240  		{
   241  			name:                 "chained assume role with credential process",
   242  			profile:              "chained_cred_proc",
   243  			dependentOnOS:        true,
   244  			expectedAccessKey:    "AKID",
   245  			expectedSecretKey:    "SECRET",
   246  			expectedSessionToken: "SESSION_TOKEN",
   247  			expectedChain: []string{
   248  				"assume_role_w_creds_proc_source_prof",
   249  			},
   250  		},
   251  		{
   252  			name:                 "sso credentials",
   253  			profile:              "sso_creds",
   254  			expectedAccessKey:    "SSO_AKID",
   255  			expectedSecretKey:    "SSO_SECRET_KEY",
   256  			expectedSessionToken: "SSO_SESSION_TOKEN",
   257  			init: func() (func(), error) {
   258  				return ssoTestSetup()
   259  			},
   260  		},
   261  		{
   262  			name:                 "chained assume role with sso credentials",
   263  			profile:              "source_sso_creds",
   264  			expectedAccessKey:    "AKID",
   265  			expectedSecretKey:    "SECRET",
   266  			expectedSessionToken: "SESSION_TOKEN",
   267  			expectedChain: []string{
   268  				"source_sso_creds_arn",
   269  			},
   270  			init: func() (func(), error) {
   271  				return ssoTestSetup()
   272  			},
   273  		},
   274  		{
   275  			name:                 "chained assume role with sso and static credentials",
   276  			profile:              "assume_sso_and_static",
   277  			expectedAccessKey:    "AKID",
   278  			expectedSecretKey:    "SECRET",
   279  			expectedSessionToken: "SESSION_TOKEN",
   280  			expectedChain: []string{
   281  				"assume_sso_and_static_arn",
   282  			},
   283  		},
   284  		{
   285  			name:          "invalid sso configuration",
   286  			profile:       "sso_invalid",
   287  			expectedError: fmt.Errorf("profile \"sso_invalid\" is configured to use SSO but is missing required configuration: sso_region, sso_start_url"),
   288  		},
   289  		{
   290  			name:              "environment credentials with invalid sso",
   291  			profile:           "sso_invalid",
   292  			expectedAccessKey: "access_key",
   293  			expectedSecretKey: "secret_key",
   294  			init: func() (func(), error) {
   295  				os.Setenv("AWS_ACCESS_KEY", "access_key")
   296  				os.Setenv("AWS_SECRET_KEY", "secret_key")
   297  				return func() {}, nil
   298  			},
   299  		},
   300  		{
   301  			name:                 "sso mixed with credential process provider",
   302  			profile:              "sso_mixed_credproc",
   303  			expectedAccessKey:    "SSO_AKID",
   304  			expectedSecretKey:    "SSO_SECRET_KEY",
   305  			expectedSessionToken: "SSO_SESSION_TOKEN",
   306  			init: func() (func(), error) {
   307  				return ssoTestSetup()
   308  			},
   309  		},
   310  		{
   311  			name:                 "sso mixed with web identity token provider",
   312  			profile:              "sso_mixed_webident",
   313  			expectedAccessKey:    "WEB_IDENTITY_AKID",
   314  			expectedSecretKey:    "WEB_IDENTITY_SECRET",
   315  			expectedSessionToken: "WEB_IDENTITY_SESSION_TOKEN",
   316  		},
   317  	}
   318  
   319  	for i, c := range cases {
   320  		t.Run(strconv.Itoa(i)+"_"+c.name, func(t *testing.T) {
   321  			restoreEnvFn := sdktesting.StashEnv()
   322  			defer restoreEnvFn()
   323  
   324  			if c.dependentOnOS && runtime.GOOS == "windows" {
   325  				os.Setenv("AWS_CONFIG_FILE", configFileForWindows)
   326  			} else {
   327  				os.Setenv("AWS_CONFIG_FILE", configFile)
   328  			}
   329  
   330  			os.Setenv("AWS_REGION", "us-east-1")
   331  			os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
   332  			if len(c.profile) != 0 {
   333  				os.Setenv("AWS_PROFILE", c.profile)
   334  			}
   335  
   336  			endpointResolver, cleanupFn := setupCredentialsEndpoints(t)
   337  			defer cleanupFn()
   338  
   339  			if c.init != nil {
   340  				cleanup, err := c.init()
   341  				if err != nil {
   342  					t.Fatalf("expect no error, got %v", err)
   343  				}
   344  				defer cleanup()
   345  			}
   346  
   347  			var credChain []string
   348  			handlers := defaults.Handlers()
   349  			handlers.Sign.PushBack(func(r *request.Request) {
   350  				if r.Config.Credentials == credentials.AnonymousCredentials {
   351  					return
   352  				}
   353  				params := r.Params.(*sts.AssumeRoleInput)
   354  				credChain = append(credChain, *params.RoleArn)
   355  			})
   356  
   357  			sess, err := NewSessionWithOptions(Options{
   358  				Profile: c.sessOptProfile,
   359  				Config: aws.Config{
   360  					Logger:           t,
   361  					EndpointResolver: endpointResolver,
   362  				},
   363  				Handlers:        handlers,
   364  				EC2IMDSEndpoint: c.sessOptEC2IMDSEndpoint,
   365  			})
   366  
   367  			if c.expectedError != nil {
   368  				var errStr string
   369  				if err != nil {
   370  					errStr = err.Error()
   371  				}
   372  				if e, a := c.expectedError.Error(), errStr; !strings.Contains(a, e) {
   373  					t.Fatalf("expected %v, but received %v", e, a)
   374  				}
   375  			}
   376  
   377  			if c.expectedError != nil {
   378  				return
   379  			}
   380  
   381  			creds, err := sess.Config.Credentials.Get()
   382  			if err != nil {
   383  				t.Fatalf("expected no error, but received %v", err)
   384  			}
   385  
   386  			if e, a := c.expectedChain, credChain; !reflect.DeepEqual(e, a) {
   387  				t.Errorf("expected %v, but received %v", e, a)
   388  			}
   389  
   390  			if e, a := c.expectedAccessKey, creds.AccessKeyID; e != a {
   391  				t.Errorf("expected %v, but received %v", e, a)
   392  			}
   393  
   394  			if e, a := c.expectedSecretKey, creds.SecretAccessKey; e != a {
   395  				t.Errorf("expected %v, but received %v", e, a)
   396  			}
   397  
   398  			if e, a := c.expectedSessionToken, creds.SessionToken; e != a {
   399  				t.Errorf("expected %v, but received %v", e, a)
   400  			}
   401  		})
   402  	}
   403  }
   404  
   405  const ecsResponse = `{
   406  	  "Code": "Success",
   407  	  "Type": "AWS-HMAC",
   408  	  "AccessKeyId" : "ecs-access-key",
   409  	  "SecretAccessKey" : "ecs-secret-key",
   410  	  "Token" : "token",
   411  	  "Expiration" : "2100-01-01T00:00:00Z",
   412  	  "LastUpdated" : "2009-11-23T0:00:00Z"
   413  	}`
   414  
   415  const ec2MetadataResponse = `{
   416  	  "Code": "Success",
   417  	  "Type": "AWS-HMAC",
   418  	  "AccessKeyId" : "%s",
   419  	  "SecretAccessKey" : "%s",
   420  	  "Token" : "token",
   421  	  "Expiration" : "2100-01-01T00:00:00Z",
   422  	  "LastUpdated" : "2009-11-23T0:00:00Z"
   423  	}`
   424  
   425  const assumeRoleRespMsg = `
   426  <AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
   427    <AssumeRoleResult>
   428      <AssumedRoleUser>
   429        <Arn>arn:aws:sts::account_id:assumed-role/role/session_name</Arn>
   430        <AssumedRoleId>AKID:session_name</AssumedRoleId>
   431      </AssumedRoleUser>
   432      <Credentials>
   433        <AccessKeyId>AKID</AccessKeyId>
   434        <SecretAccessKey>SECRET</SecretAccessKey>
   435        <SessionToken>SESSION_TOKEN</SessionToken>
   436        <Expiration>%s</Expiration>
   437      </Credentials>
   438    </AssumeRoleResult>
   439    <ResponseMetadata>
   440      <RequestId>request-id</RequestId>
   441    </ResponseMetadata>
   442  </AssumeRoleResponse>
   443  `
   444  
   445  var assumeRoleWithWebIdentityResponse = `<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
   446    <AssumeRoleWithWebIdentityResult>
   447      <SubjectFromWebIdentityToken>amzn1.account.AF6RHO7KZU5XRVQJGXK6HB56KR2A</SubjectFromWebIdentityToken>
   448      <Audience>client.5498841531868486423.1548@apps.example.com</Audience>
   449      <AssumedRoleUser>
   450        <Arn>arn:aws:sts::123456789012:assumed-role/FederatedWebIdentityRole/app1</Arn>
   451        <AssumedRoleId>AROACLKWSDQRAOEXAMPLE:app1</AssumedRoleId>
   452      </AssumedRoleUser>
   453      <Credentials>
   454        <AccessKeyId>WEB_IDENTITY_AKID</AccessKeyId>
   455        <SecretAccessKey>WEB_IDENTITY_SECRET</SecretAccessKey>
   456        <SessionToken>WEB_IDENTITY_SESSION_TOKEN</SessionToken>
   457        <Expiration>%s</Expiration>
   458      </Credentials>
   459      <Provider>www.amazon.com</Provider>
   460    </AssumeRoleWithWebIdentityResult>
   461    <ResponseMetadata>
   462      <RequestId>request-id</RequestId>
   463    </ResponseMetadata>
   464  </AssumeRoleWithWebIdentityResponse>
   465  `
   466  
   467  const getRoleCredentialsResponse = `{
   468    "roleCredentials": {
   469      "accessKeyId": "SSO_AKID",
   470      "secretAccessKey": "SSO_SECRET_KEY",
   471      "sessionToken": "SSO_SESSION_TOKEN",
   472      "expiration": %d
   473    }
   474  }`
   475  
   476  const ssoTokenCacheFile = `{
   477    "accessToken": "ssoAccessToken",
   478    "expiresAt": "%s"
   479  }`
   480  
   481  func TestSessionAssumeRole(t *testing.T) {
   482  	restoreEnvFn := initSessionTestEnv()
   483  	defer restoreEnvFn()
   484  
   485  	os.Setenv("AWS_REGION", "us-east-1")
   486  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
   487  	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
   488  	os.Setenv("AWS_PROFILE", "assume_role_w_creds")
   489  
   490  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   491  		w.Write([]byte(fmt.Sprintf(
   492  			assumeRoleRespMsg,
   493  			time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
   494  	}))
   495  	defer server.Close()
   496  
   497  	s, err := NewSession(&aws.Config{
   498  		Endpoint:   aws.String(server.URL),
   499  		DisableSSL: aws.Bool(true),
   500  	})
   501  	if err != nil {
   502  		t.Fatalf("expect no error, got %v", err)
   503  	}
   504  
   505  	creds, err := s.Config.Credentials.Get()
   506  	if err != nil {
   507  		t.Fatalf("expect no error, got %v", err)
   508  	}
   509  	if e, a := "AKID", creds.AccessKeyID; e != a {
   510  		t.Errorf("expect %v, got %v", e, a)
   511  	}
   512  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
   513  		t.Errorf("expect %v, got %v", e, a)
   514  	}
   515  	if e, a := "SESSION_TOKEN", creds.SessionToken; e != a {
   516  		t.Errorf("expect %v, got %v", e, a)
   517  	}
   518  	if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) {
   519  		t.Errorf("expect %v, to be in %v", e, a)
   520  	}
   521  }
   522  
   523  func TestSessionAssumeRole_WithMFA(t *testing.T) {
   524  	restoreEnvFn := initSessionTestEnv()
   525  	defer restoreEnvFn()
   526  
   527  	os.Setenv("AWS_REGION", "us-east-1")
   528  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
   529  	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
   530  	os.Setenv("AWS_PROFILE", "assume_role_w_creds")
   531  
   532  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   533  		if e, a := r.FormValue("SerialNumber"), "0123456789"; e != a {
   534  			t.Errorf("expect %v, got %v", e, a)
   535  		}
   536  		if e, a := r.FormValue("TokenCode"), "tokencode"; e != a {
   537  			t.Errorf("expect %v, got %v", e, a)
   538  		}
   539  		if e, a := "900", r.FormValue("DurationSeconds"); e != a {
   540  			t.Errorf("expect %v, got %v", e, a)
   541  		}
   542  
   543  		w.Write([]byte(fmt.Sprintf(
   544  			assumeRoleRespMsg,
   545  			time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
   546  	}))
   547  	defer server.Close()
   548  
   549  	customProviderCalled := false
   550  	sess, err := NewSessionWithOptions(Options{
   551  		Profile: "assume_role_w_mfa",
   552  		Config: aws.Config{
   553  			Region:     aws.String("us-east-1"),
   554  			Endpoint:   aws.String(server.URL),
   555  			DisableSSL: aws.Bool(true),
   556  		},
   557  		SharedConfigState: SharedConfigEnable,
   558  		AssumeRoleTokenProvider: func() (string, error) {
   559  			customProviderCalled = true
   560  
   561  			return "tokencode", nil
   562  		},
   563  	})
   564  	if err != nil {
   565  		t.Fatalf("expect no error, got %v", err)
   566  	}
   567  
   568  	creds, err := sess.Config.Credentials.Get()
   569  	if err != nil {
   570  		t.Fatalf("expect no error, got %v", err)
   571  	}
   572  	if !customProviderCalled {
   573  		t.Errorf("expect true")
   574  	}
   575  
   576  	if e, a := "AKID", creds.AccessKeyID; e != a {
   577  		t.Errorf("expect %v, got %v", e, a)
   578  	}
   579  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
   580  		t.Errorf("expect %v, got %v", e, a)
   581  	}
   582  	if e, a := "SESSION_TOKEN", creds.SessionToken; e != a {
   583  		t.Errorf("expect %v, got %v", e, a)
   584  	}
   585  	if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) {
   586  		t.Errorf("expect %v, to be in %v", e, a)
   587  	}
   588  }
   589  
   590  func TestSessionAssumeRole_WithMFA_NoTokenProvider(t *testing.T) {
   591  	restoreEnvFn := initSessionTestEnv()
   592  	defer restoreEnvFn()
   593  
   594  	os.Setenv("AWS_REGION", "us-east-1")
   595  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
   596  	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
   597  	os.Setenv("AWS_PROFILE", "assume_role_w_creds")
   598  
   599  	_, err := NewSessionWithOptions(Options{
   600  		Profile:           "assume_role_w_mfa",
   601  		SharedConfigState: SharedConfigEnable,
   602  	})
   603  	if e, a := (AssumeRoleTokenProviderNotSetError{}), err; e != a {
   604  		t.Errorf("expect %v, got %v", e, a)
   605  	}
   606  }
   607  
   608  func TestSessionAssumeRole_DisableSharedConfig(t *testing.T) {
   609  	// Backwards compatibility with Shared config disabled
   610  	// assume role should not be built into the config.
   611  	restoreEnvFn := initSessionTestEnv()
   612  	defer restoreEnvFn()
   613  
   614  	os.Setenv("AWS_SDK_LOAD_CONFIG", "0")
   615  	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
   616  	os.Setenv("AWS_PROFILE", "assume_role_w_creds")
   617  
   618  	s, err := NewSession(&aws.Config{
   619  		CredentialsChainVerboseErrors: aws.Bool(true),
   620  	})
   621  	if err != nil {
   622  		t.Fatalf("expect no error, got %v", err)
   623  	}
   624  
   625  	creds, err := s.Config.Credentials.Get()
   626  	if err != nil {
   627  		t.Fatalf("expect no error, got %v", err)
   628  	}
   629  	if e, a := "assume_role_w_creds_akid", creds.AccessKeyID; e != a {
   630  		t.Errorf("expect %v, got %v", e, a)
   631  	}
   632  	if e, a := "assume_role_w_creds_secret", creds.SecretAccessKey; e != a {
   633  		t.Errorf("expect %v, got %v", e, a)
   634  	}
   635  	if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) {
   636  		t.Errorf("expect %v, to be in %v", e, a)
   637  	}
   638  }
   639  
   640  func TestSessionAssumeRole_InvalidSourceProfile(t *testing.T) {
   641  	// Backwards compatibility with Shared config disabled
   642  	// assume role should not be built into the config.
   643  	restoreEnvFn := initSessionTestEnv()
   644  	defer restoreEnvFn()
   645  
   646  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
   647  	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
   648  	os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile")
   649  
   650  	s, err := NewSession()
   651  	if err == nil {
   652  		t.Fatalf("expect error, got none")
   653  	}
   654  
   655  	expectMsg := "SharedConfigAssumeRoleError: failed to load assume role"
   656  	if e, a := expectMsg, err.Error(); !strings.Contains(a, e) {
   657  		t.Errorf("expect %v, to be in %v", e, a)
   658  	}
   659  	if s != nil {
   660  		t.Errorf("expect nil, %v", err)
   661  	}
   662  }
   663  
   664  func TestSessionAssumeRole_ExtendedDuration(t *testing.T) {
   665  	restoreEnvFn := initSessionTestEnv()
   666  	defer restoreEnvFn()
   667  
   668  	cases := []struct {
   669  		profile          string
   670  		optionDuration   time.Duration
   671  		expectedDuration string
   672  	}{
   673  		{
   674  			profile:          "assume_role_w_creds",
   675  			expectedDuration: "900",
   676  		},
   677  		{
   678  			profile:          "assume_role_w_creds",
   679  			optionDuration:   30 * time.Minute,
   680  			expectedDuration: "1800",
   681  		},
   682  		{
   683  			profile:          "assume_role_w_creds_w_duration",
   684  			expectedDuration: "1800",
   685  		},
   686  		{
   687  			profile:          "assume_role_w_creds_w_invalid_duration",
   688  			expectedDuration: "900",
   689  		},
   690  	}
   691  
   692  	for _, tt := range cases {
   693  		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   694  			if e, a := tt.expectedDuration, r.FormValue("DurationSeconds"); e != a {
   695  				t.Errorf("expect %v, got %v", e, a)
   696  			}
   697  
   698  			w.Write([]byte(fmt.Sprintf(
   699  				assumeRoleRespMsg,
   700  				time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))
   701  		}))
   702  
   703  		os.Setenv("AWS_REGION", "us-east-1")
   704  		os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
   705  		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
   706  		os.Setenv("AWS_PROFILE", "assume_role_w_creds")
   707  
   708  		opts := Options{
   709  			Profile: tt.profile,
   710  			Config: aws.Config{
   711  				Endpoint:   aws.String(server.URL),
   712  				DisableSSL: aws.Bool(true),
   713  			},
   714  			SharedConfigState: SharedConfigEnable,
   715  		}
   716  		if tt.optionDuration != 0 {
   717  			opts.AssumeRoleDuration = tt.optionDuration
   718  		}
   719  
   720  		s, err := NewSessionWithOptions(opts)
   721  		if err != nil {
   722  			server.Close()
   723  			t.Fatalf("expect no error, got %v", err)
   724  		}
   725  
   726  		creds, err := s.Config.Credentials.Get()
   727  		if err != nil {
   728  			server.Close()
   729  			t.Fatalf("expect no error, got %v", err)
   730  		}
   731  
   732  		if e, a := "AKID", creds.AccessKeyID; e != a {
   733  			t.Errorf("expect %v, got %v", e, a)
   734  		}
   735  		if e, a := "SECRET", creds.SecretAccessKey; e != a {
   736  			t.Errorf("expect %v, got %v", e, a)
   737  		}
   738  		if e, a := "SESSION_TOKEN", creds.SessionToken; e != a {
   739  			t.Errorf("expect %v, got %v", e, a)
   740  		}
   741  		if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) {
   742  			t.Errorf("expect %v, to be in %v", e, a)
   743  		}
   744  
   745  		server.Close()
   746  	}
   747  }
   748  
   749  func TestSessionAssumeRole_WithMFA_ExtendedDuration(t *testing.T) {
   750  	restoreEnvFn := initSessionTestEnv()
   751  	defer restoreEnvFn()
   752  
   753  	os.Setenv("AWS_REGION", "us-east-1")
   754  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
   755  	os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename)
   756  	os.Setenv("AWS_PROFILE", "assume_role_w_creds")
   757  
   758  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   759  		if e, a := "0123456789", r.FormValue("SerialNumber"); e != a {
   760  			t.Errorf("expect %v, got %v", e, a)
   761  		}
   762  		if e, a := "tokencode", r.FormValue("TokenCode"); e != a {
   763  			t.Errorf("expect %v, got %v", e, a)
   764  		}
   765  		if e, a := "1800", r.FormValue("DurationSeconds"); e != a {
   766  			t.Errorf("expect %v, got %v", e, a)
   767  		}
   768  
   769  		w.Write([]byte(fmt.Sprintf(
   770  			assumeRoleRespMsg,
   771  			time.Now().Add(30*time.Minute).Format("2006-01-02T15:04:05Z"))))
   772  	}))
   773  	defer server.Close()
   774  
   775  	customProviderCalled := false
   776  	sess, err := NewSessionWithOptions(Options{
   777  		Profile: "assume_role_w_mfa",
   778  		Config: aws.Config{
   779  			Region:     aws.String("us-east-1"),
   780  			Endpoint:   aws.String(server.URL),
   781  			DisableSSL: aws.Bool(true),
   782  		},
   783  		SharedConfigState:  SharedConfigEnable,
   784  		AssumeRoleDuration: 30 * time.Minute,
   785  		AssumeRoleTokenProvider: func() (string, error) {
   786  			customProviderCalled = true
   787  
   788  			return "tokencode", nil
   789  		},
   790  	})
   791  	if err != nil {
   792  		t.Fatalf("expect no error, got %v", err)
   793  	}
   794  
   795  	creds, err := sess.Config.Credentials.Get()
   796  	if err != nil {
   797  		t.Fatalf("expect no error, got %v", err)
   798  	}
   799  	if !customProviderCalled {
   800  		t.Errorf("expect true")
   801  	}
   802  
   803  	if e, a := "AKID", creds.AccessKeyID; e != a {
   804  		t.Errorf("expect %v, got %v", e, a)
   805  	}
   806  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
   807  		t.Errorf("expect %v, got %v", e, a)
   808  	}
   809  	if e, a := "SESSION_TOKEN", creds.SessionToken; e != a {
   810  		t.Errorf("expect %v, got %v", e, a)
   811  	}
   812  	if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) {
   813  		t.Errorf("expect %v, to be in %v", e, a)
   814  	}
   815  }
   816  
   817  func ssoTestSetup() (func(), error) {
   818  	dir, err := ioutil.TempDir("", "sso-test")
   819  	if err != nil {
   820  		return nil, err
   821  	}
   822  
   823  	cacheDir := filepath.Join(dir, ".aws", "sso", "cache")
   824  	err = os.MkdirAll(cacheDir, 0750)
   825  	if err != nil {
   826  		os.RemoveAll(dir)
   827  		return nil, err
   828  	}
   829  
   830  	tokenFile, err := os.Create(filepath.Join(cacheDir, "eb5e43e71ce87dd92ec58903d76debd8ee42aefd.json"))
   831  	if err != nil {
   832  		os.RemoveAll(dir)
   833  		return nil, err
   834  	}
   835  	defer tokenFile.Close()
   836  
   837  	_, err = tokenFile.WriteString(fmt.Sprintf(ssoTokenCacheFile, time.Now().
   838  		Add(15*time.Minute).
   839  		Format(time.RFC3339)))
   840  	if err != nil {
   841  		os.RemoveAll(dir)
   842  		return nil, err
   843  	}
   844  
   845  	if runtime.GOOS == "windows" {
   846  		os.Setenv("USERPROFILE", dir)
   847  	} else {
   848  		os.Setenv("HOME", dir)
   849  	}
   850  
   851  	return func() {
   852  		os.RemoveAll(dir)
   853  	}, nil
   854  }