go.temporal.io/server@v1.23.0/common/persistence/tests/history_task_queue_manager_test_suite.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package tests
    26  
    27  import (
    28  	"context"
    29  	"errors"
    30  	"testing"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	commonpb "go.temporal.io/api/common/v1"
    35  	"go.temporal.io/api/enums/v1"
    36  	persistencespb "go.temporal.io/server/api/persistence/v1"
    37  	"go.temporal.io/server/client/history/historytest"
    38  	"go.temporal.io/server/common"
    39  	"go.temporal.io/server/common/definition"
    40  	"go.temporal.io/server/common/persistence"
    41  	"go.temporal.io/server/common/persistence/persistencetest"
    42  	"go.temporal.io/server/service/history/api/deletedlqtasks/deletedlqtaskstest"
    43  	"go.temporal.io/server/service/history/api/getdlqtasks/getdlqtaskstest"
    44  	"go.temporal.io/server/service/history/api/listqueues/listqueuestest"
    45  	"go.temporal.io/server/service/history/tasks"
    46  )
    47  
    48  type (
    49  	faultyQueue struct {
    50  		base                   persistence.QueueV2
    51  		enqueueErr             error
    52  		readMessagesErr        error
    53  		createQueueErr         error
    54  		rangeDeleteMessagesErr error
    55  	}
    56  )
    57  
    58  func (q faultyQueue) EnqueueMessage(
    59  	ctx context.Context,
    60  	req *persistence.InternalEnqueueMessageRequest,
    61  ) (*persistence.InternalEnqueueMessageResponse, error) {
    62  	if q.enqueueErr != nil {
    63  		return nil, q.enqueueErr
    64  	}
    65  	return q.base.EnqueueMessage(ctx, req)
    66  }
    67  
    68  func (q faultyQueue) ReadMessages(
    69  	ctx context.Context,
    70  	req *persistence.InternalReadMessagesRequest,
    71  ) (*persistence.InternalReadMessagesResponse, error) {
    72  	if q.readMessagesErr != nil {
    73  		return nil, q.readMessagesErr
    74  	}
    75  	return q.base.ReadMessages(ctx, req)
    76  }
    77  
    78  func (q faultyQueue) CreateQueue(
    79  	ctx context.Context,
    80  	req *persistence.InternalCreateQueueRequest,
    81  ) (*persistence.InternalCreateQueueResponse, error) {
    82  	if q.createQueueErr != nil {
    83  		return nil, q.createQueueErr
    84  	}
    85  	return q.base.CreateQueue(ctx, req)
    86  }
    87  
    88  func (q faultyQueue) RangeDeleteMessages(
    89  	ctx context.Context,
    90  	req *persistence.InternalRangeDeleteMessagesRequest,
    91  ) (*persistence.InternalRangeDeleteMessagesResponse, error) {
    92  	if q.rangeDeleteMessagesErr != nil {
    93  		return nil, q.rangeDeleteMessagesErr
    94  	}
    95  	return q.base.RangeDeleteMessages(ctx, req)
    96  }
    97  
    98  func (q faultyQueue) ListQueues(
    99  	ctx context.Context,
   100  	req *persistence.InternalListQueuesRequest,
   101  ) (*persistence.InternalListQueuesResponse, error) {
   102  	if q.rangeDeleteMessagesErr != nil {
   103  		return nil, q.rangeDeleteMessagesErr
   104  	}
   105  	return q.base.ListQueues(ctx, req)
   106  }
   107  
   108  // RunHistoryTaskQueueManagerTestSuite runs all tests for the history task queue manager against a given queue provided by a
   109  // particular database. This test suite should be re-used to test all queue implementations.
   110  func RunHistoryTaskQueueManagerTestSuite(t *testing.T, queue persistence.QueueV2) {
   111  	historyTaskQueueManager := persistence.NewHistoryTaskQueueManager(queue)
   112  	t.Run("ListQueues", func(t *testing.T) {
   113  		listqueuestest.TestInvoke(t, historyTaskQueueManager)
   114  	})
   115  	t.Run("TestHistoryTaskQueueManagerEnqueueTasks", func(t *testing.T) {
   116  		t.Parallel()
   117  		testHistoryTaskQueueManagerEnqueueTasks(t, historyTaskQueueManager)
   118  	})
   119  	t.Run("TestHistoryTaskQueueManagerEnqueueTasksErr", func(t *testing.T) {
   120  		t.Parallel()
   121  		testHistoryTaskQueueManagerEnqueueTasksErr(t, queue)
   122  	})
   123  	t.Run("TestHistoryTaskQueueManagerCreateQueueErr", func(t *testing.T) {
   124  		t.Parallel()
   125  		testHistoryTaskQueueManagerCreateQueueErr(t, queue)
   126  	})
   127  	t.Run("TestHistoryTQMErrDeserializeTask", func(t *testing.T) {
   128  		t.Parallel()
   129  		testHistoryTaskQueueManagerErrDeserializeHistoryTask(t, queue, historyTaskQueueManager)
   130  	})
   131  	t.Run("DeleteTasks", func(t *testing.T) {
   132  		t.Parallel()
   133  		testHistoryTaskQueueManagerDeleteTasks(t, historyTaskQueueManager)
   134  	})
   135  	t.Run("DeleteTasksErr", func(t *testing.T) {
   136  		t.Parallel()
   137  		testHistoryTaskQueueManagerDeleteTasksErr(t, queue)
   138  	})
   139  	t.Run("GetDLQTasks", func(t *testing.T) {
   140  		t.Parallel()
   141  		getdlqtaskstest.TestInvoke(t, historyTaskQueueManager)
   142  	})
   143  	t.Run("DeleteDLQTasks", func(t *testing.T) {
   144  		t.Parallel()
   145  		deletedlqtaskstest.TestInvoke(t, historyTaskQueueManager)
   146  	})
   147  	t.Run("ClientTest", func(t *testing.T) {
   148  		t.Parallel()
   149  		historytest.TestClient(t, historyTaskQueueManager)
   150  	})
   151  }
   152  
   153  func testHistoryTaskQueueManagerCreateQueueErr(t *testing.T, queue persistence.QueueV2) {
   154  	retErr := errors.New("test")
   155  	manager := persistence.NewHistoryTaskQueueManager(faultyQueue{
   156  		base:           queue,
   157  		createQueueErr: retErr,
   158  	})
   159  	_, err := manager.CreateQueue(context.Background(), &persistence.CreateQueueRequest{
   160  		QueueKey: persistencetest.GetQueueKey(t),
   161  	})
   162  	assert.ErrorIs(t, err, retErr)
   163  }
   164  
   165  func testHistoryTaskQueueManagerEnqueueTasks(t *testing.T, manager persistence.HistoryTaskQueueManager) {
   166  	numHistoryShards := 5
   167  	ctx := context.Background()
   168  
   169  	namespaceID := "test-namespace"
   170  	workflowID := "test-workflow-id"
   171  	workflowKey := definition.NewWorkflowKey(namespaceID, workflowID, "test-run-id")
   172  	shardID := 2
   173  	assert.Equal(t, int32(shardID), common.WorkflowIDToHistoryShard(namespaceID, workflowID, int32(numHistoryShards)))
   174  
   175  	queueKey := persistencetest.GetQueueKey(t)
   176  	_, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{
   177  		QueueKey: queueKey,
   178  	})
   179  	require.NoError(t, err)
   180  
   181  	for i := 0; i < 2; i++ {
   182  		task := &tasks.WorkflowTask{
   183  			WorkflowKey: workflowKey,
   184  			TaskID:      int64(i + 1),
   185  		}
   186  		res, err := enqueueTask(ctx, manager, queueKey, task)
   187  		require.NoError(t, err)
   188  		assert.Equal(t, int64(persistence.FirstQueueMessageID+i), res.Metadata.ID)
   189  	}
   190  
   191  	var nextPageToken []byte
   192  	for i := 0; i < 3; i++ {
   193  		readRes, err := manager.ReadTasks(ctx, &persistence.ReadTasksRequest{
   194  			QueueKey:      queueKey,
   195  			PageSize:      1,
   196  			NextPageToken: nextPageToken,
   197  		})
   198  		require.NoError(t, err)
   199  
   200  		if i < 2 {
   201  			require.Len(t, readRes.Tasks, 1)
   202  			assert.Equal(t, shardID, tasks.GetShardIDForTask(readRes.Tasks[0].Task, numHistoryShards))
   203  			assert.Equal(t, int64(i+1), readRes.Tasks[0].Task.GetTaskID())
   204  			nextPageToken = readRes.NextPageToken
   205  		} else {
   206  			assert.Empty(t, readRes.Tasks)
   207  			assert.Empty(t, readRes.NextPageToken)
   208  		}
   209  	}
   210  }
   211  
   212  func testHistoryTaskQueueManagerEnqueueTasksErr(t *testing.T, queue persistence.QueueV2) {
   213  	ctx := context.Background()
   214  
   215  	retErr := errors.New("test")
   216  	manager := persistence.NewHistoryTaskQueueManager(faultyQueue{
   217  		base:       queue,
   218  		enqueueErr: retErr,
   219  	})
   220  	queueKey := persistencetest.GetQueueKey(t)
   221  	_, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{
   222  		QueueKey: queueKey,
   223  	})
   224  	require.NoError(t, err)
   225  	_, err = enqueueTask(ctx, manager, queueKey, &tasks.WorkflowTask{
   226  		TaskID: 1,
   227  	})
   228  	assert.ErrorIs(t, err, retErr)
   229  }
   230  
   231  func testHistoryTaskQueueManagerErrDeserializeHistoryTask(
   232  	t *testing.T,
   233  	queue persistence.QueueV2,
   234  	manager persistence.HistoryTaskQueueManager,
   235  ) {
   236  	ctx := context.Background()
   237  
   238  	t.Run("nil blob", func(t *testing.T) {
   239  		t.Parallel()
   240  
   241  		err := enqueueAndDeserializeBlob(ctx, t, queue, manager, nil)
   242  		assert.ErrorContains(t, err, persistence.ErrHistoryTaskBlobIsNil.Error())
   243  	})
   244  	t.Run("empty blob", func(t *testing.T) {
   245  		t.Parallel()
   246  
   247  		err := enqueueAndDeserializeBlob(ctx, t, queue, manager, &commonpb.DataBlob{})
   248  		assert.ErrorContains(t, err, persistence.ErrMsgDeserializeHistoryTask)
   249  	})
   250  }
   251  
   252  func testHistoryTaskQueueManagerDeleteTasks(t *testing.T, manager *persistence.HistoryTaskQueueManagerImpl) {
   253  	ctx := context.Background()
   254  
   255  	queueKey := persistencetest.GetQueueKey(t)
   256  	_, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{
   257  		QueueKey: queueKey,
   258  	})
   259  	require.NoError(t, err)
   260  	for i := 0; i < 2; i++ {
   261  		_, err := enqueueTask(ctx, manager, queueKey, &tasks.WorkflowTask{
   262  			TaskID: int64(i + 1),
   263  		})
   264  		require.NoError(t, err)
   265  	}
   266  	_, err = manager.DeleteTasks(ctx, &persistence.DeleteTasksRequest{
   267  		QueueKey: queueKey,
   268  		InclusiveMaxMessageMetadata: persistence.MessageMetadata{
   269  			ID: persistence.FirstQueueMessageID,
   270  		},
   271  	})
   272  	require.NoError(t, err)
   273  	res, err := manager.ReadTasks(ctx, &persistence.ReadTasksRequest{
   274  		QueueKey: queueKey,
   275  		PageSize: 10,
   276  	})
   277  	require.NoError(t, err)
   278  	require.Len(t, res.Tasks, 1)
   279  	assert.Equal(t, int64(2), res.Tasks[0].Task.GetTaskID())
   280  }
   281  
   282  func enqueueAndDeserializeBlob(
   283  	ctx context.Context,
   284  	t *testing.T,
   285  	queue persistence.QueueV2,
   286  	manager persistence.HistoryTaskQueueManager,
   287  	blob *commonpb.DataBlob,
   288  ) error {
   289  	t.Helper()
   290  
   291  	queueType := persistence.QueueTypeHistoryNormal
   292  	queueKey := persistencetest.GetQueueKey(t)
   293  	queueName := queueKey.GetQueueName()
   294  
   295  	_, err := queue.CreateQueue(ctx, &persistence.InternalCreateQueueRequest{
   296  		QueueType: queueType,
   297  		QueueName: queueKey.GetQueueName(),
   298  	})
   299  	require.NoError(t, err)
   300  	historyTask := persistencespb.HistoryTask{
   301  		ShardId: 1,
   302  		Blob:    blob,
   303  	}
   304  	historyTaskBytes, _ := historyTask.Marshal()
   305  	_, err = queue.EnqueueMessage(ctx, &persistence.InternalEnqueueMessageRequest{
   306  		QueueType: queueType,
   307  		QueueName: queueName,
   308  		Blob: &commonpb.DataBlob{
   309  			EncodingType: enums.ENCODING_TYPE_PROTO3,
   310  			Data:         historyTaskBytes,
   311  		},
   312  	})
   313  	require.NoError(t, err)
   314  
   315  	_, err = manager.ReadTasks(ctx, &persistence.ReadTasksRequest{
   316  		QueueKey: queueKey,
   317  		PageSize: 1,
   318  	})
   319  	return err
   320  }
   321  
   322  func testHistoryTaskQueueManagerDeleteTasksErr(t *testing.T, queue persistence.QueueV2) {
   323  	ctx := context.Background()
   324  
   325  	retErr := errors.New("test")
   326  	manager := persistence.NewHistoryTaskQueueManager(faultyQueue{
   327  		base:                   queue,
   328  		rangeDeleteMessagesErr: retErr,
   329  	})
   330  	queueKey := persistencetest.GetQueueKey(t)
   331  	_, err := manager.CreateQueue(ctx, &persistence.CreateQueueRequest{
   332  		QueueKey: queueKey,
   333  	})
   334  	require.NoError(t, err)
   335  	_, err = enqueueTask(ctx, manager, queueKey, &tasks.WorkflowTask{
   336  		TaskID: 1,
   337  	})
   338  	require.NoError(t, err)
   339  	_, err = manager.DeleteTasks(ctx, &persistence.DeleteTasksRequest{
   340  		QueueKey: queueKey,
   341  		InclusiveMaxMessageMetadata: persistence.MessageMetadata{
   342  			ID: persistence.FirstQueueMessageID,
   343  		},
   344  	})
   345  	assert.ErrorIs(t, err, retErr)
   346  }
   347  
   348  func enqueueTask(
   349  	ctx context.Context,
   350  	manager persistence.HistoryTaskQueueManager,
   351  	queueKey persistence.QueueKey,
   352  	task *tasks.WorkflowTask,
   353  ) (*persistence.EnqueueTaskResponse, error) {
   354  	return manager.EnqueueTask(ctx, &persistence.EnqueueTaskRequest{
   355  		QueueType:     queueKey.QueueType,
   356  		SourceCluster: queueKey.SourceCluster,
   357  		TargetCluster: queueKey.TargetCluster,
   358  		Task:          task,
   359  		SourceShardID: 1,
   360  	})
   361  }