go.temporal.io/server@v1.23.0/common/persistence/sql/sqlplugin/sqlite/task.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2021 Datadog, Inc.
     4  //
     5  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     6  //
     7  // Copyright (c) 2020 Uber Technologies, Inc.
     8  //
     9  // Permission is hereby granted, free of charge, to any person obtaining a copy
    10  // of this software and associated documentation files (the "Software"), to deal
    11  // in the Software without restriction, including without limitation the rights
    12  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    13  // copies of the Software, and to permit persons to whom the Software is
    14  // furnished to do so, subject to the following conditions:
    15  //
    16  // The above copyright notice and this permission notice shall be included in
    17  // all copies or substantial portions of the Software.
    18  //
    19  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    20  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    21  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    22  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    23  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    24  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    25  // THE SOFTWARE.
    26  
    27  package sqlite
    28  
    29  import (
    30  	"context"
    31  	"database/sql"
    32  	"fmt"
    33  	"strings"
    34  
    35  	"go.temporal.io/api/serviceerror"
    36  
    37  	"go.temporal.io/server/common/persistence"
    38  	"go.temporal.io/server/common/persistence/sql/sqlplugin"
    39  )
    40  
    41  const (
    42  	taskQueueCreatePart = `INTO task_queues(range_hash, task_queue_id, range_id, data, data_encoding) ` +
    43  		`VALUES (:range_hash, :task_queue_id, :range_id, :data, :data_encoding)`
    44  
    45  	// (default range ID: initialRangeID == 1)
    46  	createTaskQueueQry = `INSERT ` + taskQueueCreatePart
    47  
    48  	updateTaskQueueQry = `UPDATE task_queues SET
    49  range_id = :range_id,
    50  data = :data,
    51  data_encoding = :data_encoding
    52  WHERE
    53  range_hash = :range_hash AND
    54  task_queue_id = :task_queue_id
    55  `
    56  
    57  	listTaskQueueRowSelect = `SELECT range_hash, task_queue_id, range_id, data, data_encoding from task_queues `
    58  
    59  	listTaskQueueWithHashRangeQry = listTaskQueueRowSelect +
    60  		`WHERE range_hash >= ? AND range_hash <= ? AND task_queue_id > ? ORDER BY task_queue_id ASC LIMIT ?`
    61  
    62  	listTaskQueueQry = listTaskQueueRowSelect +
    63  		`WHERE range_hash = ? AND task_queue_id > ? ORDER BY task_queue_id ASC LIMIT ?`
    64  
    65  	getTaskQueueQry = listTaskQueueRowSelect +
    66  		`WHERE range_hash = ? AND task_queue_id = ?`
    67  
    68  	deleteTaskQueueQry = `DELETE FROM task_queues WHERE range_hash=? AND task_queue_id=? AND range_id=?`
    69  
    70  	lockTaskQueueQry = `SELECT range_id FROM task_queues ` +
    71  		`WHERE range_hash = ? AND task_queue_id = ?`
    72  	// *** Task_Queues Table Above ***
    73  
    74  	// *** Tasks Below ***
    75  	getTaskMinMaxQry = `SELECT task_id, data, data_encoding ` +
    76  		`FROM tasks ` +
    77  		`WHERE range_hash = ? AND task_queue_id = ? AND task_id >= ? AND task_id < ? ` +
    78  		` ORDER BY task_id LIMIT ?`
    79  
    80  	getTaskMinQry = `SELECT task_id, data, data_encoding ` +
    81  		`FROM tasks ` +
    82  		`WHERE range_hash = ? AND task_queue_id = ? AND task_id >= ? ORDER BY task_id LIMIT ?`
    83  
    84  	createTaskQry = `INSERT INTO ` +
    85  		`tasks(range_hash, task_queue_id, task_id, data, data_encoding) ` +
    86  		`VALUES(:range_hash, :task_queue_id, :task_id, :data, :data_encoding)`
    87  
    88  	deleteTaskQry = `DELETE FROM tasks ` +
    89  		`WHERE range_hash = ? AND task_queue_id = ? AND task_id = ?`
    90  
    91  	rangeDeleteTaskQry = `DELETE FROM tasks ` +
    92  		`WHERE range_hash = ? AND task_queue_id = ? AND task_id IN (SELECT task_id FROM
    93  		 tasks WHERE range_hash = ? AND task_queue_id = ? AND task_id < ? ` +
    94  		`ORDER BY task_queue_id,task_id LIMIT ? ) `
    95  
    96  	getTaskQueueUserDataQry = `SELECT data, data_encoding, version FROM task_queue_user_data ` +
    97  		`WHERE namespace_id = ? AND task_queue_name = ?`
    98  
    99  	updateTaskQueueUserDataQry = `UPDATE task_queue_user_data SET ` +
   100  		`data = ?, ` +
   101  		`data_encoding = ?, ` +
   102  		`version = ? ` +
   103  		`WHERE namespace_id = ? ` +
   104  		`AND task_queue_name = ? ` +
   105  		`AND version = ?`
   106  
   107  	insertTaskQueueUserDataQry = `INSERT INTO task_queue_user_data` +
   108  		`(namespace_id, task_queue_name, data, data_encoding, version) ` +
   109  		`VALUES (?, ?, ?, ?, 1)`
   110  
   111  	listTaskQueueUserDataQry = `SELECT task_queue_name, data, data_encoding, version FROM task_queue_user_data WHERE namespace_id = ? AND task_queue_name > ? LIMIT ?`
   112  
   113  	addBuildIdToTaskQueueMappingQry    = `INSERT INTO build_id_to_task_queue (namespace_id, build_id, task_queue_name) VALUES `
   114  	removeBuildIdToTaskQueueMappingQry = `DELETE FROM build_id_to_task_queue WHERE namespace_id = ? AND task_queue_name = ? AND build_id IN (`
   115  	listTaskQueuesByBuildIdQry         = `SELECT task_queue_name FROM build_id_to_task_queue WHERE namespace_id = ? AND build_id = ?`
   116  	countTaskQueuesByBuildIdQry        = `SELECT COUNT(*) FROM build_id_to_task_queue WHERE namespace_id = ? AND build_id = ?`
   117  )
   118  
   119  // InsertIntoTasks inserts one or more rows into tasks table
   120  func (mdb *db) InsertIntoTasks(
   121  	ctx context.Context,
   122  	rows []sqlplugin.TasksRow,
   123  ) (sql.Result, error) {
   124  	return mdb.conn.NamedExecContext(ctx,
   125  		createTaskQry,
   126  		rows,
   127  	)
   128  }
   129  
   130  // SelectFromTasks reads one or more rows from tasks table
   131  func (mdb *db) SelectFromTasks(
   132  	ctx context.Context,
   133  	filter sqlplugin.TasksFilter,
   134  ) ([]sqlplugin.TasksRow, error) {
   135  	var err error
   136  	var rows []sqlplugin.TasksRow
   137  	switch {
   138  	case filter.ExclusiveMaxTaskID != nil:
   139  		err = mdb.conn.SelectContext(ctx,
   140  			&rows, getTaskMinMaxQry,
   141  			filter.RangeHash,
   142  			filter.TaskQueueID,
   143  			*filter.InclusiveMinTaskID,
   144  			*filter.ExclusiveMaxTaskID,
   145  			*filter.PageSize,
   146  		)
   147  	default:
   148  		err = mdb.conn.SelectContext(ctx,
   149  			&rows, getTaskMinQry,
   150  			filter.RangeHash,
   151  			filter.TaskQueueID,
   152  			*filter.ExclusiveMaxTaskID,
   153  			*filter.PageSize,
   154  		)
   155  	}
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	return rows, nil
   160  }
   161  
   162  // DeleteFromTasks deletes one or more rows from tasks table
   163  func (mdb *db) DeleteFromTasks(
   164  	ctx context.Context,
   165  	filter sqlplugin.TasksFilter,
   166  ) (sql.Result, error) {
   167  	if filter.ExclusiveMaxTaskID != nil {
   168  		if filter.Limit == nil || *filter.Limit == 0 {
   169  			return nil, fmt.Errorf("missing limit parameter")
   170  		}
   171  		return mdb.conn.ExecContext(ctx,
   172  			rangeDeleteTaskQry,
   173  			filter.RangeHash,
   174  			filter.TaskQueueID,
   175  			filter.RangeHash,
   176  			filter.TaskQueueID,
   177  			*filter.ExclusiveMaxTaskID,
   178  			*filter.Limit,
   179  		)
   180  	}
   181  	return mdb.conn.ExecContext(ctx,
   182  		deleteTaskQry,
   183  		filter.RangeHash,
   184  		filter.TaskQueueID,
   185  		*filter.TaskID,
   186  	)
   187  }
   188  
   189  // InsertIntoTaskQueues inserts one or more rows into task_queues table
   190  func (mdb *db) InsertIntoTaskQueues(
   191  	ctx context.Context,
   192  	row *sqlplugin.TaskQueuesRow,
   193  ) (sql.Result, error) {
   194  	return mdb.conn.NamedExecContext(ctx,
   195  		createTaskQueueQry,
   196  		row,
   197  	)
   198  }
   199  
   200  // UpdateTaskQueues updates a row in task_queues table
   201  func (mdb *db) UpdateTaskQueues(
   202  	ctx context.Context,
   203  	row *sqlplugin.TaskQueuesRow,
   204  ) (sql.Result, error) {
   205  	return mdb.conn.NamedExecContext(ctx,
   206  		updateTaskQueueQry,
   207  		row,
   208  	)
   209  }
   210  
   211  // SelectFromTaskQueues reads one or more rows from task_queues table
   212  func (mdb *db) SelectFromTaskQueues(
   213  	ctx context.Context,
   214  	filter sqlplugin.TaskQueuesFilter,
   215  ) ([]sqlplugin.TaskQueuesRow, error) {
   216  	switch {
   217  	case filter.TaskQueueID != nil:
   218  		if filter.RangeHashLessThanEqualTo != 0 || filter.RangeHashGreaterThanEqualTo != 0 {
   219  			return nil, serviceerror.NewInternal("range of hashes not supported for specific selection")
   220  		}
   221  		return mdb.selectFromTaskQueues(ctx, filter)
   222  	case filter.RangeHashLessThanEqualTo != 0 && filter.PageSize != nil:
   223  		if filter.RangeHashLessThanEqualTo < filter.RangeHashGreaterThanEqualTo {
   224  			return nil, serviceerror.NewInternal("range of hashes bound is invalid")
   225  		}
   226  		return mdb.rangeSelectFromTaskQueues(ctx, filter)
   227  	case filter.TaskQueueIDGreaterThan != nil && filter.PageSize != nil:
   228  		return mdb.rangeSelectFromTaskQueues(ctx, filter)
   229  	default:
   230  		return nil, serviceerror.NewInternal("invalid set of query filter params")
   231  	}
   232  }
   233  
   234  func (mdb *db) selectFromTaskQueues(
   235  	ctx context.Context,
   236  	filter sqlplugin.TaskQueuesFilter,
   237  ) ([]sqlplugin.TaskQueuesRow, error) {
   238  	var err error
   239  	var row sqlplugin.TaskQueuesRow
   240  	err = mdb.conn.GetContext(ctx,
   241  		&row,
   242  		getTaskQueueQry,
   243  		filter.RangeHash,
   244  		filter.TaskQueueID,
   245  	)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	return []sqlplugin.TaskQueuesRow{row}, nil
   250  }
   251  
   252  func (mdb *db) rangeSelectFromTaskQueues(
   253  	ctx context.Context,
   254  	filter sqlplugin.TaskQueuesFilter,
   255  ) ([]sqlplugin.TaskQueuesRow, error) {
   256  	var err error
   257  	var rows []sqlplugin.TaskQueuesRow
   258  
   259  	if filter.RangeHashLessThanEqualTo != 0 {
   260  		err = mdb.conn.SelectContext(ctx,
   261  			&rows,
   262  			listTaskQueueWithHashRangeQry,
   263  			filter.RangeHashGreaterThanEqualTo,
   264  			filter.RangeHashLessThanEqualTo,
   265  			filter.TaskQueueIDGreaterThan,
   266  			*filter.PageSize,
   267  		)
   268  	} else {
   269  		err = mdb.conn.SelectContext(ctx,
   270  			&rows,
   271  			listTaskQueueQry,
   272  			filter.RangeHash,
   273  			filter.TaskQueueIDGreaterThan,
   274  			*filter.PageSize,
   275  		)
   276  	}
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  	return rows, nil
   281  }
   282  
   283  // DeleteFromTaskQueues deletes a row from task_queues table
   284  func (mdb *db) DeleteFromTaskQueues(
   285  	ctx context.Context,
   286  	filter sqlplugin.TaskQueuesFilter,
   287  ) (sql.Result, error) {
   288  	return mdb.conn.ExecContext(ctx,
   289  		deleteTaskQueueQry,
   290  		filter.RangeHash,
   291  		filter.TaskQueueID,
   292  		*filter.RangeID,
   293  	)
   294  }
   295  
   296  // LockTaskQueues locks a row in task_queues table
   297  func (mdb *db) LockTaskQueues(
   298  	ctx context.Context,
   299  	filter sqlplugin.TaskQueuesFilter,
   300  ) (int64, error) {
   301  	var rangeID int64
   302  	err := mdb.conn.GetContext(ctx,
   303  		&rangeID,
   304  		lockTaskQueueQry,
   305  		filter.RangeHash,
   306  		filter.TaskQueueID,
   307  	)
   308  	return rangeID, err
   309  }
   310  
   311  func (mdb *db) GetTaskQueueUserData(ctx context.Context, request *sqlplugin.GetTaskQueueUserDataRequest) (*sqlplugin.VersionedBlob, error) {
   312  	var row sqlplugin.VersionedBlob
   313  	err := mdb.conn.GetContext(ctx, &row, getTaskQueueUserDataQry, request.NamespaceID, request.TaskQueueName)
   314  	return &row, err
   315  }
   316  
   317  func (mdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.UpdateTaskQueueDataRequest) error {
   318  	if request.Version == 0 {
   319  		_, err := mdb.conn.ExecContext(
   320  			ctx,
   321  			insertTaskQueueUserDataQry,
   322  			request.NamespaceID,
   323  			request.TaskQueueName,
   324  			request.Data,
   325  			request.DataEncoding)
   326  		return err
   327  	}
   328  	result, err := mdb.conn.ExecContext(
   329  		ctx,
   330  		updateTaskQueueUserDataQry,
   331  		request.Data,
   332  		request.DataEncoding,
   333  		request.Version+1,
   334  		request.NamespaceID,
   335  		request.TaskQueueName,
   336  		request.Version)
   337  	if err != nil {
   338  		return err
   339  	}
   340  	numRows, err := result.RowsAffected()
   341  	if err != nil {
   342  		return err
   343  	}
   344  	if numRows != 1 {
   345  		return &persistence.ConditionFailedError{Msg: "Expected exactly one row to be updated"}
   346  	}
   347  	return nil
   348  }
   349  
   350  func (mdb *db) AddToBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error {
   351  	query := addBuildIdToTaskQueueMappingQry
   352  	var params []any
   353  	for idx, buildId := range request.BuildIds {
   354  		if idx == len(request.BuildIds)-1 {
   355  			query += "(?, ?, ?)"
   356  		} else {
   357  			query += "(?, ?, ?), "
   358  		}
   359  		params = append(params, request.NamespaceID, buildId, request.TaskQueueName)
   360  	}
   361  
   362  	_, err := mdb.conn.ExecContext(ctx, query, params...)
   363  	return err
   364  }
   365  
   366  func (mdb *db) RemoveFromBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error {
   367  	query := removeBuildIdToTaskQueueMappingQry + strings.Repeat("?, ", len(request.BuildIds)-1) + "?)"
   368  	// Golang doesn't support appending a string slice to an any slice which is essentially what we're doing here.
   369  	params := make([]any, len(request.BuildIds)+2)
   370  	params[0] = request.NamespaceID
   371  	params[1] = request.TaskQueueName
   372  	for i, buildId := range request.BuildIds {
   373  		params[i+2] = buildId
   374  	}
   375  
   376  	_, err := mdb.conn.ExecContext(ctx, query, params...)
   377  	return err
   378  }
   379  
   380  func (mdb *db) ListTaskQueueUserDataEntries(ctx context.Context, request *sqlplugin.ListTaskQueueUserDataEntriesRequest) ([]sqlplugin.TaskQueueUserDataEntry, error) {
   381  	var rows []sqlplugin.TaskQueueUserDataEntry
   382  	err := mdb.conn.SelectContext(ctx, &rows, listTaskQueueUserDataQry, request.NamespaceID, request.LastTaskQueueName, request.Limit)
   383  	return rows, err
   384  }
   385  
   386  func (mdb *db) GetTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.GetTaskQueuesByBuildIdRequest) ([]string, error) {
   387  	var rows []struct {
   388  		TaskQueueName string
   389  	}
   390  
   391  	err := mdb.conn.SelectContext(ctx, &rows, listTaskQueuesByBuildIdQry, request.NamespaceID, request.BuildID)
   392  	taskQueues := make([]string, len(rows))
   393  	for i, row := range rows {
   394  		taskQueues[i] = row.TaskQueueName
   395  	}
   396  	return taskQueues, err
   397  }
   398  
   399  func (mdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.CountTaskQueuesByBuildIdRequest) (int, error) {
   400  	var count int
   401  	err := mdb.conn.GetContext(ctx, &count, countTaskQueuesByBuildIdQry, request.NamespaceID, request.BuildID)
   402  	return count, err
   403  }