go.temporal.io/server@v1.23.0/common/persistence/cassandra/queue_v2_store.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package cassandra
    26  
    27  import (
    28  	"context"
    29  	"fmt"
    30  
    31  	commonpb "go.temporal.io/api/common/v1"
    32  	"go.temporal.io/api/enums/v1"
    33  	persistencespb "go.temporal.io/server/api/persistence/v1"
    34  	"go.temporal.io/server/common/log"
    35  	"go.temporal.io/server/common/persistence"
    36  	"go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql"
    37  	"go.temporal.io/server/common/persistence/serialization"
    38  )
    39  
    40  type (
    41  	// queueV2Store contains the SQL queries and serialization/deserialization functions to interact with the queues and
    42  	// queue_messages tables that implement the QueueV2 interface. The schema is located at:
    43  	//	schema/cassandra/temporal/versioned/v1.9/queues.cql
    44  	queueV2Store struct {
    45  		session gocql.Session
    46  		logger  log.Logger
    47  	}
    48  
    49  	Queue struct {
    50  		Metadata *persistencespb.Queue
    51  		Version  int64
    52  	}
    53  )
    54  
// CQL statements used by the QueueV2 store. The queue_messages statements address a single message partition via
// (queue_type, queue_name, queue_partition); the queues statements address one metadata row via
// (queue_type, queue_name). The two INSERT statements use IF NOT EXISTS and the metadata UPDATE uses
// IF version = ?, so callers must inspect the "applied" result of those queries.
const (
	TemplateEnqueueMessageQuery      = `INSERT INTO queue_messages (queue_type, queue_name, queue_partition, message_id, message_payload, message_encoding) VALUES (?, ?, ?, ?, ?, ?) IF NOT EXISTS`
	TemplateGetMessagesQuery         = `SELECT message_id, message_payload, message_encoding FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? AND message_id >= ? ORDER BY message_id ASC LIMIT ?`
	TemplateGetMaxMessageIDQuery     = `SELECT message_id FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? ORDER BY message_id DESC LIMIT 1`
	TemplateCreateQueueQuery         = `INSERT INTO queues (queue_type, queue_name, metadata_payload, metadata_encoding, version) VALUES (?, ?, ?, ?, ?) IF NOT EXISTS`
	TemplateGetQueueQuery            = `SELECT metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? AND queue_name = ?`
	TemplateRangeDeleteMessagesQuery = `DELETE FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? AND message_id >= ? AND message_id <= ?`
	TemplateUpdateQueueMetadataQuery = `UPDATE queues SET metadata_payload = ?, metadata_encoding = ?, version = ? WHERE queue_type = ? AND queue_name = ? IF version = ?`
	// templateGetQueueNamesQuery lists every queue of a given type. It needs ALLOW FILTERING because the partition
	// key consists of both queue_type and queue_name, while this query restricts on queue_type only.
	templateGetQueueNamesQuery = `SELECT queue_name, metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? ALLOW FILTERING`
)
    66  
var (
	// ErrEnqueueMessageConflict is returned when a message with the same ID already exists in the queue. This is
	// possible when there are concurrent writes to the queue because we enqueue a message using two queries:
	//
	// 	1. SELECT the current max message ID X (for a given queue partition); the next ID is X+1
	// 	2. INSERT (ID = X+1, message) with IF NOT EXISTS
	//
	// See the following example, in which both clients compute the same next ID:
	//
	//  Client A           Client B                          Cassandra DB
	//  |                  |                                            |
	//  |--1. SELECT MAX(ID) FROM queue_messages----------------------->|
	//  |                  |                                            |
	//  |<-2. Return X--------------------------------------------------|
	//  |                  |                                            |
	//  |                  |--3. SELECT MAX(ID) FROM queue_messages---->|
	//  |                  |                                            |
	//  |                  |<-4. Return X-------------------------------|
	//  |                  |                                            |
	//  |--5. INSERT INTO queue_messages (ID = X+1)-------------------->|
	//  |                  |                                            |
	//  |<-6. Acknowledge-----------------------------------------------|
	//  |                  |                                            |
	//  |                  |--7. INSERT INTO queue_messages (ID = X+1)->|
	//  |                  |                                            |
	//  |                  |<-8. Conflict/Error-------------------------|
	//  |                  |                                            |
	ErrEnqueueMessageConflict = &persistence.ConditionFailedError{
		Msg: "conflict inserting queue message, likely due to concurrent writes",
	}
	// ErrUpdateQueueConflict is returned when a queue is updated with the wrong version. This happens when there are
	// concurrent writes to the queue because we update a queue using two queries, similar to the enqueue message query.
	//
	// 	1. SELECT (queue, version) FROM queues
	// 	2. UPDATE queue, version IF version = version from step 1
	//
	// See the following example, in which both clients read the same version v1:
	//
	//  Client A           Client B                           Cassandra DB
	//  |                  |                                            |
	//  |--1. SELECT (queue, version) FROM queues---------------------->|
	//  |                  |                                            |
	//  |<-2. Return (queue, v1)----------------------------------------|
	//  |                  |                                            |
	//  |                  |--3. SELECT (queue, version) FROM queues--->|
	//  |                  |                                            |
	//  |                  |<-4. Return (queue, v1)---------------------|
	//  |                  |                                            |
	//  |--5. UPDATE queue, version IF version = v1-------------------->|
	//  |                  |                                            |
	//  |<-6. Acknowledge-----------------------------------------------|
	//  |                  |                                            |
	//  |                  |--7. UPDATE queue, version IF version = v1->|
	//  |                  |                                            |
	//  |                  |<-8. Conflict/Error-------------------------|
	//  |                  |                                            |
	ErrUpdateQueueConflict = &persistence.ConditionFailedError{
		Msg: "conflict updating queue, likely due to concurrent writes",
	}
)
   127  
   128  func NewQueueV2Store(session gocql.Session, logger log.Logger) persistence.QueueV2 {
   129  	return &queueV2Store{
   130  		session: session,
   131  		logger:  logger,
   132  	}
   133  }
   134  
   135  func (s *queueV2Store) EnqueueMessage(
   136  	ctx context.Context,
   137  	request *persistence.InternalEnqueueMessageRequest,
   138  ) (*persistence.InternalEnqueueMessageResponse, error) {
   139  	// TODO: add concurrency control around this method to avoid things like QueueMessageIDConflict.
   140  	// TODO: cache the queue in memory to avoid querying the database every time.
   141  	_, err := s.getQueue(ctx, request.QueueType, request.QueueName)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  	messageID, err := s.getNextMessageID(ctx, request.QueueType, request.QueueName)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  	err = s.tryInsert(ctx, request.QueueType, request.QueueName, request.Blob, messageID)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	return &persistence.InternalEnqueueMessageResponse{
   154  		Metadata: persistence.MessageMetadata{ID: messageID},
   155  	}, nil
   156  }
   157  
   158  func (s *queueV2Store) ReadMessages(
   159  	ctx context.Context,
   160  	request *persistence.InternalReadMessagesRequest,
   161  ) (*persistence.InternalReadMessagesResponse, error) {
   162  	q, err := s.getQueue(ctx, request.QueueType, request.QueueName)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	if request.PageSize <= 0 {
   167  		return nil, persistence.ErrNonPositiveReadQueueMessagesPageSize
   168  	}
   169  	minMessageID, err := persistence.GetMinMessageIDToReadForQueueV2(request.QueueType, request.QueueName, request.NextPageToken, q.Metadata)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	iter := s.session.Query(
   175  		TemplateGetMessagesQuery,
   176  		request.QueueType,
   177  		request.QueueName,
   178  		0,
   179  		int(minMessageID),
   180  		request.PageSize,
   181  	).WithContext(ctx).Iter()
   182  
   183  	var (
   184  		messages []persistence.QueueV2Message
   185  		// messageID is the ID of the last message returned by the query.
   186  		messageID int64
   187  	)
   188  
   189  	for {
   190  		var (
   191  			messagePayload  []byte
   192  			messageEncoding string
   193  		)
   194  		if !iter.Scan(&messageID, &messagePayload, &messageEncoding) {
   195  			break
   196  		}
   197  		encoding, err := enums.EncodingTypeFromString(messageEncoding)
   198  		if err != nil {
   199  			return nil, serialization.NewUnknownEncodingTypeError(messageEncoding)
   200  		}
   201  
   202  		encodingType := enums.EncodingType(encoding)
   203  
   204  		message := persistence.QueueV2Message{
   205  			MetaData: persistence.MessageMetadata{ID: messageID},
   206  			Data: &commonpb.DataBlob{
   207  				EncodingType: encodingType,
   208  				Data:         messagePayload,
   209  			},
   210  		}
   211  		messages = append(messages, message)
   212  	}
   213  
   214  	if err := iter.Close(); err != nil {
   215  		return nil, gocql.ConvertError("QueueV2ReadMessages", err)
   216  	}
   217  
   218  	nextPageToken := persistence.GetNextPageTokenForReadMessages(messages)
   219  	return &persistence.InternalReadMessagesResponse{
   220  		Messages:      messages,
   221  		NextPageToken: nextPageToken,
   222  	}, nil
   223  }
   224  
   225  func (s *queueV2Store) CreateQueue(
   226  	ctx context.Context,
   227  	request *persistence.InternalCreateQueueRequest,
   228  ) (*persistence.InternalCreateQueueResponse, error) {
   229  	queueType := request.QueueType
   230  	queueName := request.QueueName
   231  	q := persistencespb.Queue{
   232  		Partitions: map[int32]*persistencespb.QueuePartition{
   233  			0: {
   234  				MinMessageId: persistence.FirstQueueMessageID,
   235  			},
   236  		},
   237  	}
   238  	bytes, _ := q.Marshal()
   239  	applied, err := s.session.Query(
   240  		TemplateCreateQueueQuery,
   241  		queueType,
   242  		queueName,
   243  		bytes,
   244  		enums.ENCODING_TYPE_PROTO3.String(),
   245  		0,
   246  	).WithContext(ctx).MapScanCAS(make(map[string]interface{}))
   247  	if err != nil {
   248  		return nil, gocql.ConvertError("QueueV2CreateQueue", err)
   249  	}
   250  
   251  	if !applied {
   252  		return nil, fmt.Errorf(
   253  			"%w: queue type %v and name %v",
   254  			persistence.ErrQueueAlreadyExists,
   255  			queueType,
   256  			queueName,
   257  		)
   258  	}
   259  	return &persistence.InternalCreateQueueResponse{}, nil
   260  }
   261  
   262  func (s *queueV2Store) RangeDeleteMessages(
   263  	ctx context.Context,
   264  	request *persistence.InternalRangeDeleteMessagesRequest,
   265  ) (*persistence.InternalRangeDeleteMessagesResponse, error) {
   266  	if request.InclusiveMaxMessageMetadata.ID < persistence.FirstQueueMessageID {
   267  		return nil, fmt.Errorf(
   268  			"%w: id is %d but must be >= %d",
   269  			persistence.ErrInvalidQueueRangeDeleteMaxMessageID,
   270  			request.InclusiveMaxMessageMetadata.ID,
   271  			persistence.FirstQueueMessageID,
   272  		)
   273  	}
   274  	queueType := request.QueueType
   275  	queueName := request.QueueName
   276  	q, err := s.getQueue(ctx, queueType, queueName)
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  	partition, err := persistence.GetPartitionForQueueV2(queueType, queueName, q.Metadata)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	maxMessageID, ok, err := s.getMaxMessageID(ctx, queueType, queueName)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  	if !ok {
   289  		// Nothing in the queue to delete.
   290  		return &persistence.InternalRangeDeleteMessagesResponse{}, nil
   291  	}
   292  	deleteRange, ok := persistence.GetDeleteRange(persistence.DeleteRequest{
   293  		LastIDToDeleteInclusive: request.InclusiveMaxMessageMetadata.ID,
   294  		ExistingMessageRange: persistence.InclusiveMessageRange{
   295  			MinMessageID: partition.MinMessageId,
   296  			MaxMessageID: maxMessageID,
   297  		},
   298  	})
   299  	if !ok {
   300  		return &persistence.InternalRangeDeleteMessagesResponse{}, nil
   301  	}
   302  	err = s.session.Query(
   303  		TemplateRangeDeleteMessagesQuery,
   304  		queueType,
   305  		queueName,
   306  		0, // partition
   307  		deleteRange.MinMessageID,
   308  		deleteRange.MaxMessageID,
   309  	).WithContext(ctx).Exec()
   310  	if err != nil {
   311  		return nil, gocql.ConvertError("QueueV2RangeDeleteMessages", err)
   312  	}
   313  	partition.MinMessageId = deleteRange.NewMinMessageID
   314  	err = s.updateQueue(ctx, q, queueType, queueName)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  	return &persistence.InternalRangeDeleteMessagesResponse{
   319  		MessagesDeleted: deleteRange.MessagesToDelete,
   320  	}, nil
   321  }
   322  
   323  func (s *queueV2Store) updateQueue(
   324  	ctx context.Context,
   325  	q *Queue,
   326  	queueType persistence.QueueV2Type,
   327  	queueName string,
   328  ) error {
   329  	bytes, _ := q.Metadata.Marshal()
   330  	version := q.Version
   331  	nextVersion := version + 1
   332  	q.Version = nextVersion
   333  	applied, err := s.session.Query(
   334  		TemplateUpdateQueueMetadataQuery,
   335  		bytes,
   336  		enums.ENCODING_TYPE_PROTO3.String(),
   337  		nextVersion,
   338  		queueType,
   339  		queueName,
   340  		version,
   341  	).WithContext(ctx).MapScanCAS(make(map[string]interface{}))
   342  	if err != nil {
   343  		return gocql.ConvertError("QueueV2UpdateQueueMetadata", err)
   344  	}
   345  	if !applied {
   346  		return fmt.Errorf(
   347  			"%w: queue type %v and name %v",
   348  			ErrUpdateQueueConflict,
   349  			queueType,
   350  			queueName,
   351  		)
   352  	}
   353  	return nil
   354  }
   355  
   356  func (s *queueV2Store) tryInsert(
   357  	ctx context.Context,
   358  	queueType persistence.QueueV2Type,
   359  	queueName string,
   360  	blob *commonpb.DataBlob,
   361  	messageID int64,
   362  ) error {
   363  	applied, err := s.session.Query(
   364  		TemplateEnqueueMessageQuery,
   365  		queueType,
   366  		queueName,
   367  		0,
   368  		messageID,
   369  		blob.Data,
   370  		blob.EncodingType.String(),
   371  	).WithContext(ctx).MapScanCAS(make(map[string]interface{}))
   372  	if err != nil {
   373  		return gocql.ConvertError("QueueV2EnqueueMessage", err)
   374  	}
   375  	if !applied {
   376  		return fmt.Errorf(
   377  			"%w: queue type %v and name %v already has a message with ID %v",
   378  			ErrEnqueueMessageConflict,
   379  			queueType,
   380  			queueName,
   381  			messageID,
   382  		)
   383  	}
   384  
   385  	return nil
   386  }
   387  
   388  func (s *queueV2Store) getQueue(
   389  	ctx context.Context,
   390  	queueType persistence.QueueV2Type,
   391  	name string,
   392  ) (*Queue, error) {
   393  	return GetQueue(ctx, s.session, name, queueType)
   394  }
   395  
   396  func GetQueue(
   397  	ctx context.Context,
   398  	session gocql.Session,
   399  	queueName string,
   400  	queueType persistence.QueueV2Type,
   401  ) (*Queue, error) {
   402  	var (
   403  		queueBytes       []byte
   404  		queueEncodingStr string
   405  		version          int64
   406  	)
   407  
   408  	err := session.Query(TemplateGetQueueQuery, queueType, queueName).WithContext(ctx).Scan(
   409  		&queueBytes,
   410  		&queueEncodingStr,
   411  		&version,
   412  	)
   413  	if err != nil {
   414  		if gocql.IsNotFoundError(err) {
   415  			return nil, persistence.NewQueueNotFoundError(queueType, queueName)
   416  		}
   417  		return nil, gocql.ConvertError("QueueV2GetQueue", err)
   418  	}
   419  	return getQueueFromMetadata(queueType, queueName, queueBytes, queueEncodingStr, version)
   420  }
   421  
   422  func getQueueFromMetadata(
   423  	queueType persistence.QueueV2Type,
   424  	queueName string,
   425  	queueBytes []byte,
   426  	queueEncodingStr string,
   427  	version int64,
   428  ) (*Queue, error) {
   429  	if queueEncodingStr != enums.ENCODING_TYPE_PROTO3.String() {
   430  		return nil, fmt.Errorf(
   431  			"%w: invalid queue encoding type: queue with type %v and name %v has invalid encoding",
   432  			serialization.NewUnknownEncodingTypeError(queueEncodingStr, enums.ENCODING_TYPE_PROTO3),
   433  			queueType,
   434  			queueName,
   435  		)
   436  	}
   437  
   438  	q := &persistencespb.Queue{}
   439  	err := q.Unmarshal(queueBytes)
   440  	if err != nil {
   441  		return nil, serialization.NewDeserializationError(
   442  			enums.ENCODING_TYPE_PROTO3,
   443  			fmt.Errorf("%w: unmarshal queue payload: failed for queue with type %v and name %v",
   444  				err, queueType, queueName),
   445  		)
   446  	}
   447  
   448  	return &Queue{
   449  		Metadata: q,
   450  		Version:  version,
   451  	}, nil
   452  }
   453  
   454  func (s *queueV2Store) getNextMessageID(ctx context.Context, queueType persistence.QueueV2Type, queueName string) (int64, error) {
   455  	maxMessageID, ok, err := s.getMaxMessageID(ctx, queueType, queueName)
   456  	if err != nil {
   457  		return 0, err
   458  	}
   459  	if !ok {
   460  		return persistence.FirstQueueMessageID, nil
   461  	}
   462  
   463  	// The next message ID is the max message ID + 1.
   464  	return maxMessageID + 1, nil
   465  }
   466  
   467  func (s *queueV2Store) getMaxMessageID(ctx context.Context, queueType persistence.QueueV2Type, queueName string) (int64, bool, error) {
   468  	var maxMessageID int64
   469  
   470  	err := s.session.Query(TemplateGetMaxMessageIDQuery, queueType, queueName, 0).WithContext(ctx).Scan(&maxMessageID)
   471  	if err != nil {
   472  		if gocql.IsNotFoundError(err) {
   473  			return 0, false, nil
   474  		}
   475  		return 0, false, gocql.ConvertError("QueueV2GetMaxMessageID", err)
   476  	}
   477  	return maxMessageID, true, nil
   478  }
   479  
   480  func (s *queueV2Store) ListQueues(
   481  	ctx context.Context,
   482  	request *persistence.InternalListQueuesRequest,
   483  ) (*persistence.InternalListQueuesResponse, error) {
   484  	if request.PageSize <= 0 {
   485  		return nil, persistence.ErrNonPositiveListQueuesPageSize
   486  	}
   487  	iter := s.session.Query(
   488  		templateGetQueueNamesQuery,
   489  		request.QueueType,
   490  	).PageSize(request.PageSize).PageState(request.NextPageToken).WithContext(ctx).Iter()
   491  
   492  	var queues []persistence.QueueInfo
   493  	for {
   494  		var (
   495  			queueName        string
   496  			metadataBytes    []byte
   497  			metadataEncoding string
   498  			version          int64
   499  		)
   500  		if !iter.Scan(&queueName, &metadataBytes, &metadataEncoding, &version) {
   501  			break
   502  		}
   503  		q, err := getQueueFromMetadata(request.QueueType, queueName, metadataBytes, metadataEncoding, version)
   504  		if err != nil {
   505  			return nil, err
   506  		}
   507  		partition, err := persistence.GetPartitionForQueueV2(request.QueueType, queueName, q.Metadata)
   508  		if err != nil {
   509  			return nil, err
   510  		}
   511  		nextMessageID, err := s.getNextMessageID(ctx, request.QueueType, queueName)
   512  		if err != nil {
   513  			return nil, err
   514  		}
   515  		messageCount := nextMessageID - partition.MinMessageId
   516  		queues = append(queues, persistence.QueueInfo{
   517  			QueueName:    queueName,
   518  			MessageCount: messageCount,
   519  		})
   520  	}
   521  	if err := iter.Close(); err != nil {
   522  		return nil, gocql.ConvertError("QueueV2ListQueues", err)
   523  	}
   524  	return &persistence.InternalListQueuesResponse{
   525  		Queues:        queues,
   526  		NextPageToken: iter.PageState(),
   527  	}, nil
   528  }