github.com/openshift/installer@v1.4.17/pkg/asset/installconfig/aws/session.go (about)

     1  package aws
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"strings"
     9  	"sync"
    10  
    11  	survey "github.com/AlecAivazis/survey/v2"
    12  	"github.com/aws/aws-sdk-go/aws"
    13  	"github.com/aws/aws-sdk-go/aws/awserr"
    14  	"github.com/aws/aws-sdk-go/aws/credentials"
    15  	"github.com/aws/aws-sdk-go/aws/defaults"
    16  	"github.com/aws/aws-sdk-go/aws/endpoints"
    17  	"github.com/aws/aws-sdk-go/aws/request"
    18  	"github.com/aws/aws-sdk-go/aws/session"
    19  	"github.com/sirupsen/logrus"
    20  	ini "gopkg.in/ini.v1"
    21  
    22  	typesaws "github.com/openshift/installer/pkg/types/aws"
    23  	"github.com/openshift/installer/pkg/version"
    24  )
    25  
    26  var (
    27  	onceLoggers = map[string]*sync.Once{
    28  		credentials.SharedCredsProviderName: new(sync.Once),
    29  		credentials.EnvProviderName:         new(sync.Once),
    30  		"credentialsFromSession":            new(sync.Once),
    31  	}
    32  )
    33  
// SessionOptions is a function that modifies the provided session.Options
// in place. Values of this type are passed to GetSessionWithOptions, which
// applies them in order before the session is created.
type SessionOptions func(sess *session.Options)
    36  
    37  // WithRegion configures the session.Option to set the AWS region.
    38  func WithRegion(region string) SessionOptions {
    39  	return func(sess *session.Options) {
    40  		cfg := aws.NewConfig().WithRegion(region)
    41  		sess.Config.MergeIn(cfg)
    42  	}
    43  }
    44  
    45  // WithServiceEndpoints configures the session.Option to use provides services for AWS endpoints.
    46  func WithServiceEndpoints(region string, services []typesaws.ServiceEndpoint) SessionOptions {
    47  	return func(sess *session.Options) {
    48  		resolver := newAWSResolver(region, services)
    49  		cfg := aws.NewConfig().WithEndpointResolver(resolver)
    50  		sess.Config.MergeIn(cfg)
    51  	}
    52  }
    53  
    54  // GetSession returns an AWS session by checking credentials
    55  // and, if no creds are found, asks for them and stores them on disk in a config file
    56  func GetSession() (*session.Session, error) { return GetSessionWithOptions() }
    57  
    58  // GetSessionWithOptions returns an AWS session by checking credentials
    59  // and, if no creds are found, asks for them and stores them on disk in a config file
    60  func GetSessionWithOptions(optFuncs ...SessionOptions) (*session.Session, error) {
    61  	options := session.Options{
    62  		Config:            aws.Config{MaxRetries: aws.Int(0)},
    63  		SharedConfigState: session.SharedConfigEnable,
    64  	}
    65  	for _, optFunc := range optFuncs {
    66  		optFunc(&options)
    67  	}
    68  
    69  	_, err := getCredentials(options)
    70  	if err != nil && errCodeEquals(err, "NoCredentialProviders") {
    71  		if err = getUserCredentials(); err != nil {
    72  			return nil, err
    73  		}
    74  	}
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	ssn := session.Must(session.NewSessionWithOptions(options))
    80  	ssn = ssn.Copy(&aws.Config{MaxRetries: aws.Int(25)})
    81  	ssn.Handlers.Build.PushBackNamed(request.NamedHandler{
    82  		Name: "openshiftInstaller.OpenshiftInstallerUserAgentHandler",
    83  		Fn:   request.MakeAddToUserAgentHandler("OpenShift/4.x Installer", version.Raw),
    84  	})
    85  	return ssn, nil
    86  }
    87  
    88  func getCredentials(options session.Options) (*credentials.Credentials, error) {
    89  	sharedCredentialsProvider := &credentials.SharedCredentialsProvider{}
    90  	providers := []credentials.Provider{
    91  		&credentials.EnvProvider{},
    92  		sharedCredentialsProvider,
    93  	}
    94  
    95  	creds := credentials.NewChainCredentials(providers)
    96  	credsValue, err := creds.Get()
    97  	if err != nil && errCodeEquals(err, "NoCredentialProviders") {
    98  		// getCredentialsFromSession returns credentials derived from a session. A
    99  		// session uses the AWS SDK Go chain of providers so may use a provider (e.g.,
   100  		// STS) which provides temporary credentials.
   101  		return getCredentialsFromSession(options)
   102  	}
   103  	if err != nil {
   104  		return nil, fmt.Errorf("error loading credentials for AWS Provider: %w", err)
   105  	}
   106  
   107  	// log the source of credential provider.
   108  	switch credsValue.ProviderName {
   109  	case credentials.SharedCredsProviderName:
   110  		onceLoggers[credentials.SharedCredsProviderName].Do(func() {
   111  			logrus.Infof("Credentials loaded from the %q profile in file %q", sharedCredentialsProvider.Profile, sharedCredentialsProvider.Filename)
   112  		})
   113  	case credentials.EnvProviderName:
   114  		onceLoggers[credentials.EnvProviderName].Do(func() {
   115  			logrus.Info("Credentials loaded from default AWS environment variables")
   116  		})
   117  	}
   118  	return creds, nil
   119  }
   120  
   121  func getCredentialsFromSession(options session.Options) (*credentials.Credentials, error) {
   122  	sess, err := session.NewSessionWithOptions(options)
   123  	if err != nil {
   124  		if errCodeEquals(err, "NoCredentialProviders") {
   125  			return nil, fmt.Errorf("failed to get credentials from session: %w", err)
   126  		}
   127  		return nil, fmt.Errorf("error creating AWS session: %w", err)
   128  	}
   129  	creds := sess.Config.Credentials
   130  
   131  	credsValue, err := sess.Config.Credentials.Get()
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	onceLoggers["credentialsFromSession"].Do(func() {
   136  		logrus.Infof("Credentials loaded from the AWS config using %q provider", credsValue.ProviderName)
   137  	})
   138  
   139  	return creds, nil
   140  }
   141  
   142  // IsStaticCredentials returns whether the credentials value provider are
   143  // static credentials safe for installer to transfer to cluster for use as-is.
   144  func IsStaticCredentials(credsValue credentials.Value) bool {
   145  	switch credsValue.ProviderName {
   146  	case credentials.EnvProviderName, credentials.StaticProviderName, credentials.SharedCredsProviderName, session.EnvProviderName:
   147  		return credsValue.SessionToken == ""
   148  	}
   149  	if strings.HasPrefix(credsValue.ProviderName, "SharedConfigCredentials") {
   150  		return credsValue.SessionToken == ""
   151  	}
   152  	return false
   153  }
   154  
   155  // errCodeEquals returns true if the error matches all these conditions:
   156  //   - err is of type awserr.Error
   157  //   - Error.Code() equals code
   158  func errCodeEquals(err error, code string) bool {
   159  	var awsErr awserr.Error
   160  	if errors.As(err, &awsErr) {
   161  		return awsErr.Code() == code
   162  	}
   163  	return false
   164  }
   165  
   166  func getUserCredentials() error {
   167  	var keyID string
   168  	err := survey.Ask([]*survey.Question{
   169  		{
   170  			Prompt: &survey.Input{
   171  				Message: "AWS Access Key ID",
   172  				Help:    "The AWS access key ID to use for installation (this is not your username).\nhttps://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html",
   173  			},
   174  		},
   175  	}, &keyID)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	var secretKey string
   181  	err = survey.Ask([]*survey.Question{
   182  		{
   183  			Prompt: &survey.Password{
   184  				Message: "AWS Secret Access Key",
   185  				Help:    "The AWS secret access key corresponding to your access key ID (this is not your password).",
   186  			},
   187  		},
   188  	}, &secretKey)
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	path := defaults.SharedCredentialsFilename()
   194  	if env := os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); env != "" {
   195  		path = env
   196  	}
   197  	logrus.Infof("Writing AWS credentials to %q (https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html)", path)
   198  	err = os.MkdirAll(filepath.Dir(path), 0700)
   199  	if err != nil {
   200  		return err
   201  	}
   202  
   203  	creds, err := ini.Load(path)
   204  	if err != nil {
   205  		if !os.IsNotExist(err) {
   206  			return fmt.Errorf("failed to load credentials file %s: %w", path, err)
   207  		}
   208  		creds = ini.Empty()
   209  		creds.Section("").Comment = "https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html"
   210  	}
   211  
   212  	profile := os.Getenv("AWS_PROFILE")
   213  	if profile == "" {
   214  		profile = "default"
   215  	}
   216  
   217  	creds.Section(profile).Key("aws_access_key_id").SetValue(keyID)
   218  	creds.Section(profile).Key("aws_secret_access_key").SetValue(secretKey)
   219  
   220  	tempPath := path + ".tmp"
   221  	file, err := os.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
   222  	if err != nil {
   223  		return err
   224  	}
   225  	defer file.Close()
   226  
   227  	_, err = creds.WriteTo(file)
   228  	if err != nil {
   229  		err2 := os.Remove(tempPath)
   230  		if err2 != nil {
   231  			logrus.Error(fmt.Errorf("failed to remove partially-written credentials file: %w", err2))
   232  		}
   233  		return err
   234  	}
   235  
   236  	return os.Rename(tempPath, path)
   237  }
   238  
// awsResolver resolves AWS service endpoints. Its EndpointFor method prefers
// user-supplied service overrides, then the built-in defaultEndpoints table,
// and finally the SDK's default resolver.
type awsResolver struct {
	// region is used as the signing region for an overridden endpoint when
	// the SDK's default resolver does not report one.
	region   string
	// services holds the user-supplied endpoint overrides, keyed by
	// resolverKey of the service name.
	services map[string]typesaws.ServiceEndpoint

	// this is a list of known default endpoints for specific regions that would
	// otherwise require the user to set the service overrides.
	// it's a map of region => service => resolved endpoint
	// this is only used when the user hasn't specified an override for the service in that region.
	defaultEndpoints map[string]map[string]endpoints.ResolvedEndpoint
}
   249  
   250  func newAWSResolver(region string, services []typesaws.ServiceEndpoint) *awsResolver {
   251  	resolver := &awsResolver{
   252  		region:           region,
   253  		services:         make(map[string]typesaws.ServiceEndpoint),
   254  		defaultEndpoints: defaultEndpoints(),
   255  	}
   256  	for _, service := range services {
   257  		service := service
   258  		resolver.services[resolverKey(service.Name)] = service
   259  	}
   260  	return resolver
   261  }
   262  
   263  func (ar *awsResolver) EndpointFor(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
   264  	if s, ok := ar.services[resolverKey(service)]; ok {
   265  		logrus.Debugf("resolved AWS service %s (%s) to %q", service, region, s.URL)
   266  		signingRegion := ar.region
   267  		def, _ := endpoints.DefaultResolver().EndpointFor(service, region)
   268  		if len(def.SigningRegion) > 0 {
   269  			signingRegion = def.SigningRegion
   270  		}
   271  		return endpoints.ResolvedEndpoint{
   272  			URL:           s.URL,
   273  			SigningRegion: signingRegion,
   274  		}, nil
   275  	}
   276  	if rv, ok := ar.defaultEndpoints[region]; ok {
   277  		if v, ok := rv[service]; ok {
   278  			return v, nil
   279  		}
   280  	}
   281  	return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
   282  }
   283  
   284  func resolverKey(service string) string {
   285  	return service
   286  }
   287  
   288  // this is a list of known default endpoints for specific regions that would
   289  // otherwise require user to set the service overrides.
   290  // it's a map of region => service => resolved endpoint
   291  // this is only used when the user hasn't specified a override for the service in that region.
   292  func defaultEndpoints() map[string]map[string]endpoints.ResolvedEndpoint {
   293  	return map[string]map[string]endpoints.ResolvedEndpoint{
   294  		endpoints.CnNorth1RegionID: {
   295  			"route53": {
   296  				URL:           "https://route53.amazonaws.com.cn",
   297  				SigningRegion: endpoints.CnNorthwest1RegionID,
   298  			},
   299  		},
   300  		endpoints.CnNorthwest1RegionID: {
   301  			"route53": {
   302  				URL:           "https://route53.amazonaws.com.cn",
   303  				SigningRegion: endpoints.CnNorthwest1RegionID,
   304  			},
   305  		},
   306  	}
   307  }