storj.io/uplink@v1.13.0/private/storage/streams/peek.go (about)

     1  // Copyright (C) 2019 Storj Labs, Inc.
     2  // See LICENSE for copying information.
     3  
     4  package streams
     5  
     6  import (
     7  	"errors"
     8  	"io"
     9  
    10  	"github.com/zeebo/errs"
    11  )
    12  
    13  // PeekThresholdReader allows a check to see if the size of a given reader
    14  // exceeds the maximum inline segment size or not.
    15  type PeekThresholdReader struct {
    16  	r              io.Reader
    17  	thresholdBuf   []byte
    18  	thresholdErr   error
    19  	isLargerCalled bool
    20  	readCalled     bool
    21  }
    22  
    23  // NewPeekThresholdReader creates a new instance of PeekThresholdReader.
    24  func NewPeekThresholdReader(r io.Reader) (pt *PeekThresholdReader) {
    25  	return &PeekThresholdReader{r: r}
    26  }
    27  
    28  // Read initially reads bytes from the internal buffer, then continues
    29  // reading from the wrapped data reader. The number of bytes read `n`
    30  // is returned.
    31  func (pt *PeekThresholdReader) Read(p []byte) (n int, err error) {
    32  	pt.readCalled = true
    33  
    34  	if len(pt.thresholdBuf) > 0 || pt.thresholdErr != nil {
    35  		n = copy(p, pt.thresholdBuf)
    36  		pt.thresholdBuf = pt.thresholdBuf[n:]
    37  		if len(pt.thresholdBuf) == 0 {
    38  			err := pt.thresholdErr
    39  			pt.thresholdErr = nil
    40  			return n, err
    41  		}
    42  		return n, nil
    43  	}
    44  
    45  	return pt.r.Read(p)
    46  }
    47  
    48  // IsLargerThan returns a bool to determine whether a reader's size
    49  // is larger than the given threshold or not.
    50  func (pt *PeekThresholdReader) IsLargerThan(thresholdSize int) (bool, error) {
    51  	if pt.isLargerCalled {
    52  		return false, errs.New("IsLargerThan can't be called more than once")
    53  	}
    54  	if pt.readCalled {
    55  		return false, errs.New("IsLargerThan can't be called after Read has been called")
    56  	}
    57  	pt.isLargerCalled = true
    58  	buf := make([]byte, thresholdSize+1)
    59  	n, err := io.ReadFull(pt.r, buf)
    60  	pt.thresholdBuf = buf[:n]
    61  	pt.thresholdErr = err
    62  	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
    63  		if errors.Is(err, io.ErrUnexpectedEOF) {
    64  			pt.thresholdErr = io.EOF
    65  		}
    66  		return false, nil
    67  	}
    68  	if err != nil {
    69  		return false, err
    70  	}
    71  	return true, nil
    72  }