github.com/versent/saml2aws@v2.17.0+incompatible/pkg/awsconfig/awsconfig.go (about)

     1  package awsconfig
     2  
     3  import (
     4  	"io/ioutil"
     5  	"os"
     6  	"path"
     7  	"path/filepath"
     8  	"runtime"
     9  	"time"
    10  
    11  	homedir "github.com/mitchellh/go-homedir"
    12  	"github.com/sirupsen/logrus"
    13  
    14  	"github.com/pkg/errors"
    15  
    16  	ini "gopkg.in/ini.v1"
    17  )
    18  
    19  var (
    20  	// ErrCredentialsHomeNotFound returned when a user home directory can't be located.
    21  	ErrCredentialsHomeNotFound = errors.New("user home directory not found")
    22  
    23  	// ErrCredentialsNotFound returned when the required aws credentials don't exist.
    24  	ErrCredentialsNotFound = errors.New("aws credentials not found")
    25  
    26  	logger = logrus.WithField("pkg", "awsconfig")
    27  )
    28  
// AWSCredentials represents the set of attributes used to authenticate to AWS with a short lived session.
//
// The `ini` struct tags give the key names used inside a profile section of the
// credentials file: Load decodes a section into this struct with MapTo, and
// saveProfile writes it back with ReflectFrom. Expires is compared against
// time.Now by Expired.
//
// NOTE(review): AWSSecurityToken presumably duplicates AWSSessionToken for
// tools that still read the legacy key name — confirm with callers.
type AWSCredentials struct {
	AWSAccessKey     string    `ini:"aws_access_key_id"`
	AWSSecretKey     string    `ini:"aws_secret_access_key"`
	AWSSessionToken  string    `ini:"aws_session_token"`
	AWSSecurityToken string    `ini:"aws_security_token"`
	PrincipalARN     string    `ini:"x_principal_arn"`
	Expires          time.Time `ini:"x_security_token_expires"`
}
    38  
// CredentialsProvider loads and saves AWS credentials for a single profile in
// a shared credentials file.
//
// Filename is the path of the credentials file. When it is empty, it is located
// on first use (AWS_SHARED_CREDENTIALS_FILE if set, otherwise .aws/credentials
// under the user's home directory) and the result is cached back into this
// field. Profile names the INI section that is read and written.
type CredentialsProvider struct {
	Filename string
	Profile  string
}
    44  
    45  // NewSharedCredentials helper to create the credentials provider
    46  func NewSharedCredentials(profile string) *CredentialsProvider {
    47  	return &CredentialsProvider{
    48  		Profile: profile,
    49  	}
    50  }
    51  
    52  // CredsExists verify that the credentials exist
    53  func (p *CredentialsProvider) CredsExists() (bool, error) {
    54  	filename, err := p.resolveFilename()
    55  	if err != nil {
    56  		return false, err
    57  	}
    58  
    59  	err = p.ensureConfigExists()
    60  	if err != nil {
    61  		if os.IsNotExist(err) {
    62  			return false, nil
    63  		}
    64  		return false, errors.Wrapf(err, "unable to load file %s", filename)
    65  	}
    66  
    67  	return true, nil
    68  }
    69  
    70  // Save persist the credentials
    71  func (p *CredentialsProvider) Save(awsCreds *AWSCredentials) error {
    72  	filename, err := p.resolveFilename()
    73  	if err != nil {
    74  		return err
    75  	}
    76  
    77  	err = p.ensureConfigExists()
    78  	if err != nil {
    79  		if os.IsNotExist(err) {
    80  			return createAndSaveProfile(filename, p.Profile, awsCreds)
    81  		}
    82  		return errors.Wrap(err, "unable to load file")
    83  	}
    84  
    85  	return saveProfile(filename, p.Profile, awsCreds)
    86  }
    87  
    88  // Load load the aws credentials file
    89  func (p *CredentialsProvider) Load() (*AWSCredentials, error) {
    90  	filename, err := p.resolveFilename()
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	config, err := ini.Load(filename)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	iniProfile, err := config.GetSection(p.Profile)
   101  	if err != nil {
   102  		return nil, ErrCredentialsNotFound
   103  	}
   104  
   105  	awsCreds := new(AWSCredentials)
   106  
   107  	err = iniProfile.MapTo(awsCreds)
   108  	if err != nil {
   109  		return nil, ErrCredentialsNotFound
   110  	}
   111  
   112  	return awsCreds, nil
   113  }
   114  
   115  // Expired checks if the current credentials are expired
   116  func (p *CredentialsProvider) Expired() bool {
   117  	creds, err := p.Load()
   118  	if err != nil {
   119  		return true
   120  	}
   121  
   122  	return time.Now().After(creds.Expires)
   123  }
   124  
   125  // ensureConfigExists verify that the config file exists
   126  func (p *CredentialsProvider) ensureConfigExists() error {
   127  	filename, err := p.resolveFilename()
   128  	if err != nil {
   129  		return err
   130  	}
   131  	logger.WithField("filename", filename).Debug("ensureConfigExists")
   132  
   133  	if _, err := os.Stat(filename); err != nil {
   134  		if os.IsNotExist(err) {
   135  
   136  			dir := filepath.Dir(filename)
   137  
   138  			err = os.MkdirAll(dir, os.ModePerm)
   139  			if err != nil {
   140  				return err
   141  			}
   142  
   143  			logger.WithField("dir", dir).Debug("Dir created")
   144  
   145  			// create an base config file
   146  			err = ioutil.WriteFile(filename, []byte("["+p.Profile+"]"), 0600)
   147  			if err != nil {
   148  				return err
   149  			}
   150  
   151  			logger.WithField("filename", filename).Debug("File created")
   152  
   153  		}
   154  		return err
   155  	}
   156  
   157  	return nil
   158  }
   159  
   160  func (p *CredentialsProvider) resolveFilename() (string, error) {
   161  	if p.Filename == "" {
   162  		filename, err := locateConfigFile()
   163  		if err != nil {
   164  			return "", err
   165  		}
   166  
   167  		p.Filename = filename
   168  	}
   169  
   170  	return p.Filename, nil
   171  }
   172  
   173  func locateConfigFile() (string, error) {
   174  
   175  	filename := os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
   176  
   177  	if filename != "" {
   178  		return filename, nil
   179  	}
   180  
   181  	var name string
   182  	var err error
   183  	if runtime.GOOS == "windows" {
   184  		name = path.Join(os.Getenv("USERPROFILE"), ".aws", "credentials")
   185  	} else {
   186  		name, err = homedir.Expand("~/.aws/credentials")
   187  		if err != nil {
   188  			return "", ErrCredentialsHomeNotFound
   189  		}
   190  	}
   191  	logger.WithField("name", name).Debug("Expand")
   192  
   193  	// is the filename a symlink?
   194  	name, err = resolveSymlink(name)
   195  	if err != nil {
   196  		return "", errors.Wrap(err, "unable to resolve symlink")
   197  	}
   198  
   199  	logger.WithField("name", name).Debug("resolveSymlink")
   200  
   201  	return name, nil
   202  }
   203  
   204  func resolveSymlink(filename string) (string, error) {
   205  	sympath, err := filepath.EvalSymlinks(filename)
   206  
   207  	// return the un modified filename
   208  	if os.IsNotExist(err) {
   209  		return filename, nil
   210  	}
   211  	if err != nil {
   212  		return "", err
   213  	}
   214  
   215  	return sympath, nil
   216  }
   217  
   218  func createAndSaveProfile(filename, profile string, awsCreds *AWSCredentials) error {
   219  
   220  	dirPath := filepath.Dir(filename)
   221  
   222  	err := os.Mkdir(dirPath, 0700)
   223  	if err != nil {
   224  		return errors.Wrapf(err, "unable to create %s directory", dirPath)
   225  	}
   226  
   227  	_, err = os.Create(filename)
   228  	if err != nil {
   229  		return errors.Wrapf(err, "unable to create configuration")
   230  	}
   231  
   232  	return saveProfile(filename, profile, awsCreds)
   233  }
   234  
   235  func saveProfile(filename, profile string, awsCreds *AWSCredentials) error {
   236  	config, err := ini.Load(filename)
   237  	if err != nil {
   238  		return err
   239  	}
   240  	iniProfile, err := config.NewSection(profile)
   241  	if err != nil {
   242  		return err
   243  	}
   244  
   245  	err = iniProfile.ReflectFrom(awsCreds)
   246  	if err != nil {
   247  		return err
   248  	}
   249  
   250  	return config.SaveTo(filename)
   251  }