github.com/Cloud-Foundations/Dominator@v0.3.4/lib/awsutil/credentials.go (about)

     1  package awsutil
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"sort"
     7  
     8  	"github.com/aws/aws-sdk-go/aws"
     9  	"github.com/aws/aws-sdk-go/aws/arn"
    10  	"github.com/aws/aws-sdk-go/aws/session"
    11  	"github.com/aws/aws-sdk-go/service/sts"
    12  	"path"
    13  )
    14  
    15  type sessionResult struct {
    16  	accountName string
    17  	accountId   string
    18  	awsSession  *session.Session
    19  	regions     []string
    20  	err         error
    21  }
    22  
    23  func getCredentialsPath() string {
    24  	return getAwsPath("AWS_CREDENTIAL_FILE", "credentials")
    25  }
    26  
    27  func getConfigPath() string {
    28  	return getAwsPath("AWS_CONFIG_FILE", "config")
    29  }
    30  
    31  func getAwsPath(environ, fileName string) string {
    32  	value := os.Getenv(environ)
    33  	if value != "" {
    34  		return value
    35  	}
    36  	home := os.Getenv("HOME")
    37  	return path.Join(home, ".aws", fileName)
    38  }
    39  
    40  func (c *CredentialsOptions) setDefaults() *CredentialsOptions {
    41  	if c.CredentialsPath == "" {
    42  		c.CredentialsPath = *awsCredentialsFile
    43  	}
    44  	if c.ConfigPath == "" {
    45  		c.ConfigPath = *awsConfigFile
    46  	}
    47  	return c
    48  }
    49  
    50  func tryLoadCredentialsWithOptions(
    51  	options *CredentialsOptions) (*CredentialsStore, map[string]error, error) {
    52  	accountNames, err := listAccountNames(options)
    53  	if err != nil {
    54  		return nil, nil, err
    55  	}
    56  	cs, unloadableAccounts := createCredentials(accountNames, options)
    57  	return cs, unloadableAccounts, nil
    58  }
    59  
    60  func loadCredentials() (*CredentialsStore, error) {
    61  	var options CredentialsOptions
    62  	cs, unloadableAccounts, err := tryLoadCredentialsWithOptions(
    63  		options.setDefaults())
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	for _, err := range unloadableAccounts {
    68  		return nil, err
    69  	}
    70  	return cs, nil
    71  }
    72  
    73  func createCredentials(
    74  	accountNames []string, options *CredentialsOptions) (
    75  	*CredentialsStore, map[string]error) {
    76  	cs := &CredentialsStore{
    77  		sessionMap:      make(map[string]*session.Session),
    78  		accountIdToName: make(map[string]string),
    79  		accountNameToId: make(map[string]string),
    80  		accountRegions:  make(map[string][]string),
    81  	}
    82  	resultsChannel := make(chan sessionResult, len(accountNames))
    83  	for _, accountName := range accountNames {
    84  		go func(accountName string) {
    85  			resultsChannel <- createSession(accountName, options)
    86  		}(accountName)
    87  	}
    88  	unloadableAccounts := make(map[string]error)
    89  	for range accountNames {
    90  		result := <-resultsChannel
    91  		if result.err != nil {
    92  			unloadableAccounts[result.accountName] = result.err
    93  		} else {
    94  			cs.accountNames = append(cs.accountNames, result.accountName)
    95  			cs.sessionMap[result.accountName] = result.awsSession
    96  			cs.accountIdToName[result.accountId] = result.accountName
    97  			cs.accountNameToId[result.accountName] = result.accountId
    98  			cs.accountRegions[result.accountName] = result.regions
    99  		}
   100  	}
   101  	close(resultsChannel)
   102  	sort.Strings(cs.accountNames)
   103  	return cs, unloadableAccounts
   104  }
   105  
   106  func createSession(
   107  	accountName string, options *CredentialsOptions) sessionResult {
   108  	awsSession, err := session.NewSessionWithOptions(session.Options{
   109  		Profile:           accountName,
   110  		SharedConfigState: session.SharedConfigEnable,
   111  		SharedConfigFiles: []string{
   112  			options.CredentialsPath,
   113  			options.ConfigPath,
   114  		},
   115  	})
   116  	if err != nil {
   117  		return sessionResult{
   118  			err:         fmt.Errorf("session.NewSessionWithOptions: %s", err),
   119  			accountName: accountName,
   120  		}
   121  	}
   122  	stsService := sts.New(awsSession)
   123  	inp := &sts.GetCallerIdentityInput{}
   124  	var accountId string
   125  	if out, err := stsService.GetCallerIdentity(inp); err != nil {
   126  		return sessionResult{
   127  			err:         fmt.Errorf("sts.GetCallerIdentity: %s", err),
   128  			accountName: accountName,
   129  		}
   130  	} else {
   131  		if arnV, err := arn.Parse(aws.StringValue(out.Arn)); err != nil {
   132  			return sessionResult{err: err, accountName: accountName}
   133  		} else {
   134  			accountId = arnV.AccountID
   135  		}
   136  	}
   137  	regions, err := listRegions(CreateService(awsSession, "us-east-1"))
   138  	if err != nil {
   139  		// Try the ec2::DescribeRegions call in other regions before giving
   140  		// up and reporting the error. We may need to add to this list.
   141  		otherRegions := []string{"cn-north-1"}
   142  		for _, otherRegion := range otherRegions {
   143  			regions, err := listRegions(CreateService(awsSession, otherRegion))
   144  			if err == nil {
   145  				return sessionResult{
   146  					accountName: accountName,
   147  					accountId:   accountId,
   148  					awsSession:  awsSession,
   149  					regions:     regions,
   150  				}
   151  			}
   152  		}
   153  
   154  		// If no success with other regions return the original error
   155  		return sessionResult{err: err, accountName: accountName}
   156  	}
   157  	return sessionResult{
   158  		accountName: accountName,
   159  		accountId:   accountId,
   160  		awsSession:  awsSession,
   161  		regions:     regions,
   162  	}
   163  }