go.temporal.io/server@v1.23.0/common/persistence/cassandra/queue_store.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package cassandra
    26  
    27  import (
    28  	"context"
    29  	"fmt"
    30  
    31  	commonpb "go.temporal.io/api/common/v1"
    32  	enumspb "go.temporal.io/api/enums/v1"
    33  	"go.temporal.io/api/serviceerror"
    34  
    35  	persistencespb "go.temporal.io/server/api/persistence/v1"
    36  	"go.temporal.io/server/common/log"
    37  	"go.temporal.io/server/common/persistence"
    38  	"go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql"
    39  	"go.temporal.io/server/common/persistence/serialization"
    40  )
    41  
    42  const (
    43  	templateEnqueueMessageQuery       = `INSERT INTO queue (queue_type, message_id, message_payload, message_encoding) VALUES(?, ?, ?, ?) IF NOT EXISTS`
    44  	templateGetLastMessageIDQuery     = `SELECT message_id FROM queue WHERE queue_type=? ORDER BY message_id DESC LIMIT 1`
    45  	templateGetMessagesQuery          = `SELECT message_id, message_payload, message_encoding FROM queue WHERE queue_type = ? and message_id > ? LIMIT ?`
    46  	templateGetMessagesFromDLQQuery   = `SELECT message_id, message_payload, message_encoding FROM queue WHERE queue_type = ? and message_id > ? and message_id <= ?`
    47  	templateDeleteMessagesBeforeQuery = `DELETE FROM queue WHERE queue_type = ? and message_id < ?`
    48  	templateDeleteMessagesQuery       = `DELETE FROM queue WHERE queue_type = ? and message_id > ? and message_id <= ?`
    49  	templateDeleteMessageQuery        = `DELETE FROM queue WHERE queue_type = ? and message_id = ?`
    50  
    51  	templateGetQueueMetadataQuery    = `SELECT cluster_ack_level, data, data_encoding, version FROM queue_metadata WHERE queue_type = ?`
    52  	templateInsertQueueMetadataQuery = `INSERT INTO queue_metadata (queue_type, cluster_ack_level, data, data_encoding, version) VALUES(?, ?, ?, ?, ?) IF NOT EXISTS`
    53  	templateUpdateQueueMetadataQuery = `UPDATE queue_metadata SET cluster_ack_level = ?, data = ?, data_encoding = ?, version = ? WHERE queue_type = ? IF version = ?`
    54  )
    55  
    56  type (
    57  	QueueStore struct {
    58  		queueType persistence.QueueType
    59  		session   gocql.Session
    60  		logger    log.Logger
    61  	}
    62  )
    63  
    64  func NewQueueStore(
    65  	queueType persistence.QueueType,
    66  	session gocql.Session,
    67  	logger log.Logger,
    68  ) (persistence.Queue, error) {
    69  	return &QueueStore{
    70  		queueType: queueType,
    71  		session:   session,
    72  		logger:    logger,
    73  	}, nil
    74  }
    75  
    76  func (q *QueueStore) Init(
    77  	ctx context.Context,
    78  	blob *commonpb.DataBlob,
    79  ) error {
    80  	if err := q.initializeQueueMetadata(ctx, blob); err != nil {
    81  		return err
    82  	}
    83  	return q.initializeDLQMetadata(ctx, blob)
    84  }
    85  
    86  func (q *QueueStore) EnqueueMessage(
    87  	ctx context.Context,
    88  	blob *commonpb.DataBlob,
    89  ) error {
    90  	lastMessageID, err := q.getLastMessageID(ctx, q.queueType)
    91  	if err != nil {
    92  		return err
    93  	}
    94  
    95  	_, err = q.tryEnqueue(ctx, q.queueType, lastMessageID+1, blob)
    96  	return err
    97  }
    98  
    99  func (q *QueueStore) EnqueueMessageToDLQ(
   100  	ctx context.Context,
   101  	blob *commonpb.DataBlob,
   102  ) (int64, error) {
   103  	// Use negative queue type as the dlq type
   104  	lastMessageID, err := q.getLastMessageID(ctx, q.getDLQTypeFromQueueType())
   105  	if err != nil {
   106  		return persistence.EmptyQueueMessageID, err
   107  	}
   108  
   109  	// Use negative queue type as the dlq type
   110  	return q.tryEnqueue(ctx, q.getDLQTypeFromQueueType(), lastMessageID+1, blob)
   111  }
   112  
   113  func (q *QueueStore) tryEnqueue(
   114  	ctx context.Context,
   115  	queueType persistence.QueueType,
   116  	messageID int64,
   117  	blob *commonpb.DataBlob,
   118  ) (int64, error) {
   119  	query := q.session.Query(templateEnqueueMessageQuery, queueType, messageID, blob.Data, blob.EncodingType.String()).WithContext(ctx)
   120  	previous := make(map[string]interface{})
   121  	applied, err := query.MapScanCAS(previous)
   122  	if err != nil {
   123  		return persistence.EmptyQueueMessageID, gocql.ConvertError("tryEnqueue", err)
   124  	}
   125  
   126  	if !applied {
   127  		return persistence.EmptyQueueMessageID, &persistence.ConditionFailedError{Msg: fmt.Sprintf("message ID %v exists in queue", previous["message_id"])}
   128  	}
   129  
   130  	return messageID, nil
   131  }
   132  
   133  func (q *QueueStore) getLastMessageID(
   134  	ctx context.Context,
   135  	queueType persistence.QueueType,
   136  ) (int64, error) {
   137  
   138  	query := q.session.Query(templateGetLastMessageIDQuery, queueType).WithContext(ctx)
   139  	result := make(map[string]interface{})
   140  	err := query.MapScan(result)
   141  	if err != nil {
   142  		if gocql.IsNotFoundError(err) {
   143  			return persistence.EmptyQueueMessageID, nil
   144  		}
   145  		return persistence.EmptyQueueMessageID, gocql.ConvertError("getLastMessageID", err)
   146  	}
   147  	return result["message_id"].(int64), nil
   148  }
   149  
   150  func (q *QueueStore) ReadMessages(
   151  	ctx context.Context,
   152  	lastMessageID int64,
   153  	maxCount int,
   154  ) ([]*persistence.QueueMessage, error) {
   155  	// Reading replication tasks need to be quorum level consistent, otherwise we could lose tasks
   156  	query := q.session.Query(templateGetMessagesQuery,
   157  		q.queueType,
   158  		lastMessageID,
   159  		maxCount,
   160  	).WithContext(ctx)
   161  
   162  	iter := query.Iter()
   163  
   164  	var result []*persistence.QueueMessage
   165  	message := make(map[string]interface{})
   166  	for iter.MapScan(message) {
   167  		queueMessage := convertQueueMessage(message)
   168  		result = append(result, queueMessage)
   169  		message = make(map[string]interface{})
   170  	}
   171  
   172  	if err := iter.Close(); err != nil {
   173  		return nil, serviceerror.NewUnavailable(fmt.Sprintf("ReadMessages operation failed. Error: %v", err))
   174  	}
   175  
   176  	return result, nil
   177  }
   178  
   179  func (q *QueueStore) ReadMessagesFromDLQ(
   180  	ctx context.Context,
   181  	firstMessageID int64,
   182  	lastMessageID int64,
   183  	pageSize int,
   184  	pageToken []byte,
   185  ) ([]*persistence.QueueMessage, []byte, error) {
   186  	// Reading replication tasks need to be quorum level consistent, otherwise we could lose tasks
   187  	// Use negative queue type as the dlq type
   188  	query := q.session.Query(templateGetMessagesFromDLQQuery,
   189  		q.getDLQTypeFromQueueType(),
   190  		firstMessageID,
   191  		lastMessageID,
   192  	).WithContext(ctx)
   193  	iter := query.PageSize(pageSize).PageState(pageToken).Iter()
   194  
   195  	var result []*persistence.QueueMessage
   196  	message := make(map[string]interface{})
   197  	for iter.MapScan(message) {
   198  		queueMessage := convertQueueMessage(message)
   199  		result = append(result, queueMessage)
   200  		message = make(map[string]interface{})
   201  	}
   202  
   203  	var nextPageToken []byte
   204  	if len(iter.PageState()) > 0 {
   205  		nextPageToken = iter.PageState()
   206  	}
   207  	if err := iter.Close(); err != nil {
   208  		return nil, nil, serviceerror.NewUnavailable(fmt.Sprintf("ReadMessagesFromDLQ operation failed. Error: %v", err))
   209  	}
   210  
   211  	return result, nextPageToken, nil
   212  }
   213  
   214  func (q *QueueStore) DeleteMessagesBefore(
   215  	ctx context.Context,
   216  	messageID int64,
   217  ) error {
   218  
   219  	query := q.session.Query(templateDeleteMessagesBeforeQuery, q.queueType, messageID).WithContext(ctx)
   220  	if err := query.Exec(); err != nil {
   221  		return serviceerror.NewUnavailable(fmt.Sprintf("DeleteMessagesBefore operation failed. Error %v", err))
   222  	}
   223  	return nil
   224  }
   225  
   226  func (q *QueueStore) DeleteMessageFromDLQ(
   227  	ctx context.Context,
   228  	messageID int64,
   229  ) error {
   230  
   231  	// Use negative queue type as the dlq type
   232  	query := q.session.Query(templateDeleteMessageQuery, q.getDLQTypeFromQueueType(), messageID).WithContext(ctx)
   233  	if err := query.Exec(); err != nil {
   234  		return serviceerror.NewUnavailable(fmt.Sprintf("DeleteMessageFromDLQ operation failed. Error %v", err))
   235  	}
   236  
   237  	return nil
   238  }
   239  
   240  func (q *QueueStore) RangeDeleteMessagesFromDLQ(
   241  	ctx context.Context,
   242  	firstMessageID int64,
   243  	lastMessageID int64,
   244  ) error {
   245  
   246  	// Use negative queue type as the dlq type
   247  	query := q.session.Query(templateDeleteMessagesQuery, q.getDLQTypeFromQueueType(), firstMessageID, lastMessageID).WithContext(ctx)
   248  	if err := query.Exec(); err != nil {
   249  		return serviceerror.NewUnavailable(fmt.Sprintf("RangeDeleteMessagesFromDLQ operation failed. Error %v", err))
   250  	}
   251  
   252  	return nil
   253  }
   254  
   255  func (q *QueueStore) UpdateAckLevel(
   256  	ctx context.Context,
   257  	metadata *persistence.InternalQueueMetadata,
   258  ) error {
   259  	return q.updateAckLevel(ctx, metadata, q.queueType)
   260  }
   261  
   262  func (q *QueueStore) GetAckLevels(
   263  	ctx context.Context,
   264  ) (*persistence.InternalQueueMetadata, error) {
   265  	queueMetadata, err := q.getQueueMetadata(ctx, q.queueType)
   266  	if err != nil {
   267  		return nil, gocql.ConvertError("GetAckLevels", err)
   268  	}
   269  
   270  	return queueMetadata, nil
   271  }
   272  
   273  func (q *QueueStore) UpdateDLQAckLevel(
   274  	ctx context.Context,
   275  	metadata *persistence.InternalQueueMetadata,
   276  ) error {
   277  	return q.updateAckLevel(ctx, metadata, q.getDLQTypeFromQueueType())
   278  }
   279  
   280  func (q *QueueStore) GetDLQAckLevels(
   281  	ctx context.Context,
   282  ) (*persistence.InternalQueueMetadata, error) {
   283  	// Use negative queue type as the dlq type
   284  	queueMetadata, err := q.getQueueMetadata(ctx, q.getDLQTypeFromQueueType())
   285  	if err != nil {
   286  		return nil, gocql.ConvertError("GetDLQAckLevels", err)
   287  	}
   288  
   289  	return queueMetadata, nil
   290  }
   291  
   292  func (q *QueueStore) insertInitialQueueMetadataRecord(
   293  	ctx context.Context,
   294  	queueType persistence.QueueType,
   295  	blob *commonpb.DataBlob,
   296  ) error {
   297  
   298  	version := 0
   299  	// TODO: remove once cluster_ack_level is removed from DB
   300  	clusterAckLevels := map[string]int64{}
   301  	query := q.session.Query(templateInsertQueueMetadataQuery,
   302  		queueType,
   303  		clusterAckLevels,
   304  		blob.Data,
   305  		blob.EncodingType.String(),
   306  		version,
   307  	).WithContext(ctx)
   308  	_, err := query.MapScanCAS(make(map[string]interface{}))
   309  	if err != nil {
   310  		return fmt.Errorf("failed to insert initial queue metadata record: %v, Type: %v", err, queueType)
   311  	}
   312  	// it's ok if the query is not applied, which means that the record exists already.
   313  	return nil
   314  }
   315  
   316  func (q *QueueStore) getQueueMetadata(
   317  	ctx context.Context,
   318  	queueType persistence.QueueType,
   319  ) (*persistence.InternalQueueMetadata, error) {
   320  
   321  	query := q.session.Query(templateGetQueueMetadataQuery, queueType).WithContext(ctx)
   322  	message := make(map[string]interface{})
   323  	err := query.MapScan(message)
   324  	if err != nil {
   325  		return nil, err
   326  	}
   327  
   328  	return convertQueueMetadata(message)
   329  }
   330  
   331  func (q *QueueStore) updateAckLevel(
   332  	ctx context.Context,
   333  	metadata *persistence.InternalQueueMetadata,
   334  	queueType persistence.QueueType,
   335  ) error {
   336  
   337  	// TODO: remove this once cluster_ack_level is removed from DB
   338  	metadataStruct, err := serialization.QueueMetadataFromBlob(metadata.Blob.Data, metadata.Blob.EncodingType.String())
   339  	if err != nil {
   340  		return gocql.ConvertError("updateAckLevel", err)
   341  	}
   342  
   343  	query := q.session.Query(templateUpdateQueueMetadataQuery,
   344  		metadataStruct.ClusterAckLevels,
   345  		metadata.Blob.Data,
   346  		metadata.Blob.EncodingType.String(),
   347  		metadata.Version+1, // always increase version number on update
   348  		queueType,
   349  		metadata.Version, // condition update
   350  	).WithContext(ctx)
   351  	applied, err := query.MapScanCAS(make(map[string]interface{}))
   352  	if err != nil {
   353  		return gocql.ConvertError("updateAckLevel", err)
   354  	}
   355  	if !applied {
   356  		return &persistence.ConditionFailedError{Msg: "UpdateAckLevel operation encountered concurrent write."}
   357  	}
   358  
   359  	return nil
   360  }
   361  
   362  func (q *QueueStore) Close() {
   363  	if q.session != nil {
   364  		q.session.Close()
   365  	}
   366  }
   367  
   368  func (q *QueueStore) getDLQTypeFromQueueType() persistence.QueueType {
   369  	return -q.queueType
   370  }
   371  
   372  func (q *QueueStore) initializeQueueMetadata(
   373  	ctx context.Context,
   374  	blob *commonpb.DataBlob,
   375  ) error {
   376  	_, err := q.getQueueMetadata(ctx, q.queueType)
   377  	if gocql.IsNotFoundError(err) {
   378  		return q.insertInitialQueueMetadataRecord(ctx, q.queueType, blob)
   379  	}
   380  	return err
   381  }
   382  
   383  func (q *QueueStore) initializeDLQMetadata(
   384  	ctx context.Context,
   385  	blob *commonpb.DataBlob,
   386  ) error {
   387  	_, err := q.getQueueMetadata(ctx, q.getDLQTypeFromQueueType())
   388  	if gocql.IsNotFoundError(err) {
   389  		return q.insertInitialQueueMetadataRecord(ctx, q.getDLQTypeFromQueueType(), blob)
   390  	}
   391  	return err
   392  }
   393  
   394  func convertQueueMessage(
   395  	message map[string]interface{},
   396  ) *persistence.QueueMessage {
   397  
   398  	id := message["message_id"].(int64)
   399  	data := message["message_payload"].([]byte)
   400  	encoding := message["message_encoding"].(string)
   401  	if encoding == "" {
   402  		encoding = enumspb.ENCODING_TYPE_PROTO3.String()
   403  	}
   404  	return &persistence.QueueMessage{
   405  		ID:       id,
   406  		Data:     data,
   407  		Encoding: encoding,
   408  	}
   409  }
   410  
   411  func convertQueueMetadata(
   412  	message map[string]interface{},
   413  ) (*persistence.InternalQueueMetadata, error) {
   414  
   415  	metadata := &persistence.InternalQueueMetadata{
   416  		Version: message["version"].(int64),
   417  	}
   418  	_, ok := message["cluster_ack_level"]
   419  	if ok {
   420  		clusterAckLevel := message["cluster_ack_level"].(map[string]int64)
   421  		// TODO: remove this once we remove cluster_ack_level from DB.
   422  		blob, err := serialization.QueueMetadataToBlob(&persistencespb.QueueMetadata{ClusterAckLevels: clusterAckLevel})
   423  		if err != nil {
   424  			return nil, err
   425  		}
   426  		metadata.Blob = blob
   427  	} else {
   428  		data := message["data"].([]byte)
   429  		encoding := message["data_encoding"].(string)
   430  
   431  		metadata.Blob = persistence.NewDataBlob(data, encoding)
   432  	}
   433  
   434  	return metadata, nil
   435  }