github.com/wfusion/gofusion@v1.1.14/common/infra/asynq/x/rate/semaphore.go (about)

     1  // Package rate contains rate limiting strategies for asynq.Handler(s).
     2  package rate
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/redis/go-redis/v9"
    11  
    12  	"github.com/wfusion/gofusion/common/infra/asynq"
    13  
    14  	asynqcontext "github.com/wfusion/gofusion/common/infra/asynq/pkg/context"
    15  )
    16  
    17  // NewSemaphore creates a counting Semaphore for the given scope with the given number of tokens.
    18  func NewSemaphore(rco asynq.RedisConnOpt, scope string, maxTokens int) *Semaphore {
    19  	rc, ok := rco.MakeRedisClient().(redis.UniversalClient)
    20  	if !ok {
    21  		panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco))
    22  	}
    23  
    24  	if maxTokens < 1 {
    25  		panic("rate.NewSemaphore: maxTokens cannot be less than 1")
    26  	}
    27  
    28  	if len(strings.TrimSpace(scope)) == 0 {
    29  		panic("rate.NewSemaphore: scope should not be empty")
    30  	}
    31  
    32  	return &Semaphore{
    33  		rc:        rc,
    34  		scope:     scope,
    35  		maxTokens: maxTokens,
    36  	}
    37  }
    38  
// Semaphore is a distributed counting semaphore which can be used to set maxTokens across multiple asynq servers.
//
// Create one with NewSemaphore. All Semaphores that share a scope (and redis
// instance) draw tokens from the same redis key.
type Semaphore struct {
	// rc is the redis client obtained in NewSemaphore; Close closes it.
	rc        redis.UniversalClient
	// maxTokens is the number of tokens that may be held at once (always >= 1).
	maxTokens int
	// scope names the semaphore and selects the key "asynq:sema:<scope>".
	scope     string
}
    45  
// acquireCmd atomically tries to take one token from a scope's semaphore.
//
// The semaphore is a sorted set. Its members are task IDs and its scores are
// the holders' deadlines in unix seconds. The script first evicts members whose
// score is at most now-1, i.e. whose deadline has already passed. This reclaims
// tokens from tasks that never called Release. If fewer than ARGV[1] members
// remain, the task ID is added and the script returns 'true'. If the ID is
// already a member, ZADD instead just updates its deadline. Otherwise the
// script returns 'false'. The caller converts that string with Cmd.Bool.
//
// NOTE(review): a task that already holds a token is still counted by ZCARD,
// so re-acquiring while the semaphore is full returns 'false'. Confirm this is
// intended.
//
// KEYS[1] -> asynq:sema:<scope>
// ARGV[1] -> max concurrency
// ARGV[2] -> current time in unix time
// ARGV[3] -> deadline in unix time
// ARGV[4] -> task ID
var acquireCmd = redis.NewScript(`
redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1)
local count = redis.call("ZCARD", KEYS[1])

if (count < tonumber(ARGV[1])) then
     redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4])
     return 'true'
else
     return 'false'
end
`)
    62  
    63  // Acquire attempts to acquire a token from the semaphore.
    64  // - Returns (true, nil), iff semaphore key exists and current value is less than maxTokens
    65  // - Returns (false, nil) when token cannot be acquired
    66  // - Returns (false, error) otherwise
    67  //
    68  // The context.Context passed to Acquire must have a deadline set,
    69  // this ensures that token is released if the job goroutine crashes and does not call Release.
    70  func (s *Semaphore) Acquire(ctx context.Context) (bool, error) {
    71  	d, ok := ctx.Deadline()
    72  	if !ok {
    73  		return false, fmt.Errorf("provided context must have a deadline")
    74  	}
    75  
    76  	taskID, ok := asynqcontext.GetTaskID(ctx)
    77  	if !ok {
    78  		return false, fmt.Errorf("provided context is missing task ID value")
    79  	}
    80  
    81  	return acquireCmd.Run(ctx, s.rc,
    82  		[]string{semaphoreKey(s.scope)},
    83  		s.maxTokens,
    84  		time.Now().Unix(),
    85  		d.Unix(),
    86  		taskID,
    87  	).Bool()
    88  }
    89  
    90  // Release will release the token on the counting semaphore.
    91  func (s *Semaphore) Release(ctx context.Context) error {
    92  	taskID, ok := asynqcontext.GetTaskID(ctx)
    93  	if !ok {
    94  		return fmt.Errorf("provided context is missing task ID value")
    95  	}
    96  
    97  	n, err := s.rc.ZRem(ctx, semaphoreKey(s.scope), taskID).Result()
    98  	if err != nil {
    99  		return fmt.Errorf("redis command failed: %v", err)
   100  	}
   101  
   102  	if n == 0 {
   103  		return fmt.Errorf("no token found for task %q", taskID)
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  // Close closes the connection to redis.
   110  func (s *Semaphore) Close() error {
   111  	return s.rc.Close()
   112  }
   113  
   114  func semaphoreKey(scope string) string {
   115  	return fmt.Sprintf("asynq:sema:%s", scope)
   116  }