github.com/mweagle/Sparta@v1.15.0/validator/drift.go (about)

     1  package validator
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  	"time"
     7  
     8  	"github.com/aws/aws-sdk-go/aws"
     9  	"github.com/aws/aws-sdk-go/aws/session"
    10  	"github.com/aws/aws-sdk-go/service/cloudformation"
    11  	sparta "github.com/mweagle/Sparta"
    12  	gocf "github.com/mweagle/go-cloudformation"
    13  	"github.com/pkg/errors"
    14  	"github.com/sirupsen/logrus"
    15  )
    16  
    17  // DriftDetector is a detector that ensures that the service hasn't
    18  // experienced configuration drift prior to being overwritten by a new provisioning
    19  // step.
    20  func DriftDetector(errorOnDrift bool) sparta.ServiceValidationHookHandler {
    21  
    22  	driftDetector := func(context map[string]interface{},
    23  		serviceName string,
    24  		template *gocf.Template,
    25  		S3Bucket string,
    26  		S3Key string,
    27  		buildID string,
    28  		awsSession *session.Session,
    29  		noop bool,
    30  		logger *logrus.Logger) error {
    31  		// Create a cloudformation service.
    32  		cfSvc := cloudformation.New(awsSession)
    33  		detectStackDrift, detectStackDriftErr := cfSvc.DetectStackDrift(&cloudformation.DetectStackDriftInput{
    34  			StackName: aws.String(serviceName),
    35  		})
    36  		if detectStackDriftErr != nil {
    37  			// If it doesn't exist, then no worries...
    38  			if strings.Contains(detectStackDriftErr.Error(), "does not exist") {
    39  				return nil
    40  			}
    41  			return errors.Wrapf(detectStackDriftErr, "attempting to determine stack drift")
    42  		}
    43  
    44  		// Poll until it's done...
    45  		describeDriftDetectionStatus := &cloudformation.DescribeStackDriftDetectionStatusInput{
    46  			StackDriftDetectionId: detectStackDrift.StackDriftDetectionId,
    47  		}
    48  		detectionComplete := false
    49  
    50  		// Put a limit on the detection
    51  		for i := 0; i <= 30 && !detectionComplete; i++ {
    52  			driftStatus, driftStatusErr := cfSvc.DescribeStackDriftDetectionStatus(describeDriftDetectionStatus)
    53  			if driftStatusErr != nil {
    54  				logger.WithField("error", driftStatusErr).Warn("Failed to check Stack Drift")
    55  			}
    56  			if driftStatus != nil {
    57  				switch *driftStatus.DetectionStatus {
    58  				case "DETECTION_COMPLETE":
    59  					detectionComplete = true
    60  				default:
    61  					logger.WithField("Status", *driftStatus.DetectionStatus).
    62  						Info("Waiting for drift detection to complete")
    63  					time.Sleep(11 * time.Second)
    64  				}
    65  			}
    66  		}
    67  		if !detectionComplete {
    68  			return errors.Errorf("Stack drift detection did not complete in time")
    69  		}
    70  
    71  		golangFuncName := func(logicalResourceID string) string {
    72  			templateRes, templateResExists := template.Resources[logicalResourceID]
    73  			if !templateResExists {
    74  				return ""
    75  			}
    76  			metadata := templateRes.Metadata
    77  			if len(metadata) <= 0 {
    78  				metadata = make(map[string]interface{})
    79  			}
    80  			golangFunc, golangFuncExists := metadata["golangFunc"]
    81  			if !golangFuncExists {
    82  				return ""
    83  			}
    84  			switch typedFunc := golangFunc.(type) {
    85  			case string:
    86  				return typedFunc
    87  			default:
    88  				return fmt.Sprintf("%#v", typedFunc)
    89  			}
    90  		}
    91  
    92  		// Log the drifts
    93  		logDrifts := func(stackResourceDrifts []*cloudformation.StackResourceDrift) {
    94  			for _, eachDrift := range stackResourceDrifts {
    95  				if len(eachDrift.PropertyDifferences) != 0 {
    96  					for _, eachDiff := range eachDrift.PropertyDifferences {
    97  						entry := logger.WithFields(logrus.Fields{
    98  							"Resource":       *eachDrift.LogicalResourceId,
    99  							"Actual":         *eachDiff.ActualValue,
   100  							"Expected":       *eachDiff.ExpectedValue,
   101  							"Relation":       *eachDiff.DifferenceType,
   102  							"PropertyPath":   *eachDiff.PropertyPath,
   103  							"LambdaFuncName": golangFuncName(*eachDrift.LogicalResourceId),
   104  						})
   105  						if errorOnDrift {
   106  							entry.Error("Stack drift detected")
   107  						} else {
   108  							entry.Warn("Stack drift detected")
   109  						}
   110  					}
   111  				}
   112  			}
   113  		}
   114  
   115  		// Utility function to fetch all the drifts
   116  		stackResourceDrifts := make([]*cloudformation.StackResourceDrift, 0)
   117  		input := &cloudformation.DescribeStackResourceDriftsInput{
   118  			MaxResults: aws.Int64(100),
   119  			StackName:  aws.String(serviceName),
   120  		}
   121  		// There can't be more than 200 resources in the template
   122  		// https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/cloudformation-limits.html
   123  		loopCounter := 0
   124  		for {
   125  			driftResults, driftResultsErr := cfSvc.DescribeStackResourceDrifts(input)
   126  			if driftResultsErr != nil {
   127  				return errors.Wrapf(driftResultsErr, "attempting to describe stack drift")
   128  			}
   129  			stackResourceDrifts = append(stackResourceDrifts, driftResults.StackResourceDrifts...)
   130  			if driftResults.NextToken == nil {
   131  				break
   132  			}
   133  			loopCounter++
   134  			// If there is more than 10 (1k total) something is seriously wrong...
   135  			if loopCounter >= 10 {
   136  				logDrifts(stackResourceDrifts)
   137  				return errors.Errorf("Exceeded maximum number of Stack resource drifts: %d", len(stackResourceDrifts))
   138  			}
   139  
   140  			input = &cloudformation.DescribeStackResourceDriftsInput{
   141  				MaxResults: aws.Int64(100),
   142  				StackName:  aws.String(serviceName),
   143  				NextToken:  driftResults.NextToken,
   144  			}
   145  		}
   146  
   147  		// Log them
   148  		logDrifts(stackResourceDrifts)
   149  		if len(stackResourceDrifts) == 0 || !errorOnDrift {
   150  			return nil
   151  		}
   152  		return errors.Errorf("stack %s operation prevented due to stack drift", serviceName)
   153  	}
   154  	return sparta.ServiceValidationHookFunc(driftDetector)
   155  }