github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/modules/l4throttle/throttle.go (about)

     1  // Copyright 2020 Matthew Holt
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package l4throttle
    16  
import (
	"context"
	"fmt"
	"math"
	"net"
	"strconv"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"go.uber.org/zap"
	"golang.org/x/time/rate"

	"github.com/mholt/caddy-l4/layer4"
)
    31  
    32  func init() {
    33  	caddy.RegisterModule(&Handler{})
    34  }
    35  
    36  // Handler throttles connections using leaky bucket rate limiting.
    37  type Handler struct {
    38  	// The number of bytes to read per second, per connection.
    39  	ReadBytesPerSecond float64 `json:"read_bytes_per_second,omitempty"`
    40  
    41  	// The maximum number of bytes to read at once (rate permitting) per connection.
    42  	// If a rate is specified, burst must be greater than zero; default is same as
    43  	// the rate (truncated to integer).
    44  	ReadBurstSize int `json:"read_burst_size,omitempty"`
    45  
    46  	// The number of bytes to read per second, across all connections ("per handler").
    47  	TotalReadBytesPerSecond float64 `json:"total_read_bytes_per_second,omitempty"`
    48  
    49  	// The maximum number of bytes to read at once (rate permitting) across all
    50  	// connections ("per handler"). If a rate is specified, burst must be greater
    51  	// than zero; default is same as the rate (truncated to integer).
    52  	TotalReadBurstSize int `json:"total_read_burst_size,omitempty"`
    53  
    54  	// Delay before initial read on each connection.
    55  	Latency caddy.Duration `json:"latency,omitempty"`
    56  
    57  	logger       *zap.Logger
    58  	totalLimiter *rate.Limiter
    59  }
    60  
    61  // CaddyModule returns the Caddy module information.
    62  func (*Handler) CaddyModule() caddy.ModuleInfo {
    63  	return caddy.ModuleInfo{
    64  		ID:  "layer4.handlers.throttle",
    65  		New: func() caddy.Module { return new(Handler) },
    66  	}
    67  }
    68  
    69  // Provision sets up the handler.
    70  func (h *Handler) Provision(ctx caddy.Context) error {
    71  	h.logger = ctx.Logger(h)
    72  	if h.ReadBytesPerSecond < 0 {
    73  		return fmt.Errorf("bytes per second must be at least 0: %f", h.ReadBytesPerSecond)
    74  	}
    75  	if h.ReadBytesPerSecond > 0 && h.ReadBurstSize == 0 {
    76  		h.ReadBurstSize = int(h.ReadBytesPerSecond) + 1
    77  	}
    78  	if h.TotalReadBytesPerSecond < 0 {
    79  		return fmt.Errorf("total bytes per second must be at least 0: %f", h.TotalReadBytesPerSecond)
    80  	}
    81  	if h.TotalReadBytesPerSecond > 0 && h.TotalReadBurstSize == 0 {
    82  		h.TotalReadBurstSize = int(h.TotalReadBytesPerSecond) + 1
    83  	}
    84  	if h.ReadBurstSize < 0 {
    85  		return fmt.Errorf("burst size must be greater than 0: %d", h.ReadBurstSize)
    86  	}
    87  	if h.TotalReadBurstSize < 0 {
    88  		return fmt.Errorf("total burst size must be greater than 0: %d", h.TotalReadBurstSize)
    89  	}
    90  	if h.TotalReadBytesPerSecond > 0 || h.TotalReadBurstSize > 0 {
    91  		h.totalLimiter = rate.NewLimiter(rate.Limit(h.TotalReadBytesPerSecond), h.TotalReadBurstSize)
    92  	}
    93  	return nil
    94  }
    95  
    96  // Handle handles the connection.
    97  func (h *Handler) Handle(cx *layer4.Connection, next layer4.Handler) error {
    98  	var localLimiter *rate.Limiter
    99  	if h.ReadBytesPerSecond > 0 || h.ReadBurstSize > 0 {
   100  		localLimiter = rate.NewLimiter(rate.Limit(h.ReadBytesPerSecond), h.ReadBurstSize)
   101  	}
   102  	cx.Conn = throttledConn{
   103  		Conn:         cx.Conn,
   104  		ctx:          cx.Context,
   105  		logger:       h.logger.Named("conn"),
   106  		totalLimiter: h.totalLimiter,
   107  		localLimiter: localLimiter,
   108  	}
   109  	if h.Latency > 0 {
   110  		timer := time.NewTimer(time.Duration(h.Latency))
   111  		select {
   112  		case <-timer.C:
   113  		case <-cx.Context.Done():
   114  			return context.Canceled
   115  		}
   116  	}
   117  	return next.Handle(cx)
   118  }
   119  
   120  // UnmarshalCaddyfile sets up the Handler from Caddyfile tokens. Syntax:
   121  //
   122  //	throttle {
   123  //		latency <duration>
   124  //		read_burst_size <int>
   125  //		read_bytes_per_second <float>
   126  //		total_read_burst_size <int>
   127  //		total_read_bytes_per_second <float>
   128  //	}
   129  func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
   130  	_, wrapper := d.Next(), d.Val() // consume wrapper name
   131  
   132  	// No same-line options are supported
   133  	if d.CountRemainingArgs() > 0 {
   134  		return d.ArgErr()
   135  	}
   136  
   137  	var hasLatency, hasReadBurstSize, hasReadBytesPerSecond, hasTotalReadBurstSize, hasTotalReadBytesPerSecond bool
   138  	for nesting := d.Nesting(); d.NextBlock(nesting); {
   139  		optionName := d.Val()
   140  		switch optionName {
   141  		case "latency":
   142  			if hasLatency {
   143  				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
   144  			}
   145  			if d.CountRemainingArgs() != 1 {
   146  				return d.ArgErr()
   147  			}
   148  			d.NextArg() // consume option value
   149  			dur, err := caddy.ParseDuration(d.Val())
   150  			if err != nil {
   151  				return d.Errf("parsing %s option '%s' duration: %v", wrapper, optionName, err)
   152  			}
   153  			h.Latency, hasLatency = caddy.Duration(dur), true
   154  		case "read_burst_size":
   155  			if hasReadBurstSize {
   156  				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
   157  			}
   158  			if d.CountRemainingArgs() != 1 {
   159  				return d.ArgErr()
   160  			}
   161  			d.NextArg() // consume option value
   162  			val, err := strconv.ParseInt(d.Val(), 10, 32)
   163  			if err != nil {
   164  				return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err)
   165  			}
   166  			h.ReadBurstSize, hasReadBurstSize = int(val), true
   167  		case "read_bytes_per_second":
   168  			if hasReadBytesPerSecond {
   169  				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
   170  			}
   171  			if d.CountRemainingArgs() != 1 {
   172  				return d.ArgErr()
   173  			}
   174  			d.NextArg() // consume option value
   175  			val, err := strconv.ParseFloat(d.Val(), 64)
   176  			if err != nil {
   177  				return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err)
   178  			}
   179  			h.ReadBytesPerSecond, hasReadBytesPerSecond = val, true
   180  		case "total_read_burst_size":
   181  			if hasTotalReadBurstSize {
   182  				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
   183  			}
   184  			if d.CountRemainingArgs() != 1 {
   185  				return d.ArgErr()
   186  			}
   187  			d.NextArg() // consume option value
   188  			val, err := strconv.ParseInt(d.Val(), 10, 32)
   189  			if err != nil {
   190  				return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err)
   191  			}
   192  			h.TotalReadBurstSize, hasTotalReadBurstSize = int(val), true
   193  		case "total_read_bytes_per_second":
   194  			if hasTotalReadBytesPerSecond {
   195  				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
   196  			}
   197  			if d.CountRemainingArgs() != 1 {
   198  				return d.ArgErr()
   199  			}
   200  			d.NextArg() // consume option value
   201  			val, err := strconv.ParseFloat(d.Val(), 64)
   202  			if err != nil {
   203  				return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err)
   204  			}
   205  			h.TotalReadBytesPerSecond, hasTotalReadBytesPerSecond = val, true
   206  		default:
   207  			return d.ArgErr()
   208  		}
   209  
   210  		// No nested blocks are supported
   211  		if d.NextBlock(nesting + 1) {
   212  			return d.Errf("malformed %s option '%s': blocks are not supported", wrapper, optionName)
   213  		}
   214  	}
   215  
   216  	return nil
   217  }
   218  
   219  type throttledConn struct {
   220  	net.Conn
   221  	ctx                        context.Context
   222  	logger                     *zap.Logger
   223  	totalLimiter, localLimiter *rate.Limiter
   224  }
   225  
   226  func (tc throttledConn) Read(p []byte) (int, error) {
   227  	// The rate limiters will not let us wait for more than their burst
   228  	// size, so the max we can read in each iteration is the minimum of
   229  	// len(p) and both limiters' burst sizes.
   230  	batchSize := len(p)
   231  	if tc.totalLimiter != nil {
   232  		if burstSize := tc.totalLimiter.Burst(); batchSize > burstSize {
   233  			batchSize = burstSize
   234  		}
   235  	}
   236  	if tc.localLimiter != nil {
   237  		if burstSize := tc.localLimiter.Burst(); batchSize > burstSize {
   238  			batchSize = burstSize
   239  		}
   240  	}
   241  
   242  	if tc.totalLimiter != nil {
   243  		err := tc.totalLimiter.WaitN(tc.ctx, batchSize)
   244  		if err != nil {
   245  			return 0, fmt.Errorf("waiting for total limiter: %v", err)
   246  		}
   247  	}
   248  	if tc.localLimiter != nil {
   249  		err := tc.localLimiter.WaitN(tc.ctx, batchSize)
   250  		if err != nil {
   251  			return 0, fmt.Errorf("waiting for local limiter: %v", err)
   252  		}
   253  	}
   254  
   255  	n, err := tc.Conn.Read(p[:batchSize])
   256  
   257  	tc.logger.Debug("read",
   258  		zap.String("remote", tc.RemoteAddr().String()),
   259  		zap.Int("batch_size", batchSize),
   260  		zap.Int("bytes_read", n),
   261  		zap.Error(err))
   262  
   263  	return n, err
   264  }
   265  
   266  // Interface guards
   267  var (
   268  	_ caddy.Provisioner     = (*Handler)(nil)
   269  	_ caddyfile.Unmarshaler = (*Handler)(nil)
   270  	_ layer4.NextHandler    = (*Handler)(nil)
   271  )