go.temporal.io/server@v1.23.0/common/persistence/task_manager.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package persistence
    26  
    27  import (
    28  	"context"
    29  	"fmt"
    30  
    31  	enumspb "go.temporal.io/api/enums/v1"
    32  	"go.temporal.io/api/serviceerror"
    33  
    34  	persistencespb "go.temporal.io/server/api/persistence/v1"
    35  	"go.temporal.io/server/common/persistence/serialization"
    36  	"go.temporal.io/server/common/primitives/timestamp"
    37  )
    38  
    39  type taskManagerImpl struct {
    40  	taskStore  TaskStore
    41  	serializer serialization.Serializer
    42  }
    43  
    44  // NewTaskManager creates a new instance of TaskManager
    45  func NewTaskManager(
    46  	store TaskStore,
    47  	serializer serialization.Serializer,
    48  ) TaskManager {
    49  	return &taskManagerImpl{
    50  		taskStore:  store,
    51  		serializer: serializer,
    52  	}
    53  }
    54  
    55  func (m *taskManagerImpl) Close() {
    56  	m.taskStore.Close()
    57  }
    58  
    59  func (m *taskManagerImpl) GetName() string {
    60  	return m.taskStore.GetName()
    61  }
    62  
    63  func (m *taskManagerImpl) CreateTaskQueue(
    64  	ctx context.Context,
    65  	request *CreateTaskQueueRequest,
    66  ) (*CreateTaskQueueResponse, error) {
    67  	taskQueueInfo := request.TaskQueueInfo
    68  	if taskQueueInfo.LastUpdateTime == nil {
    69  		panic("CreateTaskQueue encountered LastUpdateTime not set")
    70  	}
    71  	if taskQueueInfo.ExpiryTime == nil && taskQueueInfo.GetKind() == enumspb.TASK_QUEUE_KIND_STICKY {
    72  		panic("CreateTaskQueue encountered ExpiryTime not set for sticky task queue")
    73  	}
    74  	taskQueueInfoBlob, err := m.serializer.TaskQueueInfoToBlob(taskQueueInfo, enumspb.ENCODING_TYPE_PROTO3)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	internalRequest := &InternalCreateTaskQueueRequest{
    80  		NamespaceID:   request.TaskQueueInfo.GetNamespaceId(),
    81  		TaskQueue:     request.TaskQueueInfo.GetName(),
    82  		TaskType:      request.TaskQueueInfo.GetTaskType(),
    83  		TaskQueueKind: request.TaskQueueInfo.GetKind(),
    84  		RangeID:       request.RangeID,
    85  		ExpiryTime:    taskQueueInfo.ExpiryTime,
    86  		TaskQueueInfo: taskQueueInfoBlob,
    87  	}
    88  	if err := m.taskStore.CreateTaskQueue(ctx, internalRequest); err != nil {
    89  		return nil, err
    90  	}
    91  	return &CreateTaskQueueResponse{}, nil
    92  }
    93  
    94  func (m *taskManagerImpl) UpdateTaskQueue(
    95  	ctx context.Context,
    96  	request *UpdateTaskQueueRequest,
    97  ) (*UpdateTaskQueueResponse, error) {
    98  	taskQueueInfo := request.TaskQueueInfo
    99  	if taskQueueInfo.LastUpdateTime == nil {
   100  		panic("UpdateTaskQueue encountered LastUpdateTime not set")
   101  	}
   102  	if taskQueueInfo.ExpiryTime == nil && taskQueueInfo.GetKind() == enumspb.TASK_QUEUE_KIND_STICKY {
   103  		panic("UpdateTaskQueue encountered ExpiryTime not set for sticky task queue")
   104  	}
   105  	taskQueueInfoBlob, err := m.serializer.TaskQueueInfoToBlob(taskQueueInfo, enumspb.ENCODING_TYPE_PROTO3)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	internalRequest := &InternalUpdateTaskQueueRequest{
   111  		NamespaceID:   request.TaskQueueInfo.GetNamespaceId(),
   112  		TaskQueue:     request.TaskQueueInfo.GetName(),
   113  		TaskType:      request.TaskQueueInfo.GetTaskType(),
   114  		RangeID:       request.RangeID,
   115  		TaskQueueInfo: taskQueueInfoBlob,
   116  
   117  		TaskQueueKind: request.TaskQueueInfo.GetKind(),
   118  		ExpiryTime:    taskQueueInfo.ExpiryTime,
   119  
   120  		PrevRangeID: request.PrevRangeID,
   121  	}
   122  	return m.taskStore.UpdateTaskQueue(ctx, internalRequest)
   123  }
   124  
   125  func (m *taskManagerImpl) GetTaskQueue(
   126  	ctx context.Context,
   127  	request *GetTaskQueueRequest,
   128  ) (*GetTaskQueueResponse, error) {
   129  	response, err := m.taskStore.GetTaskQueue(ctx, &InternalGetTaskQueueRequest{
   130  		NamespaceID: request.NamespaceID,
   131  		TaskQueue:   request.TaskQueue,
   132  		TaskType:    request.TaskType,
   133  	})
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	taskQueueInfo, err := m.serializer.TaskQueueInfoFromBlob(response.TaskQueueInfo)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  	return &GetTaskQueueResponse{
   143  		TaskQueueInfo: taskQueueInfo,
   144  		RangeID:       response.RangeID,
   145  	}, nil
   146  }
   147  
   148  func (m *taskManagerImpl) ListTaskQueue(
   149  	ctx context.Context,
   150  	request *ListTaskQueueRequest,
   151  ) (*ListTaskQueueResponse, error) {
   152  	internalResp, err := m.taskStore.ListTaskQueue(ctx, request)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	taskQueues := make([]*PersistedTaskQueueInfo, len(internalResp.Items))
   157  	for i, item := range internalResp.Items {
   158  		tqi, err := m.serializer.TaskQueueInfoFromBlob(item.TaskQueue)
   159  		if err != nil {
   160  			return nil, err
   161  		}
   162  		taskQueues[i] = &PersistedTaskQueueInfo{
   163  			Data:    tqi,
   164  			RangeID: item.RangeID,
   165  		}
   166  
   167  	}
   168  	return &ListTaskQueueResponse{
   169  		Items:         taskQueues,
   170  		NextPageToken: internalResp.NextPageToken,
   171  	}, nil
   172  }
   173  
   174  func (m *taskManagerImpl) DeleteTaskQueue(
   175  	ctx context.Context,
   176  	request *DeleteTaskQueueRequest,
   177  ) error {
   178  	return m.taskStore.DeleteTaskQueue(ctx, request)
   179  }
   180  
   181  func (m *taskManagerImpl) CreateTasks(
   182  	ctx context.Context,
   183  	request *CreateTasksRequest,
   184  ) (*CreateTasksResponse, error) {
   185  	taskQueueInfo := request.TaskQueueInfo.Data
   186  	taskQueueInfo.LastUpdateTime = timestamp.TimeNowPtrUtc()
   187  	taskQueueInfoBlob, err := m.serializer.TaskQueueInfoToBlob(taskQueueInfo, enumspb.ENCODING_TYPE_PROTO3)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  
   192  	tasks := make([]*InternalCreateTask, len(request.Tasks))
   193  	for i, task := range request.Tasks {
   194  		taskBlob, err := m.serializer.TaskInfoToBlob(task, enumspb.ENCODING_TYPE_PROTO3)
   195  		if err != nil {
   196  			return nil, serviceerror.NewUnavailable(fmt.Sprintf("CreateTasks operation failed during serialization. Error : %v", err))
   197  		}
   198  		tasks[i] = &InternalCreateTask{
   199  			TaskId:     task.GetTaskId(),
   200  			ExpiryTime: task.Data.ExpiryTime,
   201  			Task:       taskBlob,
   202  		}
   203  	}
   204  	internalRequest := &InternalCreateTasksRequest{
   205  		NamespaceID:   request.TaskQueueInfo.Data.GetNamespaceId(),
   206  		TaskQueue:     request.TaskQueueInfo.Data.GetName(),
   207  		TaskType:      request.TaskQueueInfo.Data.GetTaskType(),
   208  		RangeID:       request.TaskQueueInfo.RangeID,
   209  		TaskQueueInfo: taskQueueInfoBlob,
   210  		Tasks:         tasks,
   211  	}
   212  	return m.taskStore.CreateTasks(ctx, internalRequest)
   213  }
   214  
   215  func (m *taskManagerImpl) GetTasks(
   216  	ctx context.Context,
   217  	request *GetTasksRequest,
   218  ) (*GetTasksResponse, error) {
   219  	if request.InclusiveMinTaskID >= request.ExclusiveMaxTaskID {
   220  		return &GetTasksResponse{}, nil
   221  	}
   222  
   223  	internalResp, err := m.taskStore.GetTasks(ctx, request)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	tasks := make([]*persistencespb.AllocatedTaskInfo, len(internalResp.Tasks))
   228  	for i, taskBlob := range internalResp.Tasks {
   229  		task, err := m.serializer.TaskInfoFromBlob(taskBlob)
   230  		if err != nil {
   231  			return nil, serviceerror.NewUnavailable(fmt.Sprintf("GetTasks failed to deserialize task: %s", err.Error()))
   232  		}
   233  		tasks[i] = task
   234  	}
   235  	return &GetTasksResponse{Tasks: tasks, NextPageToken: internalResp.NextPageToken}, nil
   236  }
   237  
   238  func (m *taskManagerImpl) CompleteTask(
   239  	ctx context.Context,
   240  	request *CompleteTaskRequest,
   241  ) error {
   242  	return m.taskStore.CompleteTask(ctx, request)
   243  }
   244  
   245  func (m *taskManagerImpl) CompleteTasksLessThan(
   246  	ctx context.Context,
   247  	request *CompleteTasksLessThanRequest,
   248  ) (int, error) {
   249  	return m.taskStore.CompleteTasksLessThan(ctx, request)
   250  }
   251  
   252  // GetTaskQueueUserData implements TaskManager
   253  func (m *taskManagerImpl) GetTaskQueueUserData(ctx context.Context, request *GetTaskQueueUserDataRequest) (*GetTaskQueueUserDataResponse, error) {
   254  	response, err := m.taskStore.GetTaskQueueUserData(ctx, request)
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  	data, err := m.serializer.TaskQueueUserDataFromBlob(response.UserData)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  	return &GetTaskQueueUserDataResponse{UserData: &persistencespb.VersionedTaskQueueUserData{Version: response.Version, Data: data}}, nil
   263  }
   264  
   265  // UpdateTaskQueueUserData implements TaskManager
   266  func (m *taskManagerImpl) UpdateTaskQueueUserData(ctx context.Context, request *UpdateTaskQueueUserDataRequest) error {
   267  	userData, err := m.serializer.TaskQueueUserDataToBlob(request.UserData.Data, enumspb.ENCODING_TYPE_PROTO3)
   268  	if err != nil {
   269  		return err
   270  	}
   271  	internalRequest := &InternalUpdateTaskQueueUserDataRequest{
   272  		NamespaceID:     request.NamespaceID,
   273  		TaskQueue:       request.TaskQueue,
   274  		Version:         request.UserData.Version,
   275  		UserData:        userData,
   276  		BuildIdsAdded:   request.BuildIdsAdded,
   277  		BuildIdsRemoved: request.BuildIdsRemoved,
   278  	}
   279  	return m.taskStore.UpdateTaskQueueUserData(ctx, internalRequest)
   280  }
   281  
   282  func (m *taskManagerImpl) ListTaskQueueUserDataEntries(ctx context.Context, request *ListTaskQueueUserDataEntriesRequest) (*ListTaskQueueUserDataEntriesResponse, error) {
   283  	response, err := m.taskStore.ListTaskQueueUserDataEntries(ctx, request)
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  	entries := make([]*TaskQueueUserDataEntry, len(response.Entries))
   288  	for i, entry := range response.Entries {
   289  		data, err := m.serializer.TaskQueueUserDataFromBlob(entry.Data)
   290  		if err != nil {
   291  			return nil, err
   292  		}
   293  		entries[i] = &TaskQueueUserDataEntry{
   294  			TaskQueue: entry.TaskQueue,
   295  			UserData: &persistencespb.VersionedTaskQueueUserData{
   296  				Data:    data,
   297  				Version: entry.Version,
   298  			},
   299  		}
   300  	}
   301  	return &ListTaskQueueUserDataEntriesResponse{
   302  		NextPageToken: response.NextPageToken,
   303  		Entries:       entries,
   304  	}, nil
   305  }
   306  
   307  func (m *taskManagerImpl) GetTaskQueuesByBuildId(ctx context.Context, request *GetTaskQueuesByBuildIdRequest) ([]string, error) {
   308  	return m.taskStore.GetTaskQueuesByBuildId(ctx, request)
   309  }
   310  
   311  func (m *taskManagerImpl) CountTaskQueuesByBuildId(ctx context.Context, request *CountTaskQueuesByBuildIdRequest) (int, error) {
   312  	return m.taskStore.CountTaskQueuesByBuildId(ctx, request)
   313  }