github.com/square/finch@v0.0.0-20240412205204-6530c03e2b96/limit/rate.go (about) 1 // Copyright 2023 Block, Inc. 2 3 package limit 4 5 import ( 6 "context" 7 "fmt" 8 9 gorate "golang.org/x/time/rate" 10 11 "github.com/square/finch" 12 ) 13 14 type Rate interface { 15 Adjust(byte) 16 Current() (byte, string) 17 Allow() <-chan bool 18 Stop() 19 } 20 21 type rate struct { 22 c chan bool 23 n uint 24 rl *gorate.Limiter 25 stopChan chan struct{} 26 } 27 28 var _ Rate = &rate{} 29 30 func NewRate(perSecond uint) Rate { 31 if perSecond == 0 { 32 return nil 33 } 34 finch.Debug("new rate: %d/s", perSecond) 35 lm := &rate{ 36 rl: gorate.NewLimiter(gorate.Limit(perSecond), 1), 37 c: make(chan bool, 1), 38 stopChan: make(chan struct{}), 39 } 40 go lm.run() 41 return lm 42 } 43 44 func (lm *rate) Adjust(p byte) { 45 } 46 47 func (lm *rate) Current() (p byte, s string) { 48 return 0, "" 49 } 50 51 func (lm *rate) Stop() { 52 } 53 54 func (lm *rate) Allow() <-chan bool { 55 return lm.c 56 } 57 58 func (lm *rate) run() { 59 var err error 60 for { 61 err = lm.rl.Wait(context.Background()) 62 if err != nil { 63 // burst limit exceeded? 64 continue 65 } 66 select { 67 case lm.c <- true: 68 case <-lm.stopChan: 69 return 70 default: 71 // dropped 72 } 73 } 74 } 75 76 // -------------------------------------------------------------------------- 77 78 type and struct { 79 c chan bool 80 n uint 81 a Rate 82 b Rate 83 } 84 85 var _ Rate = &and{} 86 87 // And makes a Rate limiter that allows execution when both a and b allow it. 88 // This is used to combine QPS and TPS rate limits to keep clients at or below 89 // both rates. 90 func And(a, b Rate) Rate { 91 if a == nil && b == nil { 92 return nil 93 } 94 if a == nil && b != nil { 95 return b 96 } 97 if a != nil && b == nil { 98 return a 99 } 100 lm := &and{ 101 a: a, 102 b: b, 103 c: make(chan bool, 1), 104 } 105 go lm.run() 106 return lm 107 } 108 109 func (lm *and) Allow() <-chan bool { 110 return lm.c 111 } 112 113 func (lm *and) N(_ uint) { 114 } 115 116 func (lm *and) Adjust(p byte) { 117 lm.a.Adjust(p) 118 lm.b.Adjust(p) 119 } 120 121 func (lm *and) Current() (p byte, s string) { 122 p1, s1 := lm.a.Current() 123 p2, s2 := lm.a.Current() 124 if p1 != p2 { 125 panic(fmt.Sprintf("lm.A %d != lm.B %d", p1, p2)) 126 } 127 return p1, s1 + " and " + s2 128 } 129 130 func (lm *and) Stop() { 131 lm.a.Stop() 132 lm.b.Stop() 133 } 134 135 func (lm *and) run() { 136 a := false 137 b := false 138 for { 139 select { 140 case <-lm.a.Allow(): 141 a = true 142 case <-lm.b.Allow(): 143 b = true 144 } 145 if a && b { 146 select { 147 case lm.c <- true: 148 default: 149 // dropped 150 } 151 a = false 152 b = false 153 } 154 } 155 }