github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ext/dload/throttler.go (about)

     1  // Package dload implements functionality to download resources into AIS cluster from external source.
     2  /*
     3   * Copyright (c) 2018-2022, NVIDIA CORPORATION. All rights reserved.
     4   */
     5  package dload
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"io"
    11  	"time"
    12  
    13  	"github.com/NVIDIA/aistore/cmn/cos"
    14  )
    15  
    16  var errThrottlerStopped = errors.New("throttler has been stopped")
    17  
    18  type (
    19  	throttler struct {
    20  		sema    *cos.Semaphore
    21  		emptyCh chan struct{} // Empty, closed channel (set only if `sema == nil`).
    22  
    23  		maxBytesPerMinute int
    24  		capacityCh        chan int
    25  		giveBackCh        chan int
    26  		ticker            *time.Ticker
    27  		stopCh            *cos.StopCh
    28  	}
    29  
    30  	throughputThrottler interface {
    31  		acquireAllowance(ctx context.Context, n int) error
    32  	}
    33  
    34  	throttledReader struct {
    35  		t   throughputThrottler
    36  		ctx context.Context
    37  		r   io.ReadCloser
    38  	}
    39  )
    40  
    41  func (t *throttler) init(limits Limits) {
    42  	if limits.Connections > 0 {
    43  		t.sema = cos.NewSemaphore(limits.Connections)
    44  	} else {
    45  		t.emptyCh = make(chan struct{})
    46  		close(t.emptyCh)
    47  	}
    48  	if limits.BytesPerHour > 0 {
    49  		t.initThroughputThrottling(limits.BytesPerHour / 60)
    50  	}
    51  }
    52  
    53  func (t *throttler) initThroughputThrottling(maxBytesPerMinute int) {
    54  	t.maxBytesPerMinute = maxBytesPerMinute
    55  	t.capacityCh = make(chan int, 1)
    56  	t.giveBackCh = make(chan int, 1)
    57  	t.ticker = time.NewTicker(time.Minute)
    58  	t.stopCh = cos.NewStopCh()
    59  
    60  	go func() {
    61  		t.do()
    62  		if t.ticker != nil {
    63  			t.ticker.Stop()
    64  			close(t.capacityCh)
    65  		}
    66  	}()
    67  }
    68  
    69  // LOOP-INVARIANT: `t.capacityCh` has 1 element and `t.giveBackCh` has 0 elements.
    70  // LOOP-INVARIANT: `t.capacityCh` and `t.giveBackCh` can't have size > 0 at the same time.
    71  // Readers start to compete for resources on `t.capacityCh`.
    72  func (t *throttler) do() {
    73  	t.capacityCh <- t.maxBytesPerMinute
    74  
    75  	for {
    76  		select {
    77  		case <-t.stopCh.Listen():
    78  			return
    79  		case leftover := <-t.giveBackCh:
    80  			// Reader took value from `t.capacityCh` and put it to `t.giveBackCh`.
    81  			// `t.capacityCh` has 0 elements and `t.giveBackCh` has 0 elements
    82  			// (we've just read from `t.giveBackCh`).
    83  			if leftover > 0 {
    84  				select {
    85  				// By default put leftover to capacity channel.
    86  				case t.capacityCh <- leftover:
    87  					break
    88  				// But if time has passed, put a big chunk.
    89  				case <-t.ticker.C:
    90  					t.capacityCh <- t.maxBytesPerMinute
    91  				}
    92  			} else {
    93  				// Readers are faster than bandwidth, throttle here.
    94  				select {
    95  				case <-t.ticker.C:
    96  					t.capacityCh <- t.maxBytesPerMinute
    97  				case <-t.stopCh.Listen():
    98  					return
    99  				}
   100  			}
   101  			// Regardless of chosen if-branch we put 1 element to `t.capacityCh`.
   102  		}
   103  	}
   104  }
   105  
   106  func (t *throttler) tryAcquire() <-chan struct{} {
   107  	if t.sema == nil {
   108  		return t.emptyCh
   109  	}
   110  	return t.sema.TryAcquire()
   111  }
   112  
   113  func (t *throttler) release() {
   114  	if t.sema == nil {
   115  		return
   116  	}
   117  	t.sema.Release()
   118  }
   119  
   120  func (t *throttler) wrapReader(ctx context.Context, r io.ReadCloser) io.ReadCloser {
   121  	if t.maxBytesPerMinute == 0 {
   122  		return r
   123  	}
   124  	return &throttledReader{
   125  		t:   t,
   126  		ctx: ctx,
   127  		r:   r,
   128  	}
   129  }
   130  
   131  func (t *throttler) stop() {
   132  	if t.ticker != nil {
   133  		t.ticker.Stop()
   134  		t.ticker = nil
   135  	}
   136  	if t.stopCh != nil {
   137  		t.stopCh.Close()
   138  	}
   139  }
   140  
   141  func (t *throttler) giveBack(leftoverSize int) {
   142  	// Never waits. If we took value from `t.capacityCh`, it means that
   143  	// `t.giveBack` was empty.
   144  	t.giveBackCh <- leftoverSize
   145  }
   146  
   147  func (t *throttler) acquireAllowance(ctx context.Context, n int) error {
   148  	select {
   149  	case size, ok := <-t.capacityCh:
   150  		if !ok {
   151  			return errThrottlerStopped
   152  		}
   153  		t.giveBack(size - n)
   154  		return nil
   155  	case <-ctx.Done():
   156  		return context.Canceled
   157  	}
   158  }
   159  
   160  func (tr *throttledReader) Read(p []byte) (n int, err error) {
   161  	if err := tr.t.acquireAllowance(tr.ctx, len(p)); err != nil {
   162  		return 0, err
   163  	}
   164  	return tr.r.Read(p)
   165  }
   166  
   167  func (tr *throttledReader) Close() (err error) {
   168  	return tr.r.Close()
   169  }