github.com/blend/go-sdk@v1.20220411.3/ratelimiter/copy.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package ratelimiter
     9  
    10  import (
    11  	"context"
    12  	"errors"
    13  	"io"
    14  	"time"
    15  )
    16  
    17  const (
    18  	// DefaultCopyChunkSizeBytes is the write chunk size in bytes.
    19  	DefaultCopyChunkSizeBytes = 32 * 1024
    20  )
    21  
    22  // CopyOptions are options for the throttled copy.
    23  type CopyOptions struct {
    24  	RateBytes   int64
    25  	RateQuantum time.Duration
    26  	ChunkSize   int
    27  	Buffer      []byte
    28  	OnWrite     func(int, time.Duration)
    29  }
    30  
    31  // CopyOption mutates CopyOptions.
    32  type CopyOption func(*CopyOptions)
    33  
    34  // OptCopyRateBytes sets the bytes portion of the rate.
    35  func OptCopyRateBytes(b int64) CopyOption {
    36  	return func(o *CopyOptions) { o.RateBytes = b }
    37  }
    38  
    39  // OptCopyRateQuantum sets the quantum portion of the rate.
    40  func OptCopyRateQuantum(q time.Duration) CopyOption {
    41  	return func(o *CopyOptions) { o.RateQuantum = q }
    42  }
    43  
    44  // OptCopyChunkSize sets the quantum portion of the rate.
    45  func OptCopyChunkSize(cs int) CopyOption {
    46  	return func(o *CopyOptions) { o.ChunkSize = cs }
    47  }
    48  
    49  // OptCopyBuffer sets the buffer for the copy.
    50  func OptCopyBuffer(buf []byte) CopyOption {
    51  	return func(o *CopyOptions) { o.Buffer = buf }
    52  }
    53  
    54  // OptCopyOnWrite sets the on write handler for the copy.
    55  func OptCopyOnWrite(handler func(bytesWritten int, elapsed time.Duration)) CopyOption {
    56  	return func(o *CopyOptions) { o.OnWrite = handler }
    57  }
    58  
    59  // errCopyInvalidWrite means that a write returned an impossible count.
    60  var errCopyInvalidWrite = errors.New("throttled copy; invalid write result")
    61  
    62  // errCopyInvalidChunkSize means that the user provided a < 1 chunk size.
    63  var errCopyInvalidChunkSize = errors.New("throttled copy; invalid chunk size")
    64  
    65  // errCopyInvalidOnWrite means that the user provided a nil write handler.
    66  var errCopyInvalidOnWrite = errors.New("throttled copy; invalid on write handler")
    67  
    68  // Copy copies from the src reader to the dst writer.
    69  func Copy(ctx context.Context, dst io.Writer, src io.Reader, opts ...CopyOption) (written int64, err error) {
    70  	options := CopyOptions{
    71  		RateBytes:   10 * (1 << 27), // 10gbit in bytes, or (10*(2^30))/8
    72  		RateQuantum: time.Second,
    73  		ChunkSize:   DefaultCopyChunkSizeBytes,
    74  		OnWrite:     func(_ int, _ time.Duration) {},
    75  	}
    76  	for _, opt := range opts {
    77  		opt(&options)
    78  	}
    79  
    80  	if options.ChunkSize <= 0 {
    81  		err = errCopyInvalidChunkSize
    82  		return
    83  	}
    84  	if options.OnWrite == nil {
    85  		err = errCopyInvalidOnWrite
    86  		return
    87  	}
    88  
    89  	if options.Buffer == nil {
    90  		size := options.ChunkSize
    91  		if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N {
    92  			if l.N < 1 {
    93  				size = 1
    94  			} else {
    95  				size = int(l.N)
    96  			}
    97  		}
    98  		options.Buffer = make([]byte, size)
    99  	}
   100  
   101  	var nr, nw int
   102  	var er, ew error
   103  	var ts time.Time
   104  	wait := Wait{
   105  		NumberOfActions: options.RateBytes,
   106  		Quantum:         options.RateQuantum,
   107  	}
   108  	var after *time.Timer
   109  	for {
   110  		ts = time.Now()
   111  		nr, er = src.Read(options.Buffer)
   112  		if nr > 0 {
   113  			nw, ew = dst.Write(options.Buffer[0:nr])
   114  			if nw < 0 || nr < nw {
   115  				nw = 0
   116  				if ew == nil {
   117  					ew = errCopyInvalidWrite
   118  				}
   119  			}
   120  			written += int64(nw)
   121  			if ew != nil {
   122  				err = ew
   123  				break
   124  			}
   125  			if nr != nw {
   126  				err = io.ErrShortWrite
   127  				break
   128  			}
   129  		}
   130  		if er != nil {
   131  			if er != io.EOF {
   132  				err = er
   133  			}
   134  			break
   135  		}
   136  		if err = wait.WaitTimer(ctx, int64(nw), time.Since(ts), after); err != nil {
   137  			return
   138  		}
   139  		options.OnWrite(nw, time.Since(ts))
   140  	}
   141  	return written, err
   142  }