github.com/psexton/git-lfs@v2.1.1-0.20170517224304-289a18b2bc53+incompatible/tools/iotools.go (about)

     1  package tools
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/hex"
     7  	"hash"
     8  	"io"
     9  	"io/ioutil"
    10  	"os"
    11  
    12  	"github.com/git-lfs/git-lfs/errors"
    13  	"github.com/git-lfs/git-lfs/progress"
    14  )
    15  
const (
	// memoryBufferLimit is the maximum number of bytes `Spool()` holds
	// in memory before falling back to a temporary file on disk for the
	// remainder of an `io.Reader`'s contents.
	memoryBufferLimit = 1024
)
    22  
    23  // CopyWithCallback copies reader to writer while performing a progress callback
    24  func CopyWithCallback(writer io.Writer, reader io.Reader, totalSize int64, cb progress.CopyCallback) (int64, error) {
    25  	if success, _ := CloneFile(writer, reader); success {
    26  		if cb != nil {
    27  			cb(totalSize, totalSize, 0)
    28  		}
    29  		return totalSize, nil
    30  	}
    31  	if cb == nil {
    32  		return io.Copy(writer, reader)
    33  	}
    34  
    35  	cbReader := &progress.CallbackReader{
    36  		C:         cb,
    37  		TotalSize: totalSize,
    38  		Reader:    reader,
    39  	}
    40  	return io.Copy(writer, cbReader)
    41  }
    42  
    43  // Get a new Hash instance of the type used to hash LFS content
    44  func NewLfsContentHash() hash.Hash {
    45  	return sha256.New()
    46  }
    47  
    48  // HashingReader wraps a reader and calculates the hash of the data as it is read
    49  type HashingReader struct {
    50  	reader io.Reader
    51  	hasher hash.Hash
    52  }
    53  
    54  func NewHashingReader(r io.Reader) *HashingReader {
    55  	return &HashingReader{r, NewLfsContentHash()}
    56  }
    57  
    58  func NewHashingReaderPreloadHash(r io.Reader, hash hash.Hash) *HashingReader {
    59  	return &HashingReader{r, hash}
    60  }
    61  
    62  func (r *HashingReader) Hash() string {
    63  	return hex.EncodeToString(r.hasher.Sum(nil))
    64  }
    65  
    66  func (r *HashingReader) Read(b []byte) (int, error) {
    67  	w, err := r.reader.Read(b)
    68  	if err == nil || err == io.EOF {
    69  		_, e := r.hasher.Write(b[0:w])
    70  		if e != nil && err == nil {
    71  			return w, e
    72  		}
    73  	}
    74  
    75  	return w, err
    76  }
    77  
    78  // RetriableReader wraps a error response of reader as RetriableError()
    79  type RetriableReader struct {
    80  	reader io.Reader
    81  }
    82  
    83  func NewRetriableReader(r io.Reader) io.Reader {
    84  	return &RetriableReader{r}
    85  }
    86  
    87  func (r *RetriableReader) Read(b []byte) (int, error) {
    88  	n, err := r.reader.Read(b)
    89  
    90  	// EOF is a successful response as it is used to signal a graceful end
    91  	// of input c.f. https://git.io/v6riQ
    92  	//
    93  	// Otherwise, if the error is non-nil and already retriable (in the
    94  	// case that the underlying reader `r.reader` is itself a
    95  	// `*RetriableReader`, return the error wholesale:
    96  	if err == nil || err == io.EOF || errors.IsRetriableError(err) {
    97  		return n, err
    98  	}
    99  
   100  	return n, errors.NewRetriableError(err)
   101  }
   102  
   103  // Spool spools the contents from 'from' to 'to' by buffering the entire
   104  // contents of 'from' into a temprorary file created in the directory "dir".
   105  // That buffer is held in memory until the file grows to larger than
   106  // 'memoryBufferLimit`, then the remaining contents are spooled to disk.
   107  //
   108  // The temporary file is cleaned up after the copy is complete.
   109  //
   110  // The number of bytes written to "to", as well as any error encountered are
   111  // returned.
   112  func Spool(to io.Writer, from io.Reader, dir string) (n int64, err error) {
   113  	// First, buffer up to `memoryBufferLimit` in memory.
   114  	buf := make([]byte, memoryBufferLimit)
   115  	if bn, err := from.Read(buf); err != nil && err != io.EOF {
   116  		return int64(bn), err
   117  	} else {
   118  		buf = buf[:bn]
   119  	}
   120  
   121  	var spool io.Reader = bytes.NewReader(buf)
   122  	if err != io.EOF {
   123  		// If we weren't at the end of the stream, create a temporary
   124  		// file, and spool the remaining contents there.
   125  		tmp, err := ioutil.TempFile(dir, "")
   126  		if err != nil {
   127  			return 0, errors.Wrap(err, "spool tmp")
   128  		}
   129  		defer os.Remove(tmp.Name())
   130  
   131  		if n, err = io.Copy(tmp, from); err != nil {
   132  			return n, errors.Wrap(err, "unable to spool")
   133  		}
   134  
   135  		if _, err = tmp.Seek(0, io.SeekStart); err != nil {
   136  			return 0, errors.Wrap(err, "unable to seek")
   137  		}
   138  
   139  		// The spooled contents will now be the concatenation of the
   140  		// contents we stored in memory, then the remainder of the
   141  		// contents on disk.
   142  		spool = io.MultiReader(spool, tmp)
   143  	}
   144  
   145  	return io.Copy(to, spool)
   146  }