github.com/insolar/vanilla@v0.0.0-20201023172447-248fdf805322/iokit/throttle.go (about)

     1  // Copyright 2020 Insolar Network Ltd.
     2  // All rights reserved.
     3  // This material is licensed under the Insolar License version 1.0,
     4  // available at https://github.com/insolar/assured-ledger/blob/master/LICENSE.md.
     5  
     6  package iokit
     7  
     8  import (
     9  	"io"
    10  	"math"
    11  
    12  	"github.com/insolar/vanilla/throw"
    13  )
    14  
    15  type RateLimiter interface {
    16  	TakeQuota(max int64) int64
    17  }
    18  
    19  func RateLimitReader(r io.Reader, q RateLimiter) *RateLimitedReader {
    20  	switch {
    21  	case r == nil:
    22  		panic(throw.IllegalValue())
    23  	case q == nil:
    24  		panic(throw.IllegalValue())
    25  	}
    26  	return &RateLimitedReader{r, q}
    27  }
    28  
    29  func RateLimitWriter(w io.Writer, q RateLimiter) *RateLimitedWriter {
    30  	switch {
    31  	case w == nil:
    32  		panic(throw.IllegalValue())
    33  	case q == nil:
    34  		panic(throw.IllegalValue())
    35  	}
    36  	return &RateLimitedWriter{w, q}
    37  }
    38  
    39  var _ io.Reader = RateLimitedReader{}
    40  var _ io.WriterTo = RateLimitedReader{}
    41  
    42  type RateLimitedReader struct {
    43  	R io.Reader
    44  	Q RateLimiter
    45  }
    46  
    47  func (r RateLimitedReader) WriteTo(w io.Writer) (int64, error) {
    48  	return RateLimitedCopy(w, r.R, r.Q)
    49  }
    50  
    51  func (r RateLimitedReader) Read(p []byte) (int, error) {
    52  	return RateLimitedByteCopy(r.R.Read, p, r.Q)
    53  }
    54  
    55  func (r RateLimitedReader) Close() error {
    56  	if c, ok := r.R.(io.Closer); ok {
    57  		return c.Close()
    58  	}
    59  	return nil
    60  }
    61  
    62  var _ io.Writer = RateLimitedWriter{}
    63  var _ io.ReaderFrom = RateLimitedWriter{}
    64  
    65  type RateLimitedWriter struct {
    66  	W io.Writer
    67  	Q RateLimiter
    68  }
    69  
    70  func (r RateLimitedWriter) ReadFrom(rd io.Reader) (int64, error) {
    71  	return RateLimitedCopy(r.W, rd, r.Q)
    72  }
    73  
    74  func (r RateLimitedWriter) Write(p []byte) (int, error) {
    75  	return RateLimitedByteCopy(r.W.Write, p, r.Q)
    76  }
    77  
    78  func (r RateLimitedWriter) Close() error {
    79  	if c, ok := r.W.(io.Closer); ok {
    80  		return c.Close()
    81  	}
    82  	return nil
    83  }
    84  
    85  
    86  /****************************/
    87  
    88  const rateLimitBlockMin = 4096
    89  const rateLimitBlockMax = 32768 // don't set too high to avoid exhaustion of rate limiter
    90  
    91  type quotaSizer struct {
    92  	v int64
    93  }
    94  
    95  func (p *quotaSizer) estimate() int64 {
    96  	switch {
    97  	case p.v < rateLimitBlockMin:
    98  		return rateLimitBlockMin
    99  	case p.v > rateLimitBlockMax:
   100  		return rateLimitBlockMax
   101  	}
   102  	return p.v
   103  }
   104  
   105  func (p *quotaSizer) update(n int64) {
   106  	if p.v < n {
   107  		p.v = n
   108  	} else {
   109  		p.v -= n
   110  	}
   111  }
   112  
   113  func RateLimitedCopy(writer io.Writer, reader io.Reader, q RateLimiter) (int64, error) {
   114  	if q == nil {
   115  		return io.Copy(writer, reader)
   116  	}
   117  
   118  	// io.LimitedReader is reused for multiple calls to io.Copy
   119  	sizer := quotaSizer{}
   120  	limited := io.LimitedReader{}
   121  	limited.R, sizer.v = traverseLimitReaders(reader, math.MaxInt64)
   122  	innerW, maxW := traverseLimitWriters(writer, math.MaxInt64)
   123  
   124  	switch {
   125  	case maxW == 0 || sizer.v == 0:
   126  		return 0, nil
   127  	case maxW < sizer.v:
   128  		sizer.v = maxW
   129  	case sizer.v == math.MaxInt64:
   130  		sizer.v = 0
   131  	}
   132  
   133  	for total := int64(0); ; {
   134  		limited.N += q.TakeQuota(sizer.estimate())
   135  		for limited.N > 0 {
   136  			// io.Copy will use io.ReaderFrom when available
   137  			n, err := io.Copy(innerW, &limited)
   138  			total += n
   139  			if err != nil {
   140  				updatedLimitReaders(reader, total)
   141  				updatedLimitWriters(writer, total)
   142  				return total, err
   143  			}
   144  			sizer.update(n)
   145  		}
   146  	}
   147  }
   148  
   149  func RateLimitedByteCopy(fn func([]byte) (int, error), p []byte, q RateLimiter) (int, error) {
   150  	if q == nil {
   151  		return fn(p)
   152  	}
   153  
   154  	total := 0
   155  	for len(p) > 0 {
   156  		for q := q.TakeQuota(int64(len(p))); q > 0; {
   157  			n, err := fn(p[:q])
   158  			if err != nil {
   159  				return total + n, err
   160  			}
   161  			total += n
   162  			p = p[n:]
   163  			q -= int64(n)
   164  		}
   165  	}
   166  	return total, nil
   167  }
   168  
   169  func RateLimitedBySize(n int64, q RateLimiter) error {
   170  	if q == nil {
   171  		return nil
   172  	}
   173  	for ; n > 0; n -= q.TakeQuota(n) {
   174  	}
   175  	return nil
   176  }