github.com/darmach/terratest@v0.34.8-0.20210517103231-80931f95e3ff/modules/aws/ssm.go (about)

     1  package aws
     2  
     3  import (
     4  	"fmt"
     5  	"time"
     6  
     7  	"github.com/aws/aws-sdk-go/aws"
     8  	"github.com/aws/aws-sdk-go/service/ssm"
     9  	"github.com/gruntwork-io/terratest/modules/logger"
    10  	"github.com/gruntwork-io/terratest/modules/retry"
    11  	"github.com/gruntwork-io/terratest/modules/testing"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  // GetParameter retrieves the latest version of SSM Parameter at keyName with decryption.
    16  func GetParameter(t testing.TestingT, awsRegion string, keyName string) string {
    17  	keyValue, err := GetParameterE(t, awsRegion, keyName)
    18  	require.NoError(t, err)
    19  	return keyValue
    20  }
    21  
    22  // GetParameterE retrieves the latest version of SSM Parameter at keyName with decryption.
    23  func GetParameterE(t testing.TestingT, awsRegion string, keyName string) (string, error) {
    24  	ssmClient, err := NewSsmClientE(t, awsRegion)
    25  	if err != nil {
    26  		return "", err
    27  	}
    28  
    29  	resp, err := ssmClient.GetParameter(&ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)})
    30  	if err != nil {
    31  		return "", err
    32  	}
    33  
    34  	parameter := *resp.Parameter
    35  	return *parameter.Value, nil
    36  }
    37  
    38  // PutParameter creates new version of SSM Parameter at keyName with keyValue as SecureString.
    39  func PutParameter(t testing.TestingT, awsRegion string, keyName string, keyDescription string, keyValue string) int64 {
    40  	version, err := PutParameterE(t, awsRegion, keyName, keyDescription, keyValue)
    41  	require.NoError(t, err)
    42  	return version
    43  }
    44  
    45  // PutParameterE creates new version of SSM Parameter at keyName with keyValue as SecureString.
    46  func PutParameterE(t testing.TestingT, awsRegion string, keyName string, keyDescription string, keyValue string) (int64, error) {
    47  	ssmClient, err := NewSsmClientE(t, awsRegion)
    48  	if err != nil {
    49  		return 0, err
    50  	}
    51  
    52  	resp, err := ssmClient.PutParameter(&ssm.PutParameterInput{Name: aws.String(keyName), Description: aws.String(keyDescription), Value: aws.String(keyValue), Type: aws.String("SecureString")})
    53  	if err != nil {
    54  		return 0, err
    55  	}
    56  
    57  	return *resp.Version, nil
    58  }
    59  
    60  // DeleteParameter deletes all versions of SSM Parameter at keyName.
    61  func DeleteParameter(t testing.TestingT, awsRegion string, keyName string) {
    62  	err := DeleteParameterE(t, awsRegion, keyName)
    63  	require.NoError(t, err)
    64  }
    65  
    66  // DeleteParameterE deletes all versions of SSM Parameter at keyName.
    67  func DeleteParameterE(t testing.TestingT, awsRegion string, keyName string) error {
    68  	ssmClient, err := NewSsmClientE(t, awsRegion)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	_, err = ssmClient.DeleteParameter(&ssm.DeleteParameterInput{Name: aws.String(keyName)})
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	return nil
    79  }
    80  
    81  // NewSsmClient creates a SSM client.
    82  func NewSsmClient(t testing.TestingT, region string) *ssm.SSM {
    83  	client, err := NewSsmClientE(t, region)
    84  	require.NoError(t, err)
    85  	return client
    86  }
    87  
    88  // NewSsmClientE creates an SSM client.
    89  func NewSsmClientE(t testing.TestingT, region string) (*ssm.SSM, error) {
    90  	sess, err := NewAuthenticatedSession(region)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	return ssm.New(sess), nil
    96  }
    97  
    98  // WaitForSsmInstanceE waits until the instance get registered to the SSM inventory.
    99  func WaitForSsmInstanceE(t testing.TestingT, awsRegion, instanceID string, timeout time.Duration) error {
   100  	timeBetweenRetries := 2 * time.Second
   101  	maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds())
   102  	description := fmt.Sprintf("Waiting for %s to appear in the SSM inventory", instanceID)
   103  
   104  	input := &ssm.GetInventoryInput{
   105  		Filters: []*ssm.InventoryFilter{
   106  			{
   107  				Key:    aws.String("AWS:InstanceInformation.InstanceId"),
   108  				Type:   aws.String("Equal"),
   109  				Values: aws.StringSlice([]string{instanceID}),
   110  			},
   111  		},
   112  	}
   113  	_, err := retry.DoWithRetryE(t, description, maxRetries, timeBetweenRetries, func() (string, error) {
   114  		client := NewSsmClient(t, awsRegion)
   115  		resp, err := client.GetInventory(input)
   116  
   117  		if err != nil {
   118  			return "", err
   119  		}
   120  
   121  		if len(resp.Entities) != 1 {
   122  			return "", fmt.Errorf("%s is not in the SSM inventory", instanceID)
   123  		}
   124  
   125  		return "", nil
   126  	})
   127  
   128  	return err
   129  }
   130  
   131  // WaitForSsmInstance waits until the instance get registered to the SSM inventory.
   132  func WaitForSsmInstance(t testing.TestingT, awsRegion, instanceID string, timeout time.Duration) {
   133  	err := WaitForSsmInstanceE(t, awsRegion, instanceID, timeout)
   134  	require.NoError(t, err)
   135  }
   136  
   137  // CheckSsmCommand checks that you can run the given command on the given instance through AWS SSM.
   138  func CheckSsmCommand(t testing.TestingT, awsRegion, instanceID, command string, timeout time.Duration) *CommandOutput {
   139  	result, err := CheckSsmCommandE(t, awsRegion, instanceID, command, timeout)
   140  	require.NoErrorf(t, err, "failed to execute '%s' on %s (%v):]\n  stdout: %#v\n  stderr: %#v", command, instanceID, err, result.Stdout, result.Stderr)
   141  	return result
   142  }
   143  
   144  // CommandOutput contains the result of the SSM command.
   145  type CommandOutput struct {
   146  	Stdout   string
   147  	Stderr   string
   148  	ExitCode int64
   149  }
   150  
   151  // CheckSsmCommandE checks that you can run the given command on the given instance through AWS SSM. Returns the result and an error if one occurs.
   152  func CheckSsmCommandE(t testing.TestingT, awsRegion, instanceID, command string, timeout time.Duration) (*CommandOutput, error) {
   153  	logger.Logf(t, "Running command '%s' on EC2 instance with ID '%s'", command, instanceID)
   154  
   155  	timeBetweenRetries := 2 * time.Second
   156  	maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds())
   157  
   158  	// Now that we know the instance in the SSM inventory, we can send the command
   159  	client, err := NewSsmClientE(t, awsRegion)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	resp, err := client.SendCommand(&ssm.SendCommandInput{
   164  		Comment:      aws.String("Terratest SSM"),
   165  		DocumentName: aws.String("AWS-RunShellScript"),
   166  		InstanceIds:  aws.StringSlice([]string{instanceID}),
   167  		Parameters: map[string][]*string{
   168  			"commands": aws.StringSlice([]string{command}),
   169  		},
   170  	})
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	// Wait for the result
   176  	description := "Waiting for the result of the command"
   177  	retryableErrors := map[string]string{
   178  		"InvocationDoesNotExist": "InvocationDoesNotExist",
   179  		"bad status: Pending":    "bad status: Pending",
   180  		"bad status: InProgress": "bad status: InProgress",
   181  		"bad status: Delayed":    "bad status: Delayed",
   182  	}
   183  
   184  	result := &CommandOutput{}
   185  	_, err = retry.DoWithRetryableErrorsE(t, description, retryableErrors, maxRetries, timeBetweenRetries, func() (string, error) {
   186  		resp, err := client.GetCommandInvocation(&ssm.GetCommandInvocationInput{
   187  			CommandId:  resp.Command.CommandId,
   188  			InstanceId: &instanceID,
   189  		})
   190  
   191  		if err != nil {
   192  			return "", err
   193  		}
   194  
   195  		result.Stderr = aws.StringValue(resp.StandardErrorContent)
   196  		result.Stdout = aws.StringValue(resp.StandardOutputContent)
   197  		result.ExitCode = aws.Int64Value(resp.ResponseCode)
   198  
   199  		status := aws.StringValue(resp.Status)
   200  
   201  		if status == ssm.CommandInvocationStatusSuccess {
   202  			return "", nil
   203  		}
   204  
   205  		if status == ssm.CommandInvocationStatusFailed {
   206  			return "", fmt.Errorf(aws.StringValue(resp.StatusDetails))
   207  		}
   208  
   209  		return "", fmt.Errorf("bad status: %s", status)
   210  	})
   211  
   212  	if err != nil {
   213  		if actualErr, ok := err.(retry.FatalError); ok {
   214  			return result, actualErr.Underlying
   215  		}
   216  		return result, fmt.Errorf("Unexpected error: %v", err)
   217  	}
   218  
   219  	return result, nil
   220  }