go.temporal.io/server@v1.23.0/common/persistence/sql/sqlplugin/postgresql/task.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package postgresql
    26  
    27  import (
    28  	"context"
    29  	"database/sql"
    30  	"fmt"
    31  
    32  	"go.temporal.io/api/serviceerror"
    33  
    34  	"go.temporal.io/server/common/persistence"
    35  	"go.temporal.io/server/common/persistence/sql/sqlplugin"
    36  )
    37  
    38  const (
    39  	taskQueueCreatePart = `INTO task_queues(range_hash, task_queue_id, range_id, data, data_encoding) ` +
    40  		`VALUES (:range_hash, :task_queue_id, :range_id, :data, :data_encoding)`
    41  
    42  	// (default range ID: initialRangeID == 1)
    43  	createTaskQueueQry = `INSERT ` + taskQueueCreatePart
    44  
    45  	updateTaskQueueQry = `UPDATE task_queues SET
    46  range_id = :range_id,
    47  data = :data,
    48  data_encoding = :data_encoding
    49  WHERE
    50  range_hash = :range_hash AND
    51  task_queue_id = :task_queue_id
    52  `
    53  
    54  	listTaskQueueRowSelect = `SELECT range_hash, task_queue_id, range_id, data, data_encoding from task_queues `
    55  
    56  	listTaskQueueWithHashRangeQry = listTaskQueueRowSelect +
    57  		`WHERE range_hash >= $1 AND range_hash <= $2 AND task_queue_id > $3 ORDER BY task_queue_id ASC LIMIT $4`
    58  
    59  	listTaskQueueQry = listTaskQueueRowSelect +
    60  		`WHERE range_hash = $1 AND task_queue_id > $2 ORDER BY task_queue_id ASC LIMIT $3`
    61  
    62  	getTaskQueueQry = listTaskQueueRowSelect +
    63  		`WHERE range_hash = $1 AND task_queue_id=$2`
    64  
    65  	deleteTaskQueueQry = `DELETE FROM task_queues WHERE range_hash=$1 AND task_queue_id=$2 AND range_id=$3`
    66  
    67  	lockTaskQueueQry = `SELECT range_id FROM task_queues ` +
    68  		`WHERE range_hash=$1 AND task_queue_id=$2 FOR UPDATE`
    69  	// *** Task_Queues Table Above ***
    70  
    71  	// *** Tasks Below ***
    72  	getTaskMinMaxQry = `SELECT task_id, data, data_encoding ` +
    73  		`FROM tasks ` +
    74  		`WHERE range_hash = $1 AND task_queue_id=$2 AND task_id >= $3 AND task_id < $4 ` +
    75  		`ORDER BY task_id LIMIT $5`
    76  
    77  	getTaskMinQry = `SELECT task_id, data, data_encoding ` +
    78  		`FROM tasks ` +
    79  		`WHERE range_hash = $1 AND task_queue_id = $2 AND task_id >= $3 ORDER BY task_id LIMIT $4`
    80  
    81  	createTaskQry = `INSERT INTO ` +
    82  		`tasks(range_hash, task_queue_id, task_id, data, data_encoding) ` +
    83  		`VALUES(:range_hash, :task_queue_id, :task_id, :data, :data_encoding)`
    84  
    85  	deleteTaskQry = `DELETE FROM tasks ` +
    86  		`WHERE range_hash = $1 AND task_queue_id = $2 AND task_id = $3`
    87  
    88  	rangeDeleteTaskQry = `DELETE FROM tasks ` +
    89  		`WHERE range_hash = $1 AND task_queue_id = $2 AND task_id IN (SELECT task_id FROM
    90  		 tasks WHERE range_hash = $1 AND task_queue_id = $2 AND task_id < $3 ` +
    91  		`ORDER BY task_queue_id,task_id LIMIT $4 )`
    92  
    93  	getTaskQueueUserDataQry = `SELECT data, data_encoding, version FROM task_queue_user_data ` +
    94  		`WHERE namespace_id = $1 AND task_queue_name = $2`
    95  
    96  	updateTaskQueueUserDataQry = `UPDATE task_queue_user_data SET ` +
    97  		`data = $1, ` +
    98  		`data_encoding = $2, ` +
    99  		`version = $3 ` +
   100  		`WHERE namespace_id = $4 ` +
   101  		`AND task_queue_name = $5 ` +
   102  		`AND version = $6`
   103  
   104  	insertTaskQueueUserDataQry = `INSERT INTO task_queue_user_data` +
   105  		`(namespace_id, task_queue_name, data, data_encoding, version) ` +
   106  		`VALUES ($1, $2, $3, $4, 1)`
   107  
   108  	listTaskQueueUserDataQry = `SELECT task_queue_name, data, data_encoding, version FROM task_queue_user_data WHERE namespace_id = $1 AND task_queue_name > $2 LIMIT $3`
   109  
   110  	addBuildIdToTaskQueueMappingQry    = `INSERT INTO build_id_to_task_queue (namespace_id, build_id, task_queue_name) VALUES `
   111  	removeBuildIdToTaskQueueMappingQry = `DELETE FROM build_id_to_task_queue WHERE namespace_id = $1 AND task_queue_name = $2 AND build_id IN (`
   112  	listTaskQueuesByBuildIdQry         = `SELECT task_queue_name FROM build_id_to_task_queue WHERE namespace_id = $1 AND build_id = $2`
   113  	countTaskQueuesByBuildIdQry        = `SELECT COUNT(*) FROM build_id_to_task_queue WHERE namespace_id = $1 AND build_id = $2`
   114  )
   115  
   116  // InsertIntoTasks inserts one or more rows into tasks table
   117  func (pdb *db) InsertIntoTasks(
   118  	ctx context.Context,
   119  	rows []sqlplugin.TasksRow,
   120  ) (sql.Result, error) {
   121  	return pdb.conn.NamedExecContext(ctx,
   122  		createTaskQry,
   123  		rows,
   124  	)
   125  }
   126  
   127  // SelectFromTasks reads one or more rows from tasks table
   128  func (pdb *db) SelectFromTasks(
   129  	ctx context.Context,
   130  	filter sqlplugin.TasksFilter,
   131  ) ([]sqlplugin.TasksRow, error) {
   132  	var err error
   133  	var rows []sqlplugin.TasksRow
   134  	switch {
   135  	case filter.ExclusiveMaxTaskID != nil:
   136  		err = pdb.conn.SelectContext(ctx,
   137  			&rows,
   138  			getTaskMinMaxQry,
   139  			filter.RangeHash,
   140  			filter.TaskQueueID,
   141  			*filter.InclusiveMinTaskID,
   142  			*filter.ExclusiveMaxTaskID,
   143  			*filter.PageSize,
   144  		)
   145  	default:
   146  		err = pdb.conn.SelectContext(ctx,
   147  			&rows,
   148  			getTaskMinQry,
   149  			filter.RangeHash,
   150  			filter.TaskQueueID,
   151  			*filter.InclusiveMinTaskID,
   152  			*filter.PageSize,
   153  		)
   154  	}
   155  	return rows, err
   156  }
   157  
   158  // DeleteFromTasks deletes one or more rows from tasks table
   159  func (pdb *db) DeleteFromTasks(
   160  	ctx context.Context,
   161  	filter sqlplugin.TasksFilter,
   162  ) (sql.Result, error) {
   163  	if filter.ExclusiveMaxTaskID != nil {
   164  		if filter.Limit == nil || *filter.Limit == 0 {
   165  			return nil, fmt.Errorf("missing limit parameter")
   166  		}
   167  		return pdb.conn.ExecContext(ctx,
   168  			rangeDeleteTaskQry,
   169  			filter.RangeHash,
   170  			filter.TaskQueueID,
   171  			*filter.ExclusiveMaxTaskID,
   172  			*filter.Limit,
   173  		)
   174  	}
   175  	return pdb.conn.ExecContext(ctx,
   176  		deleteTaskQry,
   177  		filter.RangeHash,
   178  		filter.TaskQueueID,
   179  		*filter.TaskID,
   180  	)
   181  }
   182  
   183  // InsertIntoTaskQueues inserts one or more rows into task_queues table
   184  func (pdb *db) InsertIntoTaskQueues(
   185  	ctx context.Context,
   186  	row *sqlplugin.TaskQueuesRow,
   187  ) (sql.Result, error) {
   188  	return pdb.conn.NamedExecContext(ctx,
   189  		createTaskQueueQry,
   190  		row,
   191  	)
   192  }
   193  
   194  // UpdateTaskQueues updates a row in task_queues table
   195  func (pdb *db) UpdateTaskQueues(
   196  	ctx context.Context,
   197  	row *sqlplugin.TaskQueuesRow,
   198  ) (sql.Result, error) {
   199  	return pdb.conn.NamedExecContext(ctx,
   200  		updateTaskQueueQry,
   201  		row,
   202  	)
   203  }
   204  
   205  // SelectFromTaskQueues reads one or more rows from task_queues table
   206  func (pdb *db) SelectFromTaskQueues(
   207  	ctx context.Context,
   208  	filter sqlplugin.TaskQueuesFilter,
   209  ) ([]sqlplugin.TaskQueuesRow, error) {
   210  	switch {
   211  	case filter.TaskQueueID != nil:
   212  		if filter.RangeHashLessThanEqualTo != 0 || filter.RangeHashGreaterThanEqualTo != 0 {
   213  			return nil, serviceerror.NewInternal("shardID range not supported for specific selection")
   214  		}
   215  		return pdb.selectFromTaskQueues(ctx, filter)
   216  	case filter.RangeHashLessThanEqualTo != 0 && filter.PageSize != nil:
   217  		if filter.RangeHashLessThanEqualTo < filter.RangeHashGreaterThanEqualTo {
   218  			return nil, serviceerror.NewInternal("range of hashes bound is invalid")
   219  		}
   220  		return pdb.rangeSelectFromTaskQueues(ctx, filter)
   221  	case filter.TaskQueueIDGreaterThan != nil && filter.PageSize != nil:
   222  		return pdb.rangeSelectFromTaskQueues(ctx, filter)
   223  	default:
   224  		return nil, serviceerror.NewInternal("invalid set of query filter params")
   225  	}
   226  }
   227  
   228  func (pdb *db) selectFromTaskQueues(
   229  	ctx context.Context,
   230  	filter sqlplugin.TaskQueuesFilter,
   231  ) ([]sqlplugin.TaskQueuesRow, error) {
   232  	var err error
   233  	var row sqlplugin.TaskQueuesRow
   234  	err = pdb.conn.GetContext(ctx,
   235  		&row,
   236  		getTaskQueueQry,
   237  		filter.RangeHash,
   238  		filter.TaskQueueID,
   239  	)
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  	return []sqlplugin.TaskQueuesRow{row}, err
   244  }
   245  
   246  func (pdb *db) rangeSelectFromTaskQueues(
   247  	ctx context.Context,
   248  	filter sqlplugin.TaskQueuesFilter,
   249  ) ([]sqlplugin.TaskQueuesRow, error) {
   250  	var err error
   251  	var rows []sqlplugin.TaskQueuesRow
   252  	if filter.RangeHashLessThanEqualTo > 0 {
   253  		err = pdb.conn.SelectContext(ctx,
   254  			&rows,
   255  			listTaskQueueWithHashRangeQry,
   256  			filter.RangeHashGreaterThanEqualTo,
   257  			filter.RangeHashLessThanEqualTo,
   258  			filter.TaskQueueIDGreaterThan,
   259  			*filter.PageSize,
   260  		)
   261  	} else {
   262  		err = pdb.conn.SelectContext(ctx,
   263  			&rows,
   264  			listTaskQueueQry,
   265  			filter.RangeHash,
   266  			filter.TaskQueueIDGreaterThan,
   267  			*filter.PageSize,
   268  		)
   269  	}
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	return rows, nil
   275  }
   276  
   277  // DeleteFromTaskQueues deletes a row from task_queues table
   278  func (pdb *db) DeleteFromTaskQueues(
   279  	ctx context.Context,
   280  	filter sqlplugin.TaskQueuesFilter,
   281  ) (sql.Result, error) {
   282  	return pdb.conn.ExecContext(ctx,
   283  		deleteTaskQueueQry,
   284  		filter.RangeHash,
   285  		filter.TaskQueueID,
   286  		*filter.RangeID,
   287  	)
   288  }
   289  
   290  // LockTaskQueues locks a row in task_queues table
   291  func (pdb *db) LockTaskQueues(
   292  	ctx context.Context,
   293  	filter sqlplugin.TaskQueuesFilter,
   294  ) (int64, error) {
   295  	var rangeID int64
   296  	err := pdb.conn.GetContext(ctx,
   297  		&rangeID,
   298  		lockTaskQueueQry,
   299  		filter.RangeHash,
   300  		filter.TaskQueueID,
   301  	)
   302  	return rangeID, err
   303  }
   304  
   305  func (pdb *db) GetTaskQueueUserData(ctx context.Context, request *sqlplugin.GetTaskQueueUserDataRequest) (*sqlplugin.VersionedBlob, error) {
   306  	var row sqlplugin.VersionedBlob
   307  	err := pdb.conn.GetContext(ctx, &row, getTaskQueueUserDataQry, request.NamespaceID, request.TaskQueueName)
   308  	return &row, err
   309  }
   310  
   311  func (pdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.UpdateTaskQueueDataRequest) error {
   312  	if request.Version == 0 {
   313  		_, err := pdb.conn.ExecContext(
   314  			ctx,
   315  			insertTaskQueueUserDataQry,
   316  			request.NamespaceID,
   317  			request.TaskQueueName,
   318  			request.Data,
   319  			request.DataEncoding)
   320  		return err
   321  	}
   322  	result, err := pdb.conn.ExecContext(
   323  		ctx,
   324  		updateTaskQueueUserDataQry,
   325  		request.Data,
   326  		request.DataEncoding,
   327  		request.Version+1,
   328  		request.NamespaceID,
   329  		request.TaskQueueName,
   330  		request.Version)
   331  	if err != nil {
   332  		return err
   333  	}
   334  	numRows, err := result.RowsAffected()
   335  	if err != nil {
   336  		return err
   337  	}
   338  	if numRows != 1 {
   339  		return &persistence.ConditionFailedError{Msg: "Expected exactly one row to be updated"}
   340  	}
   341  	return nil
   342  }
   343  
   344  func (pdb *db) AddToBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error {
   345  	query := addBuildIdToTaskQueueMappingQry
   346  	var params []any
   347  	for idx, buildId := range request.BuildIds {
   348  		query += fmt.Sprintf("($%d, $%d, $%d)", idx*3+1, idx*3+2, idx*3+3)
   349  		if idx < len(request.BuildIds)-1 {
   350  			query += ", "
   351  		}
   352  		params = append(params, request.NamespaceID, buildId, request.TaskQueueName)
   353  	}
   354  
   355  	_, err := pdb.conn.ExecContext(ctx, query, params...)
   356  	return err
   357  }
   358  
   359  func (pdb *db) RemoveFromBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error {
   360  	query := removeBuildIdToTaskQueueMappingQry
   361  	// Golang doesn't support appending a string slice to an any slice which is essentially what we're doing here.
   362  	params := make([]any, len(request.BuildIds)+2)
   363  	params[0] = request.NamespaceID
   364  	params[1] = request.TaskQueueName
   365  	for idx, buildId := range request.BuildIds {
   366  		sep := ", "
   367  		if idx == len(request.BuildIds)-1 {
   368  			sep = ")"
   369  		}
   370  		query += fmt.Sprintf("$%d%s", idx+3, sep)
   371  		params[idx+2] = buildId
   372  	}
   373  
   374  	_, err := pdb.conn.ExecContext(ctx, query, params...)
   375  	return err
   376  }
   377  
   378  func (pdb *db) ListTaskQueueUserDataEntries(ctx context.Context, request *sqlplugin.ListTaskQueueUserDataEntriesRequest) ([]sqlplugin.TaskQueueUserDataEntry, error) {
   379  	var rows []sqlplugin.TaskQueueUserDataEntry
   380  	err := pdb.conn.SelectContext(ctx, &rows, listTaskQueueUserDataQry, request.NamespaceID, request.LastTaskQueueName, request.Limit)
   381  	return rows, err
   382  }
   383  
   384  func (pdb *db) GetTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.GetTaskQueuesByBuildIdRequest) ([]string, error) {
   385  	var rows []struct {
   386  		TaskQueueName string
   387  	}
   388  
   389  	err := pdb.conn.SelectContext(ctx, &rows, listTaskQueuesByBuildIdQry, request.NamespaceID, request.BuildID)
   390  	taskQueues := make([]string, len(rows))
   391  	for i, row := range rows {
   392  		taskQueues[i] = row.TaskQueueName
   393  	}
   394  	return taskQueues, err
   395  }
   396  
   397  func (pdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.CountTaskQueuesByBuildIdRequest) (int, error) {
   398  	var count int
   399  	err := pdb.conn.GetContext(ctx, &count, countTaskQueuesByBuildIdQry, request.NamespaceID, request.BuildID)
   400  	return count, err
   401  }