github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/processcreds/provider_test.go (about)

     1  package processcreds_test
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"os"
    10  	"os/exec"
    11  	"runtime"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/aavshr/aws-sdk-go/aws"
    17  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    18  	"github.com/aavshr/aws-sdk-go/aws/credentials/processcreds"
    19  	"github.com/aavshr/aws-sdk-go/aws/session"
    20  	"github.com/aavshr/aws-sdk-go/internal/sdktesting"
    21  )
    22  
    23  func TestProcessProviderFromSessionCfg(t *testing.T) {
    24  	restoreEnvFn := sdktesting.StashEnv()
    25  	defer restoreEnvFn()
    26  
    27  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
    28  	if runtime.GOOS == "windows" {
    29  		os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
    30  	} else {
    31  		os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
    32  	}
    33  
    34  	sess, err := session.NewSession(&aws.Config{
    35  		Region: aws.String("region")},
    36  	)
    37  
    38  	if err != nil {
    39  		t.Errorf("error getting session: %v", err)
    40  	}
    41  
    42  	creds, err := sess.Config.Credentials.Get()
    43  	if err != nil {
    44  		t.Errorf("error getting credentials: %v", err)
    45  	}
    46  
    47  	if e, a := "accessKey", creds.AccessKeyID; e != a {
    48  		t.Errorf("expected %v, got %v", e, a)
    49  	}
    50  
    51  	if e, a := "secret", creds.SecretAccessKey; e != a {
    52  		t.Errorf("expected %v, got %v", e, a)
    53  	}
    54  
    55  	if e, a := "tokenDefault", creds.SessionToken; e != a {
    56  		t.Errorf("expected %v, got %v", e, a)
    57  	}
    58  
    59  }
    60  
    61  func TestProcessProviderFromSessionWithProfileCfg(t *testing.T) {
    62  	restoreEnvFn := sdktesting.StashEnv()
    63  	defer restoreEnvFn()
    64  
    65  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
    66  	os.Setenv("AWS_PROFILE", "non_expire")
    67  	if runtime.GOOS == "windows" {
    68  		os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
    69  	} else {
    70  		os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
    71  	}
    72  
    73  	sess, err := session.NewSession(&aws.Config{
    74  		Region: aws.String("region")},
    75  	)
    76  
    77  	if err != nil {
    78  		t.Errorf("error getting session: %v", err)
    79  	}
    80  
    81  	creds, err := sess.Config.Credentials.Get()
    82  	if err != nil {
    83  		t.Errorf("error getting credentials: %v", err)
    84  	}
    85  
    86  	if e, a := "nonDefaultToken", creds.SessionToken; e != a {
    87  		t.Errorf("expected %v, got %v", e, a)
    88  	}
    89  
    90  }
    91  
    92  func TestProcessProviderNotFromCredProcCfg(t *testing.T) {
    93  	restoreEnvFn := sdktesting.StashEnv()
    94  	defer restoreEnvFn()
    95  
    96  	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
    97  	os.Setenv("AWS_PROFILE", "not_alone")
    98  	if runtime.GOOS == "windows" {
    99  		os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
   100  	} else {
   101  		os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
   102  	}
   103  
   104  	sess, err := session.NewSession(&aws.Config{
   105  		Region: aws.String("region")},
   106  	)
   107  
   108  	if err != nil {
   109  		t.Errorf("error getting session: %v", err)
   110  	}
   111  
   112  	creds, err := sess.Config.Credentials.Get()
   113  	if err != nil {
   114  		t.Errorf("error getting credentials: %v", err)
   115  	}
   116  
   117  	if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
   118  		t.Errorf("expected %v, got %v", e, a)
   119  	}
   120  
   121  	if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
   122  		t.Errorf("expected %v, got %v", e, a)
   123  	}
   124  
   125  }
   126  
   127  func TestProcessProviderFromSessionCrd(t *testing.T) {
   128  	restoreEnvFn := sdktesting.StashEnv()
   129  	defer restoreEnvFn()
   130  
   131  	if runtime.GOOS == "windows" {
   132  		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
   133  	} else {
   134  		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
   135  	}
   136  
   137  	sess, err := session.NewSession(&aws.Config{
   138  		Region: aws.String("region")},
   139  	)
   140  
   141  	if err != nil {
   142  		t.Errorf("error getting session: %v", err)
   143  	}
   144  
   145  	creds, err := sess.Config.Credentials.Get()
   146  	if err != nil {
   147  		t.Errorf("error getting credentials: %v", err)
   148  	}
   149  
   150  	if e, a := "accessKey", creds.AccessKeyID; e != a {
   151  		t.Errorf("expected %v, got %v", e, a)
   152  	}
   153  
   154  	if e, a := "secret", creds.SecretAccessKey; e != a {
   155  		t.Errorf("expected %v, got %v", e, a)
   156  	}
   157  
   158  	if e, a := "tokenDefault", creds.SessionToken; e != a {
   159  		t.Errorf("expected %v, got %v", e, a)
   160  	}
   161  
   162  }
   163  
   164  func TestProcessProviderFromSessionWithProfileCrd(t *testing.T) {
   165  	restoreEnvFn := sdktesting.StashEnv()
   166  	defer restoreEnvFn()
   167  
   168  	os.Setenv("AWS_PROFILE", "non_expire")
   169  	if runtime.GOOS == "windows" {
   170  		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
   171  	} else {
   172  		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
   173  	}
   174  
   175  	sess, err := session.NewSession(&aws.Config{
   176  		Region: aws.String("region")},
   177  	)
   178  
   179  	if err != nil {
   180  		t.Errorf("error getting session: %v", err)
   181  	}
   182  
   183  	creds, err := sess.Config.Credentials.Get()
   184  	if err != nil {
   185  		t.Errorf("error getting credentials: %v", err)
   186  	}
   187  
   188  	if e, a := "nonDefaultToken", creds.SessionToken; e != a {
   189  		t.Errorf("expected %v, got %v", e, a)
   190  	}
   191  
   192  }
   193  
   194  func TestProcessProviderNotFromCredProcCrd(t *testing.T) {
   195  	restoreEnvFn := sdktesting.StashEnv()
   196  	defer restoreEnvFn()
   197  
   198  	os.Setenv("AWS_PROFILE", "not_alone")
   199  	if runtime.GOOS == "windows" {
   200  		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
   201  	} else {
   202  		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
   203  	}
   204  
   205  	sess, err := session.NewSession(&aws.Config{
   206  		Region: aws.String("region")},
   207  	)
   208  
   209  	if err != nil {
   210  		t.Errorf("error getting session: %v", err)
   211  	}
   212  
   213  	creds, err := sess.Config.Credentials.Get()
   214  	if err != nil {
   215  		t.Errorf("error getting credentials: %v", err)
   216  	}
   217  
   218  	if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
   219  		t.Errorf("expected %v, got %v", e, a)
   220  	}
   221  
   222  	if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
   223  		t.Errorf("expected %v, got %v", e, a)
   224  	}
   225  
   226  }
   227  
   228  func TestProcessProviderBadCommand(t *testing.T) {
   229  	restoreEnvFn := sdktesting.StashEnv()
   230  	defer restoreEnvFn()
   231  
   232  	creds := processcreds.NewCredentials("/bad/process")
   233  	_, err := creds.Get()
   234  	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
   235  		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
   236  	}
   237  }
   238  
   239  func TestProcessProviderMoreEmptyCommands(t *testing.T) {
   240  	restoreEnvFn := sdktesting.StashEnv()
   241  	defer restoreEnvFn()
   242  
   243  	creds := processcreds.NewCredentials("")
   244  	_, err := creds.Get()
   245  	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
   246  		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
   247  	}
   248  
   249  }
   250  
   251  func TestProcessProviderExpectErrors(t *testing.T) {
   252  	restoreEnvFn := sdktesting.StashEnv()
   253  	defer restoreEnvFn()
   254  
   255  	creds := processcreds.NewCredentials(
   256  		fmt.Sprintf(
   257  			"%s %s",
   258  			getOSCat(),
   259  			strings.Join(
   260  				[]string{"testdata", "malformed.json"},
   261  				string(os.PathSeparator))))
   262  	_, err := creds.Get()
   263  	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderParse {
   264  		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderParse, err)
   265  	}
   266  
   267  	creds = processcreds.NewCredentials(
   268  		fmt.Sprintf("%s %s",
   269  			getOSCat(),
   270  			strings.Join(
   271  				[]string{"testdata", "wrongversion.json"},
   272  				string(os.PathSeparator))))
   273  	_, err = creds.Get()
   274  	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderVersion {
   275  		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderVersion, err)
   276  	}
   277  
   278  	creds = processcreds.NewCredentials(
   279  		fmt.Sprintf(
   280  			"%s %s",
   281  			getOSCat(),
   282  			strings.Join(
   283  				[]string{"testdata", "missingkey.json"},
   284  				string(os.PathSeparator))))
   285  	_, err = creds.Get()
   286  	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
   287  		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
   288  	}
   289  
   290  	creds = processcreds.NewCredentials(
   291  		fmt.Sprintf(
   292  			"%s %s",
   293  			getOSCat(),
   294  			strings.Join(
   295  				[]string{"testdata", "missingsecret.json"},
   296  				string(os.PathSeparator))))
   297  	_, err = creds.Get()
   298  	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
   299  		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
   300  	}
   301  
   302  }
   303  
   304  func TestProcessProviderTimeout(t *testing.T) {
   305  	restoreEnvFn := sdktesting.StashEnv()
   306  	defer restoreEnvFn()
   307  
   308  	command := "/bin/sleep 2"
   309  	if runtime.GOOS == "windows" {
   310  		// "timeout" command does not work due to pipe redirection
   311  		command = "ping -n 2 127.0.0.1>nul"
   312  	}
   313  
   314  	creds := processcreds.NewCredentialsTimeout(
   315  		command,
   316  		time.Duration(1)*time.Second)
   317  	if _, err := creds.Get(); err == nil || err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution || err.(awserr.Error).Message() != "credential process timed out" {
   318  		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
   319  	}
   320  
   321  }
   322  
   323  func TestProcessProviderWithLongSessionToken(t *testing.T) {
   324  	restoreEnvFn := sdktesting.StashEnv()
   325  	defer restoreEnvFn()
   326  
   327  	creds := processcreds.NewCredentials(
   328  		fmt.Sprintf(
   329  			"%s %s",
   330  			getOSCat(),
   331  			strings.Join(
   332  				[]string{"testdata", "longsessiontoken.json"},
   333  				string(os.PathSeparator))))
   334  	v, err := creds.Get()
   335  	if err != nil {
   336  		t.Errorf("expected %v, got %v", "no error", err)
   337  	}
   338  
   339  	// Text string same length as session token returned by AWS for AssumeRoleWithWebIdentity
   340  	e := "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
   341  	if a := v.SessionToken; e != a {
   342  		t.Errorf("expected %v, got %v", e, a)
   343  	}
   344  }
   345  
   346  type credentialTest struct {
   347  	Version         int
   348  	AccessKeyID     string `json:"AccessKeyId"`
   349  	SecretAccessKey string
   350  	Expiration      string
   351  }
   352  
   353  func TestProcessProviderStatic(t *testing.T) {
   354  	restoreEnvFn := sdktesting.StashEnv()
   355  	defer restoreEnvFn()
   356  
   357  	// static
   358  	creds := processcreds.NewCredentials(
   359  		fmt.Sprintf(
   360  			"%s %s",
   361  			getOSCat(),
   362  			strings.Join(
   363  				[]string{"testdata", "static.json"},
   364  				string(os.PathSeparator))))
   365  	_, err := creds.Get()
   366  	if err != nil {
   367  		t.Errorf("expected %v, got %v", "no error", err)
   368  	}
   369  	if creds.IsExpired() {
   370  		t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
   371  	}
   372  
   373  }
   374  
   375  func TestProcessProviderNotExpired(t *testing.T) {
   376  	restoreEnvFn := sdktesting.StashEnv()
   377  	defer restoreEnvFn()
   378  
   379  	// non-static, not expired
   380  	exp := &credentialTest{}
   381  	exp.Version = 1
   382  	exp.AccessKeyID = "accesskey"
   383  	exp.SecretAccessKey = "secretkey"
   384  	exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
   385  	b, err := json.Marshal(exp)
   386  	if err != nil {
   387  		t.Errorf("expected %v, got %v", "no error", err)
   388  	}
   389  
   390  	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expiring")
   391  	if err != nil {
   392  		t.Errorf("expected %v, got %v", "no error", err)
   393  	}
   394  	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
   395  		t.Errorf("expected %v, got %v", "no error", err)
   396  	}
   397  	defer func() {
   398  		if err = tmpFile.Close(); err != nil {
   399  			t.Errorf("expected %v, got %v", "no error", err)
   400  		}
   401  		if err = os.Remove(tmpFile.Name()); err != nil {
   402  			t.Errorf("expected %v, got %v", "no error", err)
   403  		}
   404  	}()
   405  	creds := processcreds.NewCredentials(
   406  		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
   407  	_, err = creds.Get()
   408  	if err != nil {
   409  		t.Errorf("expected %v, got %v", "no error", err)
   410  	}
   411  	if creds.IsExpired() {
   412  		t.Errorf("expected %v, got %v", "not expired", "expired")
   413  	}
   414  }
   415  
   416  func TestProcessProviderExpired(t *testing.T) {
   417  	restoreEnvFn := sdktesting.StashEnv()
   418  	defer restoreEnvFn()
   419  
   420  	// non-static, expired
   421  	exp := &credentialTest{}
   422  	exp.Version = 1
   423  	exp.AccessKeyID = "accesskey"
   424  	exp.SecretAccessKey = "secretkey"
   425  	exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339)
   426  	b, err := json.Marshal(exp)
   427  	if err != nil {
   428  		t.Errorf("expected %v, got %v", "no error", err)
   429  	}
   430  
   431  	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expired")
   432  	if err != nil {
   433  		t.Errorf("expected %v, got %v", "no error", err)
   434  	}
   435  	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
   436  		t.Errorf("expected %v, got %v", "no error", err)
   437  	}
   438  	defer func() {
   439  		if err = tmpFile.Close(); err != nil {
   440  			t.Errorf("expected %v, got %v", "no error", err)
   441  		}
   442  		if err = os.Remove(tmpFile.Name()); err != nil {
   443  			t.Errorf("expected %v, got %v", "no error", err)
   444  		}
   445  	}()
   446  	creds := processcreds.NewCredentials(
   447  		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
   448  	_, err = creds.Get()
   449  	if err != nil {
   450  		t.Errorf("expected %v, got %v", "no error", err)
   451  	}
   452  	if !creds.IsExpired() {
   453  		t.Errorf("expected %v, got %v", "expired", "not expired")
   454  	}
   455  }
   456  
   457  func TestProcessProviderForceExpire(t *testing.T) {
   458  	restoreEnvFn := sdktesting.StashEnv()
   459  	defer restoreEnvFn()
   460  
   461  	// non-static, not expired
   462  
   463  	// setup test credentials file
   464  	exp := &credentialTest{}
   465  	exp.Version = 1
   466  	exp.AccessKeyID = "accesskey"
   467  	exp.SecretAccessKey = "secretkey"
   468  	exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
   469  	b, err := json.Marshal(exp)
   470  	if err != nil {
   471  		t.Errorf("expected %v, got %v", "no error", err)
   472  	}
   473  	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_force_expire")
   474  	if err != nil {
   475  		t.Errorf("expected %v, got %v", "no error", err)
   476  	}
   477  	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
   478  		t.Errorf("expected %v, got %v", "no error", err)
   479  	}
   480  	defer func() {
   481  		if err = tmpFile.Close(); err != nil {
   482  			t.Errorf("expected %v, got %v", "no error", err)
   483  		}
   484  		if err = os.Remove(tmpFile.Name()); err != nil {
   485  			t.Errorf("expected %v, got %v", "no error", err)
   486  		}
   487  	}()
   488  
   489  	// get credentials from file
   490  	creds := processcreds.NewCredentials(
   491  		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
   492  	if _, err = creds.Get(); err != nil {
   493  		t.Errorf("expected %v, got %v", "no error", err)
   494  	}
   495  	if creds.IsExpired() {
   496  		t.Errorf("expected %v, got %v", "not expired", "expired")
   497  	}
   498  
   499  	// force expire creds
   500  	creds.Expire()
   501  	if !creds.IsExpired() {
   502  		t.Errorf("expected %v, got %v", "expired", "not expired")
   503  	}
   504  
   505  	// renew creds
   506  	if _, err = creds.Get(); err != nil {
   507  		t.Errorf("expected %v, got %v", "no error", err)
   508  	}
   509  	if creds.IsExpired() {
   510  		t.Errorf("expected %v, got %v", "not expired", "expired")
   511  	}
   512  
   513  }
   514  
   515  func TestProcessProviderAltConstruct(t *testing.T) {
   516  	restoreEnvFn := sdktesting.StashEnv()
   517  	defer restoreEnvFn()
   518  
   519  	// constructing with exec.Cmd instead of string
   520  	myCommand := exec.Command(
   521  		fmt.Sprintf(
   522  			"%s %s",
   523  			getOSCat(),
   524  			strings.Join(
   525  				[]string{"testdata", "static.json"},
   526  				string(os.PathSeparator))))
   527  	creds := processcreds.NewCredentialsCommand(myCommand, func(opt *processcreds.ProcessProvider) {
   528  		opt.Timeout = time.Duration(1) * time.Second
   529  	})
   530  	_, err := creds.Get()
   531  	if err != nil {
   532  		t.Errorf("expected %v, got %v", "no error", err)
   533  	}
   534  	if creds.IsExpired() {
   535  		t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
   536  	}
   537  }
   538  
   539  func BenchmarkProcessProvider(b *testing.B) {
   540  	restoreEnvFn := sdktesting.StashEnv()
   541  	defer restoreEnvFn()
   542  
   543  	creds := processcreds.NewCredentials(
   544  		fmt.Sprintf(
   545  			"%s %s",
   546  			getOSCat(),
   547  			strings.Join(
   548  				[]string{"testdata", "static.json"},
   549  				string(os.PathSeparator))))
   550  	_, err := creds.Get()
   551  	if err != nil {
   552  		b.Fatal(err)
   553  	}
   554  
   555  	b.ResetTimer()
   556  	for i := 0; i < b.N; i++ {
   557  		_, err := creds.Get()
   558  		if err != nil {
   559  			b.Fatal(err)
   560  		}
   561  	}
   562  }
   563  
   564  func getOSCat() string {
   565  	if runtime.GOOS == "windows" {
   566  		return "type"
   567  	}
   568  	return "cat"
   569  }