github.com/unicornultrafoundation/go-u2u@v1.0.0-rc1.0.20240205080301-e74a83d3fadc/u2u/genesisstore/fileshash/reader_file.go (about)

     1  package fileshash
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  
    10  	"github.com/unicornultrafoundation/go-helios/common/bigendian"
    11  	"github.com/unicornultrafoundation/go-helios/hash"
    12  
    13  	"github.com/unicornultrafoundation/go-u2u/utils/ioread"
    14  )
    15  
    var (
	// ErrRootNotFound reports that no hashes root is available. It is not
	// referenced in this part of the file; presumably used by callers — verify.
	ErrRootNotFound = errors.New("hashes root not found")

	// ErrRootMismatch reports that the piece hashes read from the file do not
	// reproduce the expected root hash.
	ErrRootMismatch = errors.New("hashes root mismatch")

	// ErrHashMismatch reports that a content piece does not match its hash.
	ErrHashMismatch = errors.New("hash mismatch")

	// ErrTooMuchMem reports that the header describes a file whose memory
	// estimate exceeds the Reader's configured budget.
	ErrTooMuchMem = errors.New("hashed file requires too much memory")

	// ErrInit prefixes errors raised while reading and checking the header.
	ErrInit = errors.New("failed to init hashfile")

	// ErrPieceRead prefixes errors raised while loading a content piece.
	ErrPieceRead = errors.New("failed to read piece")

	// ErrClosed is returned by Read once the Reader has been closed.
	ErrClosed = errors.New("closed")
)
    25  
    26  type Reader struct {
    27  	backend io.Reader
    28  
    29  	size uint64
    30  	pos  uint64
    31  
    32  	pieceSize       uint64
    33  	currentPiecePos uint64
    34  	currentPiece    []byte
    35  
    36  	root   hash.Hash
    37  	hashes hash.Hashes
    38  
    39  	maxMemUsage uint64
    40  
    41  	err error
    42  }
    43  
    44  func WrapReader(backend io.Reader, maxMemUsage uint64, root hash.Hash) *Reader {
    45  	return &Reader{
    46  		backend:         backend,
    47  		pos:             0,
    48  		maxMemUsage:     maxMemUsage,
    49  		currentPiecePos: math.MaxUint64,
    50  		root:            root,
    51  	}
    52  }
    53  
    54  func (r *Reader) readHashes(n uint64) (hash.Hashes, error) {
    55  	hashes := make(hash.Hashes, n)
    56  	for i := uint64(0); i < n; i++ {
    57  		err := ioread.ReadAll(r.backend, hashes[i][:])
    58  		if err != nil {
    59  			return nil, err
    60  		}
    61  	}
    62  	return hashes, nil
    63  }
    64  
    65  func calcHash(piece []byte) hash.Hash {
    66  	hasher := sha256.New()
    67  	hasher.Write(piece)
    68  	return hash.BytesToHash(hasher.Sum(nil))
    69  }
    70  
    71  func calcHashesRoot(hashes hash.Hashes, pieceSize, size uint64) hash.Hash {
    72  	hasher := sha256.New()
    73  	hasher.Write(bigendian.Uint32ToBytes(uint32(pieceSize)))
    74  	hasher.Write(bigendian.Uint64ToBytes(size))
    75  	for _, h := range hashes {
    76  		hasher.Write(h.Bytes())
    77  	}
    78  	return hash.BytesToHash(hasher.Sum(nil))
    79  }
    80  
    // getPiecesNum returns how many pieces of pieceSize bytes are needed to hold
// size bytes, counting a trailing partial piece as a full one.
//
// pieceSize may come from an untrusted file header, so zero is handled
// explicitly instead of panicking with an integer division by zero: empty
// content needs 0 pieces, and any non-empty content saturates to
// math.MaxUint64, which no memory budget check will accept in practice.
func getPiecesNum(size, pieceSize uint64) uint64 {
	if pieceSize == 0 {
		if size == 0 {
			return 0
		}
		return math.MaxUint64
	}
	n := size / pieceSize
	if size%pieceSize != 0 {
		// Cannot overflow: a remainder implies pieceSize > 1, so n < MaxUint64.
		n++
	}
	return n
}
    87  
    88  func (r *Reader) getPiecesNum(size uint64) uint64 {
    89  	return getPiecesNum(size, r.pieceSize)
    90  }
    91  
    92  func (r *Reader) getPiecePos(pos uint64) uint64 {
    93  	return pos / r.pieceSize
    94  }
    95  
    96  func (r *Reader) readNewPiece() error {
    97  	// previous piece must be fully read at this point
    98  	// ensure currentPiece has correct size
    99  	maxToRead := r.size - r.pos
   100  	if maxToRead > r.pieceSize {
   101  		maxToRead = r.pieceSize
   102  	}
   103  	if uint64(len(r.currentPiece)) > maxToRead {
   104  		r.currentPiece = r.currentPiece[:maxToRead]
   105  	}
   106  	if uint64(len(r.currentPiece)) < maxToRead {
   107  		r.currentPiece = make([]byte, maxToRead)
   108  	}
   109  	// read currentPiece
   110  	err := ioread.ReadAll(r.backend, r.currentPiece)
   111  	if err != nil {
   112  		return err
   113  	}
   114  	// verify piece hash and advance currentPiecePos
   115  	currentPiecePos := r.getPiecePos(r.pos)
   116  	if calcHash(r.currentPiece) != r.hashes[currentPiecePos] {
   117  		return ErrHashMismatch
   118  	}
   119  	r.currentPiecePos = currentPiecePos
   120  	return nil
   121  }
   122  
   123  func (r *Reader) readFromPiece(p []byte) (n int, err error) {
   124  	if r.currentPiecePos != r.getPiecePos(r.pos) {
   125  		// switch to new piece
   126  		err := r.readNewPiece()
   127  		if err != nil {
   128  			return 0, fmt.Errorf("%v: %v", ErrPieceRead, err)
   129  		}
   130  	}
   131  	maxToRead := uint64(len(r.currentPiece))
   132  	if maxToRead > uint64(len(p)) {
   133  		maxToRead = uint64(len(p))
   134  	}
   135  	posInPiece := r.pos % r.pieceSize
   136  	consumed := copy(p[:maxToRead], r.currentPiece[posInPiece:])
   137  	r.pos += uint64(consumed)
   138  	return consumed, nil
   139  }
   140  
   // memUsageOf estimates the memory a Reader needs: one piece buffer of
// pieceSize bytes plus a budget of 128 bytes per piece hash. Hash counts
// above math.MaxUint32 saturate to math.MaxUint64 rather than risk
// overflowing the multiplication.
func memUsageOf(pieceSize, hashesNum uint64) uint64 {
	const perHash = 128
	if hashesNum <= math.MaxUint32 {
		return pieceSize + perHash*hashesNum
	}
	return math.MaxUint64
}
   147  
   148  func (r *Reader) init() error {
   149  	buf := make([]byte, 8)
   150  	// read piece size
   151  	err := ioread.ReadAll(r.backend, buf[:4])
   152  	if err != nil {
   153  		return err
   154  	}
   155  	r.pieceSize = uint64(bigendian.BytesToUint32(buf[:4]))
   156  	// read content size
   157  	err = ioread.ReadAll(r.backend, buf)
   158  	if err != nil {
   159  		return err
   160  	}
   161  	r.size = bigendian.BytesToUint64(buf)
   162  
   163  	hashesNum := r.getPiecesNum(r.size)
   164  	if memUsageOf(r.pieceSize, hashesNum) > uint64(r.maxMemUsage) {
   165  		return ErrTooMuchMem
   166  	}
   167  	// read piece hashes
   168  	hashes, err := r.readHashes(hashesNum)
   169  	if err != nil {
   170  		return err
   171  	}
   172  	if calcHashesRoot(hashes, r.pieceSize, r.size) != r.root {
   173  		return ErrRootMismatch
   174  	}
   175  	r.hashes = hashes
   176  	return nil
   177  }
   178  
   179  func (r *Reader) read(p []byte) (n int, err error) {
   180  	if len(p) == 0 {
   181  		return 0, nil
   182  	}
   183  	if r.hashes == nil {
   184  		err := r.init()
   185  		if err != nil {
   186  			return 0, fmt.Errorf("%v: %v", ErrInit, err)
   187  		}
   188  	}
   189  	if r.pos >= r.size {
   190  		return 0, io.EOF
   191  	}
   192  	return r.readFromPiece(p)
   193  }
   194  
   195  func (r *Reader) Read(p []byte) (n int, err error) {
   196  	if r.err != nil {
   197  		return 0, r.err
   198  	}
   199  	n, err = r.read(p)
   200  	if err != nil {
   201  		r.err = err
   202  	}
   203  	return n, err
   204  }
   205  
   206  func (r *Reader) Close() error {
   207  	r.hashes = nil
   208  	r.err = ErrClosed
   209  	return r.backend.(io.Closer).Close()
   210  }