github.com/mweagle/Sparta@v1.15.0/aws/accessor/s3.go (about)

     1  package accessor
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"io/ioutil"
     8  
     9  	"github.com/aws/aws-sdk-go/aws"
    10  	"github.com/aws/aws-sdk-go/service/s3"
    11  	sparta "github.com/mweagle/Sparta"
    12  	spartaAWS "github.com/mweagle/Sparta/aws"
    13  	spartaCF "github.com/mweagle/Sparta/aws/cloudformation"
    14  	gocf "github.com/mweagle/go-cloudformation"
    15  	"github.com/sirupsen/logrus"
    16  )
    17  
// S3Accessor to make it a bit easier to work with S3
// as the backing store.
//
// The bucket is identified by its CloudFormation logical resource name
// (S3BucketResourceName). At deploy time the IAM privilege helpers Ref it;
// at runtime it is resolved to a concrete bucket via sparta.Discover().
type S3Accessor struct {
	// testingBucketName, when non-empty, is returned by s3BucketName instead
	// of a discovered bucket name (presumably a test-only override; it is
	// unexported, so it can only be set from within this package).
	testingBucketName    string
	// S3BucketResourceName is the logical resource name of the backing
	// bucket. It is passed to gocf.Ref in the privilege helpers and used as
	// the key into sparta.Discover().Resources when resolving the bucket.
	S3BucketResourceName string
}
    24  
    25  // BucketPrivilege returns a privilege that targets the Bucket
    26  func (svc *S3Accessor) BucketPrivilege(bucketPrivs ...string) sparta.IAMRolePrivilege {
    27  	return sparta.IAMRolePrivilege{
    28  		Actions:  bucketPrivs,
    29  		Resource: spartaCF.S3ArnForBucket(gocf.Ref(svc.S3BucketResourceName)),
    30  	}
    31  }
    32  
    33  // KeysPrivilege returns a privilege that targets the Bucket objects
    34  func (svc *S3Accessor) KeysPrivilege(keyPrivileges ...string) sparta.IAMRolePrivilege {
    35  	return sparta.IAMRolePrivilege{
    36  		Actions:  keyPrivileges,
    37  		Resource: spartaCF.S3AllKeysArnForBucket(gocf.Ref(svc.S3BucketResourceName)),
    38  	}
    39  }
    40  
    41  func (svc *S3Accessor) s3Svc(ctx context.Context) *s3.S3 {
    42  	logger, _ := ctx.Value(sparta.ContextKeyLogger).(*logrus.Logger)
    43  	sess := spartaAWS.NewSession(logger)
    44  	s3Client := s3.New(sess)
    45  	xrayInit(s3Client.Client)
    46  	return s3Client
    47  }
    48  
    49  func (svc *S3Accessor) s3BucketName() string {
    50  	if svc.testingBucketName != "" {
    51  		return svc.testingBucketName
    52  	}
    53  	discover, discoveryInfoErr := sparta.Discover()
    54  	if discoveryInfoErr != nil {
    55  		return ""
    56  	}
    57  	s3BucketRes, s3BucketResExists := discover.Resources[svc.S3BucketResourceName]
    58  	if !s3BucketResExists {
    59  		return ""
    60  	}
    61  	return s3BucketRes.ResourceRef
    62  }
    63  
    64  // Delete handles deleting the resource
    65  func (svc *S3Accessor) Delete(ctx context.Context, keyPath string) error {
    66  	deleteObjectInput := &s3.DeleteObjectInput{
    67  		Bucket: aws.String(svc.s3BucketName()),
    68  		Key:    aws.String(keyPath),
    69  	}
    70  	_, deleteResultErr := svc.
    71  		s3Svc(ctx).
    72  		DeleteObjectWithContext(ctx, deleteObjectInput)
    73  
    74  	return deleteResultErr
    75  }
    76  
    77  // DeleteAll handles deleting all the items
    78  func (svc *S3Accessor) DeleteAll(ctx context.Context) error {
    79  	// List each one, delete it
    80  
    81  	listObjectInput := &s3.ListObjectsInput{
    82  		Bucket: aws.String(svc.s3BucketName()),
    83  	}
    84  
    85  	listObjectResult, listObjectResultErr := svc.
    86  		s3Svc(ctx).
    87  		ListObjectsWithContext(ctx, listObjectInput)
    88  
    89  	if listObjectResultErr != nil {
    90  		return nil
    91  	}
    92  	for _, eachObject := range listObjectResult.Contents {
    93  		deleteErr := svc.Delete(ctx, *eachObject.Key)
    94  		if deleteErr != nil {
    95  			return deleteErr
    96  		}
    97  	}
    98  	return nil
    99  }
   100  
   101  // Put handles saving the item
   102  func (svc *S3Accessor) Put(ctx context.Context, keyPath string, object interface{}) error {
   103  	jsonBytes, jsonBytesErr := json.Marshal(object)
   104  	if jsonBytesErr != nil {
   105  		return jsonBytesErr
   106  	}
   107  
   108  	logger, _ := ctx.Value(sparta.ContextKeyLogger).(*logrus.Logger)
   109  	logger.WithFields(logrus.Fields{
   110  		"Bytes":   string(jsonBytes),
   111  		"KeyPath": keyPath}).Debug("Saving S3 object")
   112  
   113  	bytesReader := bytes.NewReader(jsonBytes)
   114  	putObjectInput := &s3.PutObjectInput{
   115  		Bucket: aws.String(svc.s3BucketName()),
   116  		Key:    aws.String(keyPath),
   117  		Body:   bytesReader,
   118  	}
   119  	putObjectResponse, putObjectRespErr := svc.
   120  		s3Svc(ctx).
   121  		PutObjectWithContext(ctx, putObjectInput)
   122  
   123  	logger.WithFields(logrus.Fields{
   124  		"Error":   putObjectRespErr,
   125  		"Results": putObjectResponse}).Debug("Save object results")
   126  
   127  	return putObjectRespErr
   128  }
   129  
   130  // Get handles getting the item
   131  func (svc *S3Accessor) Get(ctx context.Context,
   132  	keyPath string,
   133  	destObject interface{}) error {
   134  
   135  	getObjectInput := &s3.GetObjectInput{
   136  		Bucket: aws.String(svc.s3BucketName()),
   137  		Key:    aws.String(keyPath),
   138  	}
   139  	getObjectResult, getObjectResultErr := svc.
   140  		s3Svc(ctx).
   141  		GetObjectWithContext(ctx, getObjectInput)
   142  	if getObjectResultErr != nil {
   143  		return getObjectResultErr
   144  	}
   145  	jsonBytes, jsonBytesErr := ioutil.ReadAll(getObjectResult.Body)
   146  	if jsonBytesErr != nil {
   147  		return jsonBytesErr
   148  	}
   149  	return json.Unmarshal(jsonBytes, destObject)
   150  }
   151  
   152  // GetAll handles returning all of the items
   153  func (svc *S3Accessor) GetAll(ctx context.Context,
   154  	ctor NewObjectConstructor) ([]interface{}, error) {
   155  
   156  	listObjectInput := &s3.ListObjectsInput{
   157  		Bucket: aws.String(svc.s3BucketName()),
   158  	}
   159  
   160  	listObjectResult, listObjectResultErr := svc.
   161  		s3Svc(ctx).
   162  		ListObjectsWithContext(ctx, listObjectInput)
   163  
   164  	if listObjectResultErr != nil {
   165  		return nil, listObjectResultErr
   166  	}
   167  	allObjects := make([]interface{}, 0)
   168  	for _, eachObject := range listObjectResult.Contents {
   169  		objectInstance := ctor()
   170  		entryEntryErr := svc.Get(ctx, *eachObject.Key, objectInstance)
   171  		if entryEntryErr != nil {
   172  			return nil, entryEntryErr
   173  		}
   174  		allObjects = append(allObjects, objectInstance)
   175  	}
   176  	return allObjects, nil
   177  }