github.com/songzhibin97/gkit@v1.2.13/distributed/backend/backend_db/db.go (about) 1 package backend_db 2 3 import ( 4 "database/sql" 5 "errors" 6 "strings" 7 "time" 8 9 "github.com/songzhibin97/gkit/distributed/task" 10 "gorm.io/driver/mysql" 11 "gorm.io/driver/postgres" 12 13 "gorm.io/gorm" 14 15 "github.com/songzhibin97/gkit/distributed/backend" 16 ) 17 18 // BackendSQLDB 支持mysql&pgsql 19 type BackendSQLDB struct { 20 // gClient db客户端 21 gClient *gorm.DB 22 // resultExpire 数据过期时间 23 // -1 代表永不过期 24 // 0 会设置默认过期时间 25 // 单位为ns 26 resultExpire int64 27 } 28 29 // SetResultExpire 设置结果超时时间 30 func (b *BackendSQLDB) SetResultExpire(expire int64) { 31 b.resultExpire = expire 32 } 33 34 func (b *BackendSQLDB) GroupTakeOver(groupID string, name string, taskIDs ...string) error { 35 group := task.InitGroupMeta(groupID, name, b.resultExpire, taskIDs...) 36 return b.gClient.Create(group).Error 37 } 38 39 func (b *BackendSQLDB) GroupCompleted(groupID string) (bool, error) { 40 group, err := b.getGroup(groupID) 41 if err != nil { 42 return false, err 43 } 44 status, err := b.getTaskStatus(group.TaskIDs) 45 if err != nil { 46 return false, err 47 } 48 ln := 0 49 for _, t := range status { 50 if !t.IsCompleted() { 51 return false, nil 52 } 53 ln++ 54 } 55 return len(group.TaskIDs) == ln, nil 56 } 57 58 func (b *BackendSQLDB) getGroup(groupID string) (*task.GroupMeta, error) { 59 var group task.GroupMeta 60 err := b.gClient.Model(&task.GroupMeta{}).Where("id = ?", groupID).First(&group).Error 61 if err != nil { 62 return nil, err 63 } 64 return &group, nil 65 } 66 67 func (b *BackendSQLDB) getTaskStatus(taskIDs []string) ([]*task.Status, error) { 68 statusList := make([]*task.Status, 0, len(taskIDs)) 69 err := b.gClient.Where("id in ?", taskIDs).Find(&statusList).Error 70 if err != nil { 71 return nil, err 72 } 73 return statusList, nil 74 } 75 76 func (b *BackendSQLDB) GroupTaskStatus(groupID string) ([]*task.Status, error) { 77 group, err := b.getGroup(groupID) 78 if err != nil { 79 return nil, err 80 } 81 return b.getTaskStatus(group.TaskIDs) 82 } 83 84 func (b *BackendSQLDB) TriggerCompleted(groupID string) (bool, error) { 85 result := b.gClient.Debug().Model(&task.GroupMeta{}).Where("id = ? and `lock` = false", groupID).Update("`lock`", true) 86 if result.Error != nil { 87 return false, result.Error 88 } 89 return result.RowsAffected != 0, nil 90 } 91 92 func (b *BackendSQLDB) SetStatePending(signature *task.Signature) error { 93 var status task.Status 94 err := b.gClient.Where("id = ?", signature.ID).First(&status).Error 95 if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { 96 // 创建 97 status = task.Status{ 98 TaskID: signature.ID, 99 GroupID: signature.GroupID, 100 Name: signature.Name, 101 Status: task.StatePending, 102 CreateAt: time.Now(), 103 } 104 return b.gClient.Create(&status).Error 105 } 106 if err != nil { 107 return err 108 } 109 // 更新 110 return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Update("status", task.StatePending).Error 111 } 112 113 func (b *BackendSQLDB) SetStateReceived(signature *task.Signature) error { 114 var status task.Status 115 err := b.gClient.Where("id = ?", signature.ID).First(&status).Error 116 if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { 117 // 创建 118 status = task.Status{ 119 TaskID: signature.ID, 120 GroupID: signature.GroupID, 121 Name: signature.Name, 122 Status: task.StateReceived, 123 CreateAt: time.Now(), 124 } 125 return b.gClient.Create(&status).Error 126 } 127 if err != nil { 128 return err 129 } 130 131 return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Update("status", task.StateReceived).Error 132 } 133 134 func (b *BackendSQLDB) SetStateStarted(signature *task.Signature) error { 135 var status task.Status 136 err := b.gClient.Where("id = ?", signature.ID).First(&status).Error 137 if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { 138 // 创建 139 status = task.Status{ 140 TaskID: signature.ID, 141 GroupID: signature.GroupID, 142 Name: signature.Name, 143 Status: task.StateStarted, 144 CreateAt: time.Now(), 145 } 146 return b.gClient.Create(&status).Error 147 } 148 if err != nil { 149 return err 150 } 151 152 return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Update("status", task.StateStarted).Error 153 } 154 155 func (b *BackendSQLDB) SetStateRetry(t *task.Signature) error { 156 var status task.Status 157 err := b.gClient.Where("id = ?", t.ID).First(&status).Error 158 if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { 159 // 创建 160 status = task.Status{ 161 TaskID: t.ID, 162 GroupID: t.GroupID, 163 Name: t.Name, 164 Status: task.StateRetry, 165 CreateAt: time.Now(), 166 } 167 return b.gClient.Create(&status).Error 168 } 169 if err != nil { 170 return err 171 } 172 173 return b.gClient.Model(&task.Status{}).Where("id = ?", t.ID).Update("status", task.StateRetry).Error 174 } 175 176 func (b *BackendSQLDB) SetStateSuccess(signature *task.Signature, results []*task.Result) error { 177 var status task.Status 178 err := b.gClient.Where("id = ?", signature.ID).First(&status).Error 179 if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { 180 // 创建 181 status = task.Status{ 182 TaskID: signature.ID, 183 GroupID: signature.GroupID, 184 Name: signature.Name, 185 Status: task.StateSuccess, 186 Results: task.Results(results), 187 CreateAt: time.Now(), 188 } 189 return b.gClient.Create(&status).Error 190 } 191 if err != nil { 192 return err 193 } 194 195 return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Updates(map[string]interface{}{"status": task.StateSuccess, "results": task.Results(results)}).Error 196 } 197 198 func (b *BackendSQLDB) SetStateFailure(signature *task.Signature, err string) error { 199 var status task.Status 200 _err := b.gClient.Where("id = ?", signature.ID).First(&status).Error 201 if _err != nil && errors.Is(_err, gorm.ErrRecordNotFound) { 202 // 创建 203 status = task.Status{ 204 TaskID: signature.ID, 205 GroupID: signature.GroupID, 206 Name: signature.Name, 207 Status: task.StateFailure, 208 Error: err, 209 CreateAt: time.Now(), 210 } 211 return b.gClient.Create(&status).Error 212 } 213 if _err != nil { 214 return _err 215 } 216 217 return b.gClient.Model(&task.Status{}).Where("id = ?", signature.ID).Updates(map[string]interface{}{"status": task.StateFailure, "error": err}).Error 218 } 219 220 func (b *BackendSQLDB) GetStatus(taskID string) (*task.Status, error) { 221 var status task.Status 222 err := b.gClient.Where("id = ?", taskID).First(&status).Error 223 if err != nil { 224 return nil, err 225 } 226 return &status, nil 227 } 228 229 func (b *BackendSQLDB) ResetTask(taskIDs ...string) error { 230 return b.gClient.Where("id in ?", taskIDs).Delete(&task.Status{}).Error 231 } 232 233 func (b *BackendSQLDB) ResetGroup(groupIDs ...string) error { 234 return b.gClient.Where("id in ?", groupIDs).Delete(&task.GroupMeta{}).Error 235 } 236 237 func (b *BackendSQLDB) autoMigrate() error { 238 return b.gClient.AutoMigrate( 239 task.GroupMeta{}, 240 task.Status{}, 241 ) 242 } 243 244 func NewBackendSQLDB(db *sql.DB, resultExpire int64, dbType string, config *gorm.Config) backend.Backend { 245 if config == nil { 246 config = &gorm.Config{} 247 } 248 var ( 249 gdb *gorm.DB 250 err error 251 ) 252 switch strings.ToLower(dbType) { 253 case "mysql": 254 gdb, err = gorm.Open(mysql.New(mysql.Config{Conn: db}), config) 255 case "pgsql": 256 gdb, err = gorm.Open(postgres.New(postgres.Config{Conn: db}), config) 257 default: 258 panic("dbType not supported") 259 } 260 if err != nil { 261 panic(err) 262 } 263 b := BackendSQLDB{ 264 gClient: gdb, 265 resultExpire: resultExpire, 266 } 267 _ = b.autoMigrate() 268 return &b 269 }