github.com/diggerhq/digger/libs@v0.0.0-20240604170430-9d61cdf01cc5/locking/aws/dynamo_locking.go (about)

     1  package aws
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"log"
     7  	"os"
     8  	"time"
     9  
    10  	"github.com/aws/aws-sdk-go-v2/aws"
    11  	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
    12  	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
    13  
    14  	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
    15  	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
    16  
    17  	"github.com/aws/smithy-go"
    18  )
    19  
    20  const (
    21  	TABLE_NAME              = "DiggerDynamoDBLockTable"
    22  	TableCreationInterval   = 1 * time.Second
    23  	TableCreationRetryCount = 10
    24  	TableLockTimeout        = 60 * 60 * 24 * 90 * time.Second
    25  )
    26  
    27  type DynamoDbLock struct {
    28  	DynamoDb DynamoDBClient
    29  }
    30  
    31  type DynamoDBClient interface {
    32  	DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error)
    33  	CreateTable(ctx context.Context, params *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error)
    34  	UpdateItem(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error)
    35  	DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error)
    36  	GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error)
    37  }
    38  
    39  func isTableNotFoundExceptionError(err error) bool {
    40  	var apiError smithy.APIError
    41  	if errors.As(err, &apiError) {
    42  		switch apiError.(type) {
    43  		case *types.TableNotFoundException:
    44  			return true
    45  		}
    46  	}
    47  	return false
    48  }
    49  
    50  func (dynamoDbLock *DynamoDbLock) waitUntilTableCreated(ctx context.Context) error {
    51  	input := &dynamodb.DescribeTableInput{
    52  		TableName: aws.String(TABLE_NAME),
    53  	}
    54  	status, err := dynamoDbLock.DynamoDb.DescribeTable(ctx, input)
    55  	cnt := 0
    56  
    57  	if err != nil {
    58  		if !isTableNotFoundExceptionError(err) {
    59  			return err
    60  		}
    61  	}
    62  
    63  	for status.Table.TableStatus != "ACTIVE" {
    64  		time.Sleep(TableCreationInterval)
    65  		status, err = dynamoDbLock.DynamoDb.DescribeTable(ctx, input)
    66  		if err != nil {
    67  			if !isTableNotFoundExceptionError(err) {
    68  				return err
    69  			}
    70  		}
    71  		cnt++
    72  		if cnt > TableCreationRetryCount {
    73  			log.Printf("DynamoDB failed to create, timed out during creation.\n" +
    74  				"Rerunning the action may cause creation to succeed\n")
    75  			os.Exit(1)
    76  		}
    77  	}
    78  
    79  	return nil
    80  }
    81  
    82  // TODO: refactor func to return actual error and fail on callers
    83  func (dynamoDbLock *DynamoDbLock) createTableIfNotExists(ctx context.Context) error {
    84  	_, err := dynamoDbLock.DynamoDb.DescribeTable(ctx, &dynamodb.DescribeTableInput{
    85  		TableName: aws.String(TABLE_NAME),
    86  	})
    87  	if err == nil { // Table exists
    88  		return nil
    89  	}
    90  	if !isTableNotFoundExceptionError(err) {
    91  		return err
    92  	}
    93  
    94  	createtbl_input := &dynamodb.CreateTableInput{
    95  		AttributeDefinitions: []types.AttributeDefinition{
    96  			{
    97  				AttributeName: aws.String("PK"),
    98  				AttributeType: types.ScalarAttributeTypeS,
    99  			},
   100  			{
   101  				AttributeName: aws.String("SK"),
   102  				AttributeType: types.ScalarAttributeTypeS,
   103  			},
   104  		},
   105  		KeySchema: []types.KeySchemaElement{
   106  			{
   107  				AttributeName: aws.String("PK"),
   108  				KeyType:       types.KeyTypeHash,
   109  			},
   110  			{
   111  				AttributeName: aws.String("SK"),
   112  				KeyType:       types.KeyTypeRange,
   113  			},
   114  		},
   115  		BillingMode: types.BillingModePayPerRequest,
   116  		TableName:   aws.String(TABLE_NAME),
   117  	}
   118  	_, err = dynamoDbLock.DynamoDb.CreateTable(ctx, createtbl_input)
   119  	if err != nil {
   120  		if os.Getenv("DEBUG") != "" {
   121  			log.Printf("%v\n", err)
   122  		}
   123  		return err
   124  	}
   125  
   126  	err = dynamoDbLock.waitUntilTableCreated(ctx)
   127  	if err != nil {
   128  		log.Printf("%v\n", err)
   129  		return err
   130  	}
   131  	log.Printf("DynamoDB Table %v has been created\n", TABLE_NAME)
   132  	return nil
   133  }
   134  
   135  func (dynamoDbLock *DynamoDbLock) Lock(transactionId int, resource string) (bool, error) {
   136  	ctx := context.Background()
   137  	dynamoDbLock.createTableIfNotExists(ctx)
   138  	// TODO: remove timeout completely
   139  	now := time.Now().Format(time.RFC3339)
   140  	newTimeout := time.Now().Add(TableLockTimeout).Format(time.RFC3339)
   141  
   142  	expr, err := expression.NewBuilder().
   143  		WithCondition(
   144  			expression.Or(
   145  				expression.AttributeNotExists(expression.Name("SK")),
   146  				expression.LessThan(expression.Name("timeout"), expression.Value(now)),
   147  			),
   148  		).
   149  		WithUpdate(
   150  			expression.Set(
   151  				expression.Name("transaction_id"), expression.Value(transactionId),
   152  			).Set(expression.Name("timeout"), expression.Value(newTimeout)),
   153  		).
   154  		Build()
   155  	if err != nil {
   156  		return false, err
   157  	}
   158  
   159  	input := &dynamodb.UpdateItemInput{
   160  		TableName: aws.String(TABLE_NAME),
   161  		Key: map[string]types.AttributeValue{
   162  			"PK": &types.AttributeValueMemberS{Value: "LOCK"},
   163  			"SK": &types.AttributeValueMemberS{Value: "RES#" + resource},
   164  		},
   165  		ConditionExpression:       expr.Condition(),
   166  		ExpressionAttributeNames:  expr.Names(),
   167  		ExpressionAttributeValues: expr.Values(),
   168  		UpdateExpression:          expr.Update(),
   169  	}
   170  
   171  	_, err = dynamoDbLock.DynamoDb.UpdateItem(ctx, input)
   172  	if err != nil {
   173  		var apiError smithy.APIError
   174  		if errors.As(err, &apiError) {
   175  			switch apiError.(type) {
   176  			case *types.ConditionalCheckFailedException:
   177  				return false, nil
   178  			}
   179  		}
   180  		return false, err
   181  	}
   182  
   183  	return true, nil
   184  }
   185  
   186  func (dynamoDbLock *DynamoDbLock) Unlock(resource string) (bool, error) {
   187  	ctx := context.Background()
   188  	dynamoDbLock.createTableIfNotExists(ctx)
   189  	input := &dynamodb.DeleteItemInput{
   190  		TableName: aws.String(TABLE_NAME),
   191  		Key: map[string]types.AttributeValue{
   192  			"PK": &types.AttributeValueMemberS{Value: "LOCK"},
   193  			"SK": &types.AttributeValueMemberS{Value: "RES#" + resource},
   194  		},
   195  	}
   196  
   197  	_, err := dynamoDbLock.DynamoDb.DeleteItem(ctx, input)
   198  	if err != nil {
   199  		return false, err
   200  	}
   201  	return true, nil
   202  }
   203  
   204  func (dynamoDbLock *DynamoDbLock) GetLock(lockId string) (*int, error) {
   205  	ctx := context.Background()
   206  	dynamoDbLock.createTableIfNotExists(ctx)
   207  	input := &dynamodb.GetItemInput{
   208  		TableName: aws.String(TABLE_NAME),
   209  		Key: map[string]types.AttributeValue{
   210  			"PK": &types.AttributeValueMemberS{Value: "LOCK"},
   211  			"SK": &types.AttributeValueMemberS{Value: "RES#" + lockId},
   212  		},
   213  		ConsistentRead: aws.Bool(true),
   214  	}
   215  
   216  	result, err := dynamoDbLock.DynamoDb.GetItem(ctx, input)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  
   221  	type TransactionLock struct {
   222  		TransactionID int    `dynamodbav:"transaction_id"`
   223  		Timeout       string `dynamodbav:"timeout"`
   224  	}
   225  
   226  	var t TransactionLock
   227  	err = attributevalue.UnmarshalMap(result.Item, &t)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  	if t.TransactionID != 0 {
   232  		return &t.TransactionID, nil
   233  	}
   234  	return nil, nil
   235  }