github.com/terraform-modules-krish/terratest@v0.29.0/modules/aws/ssm.go (about)

     1  package aws
     2  
     3  import (
     4  	"fmt"
     5  	"time"
     6  
     7  	"github.com/aws/aws-sdk-go/aws"
     8  	"github.com/aws/aws-sdk-go/service/ssm"
     9  	"github.com/terraform-modules-krish/terratest/modules/logger"
    10  	"github.com/terraform-modules-krish/terratest/modules/retry"
    11  	"github.com/terraform-modules-krish/terratest/modules/testing"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  // GetParameter retrieves the latest version of SSM Parameter at keyName with decryption.
    16  func GetParameter(t testing.TestingT, awsRegion string, keyName string) string {
    17  	keyValue, err := GetParameterE(t, awsRegion, keyName)
    18  	require.NoError(t, err)
    19  	return keyValue
    20  }
    21  
    22  // GetParameterE retrieves the latest version of SSM Parameter at keyName with decryption.
    23  func GetParameterE(t testing.TestingT, awsRegion string, keyName string) (string, error) {
    24  	ssmClient, err := NewSsmClientE(t, awsRegion)
    25  	if err != nil {
    26  		return "", err
    27  	}
    28  
    29  	resp, err := ssmClient.GetParameter(&ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)})
    30  	if err != nil {
    31  		return "", err
    32  	}
    33  
    34  	parameter := *resp.Parameter
    35  	return *parameter.Value, nil
    36  }
    37  
    38  // PutParameter creates new version of SSM Parameter at keyName with keyValue as SecureString.
    39  func PutParameter(t testing.TestingT, awsRegion string, keyName string, keyDescription string, keyValue string) int64 {
    40  	version, err := PutParameterE(t, awsRegion, keyName, keyDescription, keyValue)
    41  	require.NoError(t, err)
    42  	return version
    43  }
    44  
    45  // PutParameterE creates new version of SSM Parameter at keyName with keyValue as SecureString.
    46  func PutParameterE(t testing.TestingT, awsRegion string, keyName string, keyDescription string, keyValue string) (int64, error) {
    47  	ssmClient, err := NewSsmClientE(t, awsRegion)
    48  	if err != nil {
    49  		return 0, err
    50  	}
    51  
    52  	resp, err := ssmClient.PutParameter(&ssm.PutParameterInput{Name: aws.String(keyName), Description: aws.String(keyDescription), Value: aws.String(keyValue), Type: aws.String("SecureString")})
    53  	if err != nil {
    54  		return 0, err
    55  	}
    56  
    57  	return *resp.Version, nil
    58  }
    59  
    60  // NewSsmClient creates a SSM client.
    61  func NewSsmClient(t testing.TestingT, region string) *ssm.SSM {
    62  	client, err := NewSsmClientE(t, region)
    63  	require.NoError(t, err)
    64  	return client
    65  }
    66  
    67  // NewSsmClientE creates an SSM client.
    68  func NewSsmClientE(t testing.TestingT, region string) (*ssm.SSM, error) {
    69  	sess, err := NewAuthenticatedSession(region)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	return ssm.New(sess), nil
    75  }
    76  
    77  // WaitForSsmInstanceE waits until the instance get registered to the SSM inventory.
    78  func WaitForSsmInstanceE(t testing.TestingT, awsRegion, instanceID string, timeout time.Duration) error {
    79  	timeBetweenRetries := 2 * time.Second
    80  	maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds())
    81  	description := fmt.Sprintf("Waiting for %s to appear in the SSM inventory", instanceID)
    82  
    83  	input := &ssm.GetInventoryInput{
    84  		Filters: []*ssm.InventoryFilter{
    85  			{
    86  				Key:    aws.String("AWS:InstanceInformation.InstanceId"),
    87  				Type:   aws.String("Equal"),
    88  				Values: aws.StringSlice([]string{instanceID}),
    89  			},
    90  		},
    91  	}
    92  	_, err := retry.DoWithRetryE(t, description, maxRetries, timeBetweenRetries, func() (string, error) {
    93  		client := NewSsmClient(t, awsRegion)
    94  		resp, err := client.GetInventory(input)
    95  
    96  		if err != nil {
    97  			return "", err
    98  		}
    99  
   100  		if len(resp.Entities) != 1 {
   101  			return "", fmt.Errorf("%s is not in the SSM inventory", instanceID)
   102  		}
   103  
   104  		return "", nil
   105  	})
   106  
   107  	return err
   108  }
   109  
   110  // WaitForSsmInstance waits until the instance get registered to the SSM inventory.
   111  func WaitForSsmInstance(t testing.TestingT, awsRegion, instanceID string, timeout time.Duration) {
   112  	err := WaitForSsmInstanceE(t, awsRegion, instanceID, timeout)
   113  	require.NoError(t, err)
   114  }
   115  
   116  // CheckSsmCommand checks that you can run the given command on the given instance through AWS SSM.
   117  func CheckSsmCommand(t testing.TestingT, awsRegion, instanceID, command string, timeout time.Duration) *CommandOutput {
   118  	result, err := CheckSsmCommandE(t, awsRegion, instanceID, command, timeout)
   119  	require.NoErrorf(t, err, "failed to execute '%s' on %s (%v):]\n  stdout: %#v\n  stderr: %#v", command, instanceID, err, result.Stdout, result.Stderr)
   120  	return result
   121  }
   122  
   123  // CommandOutput contains the result of the SSM command.
   124  type CommandOutput struct {
   125  	Stdout   string
   126  	Stderr   string
   127  	ExitCode int64
   128  }
   129  
   130  // CheckSsmCommandE checks that you can run the given command on the given instance through AWS SSM. Returns the result and an error if one occurs.
   131  func CheckSsmCommandE(t testing.TestingT, awsRegion, instanceID, command string, timeout time.Duration) (*CommandOutput, error) {
   132  	logger.Logf(t, "Running command '%s' on EC2 instance with ID '%s'", command, instanceID)
   133  
   134  	timeBetweenRetries := 2 * time.Second
   135  	maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds())
   136  
   137  	// Now that we know the instance in the SSM inventory, we can send the command
   138  	client, err := NewSsmClientE(t, awsRegion)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  	resp, err := client.SendCommand(&ssm.SendCommandInput{
   143  		Comment:      aws.String("Terratest SSM"),
   144  		DocumentName: aws.String("AWS-RunShellScript"),
   145  		InstanceIds:  aws.StringSlice([]string{instanceID}),
   146  		Parameters: map[string][]*string{
   147  			"commands": aws.StringSlice([]string{command}),
   148  		},
   149  	})
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	// Wait for the result
   155  	description := "Waiting for the result of the command"
   156  	retryableErrors := map[string]string{
   157  		"InvocationDoesNotExist": "InvocationDoesNotExist",
   158  		"bad status: Pending":    "bad status: Pending",
   159  		"bad status: InProgress": "bad status: InProgress",
   160  		"bad status: Delayed":    "bad status: Delayed",
   161  	}
   162  
   163  	result := &CommandOutput{}
   164  	_, err = retry.DoWithRetryableErrorsE(t, description, retryableErrors, maxRetries, timeBetweenRetries, func() (string, error) {
   165  		resp, err := client.GetCommandInvocation(&ssm.GetCommandInvocationInput{
   166  			CommandId:  resp.Command.CommandId,
   167  			InstanceId: &instanceID,
   168  		})
   169  
   170  		if err != nil {
   171  			return "", err
   172  		}
   173  
   174  		result.Stderr = aws.StringValue(resp.StandardErrorContent)
   175  		result.Stdout = aws.StringValue(resp.StandardOutputContent)
   176  		result.ExitCode = aws.Int64Value(resp.ResponseCode)
   177  
   178  		status := aws.StringValue(resp.Status)
   179  
   180  		if status == ssm.CommandInvocationStatusSuccess {
   181  			return "", nil
   182  		}
   183  
   184  		if status == ssm.CommandInvocationStatusFailed {
   185  			return "", fmt.Errorf(aws.StringValue(resp.StatusDetails))
   186  		}
   187  
   188  		return "", fmt.Errorf("bad status: %s", status)
   189  	})
   190  
   191  	if err != nil {
   192  		if actualErr, ok := err.(retry.FatalError); ok {
   193  			return result, actualErr.Underlying
   194  		}
   195  		return result, fmt.Errorf("Unexpected error: %v", err)
   196  	}
   197  
   198  	return result, nil
   199  }