github.com/matrixorigin/matrixone@v1.2.0/pkg/taskservice/mysql_task_storage.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package taskservice
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"encoding/json"
    21  	"errors"
    22  	"fmt"
    23  	"os"
    24  	"strings"
    25  
    26  	"github.com/go-sql-driver/mysql"
    27  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    28  	"github.com/matrixorigin/matrixone/pkg/common/util"
    29  	"github.com/matrixorigin/matrixone/pkg/pb/task"
    30  )
    31  
// SQL templates for the task tables. Every template contains exactly one %s
// verb, filled with the task database name (mysqlTaskStorage.dbname) via
// fmt.Sprintf before use; `?` placeholders are bound at execution time.
// Templates ending in "where 1=1" are meant to have " AND <cond>" terms
// appended by the where-clause builders.
var (
	// insertAsyncTask is the prefix of a multi-row INSERT into sys_async_task.
	// Callers append one "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" group (11 values,
	// in the column order below) per row, comma separated.
	insertAsyncTask = `insert into %s.sys_async_task(
                           task_metadata_id,
                           task_metadata_executor,
                           task_metadata_context,
                           task_metadata_option,
                           task_parent_id,
                           task_status,
                           task_runner,
                           task_epoch,
                           last_heartbeat,
                           create_at,
                           end_at) values `

	// updateAsyncTask rewrites one sys_async_task row: 12 SET values followed
	// by the task_id; callers may append further " AND ..." conditions.
	updateAsyncTask = `update %s.sys_async_task set 
							task_metadata_executor=?,
							task_metadata_context=?,
							task_metadata_option=?,
							task_parent_id=?,
							task_status=?,
							task_runner=?,
							task_epoch=?,
							last_heartbeat=?,
							result_code=?,
							error_msg=?,
							create_at=?,
							end_at=? where task_id=?`

	// selectAsyncTask reads all sys_async_task columns; the column order must
	// match the Scan order used when reading rows.
	selectAsyncTask = `select 
    						task_id,
							task_metadata_id,
							task_metadata_executor,
							task_metadata_context,
							task_metadata_option,
							task_parent_id,
							task_status,
							task_runner,
							task_epoch,
							last_heartbeat,
							result_code,
							error_msg,
							create_at,
							end_at 
						from %s.sys_async_task where 1=1`

	// insertCronTask is the prefix of a multi-row INSERT into sys_cron_task;
	// callers append one 9-placeholder group per row.
	insertCronTask = `insert into %s.sys_cron_task (
                           task_metadata_id,
						   task_metadata_executor,
                           task_metadata_context,
                           task_metadata_option,
                           cron_expr,
                           next_time,
                           trigger_times,
                           create_at,
                           update_at
                    ) values `

	// selectCronTask reads all sys_cron_task columns in Scan order.
	selectCronTask = `select 
    						cron_task_id,
    						task_metadata_id,
    						task_metadata_executor,
    						task_metadata_context,
    						task_metadata_option,
    						cron_expr,
    						next_time,
    						trigger_times,
    						create_at,
    						update_at
						from %s.sys_cron_task where 1=1`

	// updateCronTask rewrites one sys_cron_task row identified by cron_task_id.
	updateCronTask = `update %s.sys_cron_task set 
							task_metadata_executor=?,
    						task_metadata_context=?,
    						task_metadata_option=?,
    						cron_expr=?,
    						next_time=?,
    						trigger_times=?,
    						create_at=?,
    						update_at=? where cron_task_id=?`

	// countTaskId counts async tasks with a given task_metadata_id.
	countTaskId = `select count(task_metadata_id) from %s.sys_async_task where task_metadata_id=?`

	// getTriggerTimes reads the trigger counter of a cron task by metadata ID.
	getTriggerTimes = `select trigger_times from %s.sys_cron_task where task_metadata_id=?`

	// deleteAsyncTask deletes sys_async_task rows; without appended conditions
	// it matches every row.
	deleteAsyncTask = `delete from %s.sys_async_task where 1=1`

	// insertDaemonTask is the prefix of a multi-row INSERT into
	// sys_daemon_task; callers append one 11-placeholder group per row.
	insertDaemonTask = `insert into %s.sys_daemon_task (
                      task_metadata_id,
							task_metadata_executor,
							task_metadata_context,
                           task_metadata_option,
                           account_id,
                           account,
                           task_type,
                           task_status,
                           create_at,
                           update_at,
                           details
                    ) values `

	// updateDaemonTask rewrites one sys_daemon_task row identified by task_id;
	// callers may append further " AND ..." conditions.
	updateDaemonTask = `update %s.sys_daemon_task set
							task_metadata_executor=?,
							task_metadata_context=?,
							task_metadata_option=?,
							task_type=?,
							task_status=?,
							task_runner=?,
							last_heartbeat=?,
							update_at=?,
							end_at=?,
                            last_run=?,
                            details=? where task_id=?`

	// heartbeatDaemonTask refreshes only the last_heartbeat of one daemon task.
	heartbeatDaemonTask = `update %s.sys_daemon_task set
							last_heartbeat=? where task_id=?`

	// deleteDaemonTask deletes sys_daemon_task rows; without appended
	// conditions it matches every row.
	deleteDaemonTask = `delete from %s.sys_daemon_task where 1=1`

	// selectDaemonTask reads all sys_daemon_task columns in Scan order.
	selectDaemonTask = `select
							task_id,
							task_metadata_id,
							task_metadata_executor,
							task_metadata_context,
							task_metadata_option,
							account_id,
							account,
							task_type,
							task_runner,
							task_status,
							last_heartbeat,
							create_at,
							update_at,
							end_at,
							last_run,
							details
						from %s.sys_daemon_task where 1=1`
)
   169  
   170  var (
   171  	forceNewConn = "async_task_force_new_connection"
   172  )
   173  
   174  type mysqlTaskStorage struct {
   175  	dsn          string
   176  	dbname       string
   177  	db           *sql.DB
   178  	forceNewConn bool
   179  }
   180  
   181  func NewMysqlTaskStorage(dsn, dbname string) (TaskStorage, error) {
   182  	db, err := sql.Open("mysql", dsn)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  
   187  	db.SetMaxOpenConns(5)
   188  	db.SetMaxIdleConns(1)
   189  
   190  	_, ok := os.LookupEnv(forceNewConn)
   191  	return &mysqlTaskStorage{
   192  		dsn:          dsn,
   193  		db:           db,
   194  		dbname:       dbname,
   195  		forceNewConn: ok,
   196  	}, nil
   197  }
   198  
   199  func (m *mysqlTaskStorage) Close() error {
   200  	return m.db.Close()
   201  }
   202  
   203  func (m *mysqlTaskStorage) AddAsyncTask(ctx context.Context, tasks ...task.AsyncTask) (int, error) {
   204  	if taskFrameworkDisabled() {
   205  		return 0, nil
   206  	}
   207  
   208  	if len(tasks) == 0 {
   209  		return 0, nil
   210  	}
   211  
   212  	db, release, err := m.getDB()
   213  	if err != nil {
   214  		return 0, err
   215  	}
   216  	defer func() {
   217  		_ = release()
   218  	}()
   219  
   220  	conn, err := db.Conn(ctx)
   221  	if err != nil {
   222  		return 0, err
   223  	}
   224  	defer func() {
   225  		_ = conn.Close()
   226  	}()
   227  
   228  	sqlStr := fmt.Sprintf(insertAsyncTask, m.dbname)
   229  	vals := make([]any, 0, len(tasks)*13)
   230  
   231  	for _, t := range tasks {
   232  		j, err := json.Marshal(t.Metadata.Options)
   233  		if err != nil {
   234  			return 0, err
   235  		}
   236  
   237  		sqlStr += "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),"
   238  		vals = append(vals, t.Metadata.ID,
   239  			t.Metadata.Executor,
   240  			t.Metadata.Context,
   241  			string(j),
   242  			t.ParentTaskID,
   243  			t.Status,
   244  			t.TaskRunner,
   245  			t.Epoch,
   246  			t.LastHeartbeat,
   247  			t.CreateAt,
   248  			t.CompletedAt,
   249  		)
   250  	}
   251  
   252  	if sqlStr == fmt.Sprintf(insertAsyncTask, m.dbname) {
   253  		return 0, nil
   254  	}
   255  	sqlStr = sqlStr[0 : len(sqlStr)-1]
   256  	stmt, err := conn.PrepareContext(ctx, sqlStr)
   257  	if err != nil {
   258  		return 0, err
   259  	}
   260  	defer stmt.Close()
   261  	exec, err := stmt.ExecContext(ctx, vals...)
   262  	if err != nil {
   263  		dup, err := removeDuplicateAsyncTasks(err, tasks)
   264  		if err != nil {
   265  			return 0, err
   266  		}
   267  		add, err := m.AddAsyncTask(ctx, dup...)
   268  		if err != nil {
   269  			return add, err
   270  		}
   271  		return add, nil
   272  	}
   273  	affected, err := exec.RowsAffected()
   274  	if err != nil {
   275  		return 0, err
   276  	}
   277  
   278  	return int(affected), nil
   279  }
   280  
   281  func (m *mysqlTaskStorage) UpdateAsyncTask(ctx context.Context, tasks []task.AsyncTask, condition ...Condition) (int, error) {
   282  	if taskFrameworkDisabled() {
   283  		return 0, nil
   284  	}
   285  
   286  	if len(tasks) == 0 {
   287  		return 0, nil
   288  	}
   289  
   290  	db, release, err := m.getDB()
   291  	if err != nil {
   292  		return 0, err
   293  	}
   294  	defer func() {
   295  		_ = release()
   296  	}()
   297  
   298  	conn, err := db.Conn(ctx)
   299  	if err != nil {
   300  		return 0, err
   301  	}
   302  	defer func() {
   303  		_ = conn.Close()
   304  	}()
   305  
   306  	tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
   307  	if err != nil {
   308  		return 0, err
   309  	}
   310  
   311  	c := newConditions(condition...)
   312  	updateSql := fmt.Sprintf(updateAsyncTask, m.dbname) + buildWhereClause(c)
   313  	n := 0
   314  	for _, t := range tasks {
   315  		err := func() error {
   316  			execResult := &task.ExecuteResult{
   317  				Code: task.ResultCode_Success,
   318  			}
   319  			if t.ExecuteResult != nil {
   320  				execResult.Code = t.ExecuteResult.Code
   321  				execResult.Error = t.ExecuteResult.Error
   322  			}
   323  
   324  			j, err := json.Marshal(t.Metadata.Options)
   325  			if err != nil {
   326  				return err
   327  			}
   328  
   329  			prepare, err := tx.PrepareContext(ctx, updateSql)
   330  			if err != nil {
   331  				return err
   332  			}
   333  			defer prepare.Close()
   334  
   335  			exec, err := prepare.ExecContext(ctx,
   336  				t.Metadata.Executor,
   337  				t.Metadata.Context,
   338  				string(j),
   339  				t.ParentTaskID,
   340  				t.Status,
   341  				t.TaskRunner,
   342  				t.Epoch,
   343  				t.LastHeartbeat,
   344  				execResult.Code,
   345  				execResult.Error,
   346  				t.CreateAt,
   347  				t.CompletedAt,
   348  				t.ID,
   349  			)
   350  			if err != nil {
   351  				return err
   352  			}
   353  			affected, err := exec.RowsAffected()
   354  			if err != nil {
   355  				return nil
   356  			}
   357  			n += int(affected)
   358  			return nil
   359  		}()
   360  		if err != nil {
   361  			if e := tx.Rollback(); e != nil {
   362  				return 0, errors.Join(e, err)
   363  			}
   364  			return 0, err
   365  		}
   366  	}
   367  	if err = tx.Commit(); err != nil {
   368  		return 0, err
   369  	}
   370  	return n, nil
   371  }
   372  
   373  func (m *mysqlTaskStorage) DeleteAsyncTask(ctx context.Context, condition ...Condition) (int, error) {
   374  	if taskFrameworkDisabled() {
   375  		return 0, nil
   376  	}
   377  
   378  	db, release, err := m.getDB()
   379  	if err != nil {
   380  		return 0, err
   381  	}
   382  	defer func() {
   383  		_ = release()
   384  	}()
   385  
   386  	conn, err := db.Conn(ctx)
   387  	if err != nil {
   388  		return 0, err
   389  	}
   390  	defer func() {
   391  		_ = conn.Close()
   392  	}()
   393  
   394  	c := newConditions(condition...)
   395  	deleteSql := fmt.Sprintf(deleteAsyncTask, m.dbname) + buildWhereClause(c)
   396  	exec, err := conn.ExecContext(ctx, deleteSql)
   397  	if err != nil {
   398  		return 0, err
   399  	}
   400  	affected, err := exec.RowsAffected()
   401  	if err != nil {
   402  		panic(err)
   403  	}
   404  	return int(affected), nil
   405  }
   406  
   407  func (m *mysqlTaskStorage) QueryAsyncTask(ctx context.Context, condition ...Condition) ([]task.AsyncTask, error) {
   408  	if taskFrameworkDisabled() {
   409  		return nil, nil
   410  	}
   411  
   412  	db, release, err := m.getDB()
   413  	if err != nil {
   414  		return nil, err
   415  	}
   416  	defer func() {
   417  		_ = release()
   418  	}()
   419  
   420  	conn, err := db.Conn(ctx)
   421  	if err != nil {
   422  		return nil, err
   423  	}
   424  	defer func() {
   425  		_ = conn.Close()
   426  	}()
   427  
   428  	c := newConditions(condition...)
   429  
   430  	query := fmt.Sprintf(selectAsyncTask, m.dbname) + buildWhereClause(c)
   431  	query += buildOrderByClause(c)
   432  	query += buildLimitClause(c)
   433  
   434  	rows, err := conn.QueryContext(ctx, query)
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  	defer func() {
   439  		_ = rows.Close()
   440  	}()
   441  
   442  	tasks := make([]task.AsyncTask, 0)
   443  	for rows.Next() {
   444  		var t task.AsyncTask
   445  		var codeOption sql.NullInt32
   446  		var msgOption sql.NullString
   447  		var options string
   448  		if err := rows.Scan(
   449  			&t.ID,
   450  			&t.Metadata.ID,
   451  			&t.Metadata.Executor,
   452  			&t.Metadata.Context,
   453  			&options,
   454  			&t.ParentTaskID,
   455  			&t.Status,
   456  			&t.TaskRunner,
   457  			&t.Epoch,
   458  			&t.LastHeartbeat,
   459  			&codeOption,
   460  			&msgOption,
   461  			&t.CreateAt,
   462  			&t.CompletedAt,
   463  		); err != nil {
   464  			return nil, err
   465  		}
   466  		if err := json.Unmarshal([]byte(options), &t.Metadata.Options); err != nil {
   467  			return nil, err
   468  		}
   469  
   470  		if codeOption.Valid {
   471  			t.ExecuteResult = &task.ExecuteResult{}
   472  			code, err := codeOption.Value()
   473  			if err != nil {
   474  				return nil, err
   475  			}
   476  			t.ExecuteResult.Code = task.ResultCode(code.(int64))
   477  
   478  			msg, err := msgOption.Value()
   479  			if err != nil {
   480  				return nil, err
   481  			}
   482  			t.ExecuteResult.Error = msg.(string)
   483  		}
   484  
   485  		tasks = append(tasks, t)
   486  	}
   487  	if err := rows.Err(); err != nil {
   488  		return tasks, err
   489  	}
   490  	return tasks, nil
   491  }
   492  
   493  func (m *mysqlTaskStorage) AddCronTask(ctx context.Context, cronTask ...task.CronTask) (int, error) {
   494  	if taskFrameworkDisabled() {
   495  		return 0, nil
   496  	}
   497  
   498  	if len(cronTask) == 0 {
   499  		return 0, nil
   500  	}
   501  
   502  	db, release, err := m.getDB()
   503  	if err != nil {
   504  		return 0, err
   505  	}
   506  	defer func() {
   507  		_ = release()
   508  	}()
   509  
   510  	conn, err := db.Conn(ctx)
   511  	if err != nil {
   512  		return 0, err
   513  	}
   514  	defer func() {
   515  		_ = conn.Close()
   516  	}()
   517  
   518  	sqlStr := fmt.Sprintf(insertCronTask, m.dbname)
   519  	vals := make([]any, 0)
   520  	for _, t := range cronTask {
   521  		sqlStr += "(?, ?, ?, ?, ?, ?, ?, ?, ?),"
   522  
   523  		j, err := json.Marshal(t.Metadata.Options)
   524  		if err != nil {
   525  			return 0, err
   526  		}
   527  
   528  		vals = append(vals,
   529  			t.Metadata.ID,
   530  			t.Metadata.Executor,
   531  			t.Metadata.Context,
   532  			string(j),
   533  			t.CronExpr,
   534  			t.NextTime,
   535  			t.TriggerTimes,
   536  			t.CreateAt,
   537  			t.UpdateAt,
   538  		)
   539  	}
   540  	sqlStr = sqlStr[0 : len(sqlStr)-1]
   541  	stmt, err := conn.PrepareContext(ctx, sqlStr)
   542  	if err != nil {
   543  		return 0, err
   544  	}
   545  	defer stmt.Close()
   546  	exec, err := stmt.Exec(vals...)
   547  	if err != nil {
   548  		dup, err := removeDuplicateCronTasks(err, cronTask)
   549  		if err != nil {
   550  			return 0, err
   551  		}
   552  		add, err := m.AddCronTask(ctx, dup...)
   553  		if err != nil {
   554  			return add, err
   555  		}
   556  		return add, nil
   557  	}
   558  	affected, err := exec.RowsAffected()
   559  	if err != nil {
   560  		return 0, err
   561  	}
   562  	return int(affected), nil
   563  }
   564  
   565  func (m *mysqlTaskStorage) QueryCronTask(ctx context.Context, condition ...Condition) (tasks []task.CronTask, err error) {
   566  	if taskFrameworkDisabled() {
   567  		return nil, nil
   568  	}
   569  
   570  	db, release, err := m.getDB()
   571  	if err != nil {
   572  		return nil, err
   573  	}
   574  	defer func() {
   575  		_ = release()
   576  	}()
   577  
   578  	conn, err := db.Conn(ctx)
   579  	if err != nil {
   580  		return nil, err
   581  	}
   582  	defer func() {
   583  		_ = conn.Close()
   584  	}()
   585  
   586  	c := newConditions(condition...)
   587  	query := fmt.Sprintf(selectCronTask, m.dbname) + buildWhereClause(c)
   588  	rows, err := conn.QueryContext(ctx, query)
   589  	defer func(rows *sql.Rows) {
   590  		if rows == nil {
   591  			return
   592  		}
   593  		_ = rows.Close()
   594  	}(rows)
   595  	if err != nil {
   596  		return nil, err
   597  	}
   598  	defer func() {
   599  		err = errors.Join(err, rows.Close(), rows.Err())
   600  	}()
   601  
   602  	tasks = make([]task.CronTask, 0)
   603  
   604  	for rows.Next() {
   605  		var t task.CronTask
   606  		var options string
   607  		err := rows.Scan(
   608  			&t.ID,
   609  			&t.Metadata.ID,
   610  			&t.Metadata.Executor,
   611  			&t.Metadata.Context,
   612  			&options,
   613  			&t.CronExpr,
   614  			&t.NextTime,
   615  			&t.TriggerTimes,
   616  			&t.CreateAt,
   617  			&t.UpdateAt,
   618  		)
   619  		if err != nil {
   620  			return nil, err
   621  		}
   622  		if err := json.Unmarshal([]byte(options), &t.Metadata.Options); err != nil {
   623  			return nil, err
   624  		}
   625  
   626  		tasks = append(tasks, t)
   627  	}
   628  
   629  	return tasks, nil
   630  }
   631  
   632  func (m *mysqlTaskStorage) UpdateCronTask(ctx context.Context, cronTask task.CronTask, asyncTask task.AsyncTask) (int, error) {
   633  	if taskFrameworkDisabled() {
   634  		return 0, nil
   635  	}
   636  
   637  	db, release, err := m.getDB()
   638  	if err != nil {
   639  		return 0, err
   640  	}
   641  	defer func() {
   642  		_ = release()
   643  	}()
   644  
   645  	conn, err := db.Conn(ctx)
   646  	if err != nil {
   647  		return 0, err
   648  	}
   649  	defer func() {
   650  		_ = conn.Close()
   651  	}()
   652  
   653  	ok, err := m.taskExists(ctx, conn, asyncTask.Metadata.ID)
   654  	if err != nil || ok {
   655  		return 0, err
   656  	}
   657  
   658  	triggerTimes, err := m.getTriggerTimes(ctx, conn, cronTask.Metadata.ID)
   659  	if errors.Is(err, sql.ErrNoRows) || triggerTimes != cronTask.TriggerTimes-1 {
   660  		return 0, moerr.NewInternalError(ctx, "cron task trigger times not match")
   661  	}
   662  
   663  	tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
   664  	if err != nil {
   665  		return 0, err
   666  	}
   667  	defer func(tx *sql.Tx) {
   668  		_ = tx.Rollback()
   669  	}(tx)
   670  
   671  	preInsert, err := tx.Prepare(fmt.Sprintf(insertAsyncTask, m.dbname) + "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
   672  	if err != nil {
   673  		return 0, err
   674  	}
   675  	defer preInsert.Close()
   676  
   677  	preUpdate, err := tx.Prepare(fmt.Sprintf(updateCronTask, m.dbname))
   678  	if err != nil {
   679  		return 0, err
   680  	}
   681  	defer preUpdate.Close()
   682  
   683  	j, err := json.Marshal(asyncTask.Metadata.Options)
   684  	if err != nil {
   685  		return 0, err
   686  	}
   687  	inserted, err := preInsert.Exec(
   688  		asyncTask.Metadata.ID,
   689  		asyncTask.Metadata.Executor,
   690  		asyncTask.Metadata.Context,
   691  		util.UnsafeBytesToString(j),
   692  		asyncTask.ParentTaskID,
   693  		asyncTask.Status,
   694  		asyncTask.TaskRunner,
   695  		asyncTask.Epoch,
   696  		asyncTask.LastHeartbeat,
   697  		asyncTask.CreateAt,
   698  		asyncTask.CompletedAt)
   699  	if err != nil {
   700  		return 0, err
   701  	}
   702  
   703  	updated, err := preUpdate.Exec(
   704  		cronTask.Metadata.Executor,
   705  		cronTask.Metadata.Context,
   706  		util.UnsafeBytesToString(j),
   707  		cronTask.CronExpr,
   708  		cronTask.NextTime,
   709  		cronTask.TriggerTimes,
   710  		cronTask.CreateAt,
   711  		cronTask.UpdateAt,
   712  		cronTask.ID)
   713  	if err != nil {
   714  		return 0, err
   715  	}
   716  	if err := tx.Commit(); err != nil {
   717  		return 0, err
   718  	}
   719  	affected1, err := updated.RowsAffected()
   720  	if err != nil {
   721  		return 0, err
   722  	}
   723  	affected2, err := inserted.RowsAffected()
   724  	if err != nil {
   725  		return 0, err
   726  	}
   727  
   728  	return int(affected2) + int(affected1), nil
   729  }
   730  
   731  func (m *mysqlTaskStorage) taskExists(ctx context.Context, conn *sql.Conn, taskMetadataID string) (bool, error) {
   732  	var count int32
   733  	if err := conn.QueryRowContext(ctx, fmt.Sprintf(countTaskId, m.dbname), taskMetadataID).Scan(&count); err != nil {
   734  		return false, err
   735  	}
   736  	return count != 0, nil
   737  }
   738  
   739  func (m *mysqlTaskStorage) getTriggerTimes(ctx context.Context, conn *sql.Conn, taskMetadataID string) (uint64, error) {
   740  	var triggerTimes uint64
   741  	err := conn.QueryRowContext(ctx, fmt.Sprintf(getTriggerTimes, m.dbname), taskMetadataID).Scan(&triggerTimes)
   742  	if err != nil {
   743  		return 0, err
   744  	}
   745  	return triggerTimes, nil
   746  }
   747  
   748  func buildWhereClause(c *conditions) string {
   749  	var clauseBuilder strings.Builder
   750  
   751  	for code := range whereConditionCodes {
   752  		if cond, ok := (*c)[code]; ok {
   753  			clauseBuilder.WriteString(" AND ")
   754  			clauseBuilder.WriteString(cond.sql())
   755  		}
   756  	}
   757  
   758  	return clauseBuilder.String()
   759  }
   760  
   761  func (m *mysqlTaskStorage) getDB() (*sql.DB, func() error, error) {
   762  	if !m.forceNewConn {
   763  		if err := m.useDB(m.db); err != nil {
   764  			return nil, nil, err
   765  		}
   766  		return m.db, func() error { return nil }, nil
   767  	}
   768  
   769  	db, err := sql.Open("mysql", m.dsn)
   770  	if err != nil {
   771  		return nil, nil, err
   772  	}
   773  
   774  	if err = m.useDB(db); err != nil {
   775  		return nil, nil, errors.Join(err, db.Close())
   776  	}
   777  
   778  	return db, db.Close, nil
   779  }
   780  
   781  func (m *mysqlTaskStorage) useDB(db *sql.DB) (err error) {
   782  	if err := db.Ping(); err != nil {
   783  		return errors.Join(err, ErrNotReady)
   784  	}
   785  	if _, err := db.Exec("use " + m.dbname); err != nil {
   786  		return errors.Join(err, db.Close())
   787  	}
   788  	if _, err := db.Exec("set session disable_txn_trace=1"); err != nil {
   789  		return errors.Join(err, db.Close())
   790  	}
   791  	return nil
   792  }
   793  
   794  func buildLimitClause(c *conditions) string {
   795  	if cond, ok := (*c)[CondLimit]; ok {
   796  		return cond.sql()
   797  	}
   798  	return ""
   799  }
   800  
   801  func buildOrderByClause(c *conditions) string {
   802  	if cond, ok := (*c)[CondOrderByDesc]; ok {
   803  		return cond.sql()
   804  	}
   805  	return " order by task_id"
   806  }
   807  
   808  func removeDuplicateAsyncTasks(err error, tasks []task.AsyncTask) ([]task.AsyncTask, error) {
   809  	var me *mysql.MySQLError
   810  	if ok := errors.As(err, &me); !ok {
   811  		return nil, err
   812  	}
   813  	if me.Number != moerr.ER_DUP_ENTRY {
   814  		return nil, err
   815  	}
   816  	b := tasks[:0]
   817  	for _, t := range tasks {
   818  		if !strings.Contains(me.Message, t.Metadata.ID) {
   819  			b = append(b, t)
   820  		}
   821  	}
   822  	return b, nil
   823  }
   824  
   825  func removeDuplicateCronTasks(err error, tasks []task.CronTask) ([]task.CronTask, error) {
   826  	var me *mysql.MySQLError
   827  	if ok := errors.As(err, &me); !ok {
   828  		return nil, err
   829  	}
   830  	if me.Number != moerr.ER_DUP_ENTRY {
   831  		return nil, err
   832  	}
   833  	b := tasks[:0]
   834  	for _, t := range tasks {
   835  		if !strings.Contains(me.Message, t.Metadata.ID) {
   836  			b = append(b, t)
   837  		}
   838  	}
   839  	return b, nil
   840  }
   841  
   842  func removeDuplicateDaemonTasks(err error, tasks []task.DaemonTask) ([]task.DaemonTask, error) {
   843  	var me *mysql.MySQLError
   844  	if ok := errors.As(err, &me); !ok {
   845  		return nil, err
   846  	}
   847  	if me.Number != moerr.ER_DUP_ENTRY {
   848  		return nil, err
   849  	}
   850  	b := tasks[:0]
   851  	for _, t := range tasks {
   852  		if !strings.Contains(me.Message, t.Metadata.ID) {
   853  			b = append(b, t)
   854  		}
   855  	}
   856  	return b, nil
   857  }
   858  
   859  func (m *mysqlTaskStorage) AddDaemonTask(ctx context.Context, tasks ...task.DaemonTask) (int, error) {
   860  	if taskFrameworkDisabled() {
   861  		return 0, nil
   862  	}
   863  	if len(tasks) == 0 {
   864  		return 0, nil
   865  	}
   866  	db, release, err := m.getDB()
   867  	if err != nil {
   868  		return 0, err
   869  	}
   870  	defer func() {
   871  		_ = release()
   872  	}()
   873  	conn, err := db.Conn(ctx)
   874  	if err != nil {
   875  		return 0, err
   876  	}
   877  	defer func() {
   878  		_ = conn.Close()
   879  	}()
   880  	sqlStr := fmt.Sprintf(insertDaemonTask, m.dbname)
   881  	values := make([]any, 0)
   882  	for _, t := range tasks {
   883  		sqlStr += "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),"
   884  
   885  		j, err := json.Marshal(t.Metadata.Options)
   886  		if err != nil {
   887  			return 0, err
   888  		}
   889  		details, err := t.Details.Marshal()
   890  		if err != nil {
   891  			return 0, err
   892  		}
   893  
   894  		values = append(values,
   895  			t.Metadata.ID,
   896  			t.Metadata.Executor,
   897  			t.Metadata.Context,
   898  			string(j),
   899  			t.Details.AccountID,
   900  			t.Details.Account,
   901  			t.TaskType.String(),
   902  			t.TaskStatus,
   903  			t.CreateAt,
   904  			t.UpdateAt,
   905  			details,
   906  		)
   907  	}
   908  	if sqlStr == fmt.Sprintf(insertDaemonTask, m.dbname) {
   909  		return 0, nil
   910  	}
   911  	sqlStr = sqlStr[0 : len(sqlStr)-1]
   912  	stmt, err := conn.PrepareContext(ctx, sqlStr)
   913  	if err != nil {
   914  		return 0, err
   915  	}
   916  	defer stmt.Close()
   917  	exec, err := stmt.Exec(values...)
   918  	if err != nil {
   919  		dup, err := removeDuplicateDaemonTasks(err, tasks)
   920  		if err != nil {
   921  			return 0, err
   922  		}
   923  		add, err := m.AddDaemonTask(ctx, dup...)
   924  		if err != nil {
   925  			return add, err
   926  		}
   927  		return add, nil
   928  	}
   929  	affected, err := exec.RowsAffected()
   930  	if err != nil {
   931  		return 0, err
   932  	}
   933  	return int(affected), nil
   934  }
   935  
   936  func (m *mysqlTaskStorage) UpdateDaemonTask(ctx context.Context, tasks []task.DaemonTask, condition ...Condition) (int, error) {
   937  	if taskFrameworkDisabled() {
   938  		return 0, nil
   939  	}
   940  
   941  	if len(tasks) == 0 {
   942  		return 0, nil
   943  	}
   944  
   945  	db, release, err := m.getDB()
   946  	if err != nil {
   947  		return 0, err
   948  	}
   949  	defer func() {
   950  		_ = release()
   951  	}()
   952  
   953  	conn, err := db.Conn(ctx)
   954  	if err != nil {
   955  		return 0, err
   956  	}
   957  	defer func() {
   958  		_ = conn.Close()
   959  	}()
   960  
   961  	tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
   962  	if err != nil {
   963  		return 0, err
   964  	}
   965  
   966  	c := newConditions(condition...)
   967  	updateSql := fmt.Sprintf(updateDaemonTask, m.dbname) + buildDaemonTaskWhereClause(c)
   968  	n := 0
   969  	for _, t := range tasks {
   970  		err := func() error {
   971  			j, err := json.Marshal(t.Metadata.Options)
   972  			if err != nil {
   973  				return err
   974  			}
   975  			details, err := t.Details.Marshal()
   976  			if err != nil {
   977  				return err
   978  			}
   979  
   980  			prepare, err := tx.PrepareContext(ctx, updateSql)
   981  			if err != nil {
   982  				return err
   983  			}
   984  			defer prepare.Close()
   985  
   986  			var lastHeartbeat, updateAt, endAt, lastRun any
   987  			if !t.LastHeartbeat.IsZero() {
   988  				lastHeartbeat = t.LastHeartbeat
   989  			}
   990  			if !t.UpdateAt.IsZero() {
   991  				updateAt = t.UpdateAt
   992  			}
   993  			if !t.EndAt.IsZero() {
   994  				endAt = t.EndAt
   995  			}
   996  			if !t.LastRun.IsZero() {
   997  				lastRun = t.LastRun
   998  			}
   999  
  1000  			exec, err := prepare.ExecContext(ctx,
  1001  				t.Metadata.Executor,
  1002  				t.Metadata.Context,
  1003  				string(j),
  1004  				t.TaskType.String(),
  1005  				t.TaskStatus,
  1006  				t.TaskRunner,
  1007  				lastHeartbeat,
  1008  				updateAt,
  1009  				endAt,
  1010  				lastRun,
  1011  				details,
  1012  				t.ID,
  1013  			)
  1014  			if err != nil {
  1015  				return err
  1016  			}
  1017  			affected, err := exec.RowsAffected()
  1018  			if err != nil {
  1019  				return nil
  1020  			}
  1021  			n += int(affected)
  1022  			return nil
  1023  		}()
  1024  		if err != nil {
  1025  			if e := tx.Rollback(); e != nil {
  1026  				return 0, errors.Join(e, err)
  1027  			}
  1028  			return 0, err
  1029  		}
  1030  	}
  1031  	if err = tx.Commit(); err != nil {
  1032  		return 0, err
  1033  	}
  1034  	return n, nil
  1035  }
  1036  
  1037  func (m *mysqlTaskStorage) DeleteDaemonTask(ctx context.Context, condition ...Condition) (int, error) {
  1038  	if taskFrameworkDisabled() {
  1039  		return 0, nil
  1040  	}
  1041  
  1042  	db, release, err := m.getDB()
  1043  	if err != nil {
  1044  		return 0, err
  1045  	}
  1046  	defer func() {
  1047  		_ = release()
  1048  	}()
  1049  
  1050  	conn, err := db.Conn(ctx)
  1051  	if err != nil {
  1052  		return 0, err
  1053  	}
  1054  	defer func() {
  1055  		_ = conn.Close()
  1056  	}()
  1057  
  1058  	c := newConditions(condition...)
  1059  	deleteSql := fmt.Sprintf(deleteDaemonTask, m.dbname) + buildDaemonTaskWhereClause(c)
  1060  	exec, err := conn.ExecContext(ctx, deleteSql)
  1061  	if err != nil {
  1062  		return 0, err
  1063  	}
  1064  	affected, err := exec.RowsAffected()
  1065  	if err != nil {
  1066  		panic(err)
  1067  	}
  1068  	return int(affected), nil
  1069  }
  1070  
  1071  func (m *mysqlTaskStorage) QueryDaemonTask(ctx context.Context, condition ...Condition) ([]task.DaemonTask, error) {
  1072  	if taskFrameworkDisabled() {
  1073  		return nil, nil
  1074  	}
  1075  
  1076  	db, release, err := m.getDB()
  1077  	if err != nil {
  1078  		return nil, err
  1079  	}
  1080  	defer func() {
  1081  		_ = release()
  1082  	}()
  1083  
  1084  	conn, err := db.Conn(ctx)
  1085  	if err != nil {
  1086  		return nil, err
  1087  	}
  1088  	defer func() {
  1089  		_ = conn.Close()
  1090  	}()
  1091  
  1092  	c := newConditions(condition...)
  1093  	query := fmt.Sprintf(selectDaemonTask, m.dbname) + buildDaemonTaskWhereClause(c)
  1094  	query += buildOrderByClause(c)
  1095  	query += buildLimitClause(c)
  1096  
  1097  	rows, err := conn.QueryContext(ctx, query)
  1098  	if err != nil {
  1099  		return nil, err
  1100  	}
  1101  	defer func() {
  1102  		_ = rows.Close()
  1103  	}()
  1104  
  1105  	tasks := make([]task.DaemonTask, 0)
  1106  	for rows.Next() {
  1107  		var t task.DaemonTask
  1108  		var options, taskType string
  1109  		var runner sql.NullString
  1110  		var lastHeartbeat, createAt, updateAt, endAt, lastRun sql.NullTime
  1111  		if err := rows.Scan(
  1112  			&t.ID,
  1113  			&t.Metadata.ID,
  1114  			&t.Metadata.Executor,
  1115  			&t.Metadata.Context,
  1116  			&options,
  1117  			&t.AccountID,
  1118  			&t.Account,
  1119  			&taskType,
  1120  			&runner,
  1121  			&t.TaskStatus,
  1122  			&lastHeartbeat,
  1123  			&createAt,
  1124  			&updateAt,
  1125  			&endAt,
  1126  			&lastRun,
  1127  			&t.Details,
  1128  		); err != nil {
  1129  			return nil, err
  1130  		}
  1131  		if err := json.Unmarshal([]byte(options), &t.Metadata.Options); err != nil {
  1132  			return nil, err
  1133  		}
  1134  
  1135  		typ, ok := task.TaskType_value[taskType]
  1136  		if !ok {
  1137  			typ = int32(task.TaskType_TypeUnknown)
  1138  		}
  1139  		t.TaskType = task.TaskType(typ)
  1140  
  1141  		t.TaskRunner = runner.String
  1142  		t.LastHeartbeat = lastHeartbeat.Time
  1143  		t.CreateAt = createAt.Time
  1144  		t.UpdateAt = updateAt.Time
  1145  		t.EndAt = endAt.Time
  1146  		t.LastRun = lastRun.Time
  1147  
  1148  		tasks = append(tasks, t)
  1149  	}
  1150  	if err := rows.Err(); err != nil {
  1151  		return tasks, err
  1152  	}
  1153  	return tasks, nil
  1154  }
  1155  
  1156  func (m *mysqlTaskStorage) HeartbeatDaemonTask(ctx context.Context, tasks []task.DaemonTask) (int, error) {
  1157  	if taskFrameworkDisabled() {
  1158  		return 0, nil
  1159  	}
  1160  
  1161  	if len(tasks) == 0 {
  1162  		return 0, nil
  1163  	}
  1164  
  1165  	db, release, err := m.getDB()
  1166  	if err != nil {
  1167  		return 0, err
  1168  	}
  1169  	defer func() {
  1170  		_ = release()
  1171  	}()
  1172  
  1173  	conn, err := db.Conn(ctx)
  1174  	if err != nil {
  1175  		return 0, err
  1176  	}
  1177  	defer func() {
  1178  		_ = conn.Close()
  1179  	}()
  1180  
  1181  	tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
  1182  	if err != nil {
  1183  		return 0, err
  1184  	}
  1185  
  1186  	update := fmt.Sprintf(heartbeatDaemonTask, m.dbname)
  1187  	n := 0
  1188  	for _, t := range tasks {
  1189  		err := func() error {
  1190  			prepare, err := tx.PrepareContext(ctx, update)
  1191  			if err != nil {
  1192  				return err
  1193  			}
  1194  			defer prepare.Close()
  1195  
  1196  			var lastHeartbeat any
  1197  			if !t.LastHeartbeat.IsZero() {
  1198  				lastHeartbeat = t.LastHeartbeat
  1199  			}
  1200  
  1201  			exec, err := prepare.ExecContext(ctx,
  1202  				lastHeartbeat,
  1203  				t.ID,
  1204  			)
  1205  			if err != nil {
  1206  				return err
  1207  			}
  1208  			affected, err := exec.RowsAffected()
  1209  			if err != nil {
  1210  				return nil
  1211  			}
  1212  			n += int(affected)
  1213  			return nil
  1214  		}()
  1215  		if err != nil {
  1216  			if e := tx.Rollback(); e != nil {
  1217  				return 0, errors.Join(e, err)
  1218  			}
  1219  			return 0, err
  1220  		}
  1221  	}
  1222  	if err = tx.Commit(); err != nil {
  1223  		return 0, err
  1224  	}
  1225  	return n, nil
  1226  }
  1227  
  1228  func buildDaemonTaskWhereClause(c *conditions) string {
  1229  	var clauseBuilder strings.Builder
  1230  
  1231  	for cond := range daemonWhereConditionCodes {
  1232  		if cond, ok := (*c)[cond]; ok {
  1233  			clauseBuilder.WriteString(" AND ")
  1234  			clauseBuilder.WriteString(cond.sql())
  1235  		}
  1236  	}
  1237  
  1238  	return clauseBuilder.String()
  1239  }