github.com/songzhibin97/gkit@v1.2.13/distributed/backend/backend_redis/redis.go (about)

     1  package backend_redis
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"reflect"
     8  	"time"
     9  
    10  	json "github.com/json-iterator/go"
    11  
    12  	"github.com/go-redsync/redsync/v4/redis/goredis/v8"
    13  	"github.com/songzhibin97/gkit/distributed/task"
    14  
    15  	"github.com/go-redis/redis/v8"
    16  	"github.com/go-redsync/redsync/v4"
    17  	"github.com/songzhibin97/gkit/distributed/backend"
    18  )
    19  
    20  var defaultResultExpire int64 = 3600
    21  
    22  var ErrType = errors.New("err type")
    23  
// BackendRedis is a backend.Backend implementation that stores task statuses
// and group metadata in Redis as JSON values.
type BackendRedis struct {
	// client is the Redis client used for every read and write.
	client redis.UniversalClient
	// lock creates distributed mutexes (used by TriggerCompleted).
	lock *redsync.Redsync
	// resultExpire is the TTL applied when writing statuses and group records.
	// Any negative value (e.g. -1) means the records never expire.
	// NewBackendRedis replaces 0 with the default TTL.
	// Unit: nanoseconds (converted with time.Duration).
	resultExpire int64
}
    35  
    36  func NewBackendRedis(client redis.UniversalClient, resultExpire int64) backend.Backend {
    37  	b := &BackendRedis{
    38  		client:       client,
    39  		lock:         redsync.New(goredis.NewPool(client)),
    40  		resultExpire: resultExpire,
    41  	}
    42  	if b.resultExpire == 0 {
    43  		b.resultExpire = defaultResultExpire
    44  	}
    45  	return b
    46  }
    47  
    48  // SetResultExpire 设置结果超时时间
    49  func (b *BackendRedis) SetResultExpire(expire int64) {
    50  	b.resultExpire = expire
    51  }
    52  
    53  func (b *BackendRedis) GroupTakeOver(groupID string, name string, taskIDs ...string) error {
    54  	group := task.InitGroupMeta(groupID, name, b.resultExpire, taskIDs...)
    55  	body, err := json.Marshal(group)
    56  	if err != nil {
    57  		return err
    58  	}
    59  	expire := b.resultExpire
    60  	if expire < 0 {
    61  		expire = 0
    62  	}
    63  	// 避免接管任务记录被覆盖
    64  	var ok bool
    65  	for !ok {
    66  		ok, err = b.client.SetNX(context.Background(), groupID, body, time.Duration(expire)).Result()
    67  		if err != nil {
    68  			return err
    69  		}
    70  		if !ok {
    71  			time.Sleep(time.Second)
    72  		}
    73  	}
    74  	return nil
    75  }
    76  
    77  func (b *BackendRedis) GroupCompleted(groupID string) (bool, error) {
    78  	list, err := b.GroupTaskStatus(groupID)
    79  	if err != nil {
    80  		return false, err
    81  	}
    82  	for _, status := range list {
    83  		if !status.IsCompleted() {
    84  			return false, nil
    85  		}
    86  	}
    87  	return true, nil
    88  }
    89  
    90  func (b *BackendRedis) GroupTaskStatus(groupID string) ([]*task.Status, error) {
    91  	var ret []*task.Status
    92  	// 同一个groupID 可能接管多个任务
    93  	// 拿到所有的key
    94  	var taskIDs []string
    95  	groups, err := b.shouldAndBind(&task.GroupMeta{}, groupID)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	_groups := groups.([]interface{})
   100  	if len(_groups) == 0 {
   101  		return nil, nil
   102  	}
   103  
   104  	for _, group := range _groups {
   105  		_group, ok := group.(*task.GroupMeta)
   106  		if !ok {
   107  			return nil, ErrType
   108  		}
   109  		for _, id := range _group.TaskIDs {
   110  			taskIDs = append(taskIDs, id)
   111  		}
   112  	}
   113  	statusList, err := b.shouldAndBind(&task.Status{}, taskIDs...)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	_statusList := statusList.([]interface{})
   119  	for _, status := range _statusList {
   120  		_status, ok := status.(*task.Status)
   121  		if !ok {
   122  			return nil, ErrType
   123  		}
   124  		ret = append(ret, _status)
   125  	}
   126  	return ret, nil
   127  }
   128  
   129  func (b *BackendRedis) TriggerCompleted(groupID string) (bool, error) {
   130  	// 分布式锁
   131  	l := b.lock.NewMutex("TriggerCompletedMutex" + groupID)
   132  	if err := l.Lock(); err != nil {
   133  		return false, err
   134  	}
   135  	defer l.Unlock()
   136  	group, err := b.getGroup(groupID)
   137  	if err != nil {
   138  		return false, err
   139  	}
   140  	if group.TriggerCompleted {
   141  		return false, nil
   142  	}
   143  	group.TriggerCompleted = true
   144  	body, _ := json.Marshal(group)
   145  	expire := b.resultExpire
   146  	if expire < 0 {
   147  		expire = 0
   148  	}
   149  	err = b.client.Set(context.Background(), groupID, body, time.Duration(expire)).Err()
   150  	if err != nil {
   151  		return false, err
   152  	}
   153  	return true, nil
   154  }
   155  
   156  func (b *BackendRedis) SetStatePending(signature *task.Signature) error {
   157  	return b.updateStatus(task.NewPendingState(signature))
   158  }
   159  
   160  func (b *BackendRedis) SetStateReceived(signature *task.Signature) error {
   161  	dst := task.NewReceivedState(signature)
   162  	b.migrate(dst)
   163  	return b.updateStatus(dst)
   164  }
   165  
   166  func (b *BackendRedis) SetStateStarted(signature *task.Signature) error {
   167  	dst := task.NewStartedState(signature)
   168  	b.migrate(dst)
   169  	return b.updateStatus(dst)
   170  }
   171  
   172  func (b *BackendRedis) SetStateRetry(signature *task.Signature) error {
   173  	dst := task.NewRetryState(signature)
   174  	b.migrate(dst)
   175  	return b.updateStatus(dst)
   176  }
   177  
   178  func (b *BackendRedis) SetStateSuccess(signature *task.Signature, results []*task.Result) error {
   179  	dst := task.NewSuccessState(signature, results...)
   180  	b.migrate(dst)
   181  	return b.updateStatus(dst)
   182  }
   183  
   184  func (b *BackendRedis) SetStateFailure(signature *task.Signature, err string) error {
   185  	dst := task.NewFailureState(signature, err)
   186  	b.migrate(dst)
   187  	return b.updateStatus(dst)
   188  }
   189  
   190  func (b *BackendRedis) GetStatus(taskID string) (*task.Status, error) {
   191  	return b.getStatus(taskID)
   192  }
   193  
   194  func (b *BackendRedis) ResetTask(taskIDs ...string) error {
   195  	if len(taskIDs) == 0 {
   196  		return nil
   197  	}
   198  	return b.client.Del(context.Background(), taskIDs...).Err()
   199  }
   200  
   201  func (b *BackendRedis) ResetGroup(groupIDs ...string) error {
   202  	if len(groupIDs) == 0 {
   203  		return nil
   204  	}
   205  	return b.client.Del(context.Background(), groupIDs...).Err()
   206  }
   207  
   208  // shouldAndBind 批量获取对应key的group信息
   209  // obj interface must ptr
   210  func (b *BackendRedis) shouldAndBind(dst interface{}, keys ...string) (interface{}, error) {
   211  	var src []interface{}
   212  	results, err := b.client.Pipelined(context.Background(), func(pipeline redis.Pipeliner) error {
   213  		for _, key := range keys {
   214  			pipeline.Get(context.Background(), key)
   215  		}
   216  		return nil
   217  	})
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  	for _, result := range results {
   222  		stringCmd, ok := result.(*redis.StringCmd)
   223  		if !ok {
   224  			continue
   225  		}
   226  		body, err := stringCmd.Bytes()
   227  		if err != nil {
   228  			return nil, err
   229  		}
   230  		obj := reflect.New(reflect.TypeOf(dst).Elem()).Interface()
   231  		err = json.Unmarshal(body, obj)
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  		src = append(src, obj)
   236  	}
   237  	return src, nil
   238  }
   239  
   240  // getGroup 获取组详情
   241  func (b *BackendRedis) getGroup(groupID string) (*task.GroupMeta, error) {
   242  	body, err := b.client.Get(context.Background(), groupID).Bytes()
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  	var group task.GroupMeta
   247  	err = json.Unmarshal(body, &group)
   248  	return &group, err
   249  }
   250  
   251  // updateStatus 更新状态
   252  func (b *BackendRedis) updateStatus(status *task.Status) error {
   253  	body, err := json.Marshal(status)
   254  	if err != nil {
   255  		return err
   256  	}
   257  	expire := b.resultExpire
   258  	if expire < 0 {
   259  		expire = 0
   260  	}
   261  	_, err = b.client.Set(context.Background(), status.TaskID, body, time.Duration(expire)).Result()
   262  	return err
   263  }
   264  
   265  // getStatus 获取任务状态
   266  func (b *BackendRedis) getStatus(taskID string) (*task.Status, error) {
   267  	body, err := b.client.Get(context.Background(), taskID).Bytes()
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  	var status task.Status
   272  	decoder := json.NewDecoder(bytes.NewReader(body))
   273  	decoder.UseNumber()
   274  	if err := decoder.Decode(&status); err != nil {
   275  		return nil, err
   276  	}
   277  
   278  	return &status, nil
   279  }
   280  
   281  func (b *BackendRedis) migrate(dst *task.Status) {
   282  	src, err := b.getStatus(dst.TaskID)
   283  	if err == nil {
   284  		dst.CreateAt = src.CreateAt
   285  		dst.Name = src.Name
   286  	}
   287  }