github.com/opentofu/opentofu@v1.7.1/internal/encryption/keyprovider/aws_kms/config.go (about)

     1  package aws_kms
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  
     8  	"github.com/aws/aws-sdk-go-v2/aws"
     9  	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
    10  	"github.com/aws/aws-sdk-go-v2/service/kms"
    11  	"github.com/aws/aws-sdk-go-v2/service/kms/types"
    12  	awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
    13  	baselogging "github.com/hashicorp/aws-sdk-go-base/v2/logging"
    14  	"github.com/opentofu/opentofu/internal/encryption/keyprovider"
    15  	"github.com/opentofu/opentofu/internal/httpclient"
    16  	"github.com/opentofu/opentofu/internal/logging"
    17  	"github.com/opentofu/opentofu/version"
    18  )
    19  
    20  // Can be overridden for test mocking
    21  var newKMSFromConfig func(aws.Config) kmsClient = func(cfg aws.Config) kmsClient {
    22  	return kms.NewFromConfig(cfg)
    23  }
    24  
// Config is the HCL-decoded configuration of the AWS KMS key provider. Besides
// the two key-provider specific attributes it carries a copy of the S3 backend's
// AWS connection settings, which asAWSBase converts into an aws-sdk-go-base
// configuration.
type Config struct {
	// KeyProvider Config
	// Both attributes are mandatory: validate rejects an empty kms_key_id or
	// key_spec, and key_spec must match one of the KMS DataKeySpec enum values.
	KMSKeyID string `hcl:"kms_key_id"`
	KeySpec  string `hcl:"key_spec"`

	// Mirrored S3 Backend Config, mirror any changes
	AccessKey                      string                     `hcl:"access_key,optional"`
	Endpoints                      []ConfigEndpoints          `hcl:"endpoints,block"`
	MaxRetries                     int                        `hcl:"max_retries,optional"`
	Profile                        string                     `hcl:"profile,optional"`
	Region                         string                     `hcl:"region,optional"`
	SecretKey                      string                     `hcl:"secret_key,optional"`
	SkipCredsValidation            bool                       `hcl:"skip_credentials_validation,optional"`
	SkipRequestingAccountId        bool                       `hcl:"skip_requesting_account_id,optional"`
	STSRegion                      string                     `hcl:"sts_region,optional"`
	Token                          string                     `hcl:"token,optional"`
	HTTPProxy                      *string                    `hcl:"http_proxy,optional"`
	HTTPSProxy                     *string                    `hcl:"https_proxy,optional"`
	NoProxy                        string                     `hcl:"no_proxy,optional"`
	Insecure                       bool                       `hcl:"insecure,optional"`
	UseDualStackEndpoint           bool                       `hcl:"use_dualstack_endpoint,optional"`
	UseFIPSEndpoint                bool                       `hcl:"use_fips_endpoint,optional"`
	CustomCABundle                 string                     `hcl:"custom_ca_bundle,optional"`
	EC2MetadataServiceEndpoint     string                     `hcl:"ec2_metadata_service_endpoint,optional"`
	EC2MetadataServiceEndpointMode string                     `hcl:"ec2_metadata_service_endpoint_mode,optional"`
	// SkipMetadataAPICheck is a pointer so that "not set" (nil) can be told
	// apart from an explicit true/false.
	SkipMetadataAPICheck           *bool                      `hcl:"skip_metadata_api_check,optional"`
	SharedCredentialsFiles         []string                   `hcl:"shared_credentials_files,optional"`
	SharedConfigFiles              []string                   `hcl:"shared_config_files,optional"`
	AssumeRole                     *AssumeRole                `hcl:"assume_role,optional"`
	AssumeRoleWithWebIdentity      *AssumeRoleWithWebIdentity `hcl:"assume_role_with_web_identity,optional"`
	// AllowedAccountIds and ForbiddenAccountIds are mutually exclusive;
	// asAWSBase returns an error when both are non-empty.
	AllowedAccountIds              []string                   `hcl:"allowed_account_ids,optional"`
	ForbiddenAccountIds            []string                   `hcl:"forbidden_account_ids,optional"`
	RetryMode                      string                     `hcl:"retry_mode,optional"`
}
    59  
    60  func stringAttrEnvFallback(val string, env string) string {
    61  	if val != "" {
    62  		return val
    63  	}
    64  	return os.Getenv(env)
    65  }
    66  
    67  func stringArrayAttrEnvFallback(val []string, env string) []string {
    68  	if len(val) != 0 {
    69  		return val
    70  	}
    71  	envVal := os.Getenv(env)
    72  	if envVal != "" {
    73  		return []string{envVal}
    74  	}
    75  	return nil
    76  }
    77  
    78  func (c Config) asAWSBase() (*awsbase.Config, error) {
    79  	// Get endpoints to use
    80  	endpoints, err := c.getEndpoints()
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	// Get assume role
    86  	assumeRole, err := c.AssumeRole.asAWSBase()
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	// Get assume role with web identity
    92  	assumeRoleWithWebIdentity, err := c.AssumeRoleWithWebIdentity.asAWSBase()
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	// validate region
    98  	if c.Region == "" && os.Getenv("AWS_REGION") == "" && os.Getenv("AWS_DEFAULT_REGION") == "" {
    99  		return nil, fmt.Errorf(`the "region" attribute or the "AWS_REGION" or "AWS_DEFAULT_REGION" environment variables must be set.`)
   100  	}
   101  
   102  	// Retry Mode
   103  	if c.MaxRetries == 0 {
   104  		c.MaxRetries = 5
   105  	}
   106  	var retryMode aws.RetryMode
   107  	if len(c.RetryMode) != 0 {
   108  		retryMode, err = aws.ParseRetryMode(c.RetryMode)
   109  		if err != nil {
   110  			return nil, fmt.Errorf("%w: expected %q or %q", err, aws.RetryModeStandard, aws.RetryModeAdaptive)
   111  		}
   112  	}
   113  
   114  	// IDMS handling
   115  	imdsEnabled := imds.ClientDefaultEnableState
   116  	if c.SkipMetadataAPICheck != nil {
   117  		if *c.SkipMetadataAPICheck {
   118  			imdsEnabled = imds.ClientEnabled
   119  		} else {
   120  			imdsEnabled = imds.ClientDisabled
   121  		}
   122  	}
   123  
   124  	// validate account_ids
   125  	if len(c.AllowedAccountIds) != 0 && len(c.ForbiddenAccountIds) != 0 {
   126  		return nil, fmt.Errorf("conflicting config attributes: only allowed_account_ids or forbidden_account_ids can be specified, not both")
   127  	}
   128  
   129  	return &awsbase.Config{
   130  		AccessKey:               c.AccessKey,
   131  		CallerDocumentationURL:  "https://opentofu.org/docs/language/settings/backends/s3", // TODO
   132  		CallerName:              "KMS Key Provider",
   133  		IamEndpoint:             stringAttrEnvFallback(endpoints.IAM, "AWS_ENDPOINT_URL_IAM"),
   134  		MaxRetries:              c.MaxRetries,
   135  		RetryMode:               retryMode,
   136  		Profile:                 c.Profile,
   137  		Region:                  c.Region,
   138  		SecretKey:               c.SecretKey,
   139  		SkipCredsValidation:     c.SkipCredsValidation,
   140  		SkipRequestingAccountId: c.SkipRequestingAccountId,
   141  		StsEndpoint:             stringAttrEnvFallback(endpoints.STS, "AWS_ENDPOINT_URL_STS"),
   142  		StsRegion:               c.STSRegion,
   143  		Token:                   c.Token,
   144  
   145  		// Note: we don't need to read env variables explicitly because they are read implicitly by aws-sdk-base-go:
   146  		// see: https://github.com/hashicorp/aws-sdk-go-base/blob/v2.0.0-beta.41/internal/config/config.go#L133
   147  		// which relies on: https://cs.opensource.google/go/x/net/+/refs/tags/v0.18.0:http/httpproxy/proxy.go;l=89-96
   148  		HTTPProxy:            c.HTTPProxy,
   149  		HTTPSProxy:           c.HTTPSProxy,
   150  		NoProxy:              c.NoProxy,
   151  		Insecure:             c.Insecure,
   152  		UseDualStackEndpoint: c.UseDualStackEndpoint,
   153  		UseFIPSEndpoint:      c.UseFIPSEndpoint,
   154  		UserAgent: awsbase.UserAgentProducts{
   155  			{Name: "APN", Version: "1.0"},
   156  			{Name: httpclient.DefaultApplicationName, Version: version.String()},
   157  		},
   158  		CustomCABundle: stringAttrEnvFallback(c.CustomCABundle, "AWS_CA_BUNDLE"),
   159  
   160  		EC2MetadataServiceEnableState:  imdsEnabled,
   161  		EC2MetadataServiceEndpoint:     stringAttrEnvFallback(c.EC2MetadataServiceEndpoint, "AWS_EC2_METADATA_SERVICE_ENDPOINT"),
   162  		EC2MetadataServiceEndpointMode: stringAttrEnvFallback(c.EC2MetadataServiceEndpointMode, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE"),
   163  
   164  		SharedCredentialsFiles:    stringArrayAttrEnvFallback(c.SharedCredentialsFiles, "AWS_SHARED_CREDENTIALS_FILE"),
   165  		SharedConfigFiles:         stringArrayAttrEnvFallback(c.SharedConfigFiles, "AWS_SHARED_CONFIG_FILE"),
   166  		AssumeRole:                assumeRole,
   167  		AssumeRoleWithWebIdentity: assumeRoleWithWebIdentity,
   168  		AllowedAccountIds:         c.AllowedAccountIds,
   169  		ForbiddenAccountIds:       c.ForbiddenAccountIds,
   170  	}, nil
   171  }
   172  
   173  func (c Config) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) {
   174  	err := c.validate()
   175  	if err != nil {
   176  		return nil, nil, err
   177  	}
   178  
   179  	cfg, err := c.asAWSBase()
   180  	if err != nil {
   181  		return nil, nil, err
   182  	}
   183  
   184  	ctx := context.Background()
   185  	ctx, baselog := attachLoggerToContext(ctx)
   186  	cfg.Logger = baselog
   187  
   188  	_, awsConfig, awsDiags := awsbase.GetAwsConfig(ctx, cfg)
   189  
   190  	if awsDiags.HasError() {
   191  		out := "errors were encountered in aws kms configuration"
   192  		for _, diag := range awsDiags.Errors() {
   193  			out += "\n" + diag.Summary() + " : " + diag.Detail()
   194  		}
   195  
   196  		return nil, nil, fmt.Errorf(out)
   197  	}
   198  
   199  	return &keyProvider{
   200  		Config: c,
   201  		svc:    newKMSFromConfig(awsConfig),
   202  		ctx:    ctx,
   203  	}, new(keyMeta), nil
   204  }
   205  
   206  // validate checks the configuration for the key provider
   207  func (c Config) validate() (err error) {
   208  	if c.KMSKeyID == "" {
   209  		return &keyprovider.ErrInvalidConfiguration{
   210  			Message: "no kms_key_id provided",
   211  		}
   212  	}
   213  
   214  	if c.KeySpec == "" {
   215  		return &keyprovider.ErrInvalidConfiguration{
   216  			Message: "no key_spec provided",
   217  		}
   218  	}
   219  
   220  	spec := c.getKeySpecAsAWSType()
   221  	if spec == nil {
   222  		// This is to fetch a list of the values from the enum, because `spec` here can be nil, so we have to grab
   223  		// at least one of the enum possibilities here just to call .Values()
   224  		values := types.DataKeySpecAes256.Values()
   225  		return &keyprovider.ErrInvalidConfiguration{
   226  			Message: fmt.Sprintf("invalid key_spec %s, expected one of %v", c.KeySpec, values),
   227  		}
   228  	}
   229  
   230  	return nil
   231  }
   232  
   233  // getSpecAsAWSType handles conversion between the string from the config and the aws expected enum type
   234  // it will return nil if it cannot find a match
   235  func (c Config) getKeySpecAsAWSType() *types.DataKeySpec {
   236  	var spec types.DataKeySpec
   237  	for _, opt := range spec.Values() {
   238  		if string(opt) == c.KeySpec {
   239  			return &opt
   240  		}
   241  	}
   242  	return nil
   243  }
   244  
   245  // Mirrored from s3 backend config
   246  func attachLoggerToContext(ctx context.Context) (context.Context, baselogging.HcLogger) {
   247  	ctx, baseLog := baselogging.NewHcLogger(ctx, logging.HCLogger().Named("backend-s3"))
   248  	ctx = baselogging.RegisterLogger(ctx, baseLog)
   249  	return ctx, baseLog
   250  }