github.com/square/finch@v0.0.0-20240412205204-6530c03e2b96/limit/rate.go (about)

     1  // Copyright 2023 Block, Inc.
     2  
     3  package limit
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  
     9  	gorate "golang.org/x/time/rate"
    10  
    11  	"github.com/square/finch"
    12  )
    13  
    14  type Rate interface {
    15  	Adjust(byte)
    16  	Current() (byte, string)
    17  	Allow() <-chan bool
    18  	Stop()
    19  }
    20  
    21  type rate struct {
    22  	c        chan bool
    23  	n        uint
    24  	rl       *gorate.Limiter
    25  	stopChan chan struct{}
    26  }
    27  
    28  var _ Rate = &rate{}
    29  
    30  func NewRate(perSecond uint) Rate {
    31  	if perSecond == 0 {
    32  		return nil
    33  	}
    34  	finch.Debug("new rate: %d/s", perSecond)
    35  	lm := &rate{
    36  		rl:       gorate.NewLimiter(gorate.Limit(perSecond), 1),
    37  		c:        make(chan bool, 1),
    38  		stopChan: make(chan struct{}),
    39  	}
    40  	go lm.run()
    41  	return lm
    42  }
    43  
    44  func (lm *rate) Adjust(p byte) {
    45  }
    46  
    47  func (lm *rate) Current() (p byte, s string) {
    48  	return 0, ""
    49  }
    50  
    51  func (lm *rate) Stop() {
    52  }
    53  
    54  func (lm *rate) Allow() <-chan bool {
    55  	return lm.c
    56  }
    57  
    58  func (lm *rate) run() {
    59  	var err error
    60  	for {
    61  		err = lm.rl.Wait(context.Background())
    62  		if err != nil {
    63  			// burst limit exceeded?
    64  			continue
    65  		}
    66  		select {
    67  		case lm.c <- true:
    68  		case <-lm.stopChan:
    69  			return
    70  		default:
    71  			// dropped
    72  		}
    73  	}
    74  }
    75  
    76  // --------------------------------------------------------------------------
    77  
    78  type and struct {
    79  	c chan bool
    80  	n uint
    81  	a Rate
    82  	b Rate
    83  }
    84  
    85  var _ Rate = &and{}
    86  
    87  // And makes a Rate limiter that allows execution when both a and b allow it.
    88  // This is used to combine QPS and TPS rate limits to keep clients at or below
    89  // both rates.
    90  func And(a, b Rate) Rate {
    91  	if a == nil && b == nil {
    92  		return nil
    93  	}
    94  	if a == nil && b != nil {
    95  		return b
    96  	}
    97  	if a != nil && b == nil {
    98  		return a
    99  	}
   100  	lm := &and{
   101  		a: a,
   102  		b: b,
   103  		c: make(chan bool, 1),
   104  	}
   105  	go lm.run()
   106  	return lm
   107  }
   108  
   109  func (lm *and) Allow() <-chan bool {
   110  	return lm.c
   111  }
   112  
   113  func (lm *and) N(_ uint) {
   114  }
   115  
   116  func (lm *and) Adjust(p byte) {
   117  	lm.a.Adjust(p)
   118  	lm.b.Adjust(p)
   119  }
   120  
   121  func (lm *and) Current() (p byte, s string) {
   122  	p1, s1 := lm.a.Current()
   123  	p2, s2 := lm.a.Current()
   124  	if p1 != p2 {
   125  		panic(fmt.Sprintf("lm.A %d != lm.B %d", p1, p2))
   126  	}
   127  	return p1, s1 + " and " + s2
   128  }
   129  
   130  func (lm *and) Stop() {
   131  	lm.a.Stop()
   132  	lm.b.Stop()
   133  }
   134  
   135  func (lm *and) run() {
   136  	a := false
   137  	b := false
   138  	for {
   139  		select {
   140  		case <-lm.a.Allow():
   141  			a = true
   142  		case <-lm.b.Allow():
   143  			b = true
   144  		}
   145  		if a && b {
   146  			select {
   147  			case lm.c <- true:
   148  			default:
   149  				// dropped
   150  			}
   151  			a = false
   152  			b = false
   153  		}
   154  	}
   155  }