github.com/matrixorigin/matrixone@v0.7.0/pkg/taskservice/mysql_task_storage.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package taskservice 16 17 import ( 18 "context" 19 "database/sql" 20 "encoding/json" 21 "fmt" 22 "os" 23 "strings" 24 25 "github.com/go-sql-driver/mysql" 26 "github.com/matrixorigin/matrixone/pkg/common/moerr" 27 "github.com/matrixorigin/matrixone/pkg/pb/task" 28 "go.uber.org/multierr" 29 ) 30 31 var ( 32 createDatabase = `create database if not exists %s` 33 createTables = map[string]string{ 34 "sys_async_task": `create table if not exists %s.sys_async_task ( 35 task_id int primary key auto_increment, 36 task_metadata_id varchar(50) unique not null, 37 task_metadata_executor int, 38 task_metadata_context blob, 39 task_metadata_option varchar(1000), 40 task_parent_id varchar(50), 41 task_status int, 42 task_runner varchar(50), 43 task_epoch int, 44 last_heartbeat bigint, 45 result_code int null, 46 error_msg varchar(1000) null, 47 create_at bigint, 48 end_at bigint)`, 49 "sys_cron_task": `create table if not exists %s.sys_cron_task ( 50 cron_task_id int primary key auto_increment, 51 task_metadata_id varchar(50) unique not null, 52 task_metadata_executor int, 53 task_metadata_context blob, 54 task_metadata_option varchar(1000), 55 cron_expr varchar(100) not null, 56 next_time bigint, 57 trigger_times int, 58 create_at bigint, 59 update_at bigint)`, 60 } 61 62 insertAsyncTask = `insert into %s.sys_async_task( 63 task_metadata_id, 64 task_metadata_executor, 65 task_metadata_context, 66 task_metadata_option, 67 task_parent_id, 68 task_status, 69 task_runner, 70 task_epoch, 71 last_heartbeat, 72 create_at, 73 end_at) values ` 74 75 updateAsyncTask = `update %s.sys_async_task set 76 task_metadata_executor=?, 77 task_metadata_context=?, 78 task_metadata_option=?, 79 task_parent_id=?, 80 task_status=?, 81 task_runner=?, 82 task_epoch=?, 83 last_heartbeat=?, 84 result_code=?, 85 error_msg=?, 86 create_at=?, 87 end_at=? where task_id=?` 88 89 selectAsyncTask = `select 90 task_id, 91 task_metadata_id, 92 task_metadata_executor, 93 task_metadata_context, 94 task_metadata_option, 95 task_parent_id, 96 task_status, 97 task_runner, 98 task_epoch, 99 last_heartbeat, 100 result_code, 101 error_msg, 102 create_at, 103 end_at 104 from %s.sys_async_task` 105 106 insertCronTask = `insert into %s.sys_cron_task ( 107 task_metadata_id, 108 task_metadata_executor, 109 task_metadata_context, 110 task_metadata_option, 111 cron_expr, 112 next_time, 113 trigger_times, 114 create_at, 115 update_at 116 ) values ` 117 118 selectCronTask = `select 119 cron_task_id, 120 task_metadata_id, 121 task_metadata_executor, 122 task_metadata_context, 123 task_metadata_option, 124 cron_expr, 125 next_time, 126 trigger_times, 127 create_at, 128 update_at 129 from %s.sys_cron_task` 130 131 updateCronTask = `update %s.sys_cron_task set 132 task_metadata_executor=?, 133 task_metadata_context=?, 134 task_metadata_option=?, 135 cron_expr=?, 136 next_time=?, 137 trigger_times=?, 138 create_at=?, 139 update_at=? where cron_task_id=?` 140 141 countTaskId = `select count(task_metadata_id) from %s.sys_async_task where task_metadata_id=?` 142 143 getTriggerTimes = `select trigger_times from %s.sys_cron_task where task_metadata_id=?` 144 145 deleteTask = `delete from %s.sys_async_task where ` 146 ) 147 148 var ( 149 forceNewConn = "async_task_force_new_connection" 150 ) 151 152 type mysqlTaskStorage struct { 153 dsn string 154 dbname string 155 db *sql.DB 156 forceNewConn bool 157 } 158 159 func NewMysqlTaskStorage(dsn, dbname string) (TaskStorage, error) { 160 db, err := sql.Open("mysql", dsn) 161 if err != nil { 162 return nil, err 163 } 164 165 db.SetMaxOpenConns(5) 166 db.SetMaxIdleConns(1) 167 168 _, ok := os.LookupEnv(forceNewConn) 169 return &mysqlTaskStorage{ 170 dsn: dsn, 171 db: db, 172 dbname: dbname, 173 forceNewConn: ok, 174 }, nil 175 } 176 177 func (m *mysqlTaskStorage) Close() error { 178 return m.db.Close() 179 } 180 181 func (m *mysqlTaskStorage) Add(ctx context.Context, tasks ...task.Task) (int, error) { 182 if taskFrameworkDisabled() { 183 return 0, nil 184 } 185 186 if len(tasks) == 0 { 187 return 0, nil 188 } 189 190 db, release, err := m.getDB() 191 if err != nil { 192 return 0, err 193 } 194 defer func() { 195 _ = release() 196 }() 197 198 conn, err := db.Conn(ctx) 199 if err != nil { 200 return 0, err 201 } 202 defer func() { 203 _ = conn.Close() 204 }() 205 206 sqlStr := fmt.Sprintf(insertAsyncTask, m.dbname) 207 vals := make([]any, 0, len(tasks)*13) 208 209 for _, t := range tasks { 210 j, err := json.Marshal(t.Metadata.Options) 211 if err != nil { 212 return 0, err 213 } 214 215 sqlStr += "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)," 216 vals = append(vals, t.Metadata.ID, 217 t.Metadata.Executor, 218 t.Metadata.Context, 219 string(j), 220 t.ParentTaskID, 221 t.Status, 222 t.TaskRunner, 223 t.Epoch, 224 t.LastHeartbeat, 225 t.CreateAt, 226 t.CompletedAt, 227 ) 228 } 229 230 if sqlStr == fmt.Sprintf(insertAsyncTask, m.dbname) { 231 return 0, nil 232 } 233 sqlStr = sqlStr[0 : len(sqlStr)-1] 234 stmt, err := conn.PrepareContext(ctx, sqlStr) 235 if err != nil { 236 return 0, err 237 } 238 exec, err := stmt.Exec(vals...) 239 if err != nil { 240 dup, err := removeDuplicateTasks(err, tasks) 241 if err != nil { 242 return 0, err 243 } 244 add, err := m.Add(ctx, dup...) 245 if err != nil { 246 return add, err 247 } 248 return add, nil 249 } 250 affected, err := exec.RowsAffected() 251 if err != nil { 252 return 0, err 253 } 254 255 return int(affected), nil 256 } 257 258 func (m *mysqlTaskStorage) Update(ctx context.Context, tasks []task.Task, condition ...Condition) (int, error) { 259 if taskFrameworkDisabled() { 260 return 0, nil 261 } 262 263 if len(tasks) == 0 { 264 return 0, nil 265 } 266 267 db, release, err := m.getDB() 268 if err != nil { 269 return 0, err 270 } 271 defer func() { 272 _ = release() 273 }() 274 275 conn, err := db.Conn(ctx) 276 if err != nil { 277 return 0, err 278 } 279 defer func() { 280 _ = conn.Close() 281 }() 282 283 c := conditions{} 284 for _, cond := range condition { 285 cond(&c) 286 } 287 where := buildWhereClause(c) 288 289 tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) 290 if err != nil { 291 return 0, err 292 } 293 294 var update string 295 if where != "" { 296 update = fmt.Sprintf(updateAsyncTask, m.dbname) + " and " + where 297 } else { 298 update = fmt.Sprintf(updateAsyncTask, m.dbname) 299 } 300 n := 0 301 for _, t := range tasks { 302 err := func() error { 303 execResult := &task.ExecuteResult{} 304 if t.ExecuteResult != nil { 305 execResult.Code = t.ExecuteResult.Code 306 execResult.Error = t.ExecuteResult.Error 307 } 308 309 j, err := json.Marshal(t.Metadata.Options) 310 if err != nil { 311 return err 312 } 313 314 exec, err := tx.ExecContext(ctx, update, 315 t.Metadata.Executor, 316 t.Metadata.Context, 317 string(j), 318 t.ParentTaskID, 319 t.Status, 320 t.TaskRunner, 321 t.Epoch, 322 t.LastHeartbeat, 323 execResult.Code, 324 execResult.Error, 325 t.CreateAt, 326 t.CompletedAt, 327 t.ID, 328 ) 329 if err != nil { 330 return err 331 } 332 affected, err := exec.RowsAffected() 333 if err != nil { 334 return nil 335 } 336 n += int(affected) 337 return nil 338 }() 339 if err != nil { 340 if e := tx.Rollback(); e != nil { 341 return 0, e 342 } 343 return 0, err 344 } 345 } 346 if err = tx.Commit(); err != nil { 347 return 0, err 348 } 349 return n, nil 350 } 351 352 func (m *mysqlTaskStorage) Delete(ctx context.Context, condition ...Condition) (int, error) { 353 if taskFrameworkDisabled() { 354 return 0, nil 355 } 356 357 db, release, err := m.getDB() 358 if err != nil { 359 return 0, err 360 } 361 defer func() { 362 _ = release() 363 }() 364 365 conn, err := db.Conn(ctx) 366 if err != nil { 367 return 0, err 368 } 369 defer func() { 370 _ = conn.Close() 371 }() 372 373 c := conditions{} 374 for _, cond := range condition { 375 cond(&c) 376 } 377 where := buildWhereClause(c) 378 379 exec, err := conn.ExecContext(ctx, fmt.Sprintf(deleteTask, m.dbname)+where) 380 if err != nil { 381 return 0, err 382 } 383 affected, err := exec.RowsAffected() 384 if err != nil { 385 panic(err) 386 } 387 return int(affected), nil 388 } 389 390 func (m *mysqlTaskStorage) Query(ctx context.Context, condition ...Condition) ([]task.Task, error) { 391 if taskFrameworkDisabled() { 392 return nil, nil 393 } 394 395 db, release, err := m.getDB() 396 if err != nil { 397 return nil, err 398 } 399 defer func() { 400 _ = release() 401 }() 402 403 conn, err := db.Conn(ctx) 404 if err != nil { 405 return nil, err 406 } 407 defer func() { 408 _ = conn.Close() 409 }() 410 411 c := conditions{} 412 for _, cond := range condition { 413 cond(&c) 414 } 415 416 where := buildWhereClause(c) 417 var query string 418 if where != "" { 419 query = fmt.Sprintf(selectAsyncTask, m.dbname) + " where " + where 420 } else { 421 query = fmt.Sprintf(selectAsyncTask, m.dbname) 422 } 423 query += buildOrderByClause(c) 424 query += buildLimitClause(c) 425 426 rows, err := conn.QueryContext(ctx, query) 427 if err != nil { 428 return nil, err 429 } 430 defer func() { 431 _ = rows.Close() 432 }() 433 434 tasks := make([]task.Task, 0) 435 for rows.Next() { 436 var t task.Task 437 var codeOption sql.NullInt32 438 var msgOption sql.NullString 439 var options string 440 if err := rows.Scan( 441 &t.ID, 442 &t.Metadata.ID, 443 &t.Metadata.Executor, 444 &t.Metadata.Context, 445 &options, 446 &t.ParentTaskID, 447 &t.Status, 448 &t.TaskRunner, 449 &t.Epoch, 450 &t.LastHeartbeat, 451 &codeOption, 452 &msgOption, 453 &t.CreateAt, 454 &t.CompletedAt, 455 ); err != nil { 456 return nil, err 457 } 458 if err := json.Unmarshal([]byte(options), &t.Metadata.Options); err != nil { 459 return nil, err 460 } 461 462 if codeOption.Valid { 463 t.ExecuteResult = &task.ExecuteResult{} 464 code, err := codeOption.Value() 465 if err != nil { 466 return nil, err 467 } 468 t.ExecuteResult.Code = task.ResultCode(code.(int64)) 469 470 msg, err := msgOption.Value() 471 if err != nil { 472 return nil, err 473 } 474 t.ExecuteResult.Error = msg.(string) 475 } 476 477 tasks = append(tasks, t) 478 } 479 if err := rows.Err(); err != nil { 480 return tasks, err 481 } 482 return tasks, nil 483 } 484 485 func (m *mysqlTaskStorage) AddCronTask(ctx context.Context, cronTask ...task.CronTask) (int, error) { 486 if taskFrameworkDisabled() { 487 return 0, nil 488 } 489 490 if len(cronTask) == 0 { 491 return 0, nil 492 } 493 494 db, release, err := m.getDB() 495 if err != nil { 496 return 0, err 497 } 498 defer func() { 499 _ = release() 500 }() 501 502 conn, err := db.Conn(ctx) 503 if err != nil { 504 return 0, err 505 } 506 defer func() { 507 _ = conn.Close() 508 }() 509 510 sqlStr := fmt.Sprintf(insertCronTask, m.dbname) 511 vals := make([]any, 0) 512 for _, t := range cronTask { 513 sqlStr += "(?, ?, ?, ?, ?, ?, ?, ?, ?)," 514 515 j, err := json.Marshal(t.Metadata.Options) 516 if err != nil { 517 return 0, err 518 } 519 520 vals = append(vals, 521 t.Metadata.ID, 522 t.Metadata.Executor, 523 t.Metadata.Context, 524 string(j), 525 t.CronExpr, 526 t.NextTime, 527 t.TriggerTimes, 528 t.CreateAt, 529 t.UpdateAt, 530 ) 531 } 532 if sqlStr == fmt.Sprintf(insertCronTask, m.dbname) { 533 return 0, nil 534 } 535 sqlStr = sqlStr[0 : len(sqlStr)-1] 536 stmt, err := conn.PrepareContext(ctx, sqlStr) 537 if err != nil { 538 return 0, err 539 } 540 exec, err := stmt.Exec(vals...) 541 if err != nil { 542 dup, err := removeDuplicateCronTasks(err, cronTask) 543 if err != nil { 544 return 0, err 545 } 546 add, err := m.AddCronTask(ctx, dup...) 547 if err != nil { 548 return add, err 549 } 550 return add, nil 551 } 552 affected, err := exec.RowsAffected() 553 if err != nil { 554 return 0, err 555 } 556 return int(affected), nil 557 } 558 559 func (m *mysqlTaskStorage) QueryCronTask(ctx context.Context) ([]task.CronTask, error) { 560 if taskFrameworkDisabled() { 561 return nil, nil 562 } 563 564 db, release, err := m.getDB() 565 if err != nil { 566 return nil, err 567 } 568 defer func() { 569 _ = release() 570 }() 571 572 conn, err := db.Conn(ctx) 573 if err != nil { 574 return nil, err 575 } 576 defer func() { 577 _ = conn.Close() 578 }() 579 580 rows, err := conn.QueryContext(ctx, fmt.Sprintf(selectCronTask, m.dbname)) 581 defer func(rows *sql.Rows) { 582 if rows == nil { 583 return 584 } 585 _ = rows.Close() 586 }(rows) 587 if err != nil { 588 return nil, err 589 } 590 591 tasks := make([]task.CronTask, 0) 592 593 for rows.Next() { 594 var t task.CronTask 595 var options string 596 err := rows.Scan( 597 &t.ID, 598 &t.Metadata.ID, 599 &t.Metadata.Executor, 600 &t.Metadata.Context, 601 &options, 602 &t.CronExpr, 603 &t.NextTime, 604 &t.TriggerTimes, 605 &t.CreateAt, 606 &t.UpdateAt, 607 ) 608 if err != nil { 609 return nil, err 610 } 611 if err := json.Unmarshal([]byte(options), &t.Metadata.Options); err != nil { 612 return nil, err 613 } 614 615 tasks = append(tasks, t) 616 } 617 if err := rows.Err(); err != nil { 618 return tasks, err 619 } 620 621 return tasks, nil 622 } 623 624 func (m *mysqlTaskStorage) UpdateCronTask(ctx context.Context, cronTask task.CronTask, t task.Task) (int, error) { 625 if taskFrameworkDisabled() { 626 return 0, nil 627 } 628 629 db, release, err := m.getDB() 630 if err != nil { 631 return 0, err 632 } 633 defer func() { 634 _ = release() 635 }() 636 637 conn, err := db.Conn(ctx) 638 if err != nil { 639 return 0, err 640 } 641 defer func() { 642 _ = conn.Close() 643 }() 644 645 ok, err := m.taskExists(ctx, conn, t.Metadata.ID) 646 if err != nil || ok { 647 return 0, err 648 } 649 650 triggerTimes, err := m.getTriggerTimes(ctx, conn, cronTask.Metadata.ID) 651 if err == sql.ErrNoRows || triggerTimes != cronTask.TriggerTimes-1 { 652 return 0, nil 653 } 654 655 tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) 656 if err != nil { 657 return 0, err 658 } 659 defer func(tx *sql.Tx) { 660 _ = tx.Rollback() 661 }(tx) 662 stmt, err := tx.Prepare(fmt.Sprintf(insertAsyncTask, m.dbname) + "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") 663 if err != nil { 664 return 0, err 665 } 666 667 j, err := json.Marshal(t.Metadata.Options) 668 if err != nil { 669 return 0, err 670 } 671 update, err := stmt.Exec( 672 t.Metadata.ID, 673 t.Metadata.Executor, 674 t.Metadata.Context, 675 string(j), 676 t.ParentTaskID, 677 t.Status, 678 t.TaskRunner, 679 t.Epoch, 680 t.LastHeartbeat, 681 t.CreateAt, 682 t.CompletedAt) 683 if err != nil { 684 return 0, err 685 } 686 exec, err := tx.Exec(fmt.Sprintf(updateCronTask, m.dbname), 687 cronTask.Metadata.Executor, 688 cronTask.Metadata.Context, 689 string(j), 690 cronTask.CronExpr, 691 cronTask.NextTime, 692 cronTask.TriggerTimes, 693 cronTask.CreateAt, 694 cronTask.UpdateAt, 695 cronTask.ID) 696 if err != nil { 697 return 0, err 698 } 699 if err := tx.Commit(); err != nil { 700 return 0, err 701 } 702 affected1, err := exec.RowsAffected() 703 if err != nil { 704 return 0, err 705 } 706 affected2, err := update.RowsAffected() 707 if err != nil { 708 return 0, err 709 } 710 711 return int(affected2) + int(affected1), nil 712 } 713 714 func (m *mysqlTaskStorage) taskExists(ctx context.Context, conn *sql.Conn, taskMetadataID string) (bool, error) { 715 var count int32 716 if err := conn.QueryRowContext(ctx, fmt.Sprintf(countTaskId, m.dbname), taskMetadataID).Scan(&count); err != nil { 717 return false, err 718 } 719 return count != 0, nil 720 } 721 722 func (m *mysqlTaskStorage) getTriggerTimes(ctx context.Context, conn *sql.Conn, taskMetadataID string) (uint64, error) { 723 var triggerTimes uint64 724 err := conn.QueryRowContext(ctx, fmt.Sprintf(getTriggerTimes, m.dbname), taskMetadataID).Scan(&triggerTimes) 725 if err != nil { 726 if err == sql.ErrNoRows { 727 return 0, nil 728 } 729 return 0, err 730 } 731 return triggerTimes, nil 732 } 733 734 func buildWhereClause(c conditions) string { 735 var clause string 736 737 if c.hasTaskIDCond { 738 clause = fmt.Sprintf("task_id%s%d", OpName[c.taskIDOp], c.taskID) 739 } 740 741 if c.hasTaskRunnerCond { 742 if clause != "" { 743 clause += " AND " 744 } 745 clause += fmt.Sprintf("task_runner%s'%s'", OpName[c.taskRunnerOp], c.taskRunner) 746 } 747 748 if c.hasTaskStatusCond { 749 if clause != "" { 750 clause += " AND " 751 } 752 clause += fmt.Sprintf("task_status%s%d", OpName[c.taskStatusOp], c.taskStatus) 753 } 754 755 if c.hasTaskEpochCond { 756 if clause != "" { 757 clause += " AND " 758 } 759 clause += fmt.Sprintf("task_epoch%s%d", OpName[c.taskEpochOp], c.taskEpoch) 760 } 761 762 if c.hasTaskParentIDCond { 763 if clause != "" { 764 clause += " AND " 765 } 766 clause += fmt.Sprintf("task_parent_id%s'%s'", OpName[c.taskParentTaskIDOp], c.taskParentTaskID) 767 } 768 769 if c.hasTaskExecutorCond { 770 if clause != "" { 771 clause += " AND " 772 } 773 clause += fmt.Sprintf("task_metadata_executor%s%d", OpName[c.taskExecutorOp], c.taskExecutor) 774 } 775 776 return clause 777 } 778 779 func (m *mysqlTaskStorage) getDB() (*sql.DB, func() error, error) { 780 if !m.forceNewConn { 781 if err := m.useDB(m.db); err != nil { 782 return nil, nil, err 783 } 784 return m.db, func() error { return nil }, nil 785 } 786 787 db, err := sql.Open("mysql", m.dsn) 788 if err != nil { 789 return nil, nil, err 790 } 791 792 if err = m.useDB(db); err != nil { 793 return nil, nil, multierr.Append(err, db.Close()) 794 } 795 796 return db, func() error { return db.Close() }, nil 797 } 798 799 func (m *mysqlTaskStorage) useDB(db *sql.DB) error { 800 if err := db.Ping(); err != nil { 801 return errNotReady 802 } 803 for _, err := db.Exec("use " + m.dbname); err != nil; _, err = db.Exec("use " + m.dbname) { 804 me, ok := err.(*mysql.MySQLError) 805 if !ok || me.Number != moerr.ER_BAD_DB_ERROR { 806 return err 807 } 808 if _, err = db.Exec(fmt.Sprintf(createDatabase, m.dbname)); err != nil { 809 return multierr.Append(err, db.Close()) 810 } 811 } 812 rows, err := db.Query("show tables") 813 if err != nil { 814 return err 815 } 816 817 tables := make(map[string]struct{}, len(createTables)) 818 for rows.Next() { 819 var table string 820 if err := rows.Scan(&table); err != nil { 821 return err 822 } 823 tables[table] = struct{}{} 824 } 825 826 for table, createSql := range createTables { 827 if _, ok := tables[table]; !ok { 828 if _, err = db.Exec(fmt.Sprintf(createSql, m.dbname)); err != nil { 829 return multierr.Append(err, db.Close()) 830 } 831 } 832 } 833 return nil 834 } 835 836 func buildLimitClause(c conditions) string { 837 if c.limit != 0 { 838 return fmt.Sprintf(" limit %d", c.limit) 839 } 840 return "" 841 } 842 843 func buildOrderByClause(c conditions) string { 844 if c.orderByDesc { 845 return " order by task_id desc" 846 } 847 return " order by task_id" 848 } 849 850 func removeDuplicateTasks(err error, tasks []task.Task) ([]task.Task, error) { 851 me, ok := err.(*mysql.MySQLError) 852 if !ok { 853 return nil, err 854 } 855 if me.Number != moerr.ER_DUP_ENTRY { 856 return nil, err 857 } 858 b := tasks[:0] 859 for _, t := range tasks { 860 if !strings.Contains(me.Message, t.Metadata.ID) { 861 b = append(b, t) 862 } 863 } 864 return b, nil 865 } 866 867 func removeDuplicateCronTasks(err error, tasks []task.CronTask) ([]task.CronTask, error) { 868 me, ok := err.(*mysql.MySQLError) 869 if !ok { 870 return nil, err 871 } 872 if me.Number != moerr.ER_DUP_ENTRY { 873 return nil, err 874 } 875 b := tasks[:0] 876 for _, t := range tasks { 877 if !strings.Contains(me.Message, t.Metadata.ID) { 878 b = append(b, t) 879 } 880 } 881 return b, nil 882 }