github.com/songzhibin97/gkit@v1.2.13/distributed/backend/backend_redis/redis.go (about) 1 package backend_redis 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "reflect" 8 "time" 9 10 json "github.com/json-iterator/go" 11 12 "github.com/go-redsync/redsync/v4/redis/goredis/v8" 13 "github.com/songzhibin97/gkit/distributed/task" 14 15 "github.com/go-redis/redis/v8" 16 "github.com/go-redsync/redsync/v4" 17 "github.com/songzhibin97/gkit/distributed/backend" 18 ) 19 20 var defaultResultExpire int64 = 3600 21 22 var ErrType = errors.New("err type") 23 24 type BackendRedis struct { 25 // client redis客户端 26 client redis.UniversalClient 27 // lock 分布式锁 28 lock *redsync.Redsync 29 // resultExpire 数据过期时间 30 // -1 代表永不过期 31 // 0 会设置默认过期时间 32 // 单位为ns 33 resultExpire int64 34 } 35 36 func NewBackendRedis(client redis.UniversalClient, resultExpire int64) backend.Backend { 37 b := &BackendRedis{ 38 client: client, 39 lock: redsync.New(goredis.NewPool(client)), 40 resultExpire: resultExpire, 41 } 42 if b.resultExpire == 0 { 43 b.resultExpire = defaultResultExpire 44 } 45 return b 46 } 47 48 // SetResultExpire 设置结果超时时间 49 func (b *BackendRedis) SetResultExpire(expire int64) { 50 b.resultExpire = expire 51 } 52 53 func (b *BackendRedis) GroupTakeOver(groupID string, name string, taskIDs ...string) error { 54 group := task.InitGroupMeta(groupID, name, b.resultExpire, taskIDs...) 55 body, err := json.Marshal(group) 56 if err != nil { 57 return err 58 } 59 expire := b.resultExpire 60 if expire < 0 { 61 expire = 0 62 } 63 // 避免接管任务记录被覆盖 64 var ok bool 65 for !ok { 66 ok, err = b.client.SetNX(context.Background(), groupID, body, time.Duration(expire)).Result() 67 if err != nil { 68 return err 69 } 70 if !ok { 71 time.Sleep(time.Second) 72 } 73 } 74 return nil 75 } 76 77 func (b *BackendRedis) GroupCompleted(groupID string) (bool, error) { 78 list, err := b.GroupTaskStatus(groupID) 79 if err != nil { 80 return false, err 81 } 82 for _, status := range list { 83 if !status.IsCompleted() { 84 return false, nil 85 } 86 } 87 return true, nil 88 } 89 90 func (b *BackendRedis) GroupTaskStatus(groupID string) ([]*task.Status, error) { 91 var ret []*task.Status 92 // 同一个groupID 可能接管多个任务 93 // 拿到所有的key 94 var taskIDs []string 95 groups, err := b.shouldAndBind(&task.GroupMeta{}, groupID) 96 if err != nil { 97 return nil, err 98 } 99 _groups := groups.([]interface{}) 100 if len(_groups) == 0 { 101 return nil, nil 102 } 103 104 for _, group := range _groups { 105 _group, ok := group.(*task.GroupMeta) 106 if !ok { 107 return nil, ErrType 108 } 109 for _, id := range _group.TaskIDs { 110 taskIDs = append(taskIDs, id) 111 } 112 } 113 statusList, err := b.shouldAndBind(&task.Status{}, taskIDs...) 114 if err != nil { 115 return nil, err 116 } 117 118 _statusList := statusList.([]interface{}) 119 for _, status := range _statusList { 120 _status, ok := status.(*task.Status) 121 if !ok { 122 return nil, ErrType 123 } 124 ret = append(ret, _status) 125 } 126 return ret, nil 127 } 128 129 func (b *BackendRedis) TriggerCompleted(groupID string) (bool, error) { 130 // 分布式锁 131 l := b.lock.NewMutex("TriggerCompletedMutex" + groupID) 132 if err := l.Lock(); err != nil { 133 return false, err 134 } 135 defer l.Unlock() 136 group, err := b.getGroup(groupID) 137 if err != nil { 138 return false, err 139 } 140 if group.TriggerCompleted { 141 return false, nil 142 } 143 group.TriggerCompleted = true 144 body, _ := json.Marshal(group) 145 expire := b.resultExpire 146 if expire < 0 { 147 expire = 0 148 } 149 err = b.client.Set(context.Background(), groupID, body, time.Duration(expire)).Err() 150 if err != nil { 151 return false, err 152 } 153 return true, nil 154 } 155 156 func (b *BackendRedis) SetStatePending(signature *task.Signature) error { 157 return b.updateStatus(task.NewPendingState(signature)) 158 } 159 160 func (b *BackendRedis) SetStateReceived(signature *task.Signature) error { 161 dst := task.NewReceivedState(signature) 162 b.migrate(dst) 163 return b.updateStatus(dst) 164 } 165 166 func (b *BackendRedis) SetStateStarted(signature *task.Signature) error { 167 dst := task.NewStartedState(signature) 168 b.migrate(dst) 169 return b.updateStatus(dst) 170 } 171 172 func (b *BackendRedis) SetStateRetry(signature *task.Signature) error { 173 dst := task.NewRetryState(signature) 174 b.migrate(dst) 175 return b.updateStatus(dst) 176 } 177 178 func (b *BackendRedis) SetStateSuccess(signature *task.Signature, results []*task.Result) error { 179 dst := task.NewSuccessState(signature, results...) 180 b.migrate(dst) 181 return b.updateStatus(dst) 182 } 183 184 func (b *BackendRedis) SetStateFailure(signature *task.Signature, err string) error { 185 dst := task.NewFailureState(signature, err) 186 b.migrate(dst) 187 return b.updateStatus(dst) 188 } 189 190 func (b *BackendRedis) GetStatus(taskID string) (*task.Status, error) { 191 return b.getStatus(taskID) 192 } 193 194 func (b *BackendRedis) ResetTask(taskIDs ...string) error { 195 if len(taskIDs) == 0 { 196 return nil 197 } 198 return b.client.Del(context.Background(), taskIDs...).Err() 199 } 200 201 func (b *BackendRedis) ResetGroup(groupIDs ...string) error { 202 if len(groupIDs) == 0 { 203 return nil 204 } 205 return b.client.Del(context.Background(), groupIDs...).Err() 206 } 207 208 // shouldAndBind 批量获取对应key的group信息 209 // obj interface must ptr 210 func (b *BackendRedis) shouldAndBind(dst interface{}, keys ...string) (interface{}, error) { 211 var src []interface{} 212 results, err := b.client.Pipelined(context.Background(), func(pipeline redis.Pipeliner) error { 213 for _, key := range keys { 214 pipeline.Get(context.Background(), key) 215 } 216 return nil 217 }) 218 if err != nil { 219 return nil, err 220 } 221 for _, result := range results { 222 stringCmd, ok := result.(*redis.StringCmd) 223 if !ok { 224 continue 225 } 226 body, err := stringCmd.Bytes() 227 if err != nil { 228 return nil, err 229 } 230 obj := reflect.New(reflect.TypeOf(dst).Elem()).Interface() 231 err = json.Unmarshal(body, obj) 232 if err != nil { 233 return nil, err 234 } 235 src = append(src, obj) 236 } 237 return src, nil 238 } 239 240 // getGroup 获取组详情 241 func (b *BackendRedis) getGroup(groupID string) (*task.GroupMeta, error) { 242 body, err := b.client.Get(context.Background(), groupID).Bytes() 243 if err != nil { 244 return nil, err 245 } 246 var group task.GroupMeta 247 err = json.Unmarshal(body, &group) 248 return &group, err 249 } 250 251 // updateStatus 更新状态 252 func (b *BackendRedis) updateStatus(status *task.Status) error { 253 body, err := json.Marshal(status) 254 if err != nil { 255 return err 256 } 257 expire := b.resultExpire 258 if expire < 0 { 259 expire = 0 260 } 261 _, err = b.client.Set(context.Background(), status.TaskID, body, time.Duration(expire)).Result() 262 return err 263 } 264 265 // getStatus 获取任务状态 266 func (b *BackendRedis) getStatus(taskID string) (*task.Status, error) { 267 body, err := b.client.Get(context.Background(), taskID).Bytes() 268 if err != nil { 269 return nil, err 270 } 271 var status task.Status 272 decoder := json.NewDecoder(bytes.NewReader(body)) 273 decoder.UseNumber() 274 if err := decoder.Decode(&status); err != nil { 275 return nil, err 276 } 277 278 return &status, nil 279 } 280 281 func (b *BackendRedis) migrate(dst *task.Status) { 282 src, err := b.getStatus(dst.TaskID) 283 if err == nil { 284 dst.CreateAt = src.CreateAt 285 dst.Name = src.Name 286 } 287 }