github.com/hairyhenderson/templater@v3.5.0+incompatible/aws/sts.go (about)

     1  package aws
     2  
     3  import (
     4  	"github.com/aws/aws-sdk-go/aws"
     5  	"github.com/aws/aws-sdk-go/service/sts"
     6  )
     7  
// STS looks up the identity behind the current AWS credentials via the STS
// GetCallerIdentity API, memoising the response per STS value.
type STS struct {
	identifier func() CallerIdentitifier // returns the client used for GetCallerIdentity
	cache      map[string]interface{}    // memoised API responses, keyed by API name
}
    13  
// identifierClient is the package-wide GetCallerIdentity client. The closure
// installed by NewSTS creates it on first use and then reuses it for every STS
// value. Access to it is not synchronised.
var identifierClient CallerIdentitifier
    15  
// CallerIdentitifier is the one-method slice of the STS client API that STS
// depends on: GetCallerIdentity. The client returned by sts.New satisfies it,
// and presumably a stub can be swapped in for tests.
//
// NOTE(review): "Identitifier" is a misspelling of "Identifier"; the name is
// exported, so it is kept as-is for compatibility.
type CallerIdentitifier interface {
	GetCallerIdentity(*sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error)
}
    20  
    21  // NewSTS -
    22  func NewSTS(options ClientOptions) *STS {
    23  	return &STS{
    24  		identifier: func() CallerIdentitifier {
    25  			if identifierClient == nil {
    26  				session := SDKSession()
    27  				identifierClient = sts.New(session)
    28  			}
    29  			return identifierClient
    30  		},
    31  		cache: make(map[string]interface{}),
    32  	}
    33  }
    34  
    35  func (s *STS) getCallerID() (*sts.GetCallerIdentityOutput, error) {
    36  	i := s.identifier()
    37  	if val, ok := s.cache["GetCallerIdentity"]; ok {
    38  		if c, ok := val.(*sts.GetCallerIdentityOutput); ok {
    39  			return c, nil
    40  		}
    41  	}
    42  	in := &sts.GetCallerIdentityInput{}
    43  	out, err := i.GetCallerIdentity(in)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	s.cache["GetCallerIdentity"] = out
    48  	return out, nil
    49  }
    50  
    51  // UserID -
    52  func (s *STS) UserID() (string, error) {
    53  	cid, err := s.getCallerID()
    54  	if err != nil {
    55  		return "", err
    56  	}
    57  	return aws.StringValue(cid.UserId), nil
    58  }
    59  
    60  // Account -
    61  func (s *STS) Account() (string, error) {
    62  	cid, err := s.getCallerID()
    63  	if err != nil {
    64  		return "", err
    65  	}
    66  	return aws.StringValue(cid.Account), nil
    67  }
    68  
    69  // Arn -
    70  func (s *STS) Arn() (string, error) {
    71  	cid, err := s.getCallerID()
    72  	if err != nil {
    73  		return "", err
    74  	}
    75  	return aws.StringValue(cid.Arn), nil
    76  }