github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/modules/l4throttle/throttle.go (about) 1 // Copyright 2020 Matthew Holt 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package l4throttle 16 17 import ( 18 "context" 19 "fmt" 20 "net" 21 "strconv" 22 "time" 23 24 "github.com/caddyserver/caddy/v2" 25 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" 26 "go.uber.org/zap" 27 "golang.org/x/time/rate" 28 29 "github.com/mholt/caddy-l4/layer4" 30 ) 31 32 func init() { 33 caddy.RegisterModule(&Handler{}) 34 } 35 36 // Handler throttles connections using leaky bucket rate limiting. 37 type Handler struct { 38 // The number of bytes to read per second, per connection. 39 ReadBytesPerSecond float64 `json:"read_bytes_per_second,omitempty"` 40 41 // The maximum number of bytes to read at once (rate permitting) per connection. 42 // If a rate is specified, burst must be greater than zero; default is same as 43 // the rate (truncated to integer). 44 ReadBurstSize int `json:"read_burst_size,omitempty"` 45 46 // The number of bytes to read per second, across all connections ("per handler"). 47 TotalReadBytesPerSecond float64 `json:"total_read_bytes_per_second,omitempty"` 48 49 // The maximum number of bytes to read at once (rate permitting) across all 50 // connections ("per handler"). If a rate is specified, burst must be greater 51 // than zero; default is same as the rate (truncated to integer). 52 TotalReadBurstSize int `json:"total_read_burst_size,omitempty"` 53 54 // Delay before initial read on each connection. 55 Latency caddy.Duration `json:"latency,omitempty"` 56 57 logger *zap.Logger 58 totalLimiter *rate.Limiter 59 } 60 61 // CaddyModule returns the Caddy module information. 62 func (*Handler) CaddyModule() caddy.ModuleInfo { 63 return caddy.ModuleInfo{ 64 ID: "layer4.handlers.throttle", 65 New: func() caddy.Module { return new(Handler) }, 66 } 67 } 68 69 // Provision sets up the handler. 70 func (h *Handler) Provision(ctx caddy.Context) error { 71 h.logger = ctx.Logger(h) 72 if h.ReadBytesPerSecond < 0 { 73 return fmt.Errorf("bytes per second must be at least 0: %f", h.ReadBytesPerSecond) 74 } 75 if h.ReadBytesPerSecond > 0 && h.ReadBurstSize == 0 { 76 h.ReadBurstSize = int(h.ReadBytesPerSecond) + 1 77 } 78 if h.TotalReadBytesPerSecond < 0 { 79 return fmt.Errorf("total bytes per second must be at least 0: %f", h.TotalReadBytesPerSecond) 80 } 81 if h.TotalReadBytesPerSecond > 0 && h.TotalReadBurstSize == 0 { 82 h.TotalReadBurstSize = int(h.TotalReadBytesPerSecond) + 1 83 } 84 if h.ReadBurstSize < 0 { 85 return fmt.Errorf("burst size must be greater than 0: %d", h.ReadBurstSize) 86 } 87 if h.TotalReadBurstSize < 0 { 88 return fmt.Errorf("total burst size must be greater than 0: %d", h.TotalReadBurstSize) 89 } 90 if h.TotalReadBytesPerSecond > 0 || h.TotalReadBurstSize > 0 { 91 h.totalLimiter = rate.NewLimiter(rate.Limit(h.TotalReadBytesPerSecond), h.TotalReadBurstSize) 92 } 93 return nil 94 } 95 96 // Handle handles the connection. 97 func (h *Handler) Handle(cx *layer4.Connection, next layer4.Handler) error { 98 var localLimiter *rate.Limiter 99 if h.ReadBytesPerSecond > 0 || h.ReadBurstSize > 0 { 100 localLimiter = rate.NewLimiter(rate.Limit(h.ReadBytesPerSecond), h.ReadBurstSize) 101 } 102 cx.Conn = throttledConn{ 103 Conn: cx.Conn, 104 ctx: cx.Context, 105 logger: h.logger.Named("conn"), 106 totalLimiter: h.totalLimiter, 107 localLimiter: localLimiter, 108 } 109 if h.Latency > 0 { 110 timer := time.NewTimer(time.Duration(h.Latency)) 111 select { 112 case <-timer.C: 113 case <-cx.Context.Done(): 114 return context.Canceled 115 } 116 } 117 return next.Handle(cx) 118 } 119 120 // UnmarshalCaddyfile sets up the Handler from Caddyfile tokens. Syntax: 121 // 122 // throttle { 123 // latency <duration> 124 // read_burst_size <int> 125 // read_bytes_per_second <float> 126 // total_read_burst_size <int> 127 // total_read_bytes_per_second <float> 128 // } 129 func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { 130 _, wrapper := d.Next(), d.Val() // consume wrapper name 131 132 // No same-line options are supported 133 if d.CountRemainingArgs() > 0 { 134 return d.ArgErr() 135 } 136 137 var hasLatency, hasReadBurstSize, hasReadBytesPerSecond, hasTotalReadBurstSize, hasTotalReadBytesPerSecond bool 138 for nesting := d.Nesting(); d.NextBlock(nesting); { 139 optionName := d.Val() 140 switch optionName { 141 case "latency": 142 if hasLatency { 143 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 144 } 145 if d.CountRemainingArgs() != 1 { 146 return d.ArgErr() 147 } 148 d.NextArg() // consume option value 149 dur, err := caddy.ParseDuration(d.Val()) 150 if err != nil { 151 return d.Errf("parsing %s option '%s' duration: %v", wrapper, optionName, err) 152 } 153 h.Latency, hasLatency = caddy.Duration(dur), true 154 case "read_burst_size": 155 if hasReadBurstSize { 156 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 157 } 158 if d.CountRemainingArgs() != 1 { 159 return d.ArgErr() 160 } 161 d.NextArg() // consume option value 162 val, err := strconv.ParseInt(d.Val(), 10, 32) 163 if err != nil { 164 return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err) 165 } 166 h.ReadBurstSize, hasReadBurstSize = int(val), true 167 case "read_bytes_per_second": 168 if hasReadBytesPerSecond { 169 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 170 } 171 if d.CountRemainingArgs() != 1 { 172 return d.ArgErr() 173 } 174 d.NextArg() // consume option value 175 val, err := strconv.ParseFloat(d.Val(), 64) 176 if err != nil { 177 return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err) 178 } 179 h.ReadBytesPerSecond, hasReadBytesPerSecond = val, true 180 case "total_read_burst_size": 181 if hasTotalReadBurstSize { 182 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 183 } 184 if d.CountRemainingArgs() != 1 { 185 return d.ArgErr() 186 } 187 d.NextArg() // consume option value 188 val, err := strconv.ParseInt(d.Val(), 10, 32) 189 if err != nil { 190 return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err) 191 } 192 h.TotalReadBurstSize, hasTotalReadBurstSize = int(val), true 193 case "total_read_bytes_per_second": 194 if hasTotalReadBytesPerSecond { 195 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 196 } 197 if d.CountRemainingArgs() != 1 { 198 return d.ArgErr() 199 } 200 d.NextArg() // consume option value 201 val, err := strconv.ParseFloat(d.Val(), 64) 202 if err != nil { 203 return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err) 204 } 205 h.TotalReadBytesPerSecond, hasTotalReadBytesPerSecond = val, true 206 default: 207 return d.ArgErr() 208 } 209 210 // No nested blocks are supported 211 if d.NextBlock(nesting + 1) { 212 return d.Errf("malformed %s option '%s': blocks are not supported", wrapper, optionName) 213 } 214 } 215 216 return nil 217 } 218 219 type throttledConn struct { 220 net.Conn 221 ctx context.Context 222 logger *zap.Logger 223 totalLimiter, localLimiter *rate.Limiter 224 } 225 226 func (tc throttledConn) Read(p []byte) (int, error) { 227 // The rate limiters will not let us wait for more than their burst 228 // size, so the max we can read in each iteration is the minimum of 229 // len(p) and both limiters' burst sizes. 230 batchSize := len(p) 231 if tc.totalLimiter != nil { 232 if burstSize := tc.totalLimiter.Burst(); batchSize > burstSize { 233 batchSize = burstSize 234 } 235 } 236 if tc.localLimiter != nil { 237 if burstSize := tc.localLimiter.Burst(); batchSize > burstSize { 238 batchSize = burstSize 239 } 240 } 241 242 if tc.totalLimiter != nil { 243 err := tc.totalLimiter.WaitN(tc.ctx, batchSize) 244 if err != nil { 245 return 0, fmt.Errorf("waiting for total limiter: %v", err) 246 } 247 } 248 if tc.localLimiter != nil { 249 err := tc.localLimiter.WaitN(tc.ctx, batchSize) 250 if err != nil { 251 return 0, fmt.Errorf("waiting for local limiter: %v", err) 252 } 253 } 254 255 n, err := tc.Conn.Read(p[:batchSize]) 256 257 tc.logger.Debug("read", 258 zap.String("remote", tc.RemoteAddr().String()), 259 zap.Int("batch_size", batchSize), 260 zap.Int("bytes_read", n), 261 zap.Error(err)) 262 263 return n, err 264 } 265 266 // Interface guards 267 var ( 268 _ caddy.Provisioner = (*Handler)(nil) 269 _ caddyfile.Unmarshaler = (*Handler)(nil) 270 _ layer4.NextHandler = (*Handler)(nil) 271 )