github.com/mckael/restic@v0.8.3/internal/limiter/static_limiter.go (about)

     1  package limiter
     2  
     3  import (
     4  	"io"
     5  	"net/http"
     6  
     7  	"github.com/juju/ratelimit"
     8  )
     9  
    10  type staticLimiter struct {
    11  	upstream   *ratelimit.Bucket
    12  	downstream *ratelimit.Bucket
    13  }
    14  
    15  // NewStaticLimiter constructs a Limiter with a fixed (static) upload and
    16  // download rate cap
    17  func NewStaticLimiter(uploadKb, downloadKb int) Limiter {
    18  	var (
    19  		upstreamBucket   *ratelimit.Bucket
    20  		downstreamBucket *ratelimit.Bucket
    21  	)
    22  
    23  	if uploadKb > 0 {
    24  		upstreamBucket = ratelimit.NewBucketWithRate(toByteRate(uploadKb), int64(toByteRate(uploadKb)))
    25  	}
    26  
    27  	if downloadKb > 0 {
    28  		downstreamBucket = ratelimit.NewBucketWithRate(toByteRate(downloadKb), int64(toByteRate(downloadKb)))
    29  	}
    30  
    31  	return staticLimiter{
    32  		upstream:   upstreamBucket,
    33  		downstream: downstreamBucket,
    34  	}
    35  }
    36  
    37  func (l staticLimiter) Upstream(r io.Reader) io.Reader {
    38  	return l.limit(r, l.upstream)
    39  }
    40  
    41  func (l staticLimiter) Downstream(r io.Reader) io.Reader {
    42  	return l.limit(r, l.downstream)
    43  }
    44  
    45  type roundTripper func(*http.Request) (*http.Response, error)
    46  
    47  func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    48  	return rt(req)
    49  }
    50  
    51  func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
    52  	if req.Body != nil {
    53  		req.Body = limitedReadCloser{
    54  			limited:  l.Upstream(req.Body),
    55  			original: req.Body,
    56  		}
    57  	}
    58  
    59  	res, err := rt.RoundTrip(req)
    60  
    61  	if res != nil && res.Body != nil {
    62  		res.Body = limitedReadCloser{
    63  			limited:  l.Downstream(res.Body),
    64  			original: res.Body,
    65  		}
    66  	}
    67  
    68  	return res, err
    69  }
    70  
    71  // Transport returns an HTTP transport limited with the limiter l.
    72  func (l staticLimiter) Transport(rt http.RoundTripper) http.RoundTripper {
    73  	return roundTripper(func(req *http.Request) (*http.Response, error) {
    74  		return l.roundTripper(rt, req)
    75  	})
    76  }
    77  
    78  func (l staticLimiter) limit(r io.Reader, b *ratelimit.Bucket) io.Reader {
    79  	if b == nil {
    80  		return r
    81  	}
    82  	return ratelimit.Reader(r, b)
    83  }
    84  
    85  func toByteRate(val int) float64 {
    86  	return float64(val) * 1024.
    87  }