github.com/mweagle/Sparta@v1.15.0/aws/accessor/dynamo.go (about)

     1  package accessor
     2  
// Simple DynamoDB accessor to get, put, delete, and range over items.
     4  
import (
	"context"
	"errors"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
	sparta "github.com/mweagle/Sparta"
	spartaAWS "github.com/mweagle/Sparta/aws"
	"github.com/sirupsen/logrus"
)
    16  
    17  const (
    18  	attrID = "id"
    19  )
    20  
    21  // DynamoAccessor to make it a bit easier to work with Dynamo
    22  // as the backing store
    23  type DynamoAccessor struct {
    24  	testingTableName        string
    25  	DynamoTableResourceName string
    26  }
    27  
    28  func (svc *DynamoAccessor) dynamoSvc(ctx context.Context) *dynamodb.DynamoDB {
    29  	logger, _ := ctx.Value(sparta.ContextKeyLogger).(*logrus.Logger)
    30  	sess := spartaAWS.NewSession(logger)
    31  	dynamoClient := dynamodb.New(sess)
    32  	xrayInit(dynamoClient.Client)
    33  	return dynamoClient
    34  }
    35  
    36  func (svc *DynamoAccessor) dynamoTableName() string {
    37  	if svc.testingTableName != "" {
    38  		return svc.testingTableName
    39  	}
    40  	discover, discoveryInfoErr := sparta.Discover()
    41  	if discoveryInfoErr != nil {
    42  		return ""
    43  	}
    44  	dynamoTableRes, dynamoTableResExists := discover.Resources[svc.DynamoTableResourceName]
    45  	if !dynamoTableResExists {
    46  		return ""
    47  	}
    48  	return dynamoTableRes.ResourceRef
    49  }
    50  
    51  func dynamoKeyValueAttrMap(keyPath string) map[string]*dynamodb.AttributeValue {
    52  	return map[string]*dynamodb.AttributeValue{
    53  		attrID: {
    54  			S: aws.String(keyPath),
    55  		}}
    56  }
    57  
    58  // Delete handles deleting the resource
    59  func (svc *DynamoAccessor) Delete(ctx context.Context, keyPath string) error {
    60  	deleteItemInput := &dynamodb.DeleteItemInput{
    61  		TableName: aws.String(svc.dynamoTableName()),
    62  		Key:       dynamoKeyValueAttrMap(keyPath),
    63  	}
    64  	_, deleteResultErr := svc.
    65  		dynamoSvc(ctx).
    66  		DeleteItemWithContext(ctx, deleteItemInput)
    67  	return deleteResultErr
    68  }
    69  
    70  // DeleteAll handles deleting all the items
    71  func (svc *DynamoAccessor) DeleteAll(ctx context.Context) error {
    72  	var deleteErr error
    73  	input := &dynamodb.ScanInput{
    74  		TableName: aws.String(svc.dynamoTableName()),
    75  	}
    76  
    77  	scanHandler := func(output *dynamodb.ScanOutput, lastPage bool) bool {
    78  		writeDeleteRequests := make([]*dynamodb.WriteRequest, len(output.Items))
    79  		for index, eachItem := range output.Items {
    80  			keyID := ""
    81  			stringVal, stringValOk := eachItem[attrID]
    82  			if stringValOk && stringVal.S != nil {
    83  				keyID = *(stringVal.S)
    84  			}
    85  			writeDeleteRequests[index] = &dynamodb.WriteRequest{
    86  				DeleteRequest: &dynamodb.DeleteRequest{
    87  					Key: dynamoKeyValueAttrMap(keyID),
    88  				},
    89  			}
    90  		}
    91  		input := &dynamodb.BatchWriteItemInput{
    92  			RequestItems: map[string][]*dynamodb.WriteRequest{
    93  				svc.dynamoTableName(): writeDeleteRequests,
    94  			},
    95  		}
    96  		_, deleteErr = svc.dynamoSvc(ctx).BatchWriteItem(input)
    97  		return deleteErr == nil
    98  	}
    99  
   100  	scanErr := svc.dynamoSvc(ctx).ScanPagesWithContext(ctx, input, scanHandler)
   101  	if scanErr != nil {
   102  		return scanErr
   103  	}
   104  	return deleteErr
   105  }
   106  
   107  // Put handles saving the item
   108  func (svc *DynamoAccessor) Put(ctx context.Context, keyPath string, object interface{}) error {
   109  
   110  	// What's the type of the object?
   111  	if object == nil {
   112  		return errors.New("DynamoAccessor Put object must not be nil")
   113  	}
   114  	// Map it...
   115  	marshal, marshalErr := dynamodbattribute.MarshalMap(object)
   116  	if marshalErr != nil {
   117  		return marshalErr
   118  	}
   119  	// TODO - consider using tags for this...
   120  	_, idExists := marshal[attrID]
   121  	if !idExists {
   122  		marshal[attrID] = &dynamodb.AttributeValue{
   123  			S: aws.String(keyPath),
   124  		}
   125  	}
   126  	putItemInput := &dynamodb.PutItemInput{
   127  		TableName: aws.String(svc.dynamoTableName()),
   128  		Item:      marshal,
   129  	}
   130  	_, putItemErr := svc.dynamoSvc(ctx).PutItemWithContext(ctx, putItemInput)
   131  	return putItemErr
   132  }
   133  
   134  // Get handles getting the item
   135  func (svc *DynamoAccessor) Get(ctx context.Context,
   136  	keyPath string,
   137  	destObject interface{}) error {
   138  	getItemInput := &dynamodb.GetItemInput{
   139  		TableName: aws.String(svc.dynamoTableName()),
   140  		Key:       dynamoKeyValueAttrMap(keyPath),
   141  	}
   142  
   143  	getItemResult, getItemResultErr := svc.dynamoSvc(ctx).GetItemWithContext(ctx, getItemInput)
   144  	if getItemResultErr != nil {
   145  		return getItemResultErr
   146  	}
   147  	return dynamodbattribute.UnmarshalMap(getItemResult.Item, destObject)
   148  }
   149  
   150  // GetAll handles returning all of the items
   151  func (svc *DynamoAccessor) GetAll(ctx context.Context,
   152  	ctor NewObjectConstructor) ([]interface{}, error) {
   153  	var getAllErr error
   154  
   155  	results := make([]interface{}, 0)
   156  	input := &dynamodb.ScanInput{
   157  		TableName: aws.String(svc.dynamoTableName()),
   158  	}
   159  	scanHandler := func(output *dynamodb.ScanOutput, lastPage bool) bool {
   160  		for _, eachItem := range output.Items {
   161  			unmarshalTarget := ctor()
   162  			unmarshalErr := dynamodbattribute.UnmarshalMap(eachItem, unmarshalTarget)
   163  			if unmarshalErr != nil {
   164  				getAllErr = unmarshalErr
   165  				return false
   166  			}
   167  			results = append(results, unmarshalTarget)
   168  		}
   169  		return true
   170  	}
   171  	scanErr := svc.dynamoSvc(ctx).ScanPagesWithContext(ctx, input, scanHandler)
   172  	if scanErr != nil {
   173  		return nil, scanErr
   174  	}
   175  	if getAllErr != nil {
   176  		return nil, getAllErr
   177  	}
   178  	return results, nil
   179  }