go.temporal.io/server@v1.23.0/common/persistence/sql/queue_v2.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package sql
    26  
    27  import (
    28  	"context"
    29  	"database/sql"
    30  	"errors"
    31  	"fmt"
    32  
    33  	commonpb "go.temporal.io/api/common/v1"
    34  	"go.temporal.io/api/enums/v1"
    35  	"go.temporal.io/api/serviceerror"
    36  
    37  	persistencespb "go.temporal.io/server/api/persistence/v1"
    38  	"go.temporal.io/server/common/log"
    39  	"go.temporal.io/server/common/log/tag"
    40  	"go.temporal.io/server/common/persistence"
    41  	"go.temporal.io/server/common/persistence/serialization"
    42  	"go.temporal.io/server/common/persistence/sql/sqlplugin"
    43  )
    44  
    45  const (
    46  	defaultPartition = 0
    47  )
    48  
    49  type (
    50  	queueV2 struct {
    51  		SqlStore
    52  	}
    53  
    54  	QueueV2Metadata struct {
    55  		Metadata *persistencespb.Queue
    56  		Version  int64
    57  	}
    58  )
    59  
    60  // NewQueueV2 returns an implementation of persistence.QueueV2.
    61  func NewQueueV2(db sqlplugin.DB,
    62  	logger log.Logger,
    63  ) persistence.QueueV2 {
    64  	return &queueV2{
    65  		SqlStore: NewSqlStore(db, logger),
    66  	}
    67  }
    68  
    69  func (q *queueV2) EnqueueMessage(
    70  	ctx context.Context,
    71  	request *persistence.InternalEnqueueMessageRequest,
    72  ) (*persistence.InternalEnqueueMessageResponse, error) {
    73  
    74  	_, err := q.getQueueMetadata(ctx, q.Db, request.QueueType, request.QueueName)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	tx, err := q.Db.BeginTx(ctx)
    79  	if err != nil {
    80  		return nil, serviceerror.NewUnavailable(fmt.Sprintf(
    81  			"EnqueueMessage failed for queue with type: %v and name: %v. BeginTx operation failed. Error: %v",
    82  			request.QueueType,
    83  			request.QueueName,
    84  			err),
    85  		)
    86  	}
    87  	nextMessageID, err := q.getNextMessageID(ctx, request.QueueType, request.QueueName, tx)
    88  	if err != nil {
    89  		rollBackErr := tx.Rollback()
    90  		if rollBackErr != nil {
    91  			q.SqlStore.logger.Error("transaction rollback error", tag.Error(rollBackErr))
    92  		}
    93  		return nil, serviceerror.NewUnavailable(fmt.Sprintf(
    94  			"EnqueueMessage failed for queue with type: %v and name: %v. failed to get next messageId. Error: %v",
    95  			request.QueueType,
    96  			request.QueueName,
    97  			err),
    98  		)
    99  	}
   100  	_, err = tx.InsertIntoQueueV2Messages(ctx, []sqlplugin.QueueV2MessageRow{
   101  		newQueueV2Row(request.QueueType, request.QueueName, nextMessageID, request.Blob),
   102  	})
   103  	if err != nil {
   104  		rollBackErr := tx.Rollback()
   105  		if rollBackErr != nil {
   106  			q.SqlStore.logger.Error("transaction rollback error", tag.Error(rollBackErr))
   107  		}
   108  		return nil, serviceerror.NewUnavailable(fmt.Sprintf(
   109  			"EnqueueMessage failed for queue with type: %v and name: %v. InsertIntoQueueV2Messages operation failed. Error: %v",
   110  			request.QueueType,
   111  			request.QueueName,
   112  			err),
   113  		)
   114  	}
   115  
   116  	if err := tx.Commit(); err != nil {
   117  		return nil, serviceerror.NewUnavailable(fmt.Sprintf(
   118  			"EnqueueMessage failed for queue with type: %v and name: %v. commit operation failed. Error: %v",
   119  			request.QueueType,
   120  			request.QueueName,
   121  			err),
   122  		)
   123  	}
   124  	return &persistence.InternalEnqueueMessageResponse{Metadata: persistence.MessageMetadata{ID: nextMessageID}}, err
   125  }
   126  
   127  func (q *queueV2) ReadMessages(
   128  	ctx context.Context,
   129  	request *persistence.InternalReadMessagesRequest,
   130  ) (*persistence.InternalReadMessagesResponse, error) {
   131  
   132  	if request.PageSize <= 0 {
   133  		return nil, persistence.ErrNonPositiveReadQueueMessagesPageSize
   134  	}
   135  	qm, err := q.getQueueMetadata(ctx, q.Db, request.QueueType, request.QueueName)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	minMessageID, err := persistence.GetMinMessageIDToReadForQueueV2(
   140  		request.QueueType,
   141  		request.QueueName,
   142  		request.NextPageToken,
   143  		qm,
   144  	)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	rows, err := q.Db.RangeSelectFromQueueV2Messages(ctx, sqlplugin.QueueV2MessagesFilter{
   149  		QueueType:    request.QueueType,
   150  		QueueName:    request.QueueName,
   151  		Partition:    defaultPartition,
   152  		MinMessageID: minMessageID,
   153  		PageSize:     request.PageSize,
   154  	})
   155  	if err != nil {
   156  		return nil, serviceerror.NewUnavailable(fmt.Sprintf(
   157  			"ReadMessages failed for queue with type: %v and name: %v. RangeSelectFromQueueV2Messages operation failed. Error: %v",
   158  			request.QueueType,
   159  			request.QueueName,
   160  			err),
   161  		)
   162  	}
   163  	var messages []persistence.QueueV2Message
   164  	for _, row := range rows {
   165  		encoding, err := enums.EncodingTypeFromString(row.MessageEncoding)
   166  		if err != nil {
   167  			return nil, serialization.NewUnknownEncodingTypeError(row.MessageEncoding)
   168  		}
   169  		encodingType := enums.EncodingType(encoding)
   170  		message := persistence.QueueV2Message{
   171  			MetaData: persistence.MessageMetadata{ID: row.MessageID},
   172  			Data: &commonpb.DataBlob{
   173  				EncodingType: encodingType,
   174  				Data:         row.MessagePayload,
   175  			},
   176  		}
   177  		messages = append(messages, message)
   178  	}
   179  	nextPageToken := persistence.GetNextPageTokenForReadMessages(messages)
   180  	response := &persistence.InternalReadMessagesResponse{
   181  		Messages:      messages,
   182  		NextPageToken: nextPageToken,
   183  	}
   184  	return response, nil
   185  }
   186  
   187  func newQueueV2Row(
   188  	queueType persistence.QueueV2Type,
   189  	queueName string,
   190  	messageID int64,
   191  	blob *commonpb.DataBlob,
   192  ) sqlplugin.QueueV2MessageRow {
   193  
   194  	return sqlplugin.QueueV2MessageRow{
   195  		QueueType:       queueType,
   196  		QueueName:       queueName,
   197  		QueuePartition:  defaultPartition,
   198  		MessageID:       messageID,
   199  		MessagePayload:  blob.Data,
   200  		MessageEncoding: blob.EncodingType.String(),
   201  	}
   202  }
   203  
   204  func (q *queueV2) CreateQueue(
   205  	ctx context.Context,
   206  	request *persistence.InternalCreateQueueRequest,
   207  ) (*persistence.InternalCreateQueueResponse, error) {
   208  	payload := persistencespb.Queue{
   209  		Partitions: map[int32]*persistencespb.QueuePartition{
   210  			defaultPartition: {
   211  				MinMessageId: persistence.FirstQueueMessageID,
   212  			},
   213  		},
   214  	}
   215  	bytes, _ := payload.Marshal()
   216  	row := sqlplugin.QueueV2MetadataRow{
   217  		QueueType:        request.QueueType,
   218  		QueueName:        request.QueueName,
   219  		MetadataPayload:  bytes,
   220  		MetadataEncoding: enums.ENCODING_TYPE_PROTO3.String(),
   221  	}
   222  	_, err := q.Db.InsertIntoQueueV2Metadata(ctx, &row)
   223  	if q.Db.IsDupEntryError(err) {
   224  		return nil, fmt.Errorf(
   225  			"%w: queue type %v and name %v",
   226  			persistence.ErrQueueAlreadyExists,
   227  			request.QueueType,
   228  			request.QueueName,
   229  		)
   230  	}
   231  	if err != nil {
   232  		return nil, serviceerror.NewUnavailable(fmt.Sprintf(
   233  			"CreateQueue failed for queue with type: %v and name: %v. InsertIntoQueueV2Metadata operation failed. Error: %v",
   234  			request.QueueType,
   235  			request.QueueName,
   236  			err),
   237  		)
   238  	}
   239  	return &persistence.InternalCreateQueueResponse{}, nil
   240  }
   241  
   242  func (q *queueV2) RangeDeleteMessages(
   243  	ctx context.Context,
   244  	request *persistence.InternalRangeDeleteMessagesRequest,
   245  ) (*persistence.InternalRangeDeleteMessagesResponse, error) {
   246  	if request.InclusiveMaxMessageMetadata.ID < persistence.FirstQueueMessageID {
   247  		return nil, fmt.Errorf(
   248  			"%w: id is %d but must be >= %d",
   249  			persistence.ErrInvalidQueueRangeDeleteMaxMessageID,
   250  			request.InclusiveMaxMessageMetadata.ID,
   251  			persistence.FirstQueueMessageID,
   252  		)
   253  	}
   254  	var resp *persistence.InternalRangeDeleteMessagesResponse
   255  	err := q.txExecute(ctx, "RangeDeleteMessages", func(tx sqlplugin.Tx) error {
   256  		qm, err := q.getQueueMetadata(ctx, tx, request.QueueType, request.QueueName)
   257  		if err != nil {
   258  			return err
   259  		}
   260  		partition, err := persistence.GetPartitionForQueueV2(request.QueueType, request.QueueName, qm)
   261  		if err != nil {
   262  			return serviceerror.NewUnavailable(fmt.Sprintf(
   263  				"RangeDeleteMessages failed for queue with type: %v and name: %v. GetPartitionForQueueV2 operation failed. Error: %v",
   264  				request.QueueType,
   265  				request.QueueName,
   266  				err),
   267  			)
   268  		}
   269  		maxMessageID, ok, err := q.getMaxMessageID(ctx, request.QueueType, request.QueueName, tx)
   270  		if err != nil {
   271  			return serviceerror.NewUnavailable(fmt.Sprintf(
   272  				"RangeDeleteMessages failed for queue with type: %v and name: %v. failed to get MaxMessageID. Error: %v",
   273  				request.QueueType,
   274  				request.QueueName,
   275  				err),
   276  			)
   277  		}
   278  		if !ok {
   279  			return nil
   280  		}
   281  		deleteRange, ok := persistence.GetDeleteRange(persistence.DeleteRequest{
   282  			LastIDToDeleteInclusive: request.InclusiveMaxMessageMetadata.ID,
   283  			ExistingMessageRange: persistence.InclusiveMessageRange{
   284  				MinMessageID: partition.MinMessageId,
   285  				MaxMessageID: maxMessageID,
   286  			},
   287  		})
   288  		if !ok {
   289  			resp = &persistence.InternalRangeDeleteMessagesResponse{
   290  				MessagesDeleted: 0,
   291  			}
   292  			return nil
   293  		}
   294  		msgFilter := sqlplugin.QueueV2MessagesFilter{
   295  			QueueType:    request.QueueType,
   296  			QueueName:    request.QueueName,
   297  			Partition:    defaultPartition,
   298  			MinMessageID: deleteRange.MinMessageID,
   299  			MaxMessageID: deleteRange.MaxMessageID,
   300  		}
   301  		_, err = tx.RangeDeleteFromQueueV2Messages(ctx, msgFilter)
   302  		if err != nil {
   303  			return serviceerror.NewUnavailable(fmt.Sprintf(
   304  				"RangeDeleteMessages failed for queue with type: %v and name: %v. RangeDeleteFromQueueV2Messages operation failed. Error: %v",
   305  				request.QueueType,
   306  				request.QueueName,
   307  				err),
   308  			)
   309  		}
   310  		partition.MinMessageId = deleteRange.NewMinMessageID
   311  		bytes, _ := qm.Marshal()
   312  		row := sqlplugin.QueueV2MetadataRow{
   313  			QueueType:        request.QueueType,
   314  			QueueName:        request.QueueName,
   315  			MetadataPayload:  bytes,
   316  			MetadataEncoding: enums.ENCODING_TYPE_PROTO3.String(),
   317  		}
   318  		_, err = tx.UpdateQueueV2Metadata(ctx, &row)
   319  		if err != nil {
   320  			return serviceerror.NewUnavailable(fmt.Sprintf(
   321  				"RangeDeleteMessages failed for queue with type: %v and name: %v. UpdateQueueV2Metadata operation failed. Error: %v",
   322  				request.QueueType,
   323  				request.QueueName,
   324  				err),
   325  			)
   326  		}
   327  		resp = &persistence.InternalRangeDeleteMessagesResponse{
   328  			MessagesDeleted: deleteRange.MessagesToDelete,
   329  		}
   330  		return nil
   331  	})
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  	return resp, nil
   336  }
   337  
   338  func (q *queueV2) getQueueMetadata(
   339  	ctx context.Context,
   340  	tc sqlplugin.TableCRUD,
   341  	queueType persistence.QueueV2Type,
   342  	queueName string,
   343  ) (*persistencespb.Queue, error) {
   344  
   345  	filter := sqlplugin.QueueV2MetadataFilter{
   346  		QueueType: queueType,
   347  		QueueName: queueName,
   348  	}
   349  	var (
   350  		metadata *sqlplugin.QueueV2MetadataRow
   351  		err      error
   352  	)
   353  	switch tc.(type) {
   354  	case sqlplugin.Tx:
   355  		metadata, err = tc.SelectFromQueueV2MetadataForUpdate(ctx, filter)
   356  	default:
   357  		metadata, err = tc.SelectFromQueueV2Metadata(ctx, filter)
   358  	}
   359  	if err != nil {
   360  		if errors.Is(err, sql.ErrNoRows) {
   361  			return nil, persistence.NewQueueNotFoundError(queueType, queueName)
   362  		}
   363  		return nil, serviceerror.NewUnavailable(
   364  			fmt.Sprintf("failed to get metadata for queue with type: %v and name: %v. Error: %v", queueType, queueName, err),
   365  		)
   366  	}
   367  	return q.extractQueueMetadata(metadata)
   368  }
   369  
   370  func (q queueV2) extractQueueMetadata(metadataRow *sqlplugin.QueueV2MetadataRow) (*persistencespb.Queue, error) {
   371  	if metadataRow.MetadataEncoding != enums.ENCODING_TYPE_PROTO3.String() {
   372  		return nil, fmt.Errorf(
   373  			"queue with type %v and name %v has invalid encoding: %w",
   374  			metadataRow.QueueType,
   375  			metadataRow.QueueName,
   376  			serialization.NewUnknownEncodingTypeError(metadataRow.MetadataEncoding, enums.ENCODING_TYPE_PROTO3),
   377  		)
   378  	}
   379  	qm := &persistencespb.Queue{}
   380  	err := qm.Unmarshal(metadataRow.MetadataPayload)
   381  	if err != nil {
   382  		return nil, serialization.NewDeserializationError(
   383  			enums.ENCODING_TYPE_PROTO3,
   384  			fmt.Errorf("unmarshal payload for queue with type %v and name %v failed: %w",
   385  				metadataRow.QueueType,
   386  				metadataRow.QueueName,
   387  				err),
   388  		)
   389  	}
   390  	return qm, nil
   391  }
   392  
   393  func (q *queueV2) getMaxMessageID(ctx context.Context, queueType persistence.QueueV2Type, queueName string, tc sqlplugin.TableCRUD) (int64, bool, error) {
   394  	lastMessageID, err := tc.GetLastEnqueuedMessageIDForUpdateV2(ctx, sqlplugin.QueueV2Filter{
   395  		QueueType: queueType,
   396  		QueueName: queueName,
   397  		Partition: defaultPartition,
   398  	})
   399  	switch {
   400  	case err == nil:
   401  		return lastMessageID, true, nil
   402  	case errors.Is(err, sql.ErrNoRows):
   403  		return 0, false, nil
   404  	default:
   405  		return 0, false, err
   406  	}
   407  }
   408  
   409  func (q *queueV2) getNextMessageID(ctx context.Context, queueType persistence.QueueV2Type, queueName string, tc sqlplugin.TableCRUD) (int64, error) {
   410  	maxMessageID, ok, err := q.getMaxMessageID(ctx, queueType, queueName, tc)
   411  	if err != nil {
   412  		return 0, err
   413  	}
   414  	if !ok {
   415  		return persistence.FirstQueueMessageID, nil
   416  	}
   417  	return maxMessageID + 1, nil
   418  }
   419  
   420  func (q *queueV2) ListQueues(
   421  	ctx context.Context,
   422  	request *persistence.InternalListQueuesRequest,
   423  ) (*persistence.InternalListQueuesResponse, error) {
   424  	if request.PageSize <= 0 {
   425  		return nil, persistence.ErrNonPositiveListQueuesPageSize
   426  	}
   427  	offset, err := persistence.GetOffsetForListQueues(request.NextPageToken)
   428  	if err != nil {
   429  		return nil, err
   430  	}
   431  	if offset < 0 {
   432  		return nil, persistence.ErrNegativeListQueuesOffset
   433  	}
   434  	rows, err := q.Db.SelectNameFromQueueV2Metadata(ctx, sqlplugin.QueueV2MetadataTypeFilter{
   435  		QueueType:  request.QueueType,
   436  		PageSize:   request.PageSize,
   437  		PageOffset: offset,
   438  	})
   439  	if err != nil && !errors.Is(err, sql.ErrNoRows) {
   440  		return nil, serviceerror.NewUnavailable(fmt.Sprintf(
   441  			"ListQueues failed for type: %v. SelectNameFromQueueV2Metadata operation failed. Error: %v",
   442  			request.QueueType,
   443  			err),
   444  		)
   445  	}
   446  	var queues []persistence.QueueInfo
   447  	for _, row := range rows {
   448  		messageCount, err := q.getMessageCount(ctx, &row)
   449  		if err != nil {
   450  			return nil, err
   451  		}
   452  		queues = append(queues, persistence.QueueInfo{
   453  			QueueName:    row.QueueName,
   454  			MessageCount: messageCount,
   455  		})
   456  	}
   457  	lastReadQueueNumber := offset + int64(len(queues))
   458  	var nextPageToken []byte
   459  	if len(queues) > 0 {
   460  		nextPageToken = persistence.GetNextPageTokenForListQueues(lastReadQueueNumber)
   461  	}
   462  	response := &persistence.InternalListQueuesResponse{
   463  		Queues:        queues,
   464  		NextPageToken: nextPageToken,
   465  	}
   466  	return response, nil
   467  }
   468  
   469  func (q *queueV2) getMessageCount(
   470  	ctx context.Context,
   471  	row *sqlplugin.QueueV2MetadataRow,
   472  ) (int64, error) {
   473  	nextMessageID, err := q.getNextMessageID(ctx, row.QueueType, row.QueueName, q.Db)
   474  	if err != nil {
   475  		return 0, serviceerror.NewUnavailable(fmt.Sprintf(
   476  			"getNextMessageID operation failed for queue with type %v and name %v. Error: %v",
   477  			row.QueueType,
   478  			row.QueueName,
   479  			err),
   480  		)
   481  	}
   482  	qm, err := q.extractQueueMetadata(row)
   483  	if err != nil {
   484  		return 0, err
   485  	}
   486  	partition, err := persistence.GetPartitionForQueueV2(row.QueueType, row.QueueName, qm)
   487  	if err != nil {
   488  		return 0, err
   489  	}
   490  	return nextMessageID - partition.MinMessageId, nil
   491  }