github.com/Jeffail/benthos/v3@v3.65.0/lib/input/aws_kinesis_checkpointer.go (about)

     1  package input
     2  
     3  // Inspired by Patrick Robinson https://github.com/patrobinson/gokini
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/Jeffail/benthos/v3/internal/docs"
    12  	"github.com/aws/aws-sdk-go/aws"
    13  	"github.com/aws/aws-sdk-go/aws/awserr"
    14  	"github.com/aws/aws-sdk-go/aws/session"
    15  	"github.com/aws/aws-sdk-go/service/dynamodb"
    16  	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
    17  )
    18  
    19  //------------------------------------------------------------------------------
    20  
// dynamoDBCheckpointFields holds the documentation specs for the checkpoint
// table options; the field names match the json/yaml tags of
// DynamoDBCheckpointConfig. The capacity unit fields are only meaningful when
// the table is created with a billing_mode of PROVISIONED.
var dynamoDBCheckpointFields = docs.FieldSpecs{
	docs.FieldCommon("table", "The name of the table to access."),
	docs.FieldCommon("create", "Whether, if the table does not exist, it should be created."),
	docs.FieldAdvanced("billing_mode", "When creating the table determines the billing mode.").HasOptions("PROVISIONED", "PAY_PER_REQUEST"),
	docs.FieldAdvanced("read_capacity_units", "Set the provisioned read capacity when creating the table with a `billing_mode` of `PROVISIONED`."),
	docs.FieldAdvanced("write_capacity_units", "Set the provisioned write capacity when creating the table with a `billing_mode` of `PROVISIONED`."),
}
    28  
// DynamoDBCheckpointConfig contains configuration parameters for a DynamoDB
// based checkpoint store for Kinesis.
type DynamoDBCheckpointConfig struct {
	// Table is the name of the DynamoDB table that stores the checkpoints.
	Table              string `json:"table" yaml:"table"`
	// Create controls whether a missing table is created; when false a
	// missing table is reported as an error.
	Create             bool   `json:"create" yaml:"create"`
	// ReadCapacityUnits is the provisioned read capacity applied at table
	// creation, used only when BillingMode is "PROVISIONED".
	ReadCapacityUnits  int64  `json:"read_capacity_units" yaml:"read_capacity_units"`
	// WriteCapacityUnits is the provisioned write capacity applied at table
	// creation, used only when BillingMode is "PROVISIONED".
	WriteCapacityUnits int64  `json:"write_capacity_units" yaml:"write_capacity_units"`
	// BillingMode is passed through as the billing mode of a created table.
	BillingMode        string `json:"billing_mode" yaml:"billing_mode"`
}
    38  
    39  // NewDynamoDBCheckpointConfig returns a DynamoDBCheckpoint config struct with
    40  // default values.
    41  func NewDynamoDBCheckpointConfig() DynamoDBCheckpointConfig {
    42  	return DynamoDBCheckpointConfig{
    43  		Table:              "",
    44  		Create:             false,
    45  		ReadCapacityUnits:  0,
    46  		WriteCapacityUnits: 0,
    47  		BillingMode:        "PAY_PER_REQUEST",
    48  	}
    49  }
    50  
    51  //------------------------------------------------------------------------------
    52  
    53  // Common errors that might occur throughout checkpointing.
    54  var (
    55  	ErrLeaseNotAcquired = errors.New("the shard could not be leased due to a collision")
    56  )
    57  
// awsKinesisCheckpointer manages the shard checkpointing for a given client
// identifier.
type awsKinesisCheckpointer struct {
	// conf holds the checkpoint table settings (table name and the options
	// used when the table has to be created).
	conf DynamoDBCheckpointConfig

	// clientID is written as the ClientID attribute when claiming and
	// checkpointing shards, and is the value checkpoints are conditioned on.
	clientID      string
	// leaseDuration is added to the current time to compute each new
	// LeaseTimeout value.
	leaseDuration time.Duration
	// commitPeriod is stored by the constructor; none of the methods visible
	// in this file read it.
	commitPeriod  time.Duration
	// svc is the DynamoDB client used for every table operation.
	svc           dynamodbiface.DynamoDBAPI
}
    68  
    69  // newAWSKinesisCheckpointer creates a new DynamoDB checkpointer from an AWS
    70  // session and a configuration struct.
    71  func newAWSKinesisCheckpointer(
    72  	session *session.Session,
    73  	clientID string,
    74  	conf DynamoDBCheckpointConfig,
    75  	leaseDuration time.Duration,
    76  	commitPeriod time.Duration,
    77  ) (*awsKinesisCheckpointer, error) {
    78  	c := &awsKinesisCheckpointer{
    79  		conf:          conf,
    80  		leaseDuration: leaseDuration,
    81  		commitPeriod:  commitPeriod,
    82  		svc:           dynamodb.New(session),
    83  		clientID:      clientID,
    84  	}
    85  
    86  	if err := c.ensureTableExists(); err != nil {
    87  		return nil, err
    88  	}
    89  	return c, nil
    90  }
    91  
    92  //------------------------------------------------------------------------------
    93  
    94  func (k *awsKinesisCheckpointer) ensureTableExists() error {
    95  	_, err := k.svc.DescribeTable(&dynamodb.DescribeTableInput{
    96  		TableName: aws.String(k.conf.Table),
    97  	})
    98  	if err == nil {
    99  		return nil
   100  	}
   101  	if aerr, ok := err.(awserr.Error); !ok || aerr.Code() != dynamodb.ErrCodeResourceNotFoundException {
   102  		return err
   103  	}
   104  	if !k.conf.Create {
   105  		return fmt.Errorf("target table %v does not exist", k.conf.Table)
   106  	}
   107  
   108  	input := &dynamodb.CreateTableInput{
   109  		AttributeDefinitions: []*dynamodb.AttributeDefinition{
   110  			{AttributeName: aws.String("StreamID"), AttributeType: aws.String("S")},
   111  			{AttributeName: aws.String("ShardID"), AttributeType: aws.String("S")},
   112  		},
   113  		BillingMode: aws.String(k.conf.BillingMode),
   114  		KeySchema: []*dynamodb.KeySchemaElement{
   115  			{AttributeName: aws.String("StreamID"), KeyType: aws.String("HASH")},
   116  			{AttributeName: aws.String("ShardID"), KeyType: aws.String("RANGE")},
   117  		},
   118  		TableName: aws.String(k.conf.Table),
   119  	}
   120  	if k.conf.BillingMode == "PROVISIONED" {
   121  		input.ProvisionedThroughput = &dynamodb.ProvisionedThroughput{
   122  			ReadCapacityUnits:  aws.Int64(k.conf.ReadCapacityUnits),
   123  			WriteCapacityUnits: aws.Int64(k.conf.WriteCapacityUnits),
   124  		}
   125  	}
   126  	if _, err = k.svc.CreateTable(input); err != nil {
   127  		return fmt.Errorf("failed to create table: %w", err)
   128  	}
   129  	return nil
   130  }
   131  
// awsKinesisCheckpoint contains details of a shard checkpoint.
type awsKinesisCheckpoint struct {
	// SequenceNumber is the value of the item's SequenceNumber attribute.
	SequenceNumber string
	// ClientID is the value of the item's ClientID attribute, or nil when the
	// item has none.
	ClientID       *string
	// LeaseTimeout is the item's LeaseTimeout attribute parsed as RFC 3339
	// (nanosecond precision), or nil when the item has none.
	LeaseTimeout   *time.Time
}
   138  
   139  // Both checkpoint and err can be nil when the item does not exist.
   140  func (k *awsKinesisCheckpointer) getCheckpoint(ctx context.Context, streamID, shardID string) (*awsKinesisCheckpoint, error) {
   141  	rawItem, err := k.svc.GetItemWithContext(ctx, &dynamodb.GetItemInput{
   142  		TableName: aws.String(k.conf.Table),
   143  		Key: map[string]*dynamodb.AttributeValue{
   144  			"ShardID": {
   145  				S: aws.String(shardID),
   146  			},
   147  			"StreamID": {
   148  				S: aws.String(streamID),
   149  			},
   150  		},
   151  	})
   152  	if err != nil {
   153  		if aerr, ok := err.(awserr.Error); ok {
   154  			if aerr.Code() == dynamodb.ErrCodeResourceNotFoundException {
   155  				return nil, nil
   156  			}
   157  		}
   158  		return nil, err
   159  	}
   160  
   161  	c := awsKinesisCheckpoint{}
   162  
   163  	if s, ok := rawItem.Item["SequenceNumber"]; ok && s.S != nil {
   164  		c.SequenceNumber = *s.S
   165  	} else {
   166  		return nil, errors.New("sequence ID was not found in checkpoint")
   167  	}
   168  
   169  	if s, ok := rawItem.Item["ClientID"]; ok && s.S != nil {
   170  		c.ClientID = s.S
   171  	}
   172  
   173  	if s, ok := rawItem.Item["LeaseTimeout"]; ok && s.S != nil {
   174  		timeout, err := time.Parse(time.RFC3339Nano, *s.S)
   175  		if err != nil {
   176  			return nil, err
   177  		}
   178  		c.LeaseTimeout = &timeout
   179  	}
   180  
   181  	return &c, nil
   182  }
   183  
   184  //------------------------------------------------------------------------------
   185  
// awsKinesisClientClaim represents a shard claimed by a client.
type awsKinesisClientClaim struct {
	// ShardID is the value of the claim item's ShardID attribute.
	ShardID      string
	// LeaseTimeout is the claim item's LeaseTimeout attribute parsed as
	// RFC 3339 (nanosecond precision).
	LeaseTimeout time.Time
}
   191  
   192  // AllClaims returns a map of client IDs to shards claimed by that client,
   193  // including the lease timeout of the claim.
   194  func (k *awsKinesisCheckpointer) AllClaims(ctx context.Context, streamID string) (map[string][]awsKinesisClientClaim, error) {
   195  	clientClaims := make(map[string][]awsKinesisClientClaim)
   196  	var scanErr error
   197  
   198  	if err := k.svc.ScanPagesWithContext(ctx, &dynamodb.ScanInput{
   199  		TableName:        aws.String(k.conf.Table),
   200  		FilterExpression: aws.String("StreamID = :stream_id"),
   201  		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
   202  			":stream_id": {
   203  				S: &streamID,
   204  			},
   205  		},
   206  	}, func(page *dynamodb.ScanOutput, last bool) bool {
   207  		for _, i := range page.Items {
   208  			var clientID string
   209  			if s, ok := i["ClientID"]; ok && s.S != nil {
   210  				clientID = *s.S
   211  			} else {
   212  				continue
   213  			}
   214  
   215  			var claim awsKinesisClientClaim
   216  			if s, ok := i["ShardID"]; ok && s.S != nil {
   217  				claim.ShardID = *s.S
   218  			}
   219  			if claim.ShardID == "" {
   220  				scanErr = errors.New("failed to extract shard id from claim")
   221  				return false
   222  			}
   223  
   224  			if s, ok := i["LeaseTimeout"]; ok && s.S != nil {
   225  				if claim.LeaseTimeout, scanErr = time.Parse(time.RFC3339Nano, *s.S); scanErr != nil {
   226  					scanErr = fmt.Errorf("failed to parse claim lease: %w", scanErr)
   227  					return false
   228  				}
   229  			}
   230  			if claim.LeaseTimeout.IsZero() {
   231  				scanErr = errors.New("failed to extract lease timeout from claim")
   232  				return false
   233  			}
   234  
   235  			clientClaims[clientID] = append(clientClaims[clientID], claim)
   236  		}
   237  
   238  		return true
   239  	}); err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	return clientClaims, scanErr
   244  }
   245  
   246  // Claim attempts to claim a shard for a particular stream ID. If fromClientID
   247  // is specified the shard is stolen from that particular client, and the
   248  // operation fails if a different client ID has it claimed.
   249  //
   250  // If fromClientID is specified this call will claim the new shard but block
   251  // for a period of time before reacquiring the sequence ID. This allows the
   252  // client we're claiming from to gracefully update the sequence number before
   253  // stopping.
   254  func (k *awsKinesisCheckpointer) Claim(ctx context.Context, streamID, shardID, fromClientID string) (string, error) {
   255  	newLeaseTimeoutString := time.Now().Add(k.leaseDuration).Format(time.RFC3339Nano)
   256  
   257  	var conditionalExpression string
   258  	expressionAttributeValues := map[string]*dynamodb.AttributeValue{
   259  		":new_client_id": {
   260  			S: &k.clientID,
   261  		},
   262  		":new_lease_timeout": {
   263  			S: &newLeaseTimeoutString,
   264  		},
   265  	}
   266  
   267  	if len(fromClientID) > 0 {
   268  		conditionalExpression = "ClientID = :old_client_id"
   269  		expressionAttributeValues[":old_client_id"] = &dynamodb.AttributeValue{
   270  			S: &fromClientID,
   271  		}
   272  	} else {
   273  		conditionalExpression = "attribute_not_exists(ClientID)"
   274  	}
   275  
   276  	res, err := k.svc.UpdateItemWithContext(ctx, &dynamodb.UpdateItemInput{
   277  		ReturnValues:              aws.String("ALL_OLD"),
   278  		TableName:                 aws.String(k.conf.Table),
   279  		ConditionExpression:       aws.String(conditionalExpression),
   280  		UpdateExpression:          aws.String("SET ClientID = :new_client_id, LeaseTimeout = :new_lease_timeout"),
   281  		ExpressionAttributeValues: expressionAttributeValues,
   282  		Key: map[string]*dynamodb.AttributeValue{
   283  			"StreamID": {
   284  				S: &streamID,
   285  			},
   286  			"ShardID": {
   287  				S: &shardID,
   288  			},
   289  		},
   290  	})
   291  	if err != nil {
   292  		if awsErr, ok := err.(awserr.Error); ok {
   293  			if awsErr.Code() == dynamodb.ErrCodeConditionalCheckFailedException {
   294  				return "", ErrLeaseNotAcquired
   295  			}
   296  		}
   297  		return "", err
   298  	}
   299  
   300  	var startingSequence string
   301  	if s, ok := res.Attributes["SequenceNumber"]; ok && s.S != nil {
   302  		startingSequence = *s.S
   303  	}
   304  
   305  	var currentLease time.Time
   306  	if s, ok := res.Attributes["LeaseTimeout"]; ok && s.S != nil {
   307  		currentLease, _ = time.Parse(time.RFC3339Nano, *s.S)
   308  	}
   309  
   310  	// Since we've aggressively stolen a shard then it's pretty much guaranteed
   311  	// that the client we're stealing from is still processing. What we do is we
   312  	// wait a grace period calculated by how long since the previous checkpoint
   313  	// and then reacquire the sequence.
   314  	//
   315  	// This allows the victim client to update the checkpoint with the final
   316  	// sequence as it yields the shard.
   317  	if len(fromClientID) > 0 && time.Since(currentLease) < k.leaseDuration {
   318  		// Wait for the estimated next checkpoint time plus a grace period of
   319  		// one second.
   320  		waitFor := k.leaseDuration - time.Since(currentLease) + time.Second
   321  		select {
   322  		case <-time.After(waitFor):
   323  		case <-ctx.Done():
   324  			return "", ctx.Err()
   325  		}
   326  
   327  		cp, err := k.getCheckpoint(ctx, streamID, shardID)
   328  		if err != nil {
   329  			return "", err
   330  		}
   331  		startingSequence = cp.SequenceNumber
   332  	}
   333  
   334  	return startingSequence, nil
   335  }
   336  
   337  // Checkpoint attempts to set a sequence number for a stream shard. Returns a
   338  // boolean indicating whether this shard is still owned by the client.
   339  //
   340  // If the shard has been claimed by a new client the sequence will still be set
   341  // so that the new client can begin with the latest sequence.
   342  //
   343  // If final is true the client ID is removed from the checkpoint, indicating
   344  // that this client is finished with the shard.
   345  func (k *awsKinesisCheckpointer) Checkpoint(ctx context.Context, streamID, shardID, sequenceNumber string, final bool) (bool, error) {
   346  	item := map[string]*dynamodb.AttributeValue{
   347  		"StreamID": {
   348  			S: &streamID,
   349  		},
   350  		"ShardID": {
   351  			S: &shardID,
   352  		},
   353  	}
   354  
   355  	if len(sequenceNumber) > 0 {
   356  		item["SequenceNumber"] = &dynamodb.AttributeValue{
   357  			S: &sequenceNumber,
   358  		}
   359  	}
   360  
   361  	if !final {
   362  		item["ClientID"] = &dynamodb.AttributeValue{
   363  			S: &k.clientID,
   364  		}
   365  		item["LeaseTimeout"] = &dynamodb.AttributeValue{
   366  			S: aws.String(time.Now().Add(k.leaseDuration).Format(time.RFC3339Nano)),
   367  		}
   368  	}
   369  
   370  	if _, err := k.svc.PutItem(&dynamodb.PutItemInput{
   371  		ConditionExpression: aws.String("ClientID = :client_id"),
   372  		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
   373  			":client_id": {
   374  				S: &k.clientID,
   375  			},
   376  		},
   377  		TableName: aws.String(k.conf.Table),
   378  		Item:      item,
   379  	}); err != nil {
   380  		if awsErr, ok := err.(awserr.Error); ok {
   381  			if awsErr.Code() == dynamodb.ErrCodeConditionalCheckFailedException {
   382  				return false, nil
   383  			}
   384  		}
   385  		return false, err
   386  	}
   387  	return true, nil
   388  }
   389  
   390  // Yield updates an existing checkpoint sequence number and no other fields.
   391  // This should be done after a non-final checkpoint indicates that shard has
   392  // been stolen and allows the thief client to start with the latest sequence
   393  // rather than the sequence at the point of the theft.
   394  //
   395  // This call is entirely optional, but the benefit is a reduction in duplicated
   396  // messages during a rebalance of shards.
   397  func (k *awsKinesisCheckpointer) Yield(ctx context.Context, streamID, shardID, sequenceNumber string) error {
   398  	if sequenceNumber == "" {
   399  		// Nothing to present to the thief
   400  		return nil
   401  	}
   402  
   403  	_, err := k.svc.UpdateItemWithContext(ctx, &dynamodb.UpdateItemInput{
   404  		TableName: aws.String(k.conf.Table),
   405  		Key: map[string]*dynamodb.AttributeValue{
   406  			"StreamID": {
   407  				S: &streamID,
   408  			},
   409  			"ShardID": {
   410  				S: &shardID,
   411  			},
   412  		},
   413  		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
   414  			":new_sequence_number": {
   415  				S: &sequenceNumber,
   416  			},
   417  		},
   418  		UpdateExpression: aws.String("SET SequenceNumber = :new_sequence_number"),
   419  	})
   420  	return err
   421  }
   422  
   423  // Delete attempts to delete a checkpoint, this should be called when a shard is
   424  // emptied.
   425  func (k *awsKinesisCheckpointer) Delete(ctx context.Context, streamID, shardID string) error {
   426  	_, err := k.svc.DeleteItemWithContext(ctx, &dynamodb.DeleteItemInput{
   427  		TableName: aws.String(k.conf.Table),
   428  		Key: map[string]*dynamodb.AttributeValue{
   429  			"StreamID": {
   430  				S: &streamID,
   431  			},
   432  			"ShardID": {
   433  				S: &shardID,
   434  			},
   435  		},
   436  	})
   437  	return err
   438  }
   439  
   440  //------------------------------------------------------------------------------