github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/iolimiter/iolimiter.go (about)

     1  package iolimiter
     2  
import (
	"context"
	"io"
	"net"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
    11  
    12  const burstLimit = 1000 * 1000 * 1000
    13  
    14  type Reader struct {
    15  	r       io.Reader
    16  	limiter *rate.Limiter
    17  	ctx     context.Context
    18  }
    19  
    20  type Writer struct {
    21  	w       io.Writer
    22  	limiter *rate.Limiter
    23  	ctx     context.Context
    24  }
    25  
    26  type conn struct {
    27  	net.Conn
    28  	r            io.Reader
    29  	w            io.Writer
    30  	readLimiter  *rate.Limiter
    31  	writeLimiter *rate.Limiter
    32  	ctx          context.Context
    33  }
    34  
    35  //NewtRateLimitConn sets rate limit (bytes/sec) to the Conn read and write.
    36  func NewtConn(c net.Conn, bytesPerSec float64) net.Conn {
    37  	s := &conn{
    38  		Conn: c,
    39  		r:    c,
    40  		w:    c,
    41  		ctx:  context.Background(),
    42  	}
    43  	s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
    44  	s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
    45  	s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
    46  	s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
    47  	return s
    48  }
    49  
    50  //NewtRateLimitReaderConn sets rate limit (bytes/sec) to the Conn read.
    51  func NewReaderConn(c net.Conn, bytesPerSec float64) net.Conn {
    52  	s := &conn{
    53  		Conn: c,
    54  		r:    c,
    55  		w:    c,
    56  		ctx:  context.Background(),
    57  	}
    58  	s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
    59  	s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
    60  	return s
    61  }
    62  
    63  //NewtRateLimitWriterConn sets rate limit (bytes/sec) to the Conn write.
    64  func NewWriterConn(c net.Conn, bytesPerSec float64) net.Conn {
    65  	s := &conn{
    66  		Conn: c,
    67  		r:    c,
    68  		w:    c,
    69  		ctx:  context.Background(),
    70  	}
    71  	s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
    72  	s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
    73  	return s
    74  }
    75  
    76  // Read reads bytes into p.
    77  func (s *conn) Read(p []byte) (int, error) {
    78  	if s.readLimiter == nil {
    79  		return s.r.Read(p)
    80  	}
    81  	n, err := s.r.Read(p)
    82  	if err != nil {
    83  		return n, err
    84  	}
    85  	if err := s.readLimiter.WaitN(s.ctx, n); err != nil {
    86  		return n, err
    87  	}
    88  	return n, nil
    89  }
    90  
    91  // Write writes bytes from p.
    92  func (s *conn) Write(p []byte) (int, error) {
    93  	if s.writeLimiter == nil {
    94  		return s.w.Write(p)
    95  	}
    96  	n, err := s.w.Write(p)
    97  	if err != nil {
    98  		return n, err
    99  	}
   100  	if err := s.writeLimiter.WaitN(s.ctx, n); err != nil {
   101  		return n, err
   102  	}
   103  	return n, err
   104  }
   105  func (s *conn) Close() error {
   106  	if s.Conn != nil {
   107  		e := s.Conn.Close()
   108  		s.Conn = nil
   109  		s.r = nil
   110  		s.w = nil
   111  		s.readLimiter = nil
   112  		s.writeLimiter = nil
   113  		s.ctx = nil
   114  		return e
   115  	}
   116  	return nil
   117  }
   118  
   119  // NewReader returns a reader that implements io.Reader with rate limiting.
   120  func NewReader(r io.Reader) *Reader {
   121  	return &Reader{
   122  		r:   r,
   123  		ctx: context.Background(),
   124  	}
   125  }
   126  
   127  // NewReaderWithContext returns a reader that implements io.Reader with rate limiting.
   128  func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader {
   129  	return &Reader{
   130  		r:   r,
   131  		ctx: ctx,
   132  	}
   133  }
   134  
   135  // NewWriter returns a writer that implements io.Writer with rate limiting.
   136  func NewWriter(w io.Writer) *Writer {
   137  	return &Writer{
   138  		w:   w,
   139  		ctx: context.Background(),
   140  	}
   141  }
   142  
   143  // NewWriterWithContext returns a writer that implements io.Writer with rate limiting.
   144  func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer {
   145  	return &Writer{
   146  		w:   w,
   147  		ctx: ctx,
   148  	}
   149  }
   150  
   151  // SetRateLimit sets rate limit (bytes/sec) to the reader.
   152  func (s *Reader) SetRateLimit(bytesPerSec float64) {
   153  	s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
   154  	s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
   155  }
   156  
   157  // Read reads bytes into p.
   158  func (s *Reader) Read(p []byte) (int, error) {
   159  	if s.limiter == nil {
   160  		return s.r.Read(p)
   161  	}
   162  	n, err := s.r.Read(p)
   163  	if err != nil {
   164  		return n, err
   165  	}
   166  	if err := s.limiter.WaitN(s.ctx, n); err != nil {
   167  		return n, err
   168  	}
   169  	return n, nil
   170  }
   171  
   172  // SetRateLimit sets rate limit (bytes/sec) to the writer.
   173  func (s *Writer) SetRateLimit(bytesPerSec float64) {
   174  	s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
   175  	s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
   176  }
   177  
   178  // Write writes bytes from p.
   179  func (s *Writer) Write(p []byte) (int, error) {
   180  	if s.limiter == nil {
   181  		return s.w.Write(p)
   182  	}
   183  	n, err := s.w.Write(p)
   184  	if err != nil {
   185  		return n, err
   186  	}
   187  	if err := s.limiter.WaitN(s.ctx, n); err != nil {
   188  		return n, err
   189  	}
   190  	return n, err
   191  }