github.com/versent/saml2aws@v2.17.0+incompatible/cmd/saml2aws/commands/login.go (about)

     1  package commands
     2  
     3  import (
     4  	"encoding/base64"
     5  	"fmt"
     6  	"os"
     7  
     8  	"github.com/aws/aws-sdk-go/aws"
     9  	"github.com/aws/aws-sdk-go/aws/session"
    10  	"github.com/aws/aws-sdk-go/service/sts"
    11  	"github.com/pkg/errors"
    12  	"github.com/sirupsen/logrus"
    13  	"github.com/versent/saml2aws"
    14  	"github.com/versent/saml2aws/helper/credentials"
    15  	"github.com/versent/saml2aws/pkg/awsconfig"
    16  	"github.com/versent/saml2aws/pkg/cfg"
    17  	"github.com/versent/saml2aws/pkg/creds"
    18  	"github.com/versent/saml2aws/pkg/flags"
    19  )
    20  
    21  // Login login to ADFS
    22  func Login(loginFlags *flags.LoginExecFlags) error {
    23  
    24  	logger := logrus.WithField("command", "login")
    25  
    26  	account, err := buildIdpAccount(loginFlags)
    27  	if err != nil {
    28  		return errors.Wrap(err, "error building login details")
    29  	}
    30  
    31  	sharedCreds := awsconfig.NewSharedCredentials(account.Profile)
    32  
    33  	logger.Debug("check if Creds Exist")
    34  
    35  	// this checks if the credentials file has been created yet
    36  	exist, err := sharedCreds.CredsExists()
    37  	if err != nil {
    38  		return errors.Wrap(err, "error loading credentials")
    39  	}
    40  	if !exist {
    41  		fmt.Println("unable to load credentials, login required to create them")
    42  		return nil
    43  	}
    44  
    45  	if !sharedCreds.Expired() && !loginFlags.Force {
    46  		fmt.Println("credentials are not expired skipping")
    47  		return nil
    48  	}
    49  
    50  	loginDetails, err := resolveLoginDetails(account, loginFlags)
    51  	if err != nil {
    52  		fmt.Printf("%+v\n", err)
    53  		os.Exit(1)
    54  	}
    55  
    56  	err = loginDetails.Validate()
    57  	if err != nil {
    58  		return errors.Wrap(err, "error validating login details")
    59  	}
    60  
    61  	logger.WithField("idpAccount", account).Debug("building provider")
    62  
    63  	provider, err := saml2aws.NewSAMLClient(account)
    64  	if err != nil {
    65  		return errors.Wrap(err, "error building IdP client")
    66  	}
    67  
    68  	fmt.Printf("Authenticating as %s ...\n", loginDetails.Username)
    69  
    70  	samlAssertion, err := provider.Authenticate(loginDetails)
    71  	if err != nil {
    72  		return errors.Wrap(err, "error authenticating to IdP")
    73  
    74  	}
    75  
    76  	if samlAssertion == "" {
    77  		fmt.Println("Response did not contain a valid SAML assertion")
    78  		fmt.Println("Please check your username and password is correct")
    79  		os.Exit(1)
    80  	}
    81  
    82  	err = credentials.SaveCredentials(loginDetails.URL, loginDetails.Username, loginDetails.Password)
    83  	if err != nil {
    84  		return errors.Wrap(err, "error storing password in keychain")
    85  	}
    86  
    87  	role, err := selectAwsRole(samlAssertion, account)
    88  	if err != nil {
    89  		return errors.Wrap(err, "Failed to assume role, please check whether you are permitted to assume the given role for the AWS service")
    90  	}
    91  
    92  	fmt.Println("Selected role:", role.RoleARN)
    93  
    94  	awsCreds, err := loginToStsUsingRole(account, role, samlAssertion)
    95  	if err != nil {
    96  		return errors.Wrap(err, "error logging into aws role using saml assertion")
    97  	}
    98  
    99  	return saveCredentials(awsCreds, sharedCreds)
   100  }
   101  
   102  func buildIdpAccount(loginFlags *flags.LoginExecFlags) (*cfg.IDPAccount, error) {
   103  	cfgm, err := cfg.NewConfigManager(cfg.DefaultConfigPath)
   104  	if err != nil {
   105  		return nil, errors.Wrap(err, "failed to load configuration")
   106  	}
   107  
   108  	account, err := cfgm.LoadIDPAccount(loginFlags.CommonFlags.IdpAccount)
   109  	if err != nil {
   110  		return nil, errors.Wrap(err, "failed to load idp account")
   111  	}
   112  
   113  	// update username and hostname if supplied
   114  	flags.ApplyFlagOverrides(loginFlags.CommonFlags, account)
   115  
   116  	err = account.Validate()
   117  	if err != nil {
   118  		return nil, errors.Wrap(err, "failed to validate account")
   119  	}
   120  
   121  	return account, nil
   122  }
   123  
   124  func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFlags) (*creds.LoginDetails, error) {
   125  
   126  	// fmt.Printf("loginFlags %+v\n", loginFlags)
   127  
   128  	loginDetails := &creds.LoginDetails{URL: account.URL, Username: account.Username, MFAToken: loginFlags.CommonFlags.MFAToken, DuoMFAOption: loginFlags.DuoMFAOption}
   129  
   130  	fmt.Printf("Using IDP Account %s to access %s %s\n", loginFlags.CommonFlags.IdpAccount, account.Provider, account.URL)
   131  
   132  	err := credentials.LookupCredentials(loginDetails, account.Provider)
   133  	if err != nil {
   134  		if !credentials.IsErrCredentialsNotFound(err) {
   135  			return nil, errors.Wrap(err, "error loading saved password")
   136  		}
   137  	}
   138  
   139  	// fmt.Printf("%s %s\n", savedUsername, savedPassword)
   140  
   141  	// if you supply a username in a flag it takes precedence
   142  	if loginFlags.CommonFlags.Username != "" {
   143  		loginDetails.Username = loginFlags.CommonFlags.Username
   144  	}
   145  
   146  	// if you supply a password in a flag it takes precedence
   147  	if loginFlags.CommonFlags.Password != "" {
   148  		loginDetails.Password = loginFlags.CommonFlags.Password
   149  	}
   150  
   151  	// fmt.Printf("loginDetails %+v\n", loginDetails)
   152  
   153  	// if skip prompt was passed just pass back the flag values
   154  	if loginFlags.CommonFlags.SkipPrompt {
   155  		return loginDetails, nil
   156  	}
   157  
   158  	err = saml2aws.PromptForLoginDetails(loginDetails, account.Provider)
   159  	if err != nil {
   160  		return nil, errors.Wrap(err, "Error occurred accepting input")
   161  	}
   162  
   163  	return loginDetails, nil
   164  }
   165  
   166  func selectAwsRole(samlAssertion string, account *cfg.IDPAccount) (*saml2aws.AWSRole, error) {
   167  	data, err := base64.StdEncoding.DecodeString(samlAssertion)
   168  	if err != nil {
   169  		return nil, errors.Wrap(err, "error decoding saml assertion")
   170  	}
   171  
   172  	roles, err := saml2aws.ExtractAwsRoles(data)
   173  	if err != nil {
   174  		return nil, errors.Wrap(err, "error parsing aws roles")
   175  	}
   176  
   177  	if len(roles) == 0 {
   178  		fmt.Println("No roles to assume")
   179  		fmt.Println("Please check you are permitted to assume roles for the AWS service")
   180  		os.Exit(1)
   181  	}
   182  
   183  	awsRoles, err := saml2aws.ParseAWSRoles(roles)
   184  	if err != nil {
   185  		return nil, errors.Wrap(err, "error parsing aws roles")
   186  	}
   187  
   188  	return resolveRole(awsRoles, samlAssertion, account)
   189  }
   190  
   191  func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string, account *cfg.IDPAccount) (*saml2aws.AWSRole, error) {
   192  	var role = new(saml2aws.AWSRole)
   193  
   194  	if len(awsRoles) == 1 {
   195  		if account.RoleARN != "" {
   196  			return saml2aws.LocateRole(awsRoles, account.RoleARN)
   197  		}
   198  		return awsRoles[0], nil
   199  	} else if len(awsRoles) == 0 {
   200  		return nil, errors.New("no roles available")
   201  	}
   202  
   203  	awsAccounts, err := saml2aws.ParseAWSAccounts(samlAssertion)
   204  	if err != nil {
   205  		return nil, errors.Wrap(err, "error parsing aws role accounts")
   206  	}
   207  	if len(awsAccounts) == 0 {
   208  		return nil, errors.New("no accounts available")
   209  	}
   210  
   211  	saml2aws.AssignPrincipals(awsRoles, awsAccounts)
   212  
   213  	if account.RoleARN != "" {
   214  		return saml2aws.LocateRole(awsRoles, account.RoleARN)
   215  	}
   216  
   217  	for {
   218  		role, err = saml2aws.PromptForAWSRoleSelection(awsAccounts)
   219  		if err == nil {
   220  			break
   221  		}
   222  		fmt.Println("error selecting role, try again")
   223  	}
   224  
   225  	return role, nil
   226  }
   227  
   228  func loginToStsUsingRole(account *cfg.IDPAccount, role *saml2aws.AWSRole, samlAssertion string) (*awsconfig.AWSCredentials, error) {
   229  
   230  	sess, err := session.NewSession()
   231  	if err != nil {
   232  		return nil, errors.Wrap(err, "failed to create session")
   233  	}
   234  
   235  	svc := sts.New(sess)
   236  
   237  	params := &sts.AssumeRoleWithSAMLInput{
   238  		PrincipalArn:    aws.String(role.PrincipalARN), // Required
   239  		RoleArn:         aws.String(role.RoleARN),      // Required
   240  		SAMLAssertion:   aws.String(samlAssertion),     // Required
   241  		DurationSeconds: aws.Int64(int64(account.SessionDuration)),
   242  	}
   243  
   244  	fmt.Println("Requesting AWS credentials using SAML assertion")
   245  
   246  	resp, err := svc.AssumeRoleWithSAML(params)
   247  	if err != nil {
   248  		return nil, errors.Wrap(err, "error retrieving STS credentials using SAML")
   249  	}
   250  
   251  	return &awsconfig.AWSCredentials{
   252  		AWSAccessKey:     aws.StringValue(resp.Credentials.AccessKeyId),
   253  		AWSSecretKey:     aws.StringValue(resp.Credentials.SecretAccessKey),
   254  		AWSSessionToken:  aws.StringValue(resp.Credentials.SessionToken),
   255  		AWSSecurityToken: aws.StringValue(resp.Credentials.SessionToken),
   256  		PrincipalARN:     aws.StringValue(resp.AssumedRoleUser.Arn),
   257  		Expires:          resp.Credentials.Expiration.Local(),
   258  	}, nil
   259  }
   260  
   261  func saveCredentials(awsCreds *awsconfig.AWSCredentials, sharedCreds *awsconfig.CredentialsProvider) error {
   262  	err := sharedCreds.Save(awsCreds)
   263  	if err != nil {
   264  		return errors.Wrap(err, "error saving credentials")
   265  	}
   266  
   267  	fmt.Println("Logged in as:", awsCreds.PrincipalARN)
   268  	fmt.Println("")
   269  	fmt.Println("Your new access key pair has been stored in the AWS configuration")
   270  	fmt.Printf("Note that it will expire at %v\n", awsCreds.Expires)
   271  	fmt.Println("To use this credential, call the AWS CLI with the --profile option (e.g. aws --profile", sharedCreds.Profile, "ec2 describe-instances).")
   272  
   273  	return nil
   274  }