github.com/songzhibin97/gkit@v1.2.13/distributed/backend/backend_db/db.go (about)

     1  package backend_db
     2  
import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"time"

	"github.com/songzhibin97/gkit/distributed/backend"
	"github.com/songzhibin97/gkit/distributed/task"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)
    17  
    18  // BackendSQLDB 支持mysql&pgsql
    19  type BackendSQLDB struct {
    20  	// gClient db客户端
    21  	gClient *gorm.DB
    22  	// resultExpire 数据过期时间
    23  	// -1 代表永不过期
    24  	// 0 会设置默认过期时间
    25  	// 单位为ns
    26  	resultExpire int64
    27  }
    28  
    29  // SetResultExpire 设置结果超时时间
    30  func (b *BackendSQLDB) SetResultExpire(expire int64) {
    31  	b.resultExpire = expire
    32  }
    33  
    34  func (b *BackendSQLDB) GroupTakeOver(groupID string, name string, taskIDs ...string) error {
    35  	group := task.InitGroupMeta(groupID, name, b.resultExpire, taskIDs...)
    36  	return b.gClient.Create(group).Error
    37  }
    38  
    39  func (b *BackendSQLDB) GroupCompleted(groupID string) (bool, error) {
    40  	group, err := b.getGroup(groupID)
    41  	if err != nil {
    42  		return false, err
    43  	}
    44  	status, err := b.getTaskStatus(group.TaskIDs)
    45  	if err != nil {
    46  		return false, err
    47  	}
    48  	ln := 0
    49  	for _, t := range status {
    50  		if !t.IsCompleted() {
    51  			return false, nil
    52  		}
    53  		ln++
    54  	}
    55  	return len(group.TaskIDs) == ln, nil
    56  }
    57  
    58  func (b *BackendSQLDB) getGroup(groupID string) (*task.GroupMeta, error) {
    59  	var group task.GroupMeta
    60  	err := b.gClient.Model(&task.GroupMeta{}).Where("id = ?", groupID).First(&group).Error
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	return &group, nil
    65  }
    66  
    67  func (b *BackendSQLDB) getTaskStatus(taskIDs []string) ([]*task.Status, error) {
    68  	statusList := make([]*task.Status, 0, len(taskIDs))
    69  	err := b.gClient.Where("id in ?", taskIDs).Find(&statusList).Error
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	return statusList, nil
    74  }
    75  
    76  func (b *BackendSQLDB) GroupTaskStatus(groupID string) ([]*task.Status, error) {
    77  	group, err := b.getGroup(groupID)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	return b.getTaskStatus(group.TaskIDs)
    82  }
    83  
    84  func (b *BackendSQLDB) TriggerCompleted(groupID string) (bool, error) {
    85  	result := b.gClient.Debug().Model(&task.GroupMeta{}).Where("id = ? and `lock` = false", groupID).Update("`lock`", true)
    86  	if result.Error != nil {
    87  		return false, result.Error
    88  	}
    89  	return result.RowsAffected != 0, nil
    90  }
    91  
    92  func (b *BackendSQLDB) SetStatePending(signature *task.Signature) error {
    93  	var status task.Status
    94  	err := b.gClient.Where("id = ?", signature.ID).First(&status).Error
    95  	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
    96  		// 创建
    97  		status = task.Status{
    98  			TaskID:   signature.ID,
    99  			GroupID:  signature.GroupID,
   100  			Name:     signature.Name,
   101  			Status:   task.StatePending,
   102  			CreateAt: time.Now(),
   103  		}
   104  		return b.gClient.Create(&status).Error
   105  	}
   106  	if err != nil {
   107  		return err
   108  	}
   109  	// 更新
   110  	return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Update("status", task.StatePending).Error
   111  }
   112  
   113  func (b *BackendSQLDB) SetStateReceived(signature *task.Signature) error {
   114  	var status task.Status
   115  	err := b.gClient.Where("id = ?", signature.ID).First(&status).Error
   116  	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
   117  		// 创建
   118  		status = task.Status{
   119  			TaskID:   signature.ID,
   120  			GroupID:  signature.GroupID,
   121  			Name:     signature.Name,
   122  			Status:   task.StateReceived,
   123  			CreateAt: time.Now(),
   124  		}
   125  		return b.gClient.Create(&status).Error
   126  	}
   127  	if err != nil {
   128  		return err
   129  	}
   130  
   131  	return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Update("status", task.StateReceived).Error
   132  }
   133  
   134  func (b *BackendSQLDB) SetStateStarted(signature *task.Signature) error {
   135  	var status task.Status
   136  	err := b.gClient.Where("id = ?", signature.ID).First(&status).Error
   137  	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
   138  		// 创建
   139  		status = task.Status{
   140  			TaskID:   signature.ID,
   141  			GroupID:  signature.GroupID,
   142  			Name:     signature.Name,
   143  			Status:   task.StateStarted,
   144  			CreateAt: time.Now(),
   145  		}
   146  		return b.gClient.Create(&status).Error
   147  	}
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Update("status", task.StateStarted).Error
   153  }
   154  
   155  func (b *BackendSQLDB) SetStateRetry(t *task.Signature) error {
   156  	var status task.Status
   157  	err := b.gClient.Where("id = ?", t.ID).First(&status).Error
   158  	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
   159  		// 创建
   160  		status = task.Status{
   161  			TaskID:   t.ID,
   162  			GroupID:  t.GroupID,
   163  			Name:     t.Name,
   164  			Status:   task.StateRetry,
   165  			CreateAt: time.Now(),
   166  		}
   167  		return b.gClient.Create(&status).Error
   168  	}
   169  	if err != nil {
   170  		return err
   171  	}
   172  
   173  	return b.gClient.Model(&task.Status{}).Where("id = ?", t.ID).Update("status", task.StateRetry).Error
   174  }
   175  
   176  func (b *BackendSQLDB) SetStateSuccess(signature *task.Signature, results []*task.Result) error {
   177  	var status task.Status
   178  	err := b.gClient.Where("id = ?", signature.ID).First(&status).Error
   179  	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
   180  		// 创建
   181  		status = task.Status{
   182  			TaskID:   signature.ID,
   183  			GroupID:  signature.GroupID,
   184  			Name:     signature.Name,
   185  			Status:   task.StateSuccess,
   186  			Results:  task.Results(results),
   187  			CreateAt: time.Now(),
   188  		}
   189  		return b.gClient.Create(&status).Error
   190  	}
   191  	if err != nil {
   192  		return err
   193  	}
   194  
   195  	return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Updates(map[string]interface{}{"status": task.StateSuccess, "results": task.Results(results)}).Error
   196  }
   197  
   198  func (b *BackendSQLDB) SetStateFailure(signature *task.Signature, err string) error {
   199  	var status task.Status
   200  	_err := b.gClient.Where("id = ?", signature.ID).First(&status).Error
   201  	if _err != nil && errors.Is(_err, gorm.ErrRecordNotFound) {
   202  		// 创建
   203  		status = task.Status{
   204  			TaskID:   signature.ID,
   205  			GroupID:  signature.GroupID,
   206  			Name:     signature.Name,
   207  			Status:   task.StateFailure,
   208  			Error:    err,
   209  			CreateAt: time.Now(),
   210  		}
   211  		return b.gClient.Create(&status).Error
   212  	}
   213  	if _err != nil {
   214  		return _err
   215  	}
   216  
   217  	return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Updates(map[string]interface{}{"status": task.StateFailure, "error": err}).Error
   218  }
   219  
   220  func (b *BackendSQLDB) GetStatus(taskID string) (*task.Status, error) {
   221  	var status task.Status
   222  	err := b.gClient.Where("id = ?", taskID).First(&status).Error
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	return &status, nil
   227  }
   228  
   229  func (b *BackendSQLDB) ResetTask(taskIDs ...string) error {
   230  	return b.gClient.Where("id in ?", taskIDs).Delete(&task.Status{}).Error
   231  }
   232  
   233  func (b *BackendSQLDB) ResetGroup(groupIDs ...string) error {
   234  	return b.gClient.Where("id in ?", groupIDs).Delete(&task.GroupMeta{}).Error
   235  }
   236  
   237  func (b *BackendSQLDB) autoMigrate() error {
   238  	return b.gClient.AutoMigrate(
   239  		task.GroupMeta{},
   240  		task.Status{},
   241  	)
   242  }
   243  
   244  func NewBackendSQLDB(db *sql.DB, resultExpire int64, dbType string, config *gorm.Config) backend.Backend {
   245  	if config == nil {
   246  		config = &gorm.Config{}
   247  	}
   248  	var (
   249  		gdb *gorm.DB
   250  		err error
   251  	)
   252  	switch strings.ToLower(dbType) {
   253  	case "mysql":
   254  		gdb, err = gorm.Open(mysql.New(mysql.Config{Conn: db}), config)
   255  	case "pgsql":
   256  		gdb, err = gorm.Open(postgres.New(postgres.Config{Conn: db}), config)
   257  	default:
   258  		panic("dbType not supported")
   259  	}
   260  	if err != nil {
   261  		panic(err)
   262  	}
   263  	b := BackendSQLDB{
   264  		gClient:      gdb,
   265  		resultExpire: resultExpire,
   266  	}
   267  	_ = b.autoMigrate()
   268  	return &b
   269  }