github.com/cavaliergopher/grab/v3@v3.0.1/transfer.go (about)

     1  package grab
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/cavaliergopher/grab/v3/pkg/bps"
    10  )
    11  
    12  type transfer struct {
    13  	n     int64 // must be 64bit aligned on 386
    14  	ctx   context.Context
    15  	gauge bps.Gauge
    16  	lim   RateLimiter
    17  	w     io.Writer
    18  	r     io.Reader
    19  	b     []byte
    20  }
    21  
    22  func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, buf []byte) *transfer {
    23  	return &transfer{
    24  		ctx:   ctx,
    25  		gauge: bps.NewSMA(6), // five second moving average sampling every second
    26  		lim:   lim,
    27  		w:     dst,
    28  		r:     src,
    29  		b:     buf,
    30  	}
    31  }
    32  
    33  // copy behaves similarly to io.CopyBuffer except that it checks for cancelation
    34  // of the given context.Context, reports progress in a thread-safe manner and
    35  // tracks the transfer rate.
    36  func (c *transfer) copy() (written int64, err error) {
    37  	// maintain a bps gauge in another goroutine
    38  	ctx, cancel := context.WithCancel(c.ctx)
    39  	defer cancel()
    40  	go bps.Watch(ctx, c.gauge, c.N, time.Second)
    41  
    42  	// start the transfer
    43  	if c.b == nil {
    44  		c.b = make([]byte, 32*1024)
    45  	}
    46  	for {
    47  		select {
    48  		case <-c.ctx.Done():
    49  			err = c.ctx.Err()
    50  			return
    51  		default:
    52  			// keep working
    53  		}
    54  		nr, er := c.r.Read(c.b)
    55  		if nr > 0 {
    56  			nw, ew := c.w.Write(c.b[0:nr])
    57  			if nw > 0 {
    58  				written += int64(nw)
    59  				atomic.StoreInt64(&c.n, written)
    60  			}
    61  			if ew != nil {
    62  				err = ew
    63  				break
    64  			}
    65  			if nr != nw {
    66  				err = io.ErrShortWrite
    67  				break
    68  			}
    69  			// wait for rate limiter
    70  			if c.lim != nil {
    71  				err = c.lim.WaitN(c.ctx, nr)
    72  				if err != nil {
    73  					return
    74  				}
    75  			}
    76  		}
    77  		if er != nil {
    78  			if er != io.EOF {
    79  				err = er
    80  			}
    81  			break
    82  		}
    83  	}
    84  	return written, err
    85  }
    86  
    87  // N returns the number of bytes transferred.
    88  func (c *transfer) N() (n int64) {
    89  	if c == nil {
    90  		return 0
    91  	}
    92  	n = atomic.LoadInt64(&c.n)
    93  	return
    94  }
    95  
    96  // BPS returns the current bytes per second transfer rate using a simple moving
    97  // average.
    98  func (c *transfer) BPS() (bps float64) {
    99  	if c == nil || c.gauge == nil {
   100  		return 0
   101  	}
   102  	return c.gauge.BPS()
   103  }