go.temporal.io/server@v1.23.0/common/persistence/sql/queue.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package sql
    26  
import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	commonpb "go.temporal.io/api/common/v1"
	"go.temporal.io/api/serviceerror"

	"go.temporal.io/server/common/log"
	"go.temporal.io/server/common/persistence"
	"go.temporal.io/server/common/persistence/sql/sqlplugin"
)
    39  
    40  type (
    41  	sqlQueue struct {
    42  		queueType persistence.QueueType
    43  		logger    log.Logger
    44  		SqlStore
    45  	}
    46  )
    47  
    48  func newQueue(
    49  	db sqlplugin.DB,
    50  	logger log.Logger,
    51  	queueType persistence.QueueType,
    52  ) (persistence.Queue, error) {
    53  	queue := &sqlQueue{
    54  		SqlStore:  NewSqlStore(db, logger),
    55  		queueType: queueType,
    56  		logger:    logger,
    57  	}
    58  	return queue, nil
    59  }
    60  
    61  func (q *sqlQueue) Init(
    62  	ctx context.Context,
    63  	blob *commonpb.DataBlob,
    64  ) error {
    65  	if err := q.initializeQueueMetadata(ctx, blob); err != nil {
    66  		return err
    67  	}
    68  	return q.initializeDLQMetadata(ctx, blob)
    69  }
    70  
    71  func (q *sqlQueue) EnqueueMessage(
    72  	ctx context.Context,
    73  	blob *commonpb.DataBlob,
    74  ) error {
    75  	err := q.txExecute(ctx, "EnqueueMessage", func(tx sqlplugin.Tx) error {
    76  		lastMessageID, err := tx.GetLastEnqueuedMessageIDForUpdate(ctx, q.queueType)
    77  		switch err {
    78  		case nil:
    79  			_, err = tx.InsertIntoMessages(ctx, []sqlplugin.QueueMessageRow{
    80  				newQueueRow(q.queueType, lastMessageID+1, blob),
    81  			})
    82  			return err
    83  		case sql.ErrNoRows:
    84  			_, err = tx.InsertIntoMessages(ctx, []sqlplugin.QueueMessageRow{
    85  				newQueueRow(q.queueType, persistence.EmptyQueueMessageID+1, blob),
    86  			})
    87  			return err
    88  		default:
    89  			return fmt.Errorf("failed to get last enqueued message id: %v", err)
    90  		}
    91  	})
    92  	if err != nil {
    93  		return serviceerror.NewUnavailable(err.Error())
    94  	}
    95  	return nil
    96  }
    97  
    98  func (q *sqlQueue) ReadMessages(
    99  	ctx context.Context,
   100  	lastMessageID int64,
   101  	pageSize int,
   102  ) ([]*persistence.QueueMessage, error) {
   103  	rows, err := q.Db.RangeSelectFromMessages(ctx, sqlplugin.QueueMessagesRangeFilter{
   104  		QueueType:    q.queueType,
   105  		MinMessageID: lastMessageID,
   106  		MaxMessageID: persistence.MaxQueueMessageID,
   107  		PageSize:     pageSize,
   108  	})
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	var messages []*persistence.QueueMessage
   114  	for _, row := range rows {
   115  		messages = append(messages, &persistence.QueueMessage{
   116  			QueueType: q.queueType,
   117  			ID:        row.MessageID,
   118  			Data:      row.MessagePayload,
   119  			Encoding:  row.MessageEncoding,
   120  		})
   121  	}
   122  	return messages, nil
   123  }
   124  
   125  func (q *sqlQueue) DeleteMessagesBefore(
   126  	ctx context.Context,
   127  	messageID int64,
   128  ) error {
   129  	_, err := q.Db.RangeDeleteFromMessages(ctx, sqlplugin.QueueMessagesRangeFilter{
   130  		QueueType:    q.queueType,
   131  		MinMessageID: persistence.EmptyQueueMessageID,
   132  		MaxMessageID: messageID - 1,
   133  	})
   134  	if err != nil {
   135  		return serviceerror.NewUnavailable(fmt.Sprintf("DeleteMessagesBefore operation failed. Error %v", err))
   136  	}
   137  	return nil
   138  }
   139  
   140  func (q *sqlQueue) UpdateAckLevel(
   141  	ctx context.Context,
   142  	metadata *persistence.InternalQueueMetadata,
   143  ) error {
   144  	err := q.txExecute(ctx, "UpdateAckLevel", func(tx sqlplugin.Tx) error {
   145  		result, err := tx.UpdateQueueMetadata(ctx, &sqlplugin.QueueMetadataRow{
   146  			QueueType:    q.queueType,
   147  			Data:         metadata.Blob.Data,
   148  			DataEncoding: metadata.Blob.EncodingType.String(),
   149  			Version:      metadata.Version,
   150  		})
   151  		if err != nil {
   152  			return serviceerror.NewUnavailable(fmt.Sprintf("UpdateAckLevel operation failed. Error %v", err))
   153  		}
   154  		rowsAffected, err := result.RowsAffected()
   155  		if err != nil {
   156  			return fmt.Errorf("rowsAffected returned error for queue metadata %v: %v", q.queueType, err)
   157  		}
   158  		if rowsAffected != 1 {
   159  			return &persistence.ConditionFailedError{Msg: "UpdateAckLevel operation encountered concurrent write."}
   160  		}
   161  		return nil
   162  	})
   163  
   164  	if err != nil {
   165  		return serviceerror.NewUnavailable(err.Error())
   166  	}
   167  	return nil
   168  }
   169  
   170  func (q *sqlQueue) GetAckLevels(
   171  	ctx context.Context,
   172  ) (*persistence.InternalQueueMetadata, error) {
   173  	row, err := q.Db.SelectFromQueueMetadata(ctx, sqlplugin.QueueMetadataFilter{
   174  		QueueType: q.queueType,
   175  	})
   176  	if err != nil {
   177  		return nil, serviceerror.NewUnavailable(fmt.Sprintf("GetAckLevels operation failed. Error %v", err))
   178  	}
   179  
   180  	return &persistence.InternalQueueMetadata{
   181  		Blob:    persistence.NewDataBlob(row.Data, row.DataEncoding),
   182  		Version: row.Version,
   183  	}, nil
   184  }
   185  
   186  func (q *sqlQueue) EnqueueMessageToDLQ(
   187  	ctx context.Context,
   188  	blob *commonpb.DataBlob,
   189  ) (int64, error) {
   190  	var lastMessageID int64
   191  	err := q.txExecute(ctx, "EnqueueMessageToDLQ", func(tx sqlplugin.Tx) error {
   192  		var err error
   193  		lastMessageID, err = tx.GetLastEnqueuedMessageIDForUpdate(ctx, q.getDLQTypeFromQueueType())
   194  		switch err {
   195  		case nil:
   196  			_, err = tx.InsertIntoMessages(ctx, []sqlplugin.QueueMessageRow{
   197  				newQueueRow(q.getDLQTypeFromQueueType(), lastMessageID+1, blob),
   198  			})
   199  			return err
   200  		case sql.ErrNoRows:
   201  			_, err = tx.InsertIntoMessages(ctx, []sqlplugin.QueueMessageRow{
   202  				newQueueRow(q.getDLQTypeFromQueueType(), persistence.EmptyQueueMessageID+1, blob),
   203  			})
   204  			return err
   205  		default:
   206  			return fmt.Errorf("failed to get last enqueued message id from DLQ: %v", err)
   207  		}
   208  	})
   209  	if err != nil {
   210  		return persistence.EmptyQueueMessageID, serviceerror.NewUnavailable(err.Error())
   211  	}
   212  	return lastMessageID + 1, nil
   213  }
   214  
   215  func (q *sqlQueue) ReadMessagesFromDLQ(
   216  	ctx context.Context,
   217  	firstMessageID int64,
   218  	lastMessageID int64,
   219  	pageSize int,
   220  	pageToken []byte,
   221  ) ([]*persistence.QueueMessage, []byte, error) {
   222  	if len(pageToken) != 0 {
   223  		lastReadMessageID, err := deserializePageToken(pageToken)
   224  		if err != nil {
   225  			return nil, nil, serviceerror.NewInternal(fmt.Sprintf("invalid next page token %v", pageToken))
   226  		}
   227  		firstMessageID = lastReadMessageID
   228  	}
   229  
   230  	rows, err := q.Db.RangeSelectFromMessages(ctx, sqlplugin.QueueMessagesRangeFilter{
   231  		QueueType:    q.getDLQTypeFromQueueType(),
   232  		MinMessageID: firstMessageID,
   233  		MaxMessageID: lastMessageID,
   234  		PageSize:     pageSize,
   235  	})
   236  	if err != nil {
   237  		return nil, nil, serviceerror.NewUnavailable(fmt.Sprintf("ReadMessagesFromDLQ operation failed. Error %v", err))
   238  	}
   239  
   240  	var messages []*persistence.QueueMessage
   241  	for _, row := range rows {
   242  		messages = append(messages, &persistence.QueueMessage{
   243  			QueueType: q.getDLQTypeFromQueueType(),
   244  			ID:        row.MessageID,
   245  			Data:      row.MessagePayload,
   246  			Encoding:  row.MessageEncoding,
   247  		})
   248  	}
   249  
   250  	var newPagingToken []byte
   251  	if messages != nil && len(messages) >= pageSize {
   252  		lastReadMessageID := messages[len(messages)-1].ID
   253  		newPagingToken = serializePageToken(lastReadMessageID)
   254  	}
   255  	return messages, newPagingToken, nil
   256  }
   257  
   258  func (q *sqlQueue) DeleteMessageFromDLQ(
   259  	ctx context.Context,
   260  	messageID int64,
   261  ) error {
   262  	_, err := q.Db.DeleteFromMessages(ctx, sqlplugin.QueueMessagesFilter{
   263  		QueueType: q.getDLQTypeFromQueueType(),
   264  		MessageID: messageID,
   265  	})
   266  	if err != nil {
   267  		return serviceerror.NewUnavailable(fmt.Sprintf("DeleteMessageFromDLQ operation failed. Error %v", err))
   268  	}
   269  	return nil
   270  }
   271  
   272  func (q *sqlQueue) RangeDeleteMessagesFromDLQ(
   273  	ctx context.Context,
   274  	firstMessageID int64,
   275  	lastMessageID int64,
   276  ) error {
   277  	_, err := q.Db.RangeDeleteFromMessages(ctx, sqlplugin.QueueMessagesRangeFilter{
   278  		QueueType:    q.getDLQTypeFromQueueType(),
   279  		MinMessageID: firstMessageID,
   280  		MaxMessageID: lastMessageID,
   281  	})
   282  	if err != nil {
   283  		return serviceerror.NewUnavailable(fmt.Sprintf("RangeDeleteMessagesFromDLQ operation failed. Error %v", err))
   284  	}
   285  	return nil
   286  }
   287  
   288  func (q *sqlQueue) UpdateDLQAckLevel(
   289  	ctx context.Context,
   290  	metadata *persistence.InternalQueueMetadata,
   291  ) error {
   292  	err := q.txExecute(ctx, "UpdateDLQAckLevel", func(tx sqlplugin.Tx) error {
   293  
   294  		result, err := tx.UpdateQueueMetadata(ctx, &sqlplugin.QueueMetadataRow{
   295  			QueueType:    q.getDLQTypeFromQueueType(),
   296  			Data:         metadata.Blob.Data,
   297  			DataEncoding: metadata.Blob.EncodingType.String(),
   298  		})
   299  		if err != nil {
   300  			return serviceerror.NewUnavailable(fmt.Sprintf("UpdateDLQAckLevel operation failed. Error %v", err))
   301  		}
   302  		rowsAffected, err := result.RowsAffected()
   303  		if err != nil {
   304  			return fmt.Errorf("rowsAffected returned error for DLQ metadata %v: %v", q.queueType, err)
   305  		}
   306  		if rowsAffected != 1 {
   307  			return fmt.Errorf("rowsAffected returned %v DLQ metadata instead of one", rowsAffected)
   308  		}
   309  		return nil
   310  	})
   311  
   312  	if err != nil {
   313  		return serviceerror.NewUnavailable(err.Error())
   314  	}
   315  	return nil
   316  }
   317  
   318  func (q *sqlQueue) GetDLQAckLevels(
   319  	ctx context.Context,
   320  ) (*persistence.InternalQueueMetadata, error) {
   321  	row, err := q.Db.SelectFromQueueMetadata(ctx, sqlplugin.QueueMetadataFilter{
   322  		QueueType: q.getDLQTypeFromQueueType(),
   323  	})
   324  	if err != nil {
   325  		return nil, serviceerror.NewUnavailable(fmt.Sprintf("GetDLQAckLevels operation failed. Error %v", err))
   326  	}
   327  
   328  	return &persistence.InternalQueueMetadata{
   329  		Blob:    persistence.NewDataBlob(row.Data, row.DataEncoding),
   330  		Version: row.Version,
   331  	}, nil
   332  }
   333  
   334  func (q *sqlQueue) getDLQTypeFromQueueType() persistence.QueueType {
   335  	return -q.queueType
   336  }
   337  
   338  func (q *sqlQueue) initializeQueueMetadata(
   339  	ctx context.Context,
   340  	blob *commonpb.DataBlob,
   341  ) error {
   342  	_, err := q.Db.SelectFromQueueMetadata(ctx, sqlplugin.QueueMetadataFilter{
   343  		QueueType: q.queueType,
   344  	})
   345  	switch err {
   346  	case nil:
   347  		return nil
   348  	case sql.ErrNoRows:
   349  		result, err := q.Db.InsertIntoQueueMetadata(ctx, &sqlplugin.QueueMetadataRow{
   350  			QueueType:    q.queueType,
   351  			Data:         blob.Data,
   352  			DataEncoding: blob.EncodingType.String(),
   353  		})
   354  		if err != nil {
   355  			return serviceerror.NewUnavailable(fmt.Sprintf("initializeQueueMetadata operation failed. Error %v", err))
   356  		}
   357  		rowsAffected, err := result.RowsAffected()
   358  		if err != nil {
   359  			return fmt.Errorf("rowsAffected returned error when initializing queue metadata  %v: %v", q.queueType, err)
   360  		}
   361  		if rowsAffected != 1 {
   362  			return fmt.Errorf("rowsAffected returned %v queue metadata instead of one", rowsAffected)
   363  		}
   364  		return nil
   365  	default:
   366  		return err
   367  	}
   368  }
   369  
   370  func (q *sqlQueue) initializeDLQMetadata(
   371  	ctx context.Context,
   372  	blob *commonpb.DataBlob,
   373  ) error {
   374  	_, err := q.Db.SelectFromQueueMetadata(ctx, sqlplugin.QueueMetadataFilter{
   375  		QueueType: q.getDLQTypeFromQueueType(),
   376  	})
   377  	switch err {
   378  	case nil:
   379  		return nil
   380  	case sql.ErrNoRows:
   381  		result, err := q.Db.InsertIntoQueueMetadata(ctx, &sqlplugin.QueueMetadataRow{
   382  			QueueType:    q.getDLQTypeFromQueueType(),
   383  			Data:         blob.Data,
   384  			DataEncoding: blob.EncodingType.String(),
   385  		})
   386  		if err != nil {
   387  			return serviceerror.NewUnavailable(fmt.Sprintf("initializeDLQMetadata operation failed. Error %v", err))
   388  		}
   389  		rowsAffected, err := result.RowsAffected()
   390  		if err != nil {
   391  			return fmt.Errorf("rowsAffected returned error when initializing DLQ metadata  %v: %v", q.queueType, err)
   392  		}
   393  		if rowsAffected != 1 {
   394  			return fmt.Errorf("rowsAffected returned %v DLQ metadata instead of one", rowsAffected)
   395  		}
   396  		return nil
   397  	default:
   398  		return err
   399  	}
   400  }
   401  
   402  func newQueueRow(
   403  	queueType persistence.QueueType,
   404  	messageID int64,
   405  	blob *commonpb.DataBlob,
   406  ) sqlplugin.QueueMessageRow {
   407  
   408  	return sqlplugin.QueueMessageRow{
   409  		QueueType:       queueType,
   410  		MessageID:       messageID,
   411  		MessagePayload:  blob.Data,
   412  		MessageEncoding: blob.EncodingType.String(),
   413  	}
   414  }