github.com/wfusion/gofusion@v1.1.14/common/infra/asynq/x/rate/semaphore.go (about) 1 // Package rate contains rate limiting strategies for asynq.Handler(s). 2 package rate 3 4 import ( 5 "context" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/redis/go-redis/v9" 11 12 "github.com/wfusion/gofusion/common/infra/asynq" 13 14 asynqcontext "github.com/wfusion/gofusion/common/infra/asynq/pkg/context" 15 ) 16 17 // NewSemaphore creates a counting Semaphore for the given scope with the given number of tokens. 18 func NewSemaphore(rco asynq.RedisConnOpt, scope string, maxTokens int) *Semaphore { 19 rc, ok := rco.MakeRedisClient().(redis.UniversalClient) 20 if !ok { 21 panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco)) 22 } 23 24 if maxTokens < 1 { 25 panic("rate.NewSemaphore: maxTokens cannot be less than 1") 26 } 27 28 if len(strings.TrimSpace(scope)) == 0 { 29 panic("rate.NewSemaphore: scope should not be empty") 30 } 31 32 return &Semaphore{ 33 rc: rc, 34 scope: scope, 35 maxTokens: maxTokens, 36 } 37 } 38 39 // Semaphore is a distributed counting semaphore which can be used to set maxTokens across multiple asynq servers. 40 type Semaphore struct { 41 rc redis.UniversalClient 42 maxTokens int 43 scope string 44 } 45 46 // KEYS[1] -> asynq:sema:<scope> 47 // ARGV[1] -> max concurrency 48 // ARGV[2] -> current time in unix time 49 // ARGV[3] -> deadline in unix time 50 // ARGV[4] -> task ID 51 var acquireCmd = redis.NewScript(` 52 redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1) 53 local count = redis.call("ZCARD", KEYS[1]) 54 55 if (count < tonumber(ARGV[1])) then 56 redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4]) 57 return 'true' 58 else 59 return 'false' 60 end 61 `) 62 63 // Acquire attempts to acquire a token from the semaphore. 64 // - Returns (true, nil), iff semaphore key exists and current value is less than maxTokens 65 // - Returns (false, nil) when token cannot be acquired 66 // - Returns (false, error) otherwise 67 // 68 // The context.Context passed to Acquire must have a deadline set, 69 // this ensures that token is released if the job goroutine crashes and does not call Release. 70 func (s *Semaphore) Acquire(ctx context.Context) (bool, error) { 71 d, ok := ctx.Deadline() 72 if !ok { 73 return false, fmt.Errorf("provided context must have a deadline") 74 } 75 76 taskID, ok := asynqcontext.GetTaskID(ctx) 77 if !ok { 78 return false, fmt.Errorf("provided context is missing task ID value") 79 } 80 81 return acquireCmd.Run(ctx, s.rc, 82 []string{semaphoreKey(s.scope)}, 83 s.maxTokens, 84 time.Now().Unix(), 85 d.Unix(), 86 taskID, 87 ).Bool() 88 } 89 90 // Release will release the token on the counting semaphore. 91 func (s *Semaphore) Release(ctx context.Context) error { 92 taskID, ok := asynqcontext.GetTaskID(ctx) 93 if !ok { 94 return fmt.Errorf("provided context is missing task ID value") 95 } 96 97 n, err := s.rc.ZRem(ctx, semaphoreKey(s.scope), taskID).Result() 98 if err != nil { 99 return fmt.Errorf("redis command failed: %v", err) 100 } 101 102 if n == 0 { 103 return fmt.Errorf("no token found for task %q", taskID) 104 } 105 106 return nil 107 } 108 109 // Close closes the connection to redis. 110 func (s *Semaphore) Close() error { 111 return s.rc.Close() 112 } 113 114 func semaphoreKey(scope string) string { 115 return fmt.Sprintf("asynq:sema:%s", scope) 116 }