github.com/Cloud-Foundations/Dominator@v0.3.4/lib/awsutil/iterators.go (about)

     1  package awsutil
     2  
     3  import (
     4  	"errors"
     5  
     6  	"github.com/Cloud-Foundations/Dominator/lib/log"
     7  	"github.com/Cloud-Foundations/Dominator/lib/log/prefixlogger"
     8  	"github.com/aws/aws-sdk-go/aws/session"
     9  	"github.com/aws/aws-sdk-go/service/ec2"
    10  )
    11  
// resultsType is the per-account summary sent over the account results
// channel: in this file forEachRegionInAccount sends exactly one per account
// and forEachTarget collects them.
type resultsType struct {
	numTargets int   // Number of target goroutines started for the account.
	err        error // Error for the account, if any; forEachTarget reports the first non-nil one.
}
    16  
    17  func forEachTarget(targets TargetList, skipList TargetList,
    18  	targetFunc func(awsService *ec2.EC2, accountName, regionName string,
    19  		logger log.Logger),
    20  	logger log.Logger) (int, error) {
    21  	cs, err := LoadCredentials()
    22  	if err != nil {
    23  		return 0, err
    24  	}
    25  	return cs.ForEachEC2Target(targets, skipList, targetFunc, false, logger)
    26  }
    27  
    28  func (cs *CredentialsStore) forEachEC2Target(targets TargetList,
    29  	skipList TargetList,
    30  	targetFunc func(awsService *ec2.EC2, accountName, regionName string,
    31  		logger log.Logger),
    32  	wait bool, logger log.Logger) (int, error) {
    33  	return cs.ForEachTarget(targets, skipList,
    34  		func(awsSession *session.Session, accountName, regionName string,
    35  			logger log.Logger) {
    36  			targetFunc(cs.GetEC2Service(accountName, regionName), accountName,
    37  				regionName, logger)
    38  		},
    39  		wait, logger)
    40  }
    41  
    42  func (cs *CredentialsStore) forEachTarget(targets TargetList,
    43  	skipList TargetList,
    44  	targetFunc func(awsSession *session.Session, accountName, regionName string,
    45  		logger log.Logger),
    46  	wait bool, logger log.Logger) (int, error) {
    47  	if len(targets) < 1 { // Full wildcard.
    48  		targets = make(TargetList, 1)
    49  	}
    50  	accountMap := make(map[string][]string) // Key: accountName, value: regions.
    51  	skipTargets := make(map[Target]struct{})
    52  	for _, target := range skipList {
    53  		if target.AccountName != "" || target.Region != "" {
    54  			skipTargets[Target{target.AccountName, target.Region}] = struct{}{}
    55  		}
    56  	}
    57  	// Expand any wildcard account names.
    58  	for _, target := range targets {
    59  		if target.AccountName == "" {
    60  			for _, accountName := range cs.ListAccountsWithCredentials() {
    61  				regions := accountMap[accountName]
    62  				regions = append(regions, target.Region)
    63  				accountMap[accountName] = regions
    64  			}
    65  		} else {
    66  			regions := accountMap[target.AccountName]
    67  			regions = append(regions, target.Region)
    68  			accountMap[target.AccountName] = regions
    69  		}
    70  	}
    71  	// Verify we have credentials for all accounts.
    72  	for accountName := range accountMap {
    73  		if cs.GetSessionForAccount(accountName) == nil {
    74  			return 0, errors.New("no session for account: " + accountName)
    75  		}
    76  	}
    77  	accountResultsChannel := make(chan resultsType, 1)
    78  	var numTargets int
    79  	var waitChannel chan struct{}
    80  	if wait {
    81  		waitChannel = make(chan struct{}, 1)
    82  	}
    83  	for accountName, regions := range accountMap {
    84  		// Remove duplicate/redundant regions.
    85  		regionMap := make(map[string]struct{})
    86  		for _, region := range regions {
    87  			if region == "" {
    88  				regionMap = nil
    89  				break
    90  			}
    91  			regionMap[region] = struct{}{}
    92  		}
    93  		regionList := make([]string, 0, len(regionMap))
    94  		for region := range regionMap {
    95  			regionList = append(regionList, region)
    96  		}
    97  		awsSession := cs.GetSessionForAccount(accountName)
    98  		go cs.forEachRegionInAccount(awsSession, accountName, regionList,
    99  			accountResultsChannel, skipTargets, targetFunc, waitChannel, logger)
   100  	}
   101  	var firstError error
   102  	// Collect account results.
   103  	for range accountMap {
   104  		result := <-accountResultsChannel
   105  		if result.err != nil && firstError == nil {
   106  			firstError = result.err
   107  		}
   108  		numTargets += result.numTargets
   109  	}
   110  	if waitChannel != nil {
   111  		for count := 0; count < numTargets; count++ {
   112  			<-waitChannel
   113  		}
   114  	}
   115  	return numTargets, firstError
   116  }
   117  
   118  func (cs *CredentialsStore) forEachRegionInAccount(awsSession *session.Session,
   119  	accountName string, regions []string,
   120  	resultsChannel chan<- resultsType, skipTargets map[Target]struct{},
   121  	targetFunc func(*session.Session, string, string, log.Logger),
   122  	waitChannel chan<- struct{}, logger log.Logger) {
   123  	if len(regions) < 1 {
   124  		regions = cs.ListRegionsForAccount(accountName)
   125  	}
   126  	if _, ok := skipTargets[Target{accountName, ""}]; ok {
   127  		logger.Println(accountName + ": skipping account")
   128  		resultsChannel <- resultsType{0, nil}
   129  		return
   130  	}
   131  	// Start goroutine for each target ((account,region) tuple).
   132  	numRegions := 0
   133  	for _, region := range regions {
   134  		logger := prefixlogger.New(accountName+": "+region+": ", logger)
   135  		if _, ok := skipTargets[Target{accountName, region}]; ok {
   136  			logger.Println("skipping target")
   137  			continue
   138  		}
   139  		if _, ok := skipTargets[Target{"", region}]; ok {
   140  			logger.Println("skipping region")
   141  			continue
   142  		}
   143  		go func(awsSession *session.Session, accountName, regionName string,
   144  			logger log.Logger) {
   145  			targetFunc(awsSession, accountName, regionName, logger)
   146  			if waitChannel != nil {
   147  				waitChannel <- struct{}{}
   148  			}
   149  		}(awsSession, accountName, region, logger)
   150  		numRegions++
   151  	}
   152  	resultsChannel <- resultsType{numRegions, nil}
   153  }