github.com/mweagle/Sparta@v1.15.0/aws/accessor/dynamo.go (about) 1 package accessor 2 3 // Simple dynamo accessor to get put range over items... 4 5 import ( 6 "context" 7 "errors" 8 9 "github.com/aws/aws-sdk-go/aws" 10 "github.com/aws/aws-sdk-go/service/dynamodb" 11 "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" 12 sparta "github.com/mweagle/Sparta" 13 spartaAWS "github.com/mweagle/Sparta/aws" 14 "github.com/sirupsen/logrus" 15 ) 16 17 const ( 18 attrID = "id" 19 ) 20 21 // DynamoAccessor to make it a bit easier to work with Dynamo 22 // as the backing store 23 type DynamoAccessor struct { 24 testingTableName string 25 DynamoTableResourceName string 26 } 27 28 func (svc *DynamoAccessor) dynamoSvc(ctx context.Context) *dynamodb.DynamoDB { 29 logger, _ := ctx.Value(sparta.ContextKeyLogger).(*logrus.Logger) 30 sess := spartaAWS.NewSession(logger) 31 dynamoClient := dynamodb.New(sess) 32 xrayInit(dynamoClient.Client) 33 return dynamoClient 34 } 35 36 func (svc *DynamoAccessor) dynamoTableName() string { 37 if svc.testingTableName != "" { 38 return svc.testingTableName 39 } 40 discover, discoveryInfoErr := sparta.Discover() 41 if discoveryInfoErr != nil { 42 return "" 43 } 44 dynamoTableRes, dynamoTableResExists := discover.Resources[svc.DynamoTableResourceName] 45 if !dynamoTableResExists { 46 return "" 47 } 48 return dynamoTableRes.ResourceRef 49 } 50 51 func dynamoKeyValueAttrMap(keyPath string) map[string]*dynamodb.AttributeValue { 52 return map[string]*dynamodb.AttributeValue{ 53 attrID: { 54 S: aws.String(keyPath), 55 }} 56 } 57 58 // Delete handles deleting the resource 59 func (svc *DynamoAccessor) Delete(ctx context.Context, keyPath string) error { 60 deleteItemInput := &dynamodb.DeleteItemInput{ 61 TableName: aws.String(svc.dynamoTableName()), 62 Key: dynamoKeyValueAttrMap(keyPath), 63 } 64 _, deleteResultErr := svc. 65 dynamoSvc(ctx). 66 DeleteItemWithContext(ctx, deleteItemInput) 67 return deleteResultErr 68 } 69 70 // DeleteAll handles deleting all the items 71 func (svc *DynamoAccessor) DeleteAll(ctx context.Context) error { 72 var deleteErr error 73 input := &dynamodb.ScanInput{ 74 TableName: aws.String(svc.dynamoTableName()), 75 } 76 77 scanHandler := func(output *dynamodb.ScanOutput, lastPage bool) bool { 78 writeDeleteRequests := make([]*dynamodb.WriteRequest, len(output.Items)) 79 for index, eachItem := range output.Items { 80 keyID := "" 81 stringVal, stringValOk := eachItem[attrID] 82 if stringValOk && stringVal.S != nil { 83 keyID = *(stringVal.S) 84 } 85 writeDeleteRequests[index] = &dynamodb.WriteRequest{ 86 DeleteRequest: &dynamodb.DeleteRequest{ 87 Key: dynamoKeyValueAttrMap(keyID), 88 }, 89 } 90 } 91 input := &dynamodb.BatchWriteItemInput{ 92 RequestItems: map[string][]*dynamodb.WriteRequest{ 93 svc.dynamoTableName(): writeDeleteRequests, 94 }, 95 } 96 _, deleteErr = svc.dynamoSvc(ctx).BatchWriteItem(input) 97 return deleteErr == nil 98 } 99 100 scanErr := svc.dynamoSvc(ctx).ScanPagesWithContext(ctx, input, scanHandler) 101 if scanErr != nil { 102 return scanErr 103 } 104 return deleteErr 105 } 106 107 // Put handles saving the item 108 func (svc *DynamoAccessor) Put(ctx context.Context, keyPath string, object interface{}) error { 109 110 // What's the type of the object? 111 if object == nil { 112 return errors.New("DynamoAccessor Put object must not be nil") 113 } 114 // Map it... 115 marshal, marshalErr := dynamodbattribute.MarshalMap(object) 116 if marshalErr != nil { 117 return marshalErr 118 } 119 // TODO - consider using tags for this... 120 _, idExists := marshal[attrID] 121 if !idExists { 122 marshal[attrID] = &dynamodb.AttributeValue{ 123 S: aws.String(keyPath), 124 } 125 } 126 putItemInput := &dynamodb.PutItemInput{ 127 TableName: aws.String(svc.dynamoTableName()), 128 Item: marshal, 129 } 130 _, putItemErr := svc.dynamoSvc(ctx).PutItemWithContext(ctx, putItemInput) 131 return putItemErr 132 } 133 134 // Get handles getting the item 135 func (svc *DynamoAccessor) Get(ctx context.Context, 136 keyPath string, 137 destObject interface{}) error { 138 getItemInput := &dynamodb.GetItemInput{ 139 TableName: aws.String(svc.dynamoTableName()), 140 Key: dynamoKeyValueAttrMap(keyPath), 141 } 142 143 getItemResult, getItemResultErr := svc.dynamoSvc(ctx).GetItemWithContext(ctx, getItemInput) 144 if getItemResultErr != nil { 145 return getItemResultErr 146 } 147 return dynamodbattribute.UnmarshalMap(getItemResult.Item, destObject) 148 } 149 150 // GetAll handles returning all of the items 151 func (svc *DynamoAccessor) GetAll(ctx context.Context, 152 ctor NewObjectConstructor) ([]interface{}, error) { 153 var getAllErr error 154 155 results := make([]interface{}, 0) 156 input := &dynamodb.ScanInput{ 157 TableName: aws.String(svc.dynamoTableName()), 158 } 159 scanHandler := func(output *dynamodb.ScanOutput, lastPage bool) bool { 160 for _, eachItem := range output.Items { 161 unmarshalTarget := ctor() 162 unmarshalErr := dynamodbattribute.UnmarshalMap(eachItem, unmarshalTarget) 163 if unmarshalErr != nil { 164 getAllErr = unmarshalErr 165 return false 166 } 167 results = append(results, unmarshalTarget) 168 } 169 return true 170 } 171 scanErr := svc.dynamoSvc(ctx).ScanPagesWithContext(ctx, input, scanHandler) 172 if scanErr != nil { 173 return nil, scanErr 174 } 175 if getAllErr != nil { 176 return nil, getAllErr 177 } 178 return results, nil 179 }