github.com/wfusion/gofusion@v1.1.14/common/infra/asynq/pkg/testutil/testutil.go (about)

     1  // Copyright 2020 Kentaro Hibino. All rights reserved.
     2  // Use of this source code is governed by a MIT license
     3  // that can be found in the LICENSE file.
     4  
     5  // Package testutil defines test helpers for asynq and its internal packages.
     6  package testutil
     7  
     8  import (
     9  	"context"
    10  	"math"
    11  	"sort"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/google/go-cmp/cmp"
    16  	"github.com/google/go-cmp/cmp/cmpopts"
    17  	"github.com/google/uuid"
    18  	"github.com/redis/go-redis/v9"
    19  
    20  	"github.com/wfusion/gofusion/common/infra/asynq/pkg/base"
    21  	"github.com/wfusion/gofusion/common/infra/asynq/pkg/timeutil"
    22  	"github.com/wfusion/gofusion/common/utils/serialize/json"
    23  )
    24  
    25  // EquateInt64Approx returns a Comparer option that treats int64 values
    26  // to be equal if they are within the given margin.
    27  func EquateInt64Approx(margin int64) cmp.Option {
    28  	return cmp.Comparer(func(a, b int64) bool {
    29  		return math.Abs(float64(a-b)) <= float64(margin)
    30  	})
    31  }
    32  
    33  // SortMsgOpt is a cmp.Option to sort base.TaskMessage for comparing slice of task messages.
    34  var SortMsgOpt = cmp.Transformer("SortTaskMessages", func(in []*base.TaskMessage) []*base.TaskMessage {
    35  	out := append([]*base.TaskMessage(nil), in...) // Copy input to avoid mutating it
    36  	sort.Slice(out, func(i, j int) bool {
    37  		return out[i].ID < out[j].ID
    38  	})
    39  	return out
    40  })
    41  
    42  // SortZSetEntryOpt is an cmp.Option to sort ZSetEntry for comparing slice of zset entries.
    43  var SortZSetEntryOpt = cmp.Transformer("SortZSetEntries", func(in []base.Z) []base.Z {
    44  	out := append([]base.Z(nil), in...) // Copy input to avoid mutating it
    45  	sort.Slice(out, func(i, j int) bool {
    46  		return out[i].Message.ID < out[j].Message.ID
    47  	})
    48  	return out
    49  })
    50  
    51  // SortServerInfoOpt is a cmp.Option to sort base.ServerInfo for comparing slice of process info.
    52  var SortServerInfoOpt = cmp.Transformer("SortServerInfo", func(in []*base.ServerInfo) []*base.ServerInfo {
    53  	out := append([]*base.ServerInfo(nil), in...) // Copy input to avoid mutating it
    54  	sort.Slice(out, func(i, j int) bool {
    55  		if out[i].Host != out[j].Host {
    56  			return out[i].Host < out[j].Host
    57  		}
    58  		return out[i].PID < out[j].PID
    59  	})
    60  	return out
    61  })
    62  
    63  // SortWorkerInfoOpt is a cmp.Option to sort base.WorkerInfo for comparing slice of worker info.
    64  var SortWorkerInfoOpt = cmp.Transformer("SortWorkerInfo", func(in []*base.WorkerInfo) []*base.WorkerInfo {
    65  	out := append([]*base.WorkerInfo(nil), in...) // Copy input to avoid mutating it
    66  	sort.Slice(out, func(i, j int) bool {
    67  		return out[i].ID < out[j].ID
    68  	})
    69  	return out
    70  })
    71  
    72  // SortSchedulerEntryOpt is a cmp.Option to sort base.SchedulerEntry for comparing slice of entries.
    73  var SortSchedulerEntryOpt = cmp.Transformer("SortSchedulerEntry", func(in []*base.SchedulerEntry) []*base.SchedulerEntry {
    74  	out := append([]*base.SchedulerEntry(nil), in...) // Copy input to avoid mutating it
    75  	sort.Slice(out, func(i, j int) bool {
    76  		return out[i].Spec < out[j].Spec
    77  	})
    78  	return out
    79  })
    80  
    81  // SortSchedulerEnqueueEventOpt is a cmp.Option to sort base.SchedulerEnqueueEvent for comparing slice of events.
    82  var SortSchedulerEnqueueEventOpt = cmp.Transformer("SortSchedulerEnqueueEvent", func(in []*base.SchedulerEnqueueEvent) []*base.SchedulerEnqueueEvent {
    83  	out := append([]*base.SchedulerEnqueueEvent(nil), in...)
    84  	sort.Slice(out, func(i, j int) bool {
    85  		return out[i].EnqueuedAt.Unix() < out[j].EnqueuedAt.Unix()
    86  	})
    87  	return out
    88  })
    89  
    90  // SortStringSliceOpt is a cmp.Option to sort string slice.
    91  var SortStringSliceOpt = cmp.Transformer("SortStringSlice", func(in []string) []string {
    92  	out := append([]string(nil), in...)
    93  	sort.Strings(out)
    94  	return out
    95  })
    96  
    97  var SortRedisZSetEntryOpt = cmp.Transformer("SortZSetEntries", func(in []redis.Z) []redis.Z {
    98  	out := append([]redis.Z(nil), in...) // Copy input to avoid mutating it
    99  	sort.Slice(out, func(i, j int) bool {
   100  		// TODO: If member is a comparable type (int, string, etc) compare by the member
   101  		// Use generic comparable type here once update to go1.18
   102  		if _, ok := out[i].Member.(string); ok {
   103  			// If member is a string, compare the member
   104  			return out[i].Member.(string) < out[j].Member.(string)
   105  		}
   106  		return out[i].Score < out[j].Score
   107  	})
   108  	return out
   109  })
   110  
   111  // IgnoreIDOpt is an cmp.Option to ignore ID field in task messages when comparing.
   112  var IgnoreIDOpt = cmpopts.IgnoreFields(base.TaskMessage{}, "ID")
   113  
   114  // NewTaskMessage returns a new instance of TaskMessage given a task type and payload.
   115  func NewTaskMessage(taskType string, payload []byte) *base.TaskMessage {
   116  	return NewTaskMessageWithQueue(taskType, payload, base.DefaultQueueName)
   117  }
   118  
   119  // NewTaskMessageWithQueue returns a new instance of TaskMessage given a
   120  // task type, payload and queue name.
   121  func NewTaskMessageWithQueue(taskType string, payload []byte, qname string) *base.TaskMessage {
   122  	return &base.TaskMessage{
   123  		ID:       uuid.NewString(),
   124  		Type:     taskType,
   125  		Queue:    qname,
   126  		Retry:    25,
   127  		Payload:  payload,
   128  		Timeout:  1800, // default timeout of 30 mins
   129  		Deadline: 0,    // no deadline
   130  	}
   131  }
   132  
   133  // NewLeaseWithClock returns a new lease with the given expiration time and clock.
   134  func NewLeaseWithClock(expirationTime time.Time, clock timeutil.Clock) *base.Lease {
   135  	l := base.NewLease(expirationTime)
   136  	l.Clock = clock
   137  	return l
   138  }
   139  
   140  // JSON serializes the given key-value pairs into stream of bytes in JSON.
   141  func JSON(kv map[string]any) []byte {
   142  	b, err := json.Marshal(kv)
   143  	if err != nil {
   144  		panic(err)
   145  	}
   146  	return b
   147  }
   148  
   149  // TaskMessageAfterRetry returns an updated copy of t after retry.
   150  // It increments retry count and sets the error message and last_failed_at time.
   151  func TaskMessageAfterRetry(t base.TaskMessage, errMsg string, failedAt time.Time) *base.TaskMessage {
   152  	t.Retried = t.Retried + 1
   153  	t.ErrorMsg = errMsg
   154  	t.LastFailedAt = failedAt.Unix()
   155  	return &t
   156  }
   157  
   158  // TaskMessageWithError returns an updated copy of t with the given error message.
   159  func TaskMessageWithError(t base.TaskMessage, errMsg string, failedAt time.Time) *base.TaskMessage {
   160  	t.ErrorMsg = errMsg
   161  	t.LastFailedAt = failedAt.Unix()
   162  	return &t
   163  }
   164  
   165  // TaskMessageWithCompletedAt returns an updated copy of t after completion.
   166  func TaskMessageWithCompletedAt(t base.TaskMessage, completedAt time.Time) *base.TaskMessage {
   167  	t.CompletedAt = completedAt.Unix()
   168  	return &t
   169  }
   170  
   171  // MustMarshal marshals given task message and returns a json string.
   172  // Calling test will fail if marshaling errors out.
   173  func MustMarshal(tb testing.TB, msg *base.TaskMessage) string {
   174  	tb.Helper()
   175  	data, err := base.EncodeMessage(msg)
   176  	if err != nil {
   177  		tb.Fatal(err)
   178  	}
   179  	return string(data)
   180  }
   181  
   182  // MustUnmarshal unmarshals given string into task message struct.
   183  // Calling test will fail if unmarshaling errors out.
   184  func MustUnmarshal(tb testing.TB, data string) *base.TaskMessage {
   185  	tb.Helper()
   186  	msg, err := base.DecodeMessage([]byte(data))
   187  	if err != nil {
   188  		tb.Fatal(err)
   189  	}
   190  	return msg
   191  }
   192  
   193  // FlushDB deletes all the keys of the currently selected DB.
   194  func FlushDB(tb testing.TB, r redis.UniversalClient) {
   195  	tb.Helper()
   196  	switch r := r.(type) {
   197  	case *redis.Client:
   198  		if err := r.FlushDB(context.Background()).Err(); err != nil {
   199  			tb.Fatal(err)
   200  		}
   201  	case *redis.ClusterClient:
   202  		err := r.ForEachMaster(context.Background(), func(ctx context.Context, c *redis.Client) error {
   203  			if err := c.FlushAll(ctx).Err(); err != nil {
   204  				return err
   205  			}
   206  			return nil
   207  		})
   208  		if err != nil {
   209  			tb.Fatal(err)
   210  		}
   211  	}
   212  }
   213  
   214  // SeedPendingQueue initializes the specified queue with the given messages.
   215  func SeedPendingQueue(tb testing.TB, r redis.UniversalClient, msgs []*base.TaskMessage, qname string) {
   216  	tb.Helper()
   217  	r.SAdd(context.Background(), base.AllQueues, qname)
   218  	seedRedisList(tb, r, base.PendingKey(qname), msgs, base.TaskStatePending)
   219  }
   220  
   221  // SeedActiveQueue initializes the active queue with the given messages.
   222  func SeedActiveQueue(tb testing.TB, r redis.UniversalClient, msgs []*base.TaskMessage, qname string) {
   223  	tb.Helper()
   224  	r.SAdd(context.Background(), base.AllQueues, qname)
   225  	seedRedisList(tb, r, base.ActiveKey(qname), msgs, base.TaskStateActive)
   226  }
   227  
   228  // SeedScheduledQueue initializes the scheduled queue with the given messages.
   229  func SeedScheduledQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) {
   230  	tb.Helper()
   231  	r.SAdd(context.Background(), base.AllQueues, qname)
   232  	seedRedisZSet(tb, r, base.ScheduledKey(qname), entries, base.TaskStateScheduled)
   233  }
   234  
   235  // SeedRetryQueue initializes the retry queue with the given messages.
   236  func SeedRetryQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) {
   237  	tb.Helper()
   238  	r.SAdd(context.Background(), base.AllQueues, qname)
   239  	seedRedisZSet(tb, r, base.RetryKey(qname), entries, base.TaskStateRetry)
   240  }
   241  
   242  // SeedArchivedQueue initializes the archived queue with the given messages.
   243  func SeedArchivedQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) {
   244  	tb.Helper()
   245  	r.SAdd(context.Background(), base.AllQueues, qname)
   246  	seedRedisZSet(tb, r, base.ArchivedKey(qname), entries, base.TaskStateArchived)
   247  }
   248  
   249  // SeedLease initializes the lease set with the given entries.
   250  func SeedLease(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) {
   251  	tb.Helper()
   252  	r.SAdd(context.Background(), base.AllQueues, qname)
   253  	seedRedisZSet(tb, r, base.LeaseKey(qname), entries, base.TaskStateActive)
   254  }
   255  
   256  // SeedCompletedQueue initializes the completed set with the given entries.
   257  func SeedCompletedQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) {
   258  	tb.Helper()
   259  	r.SAdd(context.Background(), base.AllQueues, qname)
   260  	seedRedisZSet(tb, r, base.CompletedKey(qname), entries, base.TaskStateCompleted)
   261  }
   262  
   263  // SeedGroup initializes the group with the given entries.
   264  func SeedGroup(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname, gname string) {
   265  	tb.Helper()
   266  	ctx := context.Background()
   267  	r.SAdd(ctx, base.AllQueues, qname)
   268  	r.SAdd(ctx, base.AllGroups(qname), gname)
   269  	seedRedisZSet(tb, r, base.GroupKey(qname, gname), entries, base.TaskStateAggregating)
   270  }
   271  
   272  func SeedAggregationSet(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname, gname, setID string) {
   273  	tb.Helper()
   274  	r.SAdd(context.Background(), base.AllQueues, qname)
   275  	seedRedisZSet(tb, r, base.AggregationSetKey(qname, gname, setID), entries, base.TaskStateAggregating)
   276  }
   277  
   278  // SeedAllPendingQueues initializes all of the specified queues with the given messages.
   279  //
   280  // pending maps a queue name to a list of messages.
   281  func SeedAllPendingQueues(tb testing.TB, r redis.UniversalClient, pending map[string][]*base.TaskMessage) {
   282  	tb.Helper()
   283  	for q, msgs := range pending {
   284  		SeedPendingQueue(tb, r, msgs, q)
   285  	}
   286  }
   287  
   288  // SeedAllActiveQueues initializes all of the specified active queues with the given messages.
   289  func SeedAllActiveQueues(tb testing.TB, r redis.UniversalClient, active map[string][]*base.TaskMessage) {
   290  	tb.Helper()
   291  	for q, msgs := range active {
   292  		SeedActiveQueue(tb, r, msgs, q)
   293  	}
   294  }
   295  
   296  // SeedAllScheduledQueues initializes all of the specified scheduled queues with the given entries.
   297  func SeedAllScheduledQueues(tb testing.TB, r redis.UniversalClient, scheduled map[string][]base.Z) {
   298  	tb.Helper()
   299  	for q, entries := range scheduled {
   300  		SeedScheduledQueue(tb, r, entries, q)
   301  	}
   302  }
   303  
   304  // SeedAllRetryQueues initializes all of the specified retry queues with the given entries.
   305  func SeedAllRetryQueues(tb testing.TB, r redis.UniversalClient, retry map[string][]base.Z) {
   306  	tb.Helper()
   307  	for q, entries := range retry {
   308  		SeedRetryQueue(tb, r, entries, q)
   309  	}
   310  }
   311  
   312  // SeedAllArchivedQueues initializes all of the specified archived queues with the given entries.
   313  func SeedAllArchivedQueues(tb testing.TB, r redis.UniversalClient, archived map[string][]base.Z) {
   314  	tb.Helper()
   315  	for q, entries := range archived {
   316  		SeedArchivedQueue(tb, r, entries, q)
   317  	}
   318  }
   319  
   320  // SeedAllLease initializes all of the lease sets with the given entries.
   321  func SeedAllLease(tb testing.TB, r redis.UniversalClient, lease map[string][]base.Z) {
   322  	tb.Helper()
   323  	for q, entries := range lease {
   324  		SeedLease(tb, r, entries, q)
   325  	}
   326  }
   327  
   328  // SeedAllCompletedQueues initializes all of the completed queues with the given entries.
   329  func SeedAllCompletedQueues(tb testing.TB, r redis.UniversalClient, completed map[string][]base.Z) {
   330  	tb.Helper()
   331  	for q, entries := range completed {
   332  		SeedCompletedQueue(tb, r, entries, q)
   333  	}
   334  }
   335  
   336  // SeedAllGroups initializes all groups in all queues.
   337  // The map maps queue names to group names which maps to a list of task messages and the time it was
   338  // added to the group.
   339  func SeedAllGroups(tb testing.TB, r redis.UniversalClient, groups map[string]map[string][]base.Z) {
   340  	tb.Helper()
   341  	for qname, g := range groups {
   342  		for gname, entries := range g {
   343  			SeedGroup(tb, r, entries, qname, gname)
   344  		}
   345  	}
   346  }
   347  
   348  func seedRedisList(tb testing.TB, c redis.UniversalClient, key string,
   349  	msgs []*base.TaskMessage, state base.TaskState) {
   350  	tb.Helper()
   351  	for _, msg := range msgs {
   352  		encoded := MustMarshal(tb, msg)
   353  		if err := c.LPush(context.Background(), key, msg.ID).Err(); err != nil {
   354  			tb.Fatal(err)
   355  		}
   356  		taskKey := base.TaskKey(msg.Queue, msg.ID)
   357  		data := map[string]any{
   358  			"msg":        encoded,
   359  			"state":      state.String(),
   360  			"unique_key": msg.UniqueKey,
   361  			"group":      msg.GroupKey,
   362  		}
   363  		if err := c.HSet(context.Background(), taskKey, data).Err(); err != nil {
   364  			tb.Fatal(err)
   365  		}
   366  		if len(msg.UniqueKey) > 0 {
   367  			err := c.SetNX(context.Background(), msg.UniqueKey, msg.ID, 1*time.Minute).Err()
   368  			if err != nil {
   369  				tb.Fatalf("Failed to set unique lock in redis: %v", err)
   370  			}
   371  		}
   372  	}
   373  }
   374  
   375  func seedRedisZSet(tb testing.TB, c redis.UniversalClient, key string,
   376  	items []base.Z, state base.TaskState) {
   377  	tb.Helper()
   378  	for _, item := range items {
   379  		msg := item.Message
   380  		encoded := MustMarshal(tb, msg)
   381  		z := redis.Z{Member: msg.ID, Score: float64(item.Score)}
   382  		if err := c.ZAdd(context.Background(), key, z).Err(); err != nil {
   383  			tb.Fatal(err)
   384  		}
   385  		taskKey := base.TaskKey(msg.Queue, msg.ID)
   386  		data := map[string]any{
   387  			"msg":        encoded,
   388  			"state":      state.String(),
   389  			"unique_key": msg.UniqueKey,
   390  			"group":      msg.GroupKey,
   391  		}
   392  		if err := c.HSet(context.Background(), taskKey, data).Err(); err != nil {
   393  			tb.Fatal(err)
   394  		}
   395  		if len(msg.UniqueKey) > 0 {
   396  			err := c.SetNX(context.Background(), msg.UniqueKey, msg.ID, 1*time.Minute).Err()
   397  			if err != nil {
   398  				tb.Fatalf("Failed to set unique lock in redis: %v", err)
   399  			}
   400  		}
   401  	}
   402  }
   403  
   404  // GetPendingMessages returns all pending messages in the given queue.
   405  // It also asserts the state field of the task.
   406  func GetPendingMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage {
   407  	tb.Helper()
   408  	return getMessagesFromList(tb, r, qname, base.PendingKey, base.TaskStatePending)
   409  }
   410  
   411  // GetActiveMessages returns all active messages in the given queue.
   412  // It also asserts the state field of the task.
   413  func GetActiveMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage {
   414  	tb.Helper()
   415  	return getMessagesFromList(tb, r, qname, base.ActiveKey, base.TaskStateActive)
   416  }
   417  
   418  // GetScheduledMessages returns all scheduled task messages in the given queue.
   419  // It also asserts the state field of the task.
   420  func GetScheduledMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage {
   421  	tb.Helper()
   422  	return getMessagesFromZSet(tb, r, qname, base.ScheduledKey, base.TaskStateScheduled)
   423  }
   424  
   425  // GetRetryMessages returns all retry messages in the given queue.
   426  // It also asserts the state field of the task.
   427  func GetRetryMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage {
   428  	tb.Helper()
   429  	return getMessagesFromZSet(tb, r, qname, base.RetryKey, base.TaskStateRetry)
   430  }
   431  
   432  // GetArchivedMessages returns all archived messages in the given queue.
   433  // It also asserts the state field of the task.
   434  func GetArchivedMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage {
   435  	tb.Helper()
   436  	return getMessagesFromZSet(tb, r, qname, base.ArchivedKey, base.TaskStateArchived)
   437  }
   438  
   439  // GetCompletedMessages returns all completed task messages in the given queue.
   440  // It also asserts the state field of the task.
   441  func GetCompletedMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage {
   442  	tb.Helper()
   443  	return getMessagesFromZSet(tb, r, qname, base.CompletedKey, base.TaskStateCompleted)
   444  }
   445  
   446  // GetScheduledEntries returns all scheduled messages and its score in the given queue.
   447  // It also asserts the state field of the task.
   448  func GetScheduledEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z {
   449  	tb.Helper()
   450  	return getMessagesFromZSetWithScores(tb, r, qname, base.ScheduledKey, base.TaskStateScheduled)
   451  }
   452  
   453  // GetRetryEntries returns all retry messages and its score in the given queue.
   454  // It also asserts the state field of the task.
   455  func GetRetryEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z {
   456  	tb.Helper()
   457  	return getMessagesFromZSetWithScores(tb, r, qname, base.RetryKey, base.TaskStateRetry)
   458  }
   459  
   460  // GetArchivedEntries returns all archived messages and its score in the given queue.
   461  // It also asserts the state field of the task.
   462  func GetArchivedEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z {
   463  	tb.Helper()
   464  	return getMessagesFromZSetWithScores(tb, r, qname, base.ArchivedKey, base.TaskStateArchived)
   465  }
   466  
   467  // GetLeaseEntries returns all task IDs and its score in the lease set for the given queue.
   468  // It also asserts the state field of the task.
   469  func GetLeaseEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z {
   470  	tb.Helper()
   471  	return getMessagesFromZSetWithScores(tb, r, qname, base.LeaseKey, base.TaskStateActive)
   472  }
   473  
   474  // GetCompletedEntries returns all completed messages and its score in the given queue.
   475  // It also asserts the state field of the task.
   476  func GetCompletedEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z {
   477  	tb.Helper()
   478  	return getMessagesFromZSetWithScores(tb, r, qname, base.CompletedKey, base.TaskStateCompleted)
   479  }
   480  
   481  // GetGroupEntries returns all scheduled messages and its score in the given queue.
   482  // It also asserts the state field of the task.
   483  func GetGroupEntries(tb testing.TB, r redis.UniversalClient, qname, groupKey string) []base.Z {
   484  	tb.Helper()
   485  	return getMessagesFromZSetWithScores(tb, r, qname,
   486  		func(qname string) string { return base.GroupKey(qname, groupKey) }, base.TaskStateAggregating)
   487  }
   488  
   489  // Retrieves all messages stored under `keyFn(qname)` key in redis list.
   490  func getMessagesFromList(tb testing.TB, r redis.UniversalClient, qname string,
   491  	keyFn func(qname string) string, state base.TaskState) []*base.TaskMessage {
   492  	tb.Helper()
   493  	ids := r.LRange(context.Background(), keyFn(qname), 0, -1).Val()
   494  	var msgs []*base.TaskMessage
   495  	for _, id := range ids {
   496  		taskKey := base.TaskKey(qname, id)
   497  		data := r.HGet(context.Background(), taskKey, "msg").Val()
   498  		msgs = append(msgs, MustUnmarshal(tb, data))
   499  		if gotState := r.HGet(context.Background(), taskKey, "state").Val(); gotState != state.String() {
   500  			tb.Errorf("task (id=%q) is in %q state, want %v", id, gotState, state)
   501  		}
   502  	}
   503  	return msgs
   504  }
   505  
   506  // Retrieves all messages stored under `keyFn(qname)` key in redis zset (sorted-set).
   507  func getMessagesFromZSet(tb testing.TB, r redis.UniversalClient, qname string,
   508  	keyFn func(qname string) string, state base.TaskState) []*base.TaskMessage {
   509  	tb.Helper()
   510  	ids := r.ZRange(context.Background(), keyFn(qname), 0, -1).Val()
   511  	var msgs []*base.TaskMessage
   512  	for _, id := range ids {
   513  		taskKey := base.TaskKey(qname, id)
   514  		msg := r.HGet(context.Background(), taskKey, "msg").Val()
   515  		msgs = append(msgs, MustUnmarshal(tb, msg))
   516  		if gotState := r.HGet(context.Background(), taskKey, "state").Val(); gotState != state.String() {
   517  			tb.Errorf("task (id=%q) is in %q state, want %v", id, gotState, state)
   518  		}
   519  	}
   520  	return msgs
   521  }
   522  
   523  // Retrieves all messages along with their scores stored under `keyFn(qname)` key in redis zset (sorted-set).
   524  func getMessagesFromZSetWithScores(tb testing.TB, r redis.UniversalClient,
   525  	qname string, keyFn func(qname string) string, state base.TaskState) []base.Z {
   526  	tb.Helper()
   527  	zs := r.ZRangeWithScores(context.Background(), keyFn(qname), 0, -1).Val()
   528  	var res []base.Z
   529  	for _, z := range zs {
   530  		taskID := z.Member.(string)
   531  		taskKey := base.TaskKey(qname, taskID)
   532  		msg := r.HGet(context.Background(), taskKey, "msg").Val()
   533  		res = append(res, base.Z{Message: MustUnmarshal(tb, msg), Score: int64(z.Score)})
   534  		if gotState := r.HGet(context.Background(), taskKey, "state").Val(); gotState != state.String() {
   535  			tb.Errorf("task (id=%q) is in %q state, want %v", taskID, gotState, state)
   536  		}
   537  	}
   538  	return res
   539  }
   540  
   541  // TaskSeedData holds the data required to seed tasks under the task key in test.
   542  type TaskSeedData struct {
   543  	Msg          *base.TaskMessage
   544  	State        base.TaskState
   545  	PendingSince time.Time
   546  }
   547  
   548  func SeedTasks(tb testing.TB, r redis.UniversalClient, taskData []*TaskSeedData) {
   549  	for _, data := range taskData {
   550  		msg := data.Msg
   551  		ctx := context.Background()
   552  		key := base.TaskKey(msg.Queue, msg.ID)
   553  		v := map[string]any{
   554  			"msg":        MustMarshal(tb, msg),
   555  			"state":      data.State.String(),
   556  			"unique_key": msg.UniqueKey,
   557  			"group":      msg.GroupKey,
   558  		}
   559  		if !data.PendingSince.IsZero() {
   560  			v["pending_since"] = data.PendingSince.Unix()
   561  		}
   562  		if err := r.HSet(ctx, key, v).Err(); err != nil {
   563  			tb.Fatalf("Failed to write task data in redis: %v", err)
   564  		}
   565  		if len(msg.UniqueKey) > 0 {
   566  			err := r.SetNX(ctx, msg.UniqueKey, msg.ID, 1*time.Minute).Err()
   567  			if err != nil {
   568  				tb.Fatalf("Failed to set unique lock in redis: %v", err)
   569  			}
   570  		}
   571  	}
   572  }
   573  
   574  func SeedRedisZSets(tb testing.TB, r redis.UniversalClient, zsets map[string][]redis.Z) {
   575  	for key, zs := range zsets {
   576  		// FIXME: How come we can't simply do ZAdd(ctx, key, zs...) here?
   577  		for _, z := range zs {
   578  			if err := r.ZAdd(context.Background(), key, z).Err(); err != nil {
   579  				tb.Fatalf("Failed to seed zset (key=%q): %v", key, err)
   580  			}
   581  		}
   582  	}
   583  }
   584  
   585  func SeedRedisSets(tb testing.TB, r redis.UniversalClient, sets map[string][]string) {
   586  	for key, set := range sets {
   587  		SeedRedisSet(tb, r, key, set)
   588  	}
   589  }
   590  
   591  func SeedRedisSet(tb testing.TB, r redis.UniversalClient, key string, members []string) {
   592  	for _, mem := range members {
   593  		if err := r.SAdd(context.Background(), key, mem).Err(); err != nil {
   594  			tb.Fatalf("Failed to seed set (key=%q): %v", key, err)
   595  		}
   596  	}
   597  }
   598  
   599  func SeedRedisLists(tb testing.TB, r redis.UniversalClient, lists map[string][]string) {
   600  	for key, vals := range lists {
   601  		for _, v := range vals {
   602  			if err := r.LPush(context.Background(), key, v).Err(); err != nil {
   603  				tb.Fatalf("Failed to seed list (key=%q): %v", key, err)
   604  			}
   605  		}
   606  	}
   607  }
   608  
   609  func AssertRedisLists(t *testing.T, r redis.UniversalClient, wantLists map[string][]string) {
   610  	for key, want := range wantLists {
   611  		got, err := r.LRange(context.Background(), key, 0, -1).Result()
   612  		if err != nil {
   613  			t.Fatalf("Failed to read list (key=%q): %v", key, err)
   614  		}
   615  		if diff := cmp.Diff(want, got, SortStringSliceOpt); diff != "" {
   616  			t.Errorf("mismatch found in list (key=%q): (-want,+got)\n%s", key, diff)
   617  		}
   618  	}
   619  }
   620  
   621  func AssertRedisSets(t *testing.T, r redis.UniversalClient, wantSets map[string][]string) {
   622  	for key, want := range wantSets {
   623  		got, err := r.SMembers(context.Background(), key).Result()
   624  		if err != nil {
   625  			t.Fatalf("Failed to read set (key=%q): %v", key, err)
   626  		}
   627  		if diff := cmp.Diff(want, got, SortStringSliceOpt); diff != "" {
   628  			t.Errorf("mismatch found in set (key=%q): (-want,+got)\n%s", key, diff)
   629  		}
   630  	}
   631  }
   632  
   633  func AssertRedisZSets(t *testing.T, r redis.UniversalClient, wantZSets map[string][]redis.Z) {
   634  	for key, want := range wantZSets {
   635  		got, err := r.ZRangeWithScores(context.Background(), key, 0, -1).Result()
   636  		if err != nil {
   637  			t.Fatalf("Failed to read zset (key=%q): %v", key, err)
   638  		}
   639  		if diff := cmp.Diff(want, got, SortRedisZSetEntryOpt); diff != "" {
   640  			t.Errorf("mismatch found in zset (key=%q): (-want,+got)\n%s", key, diff)
   641  		}
   642  	}
   643  }