github.com/unicornultrafoundation/go-u2u@v1.0.0-rc1.0.20240205080301-e74a83d3fadc/u2u/genesisstore/fileshash/write_file.go

// Package fileshash implements writing of a data stream split into fixed-size
// pieces, where each piece is hashed so the output can be committed to by a
// single root hash.
package fileshash

import (
	"crypto/sha256"
	"errors"
	hasher "hash"
	"io"

	"github.com/unicornultrafoundation/go-helios/common/bigendian"
	"github.com/unicornultrafoundation/go-helios/hash"

	"github.com/unicornultrafoundation/go-u2u/utils/ioread"
)

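// TmpWriter is the abstraction over a temporary piece file: a seekable
// read/write stream that can be closed and, once its data is no longer
// needed, dropped (removed) via Drop.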
type TmpWriter interface {
	io.ReadWriteSeeker
	io.Closer
	Drop() error
}

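// tmpWriter couples an open temporary piece file with a running sha256
// hasher over all bytes written into that piece.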
type tmpWriter struct {
	TmpWriter
	h hasher.Hash
}

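// Writer buffers a written stream into pieces of pieceSize bytes, keeping each
// piece in its own temporary file and hashing it on the fly. Flush then emits
// the piece size, the total size, the piece hashes and the piece data into the
// backend writer.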
type Writer struct {
	backend io.Writer

	openTmp    func(int) TmpWriter
	tmps       []tmpWriter
	tmpReadPos uint64

	size uint64

	pieceSize uint64
}

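// WrapWriter creates a Writer over the given backend. pieceSize defines how the
// stream is split into pieces, and openTmp is called to open (or reopen) the
// temporary file for a given piece index.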
func WrapWriter(backend io.Writer, pieceSize uint64, openTmp func(int) TmpWriter) *Writer {
	return &Writer{
		backend:   backend,
		openTmp:   openTmp,
		pieceSize: pieceSize,
	}
}

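// writeIntoTmp appends p to the temporary piece files, opening a new piece file
// (and closing the previous one) whenever the current piece reaches pieceSize,
// and recursing until the whole buffer is consumed.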
func (w *Writer) writeIntoTmp(p []byte) error {
	if len(p) == 0 {
		return nil
	}
	if w.size/w.pieceSize >= uint64(len(w.tmps)) {
		tmpI := len(w.tmps)
		f := w.openTmp(len(w.tmps))
		if tmpI > 0 {
			err := w.tmps[tmpI-1].Close()
			if err != nil {
				return err
			}
			w.tmps[tmpI-1].TmpWriter = nil
		}
		w.tmps = append(w.tmps, tmpWriter{
			TmpWriter: f,
			h:         sha256.New(),
		})
	}
	currentPosInTmp := w.size % w.pieceSize
	maxToWrite := w.pieceSize - currentPosInTmp
	if maxToWrite > uint64(len(p)) {
		maxToWrite = uint64(len(p))
	}
	n, err := w.tmps[len(w.tmps)-1].Write(p[:maxToWrite])
	w.tmps[len(w.tmps)-1].h.Write(p[:maxToWrite])
	w.size += uint64(n)
	if err != nil {
		return err
	}
	return w.writeIntoTmp(p[maxToWrite:])
}

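// resetTmpReads rewinds all still-open temporary files and resets the read
// position, so the buffered pieces can be read back from the beginning.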
func (w *Writer) resetTmpReads() error {
	for _, tmp := range w.tmps {
		if tmp.TmpWriter != nil {
			_, err := tmp.Seek(0, io.SeekStart)
			if err != nil {
				return err
			}
		}
	}
	w.tmpReadPos = 0
	return nil
}

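// readFromTmp reads len(p) bytes of previously buffered data starting at the
// current read position, reopening closed piece files on demand. If destructive
// is true, each fully consumed piece file is dropped.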
func (w *Writer) readFromTmp(p []byte, destructive bool) error {
	if len(p) == 0 {
		return nil
	}
	tmpI := w.tmpReadPos / w.pieceSize
	if tmpI >= uint64(len(w.tmps)) {
		return errors.New("all tmp files are consumed")
	}
	if w.tmps[tmpI].TmpWriter == nil {
		w.tmps[tmpI].TmpWriter = w.openTmp(int(tmpI))
	}
	currentPosInTmp := w.tmpReadPos % w.pieceSize
	maxToRead := w.pieceSize - currentPosInTmp
	if maxToRead > uint64(len(p)) {
		maxToRead = uint64(len(p))
	}
	err := ioread.ReadAll(w.tmps[tmpI], p[:maxToRead])
	if err != nil {
		return err
	}
	w.tmpReadPos += maxToRead
	if w.tmpReadPos%w.pieceSize == 0 {
		_ = w.tmps[tmpI].Close()
		if destructive {
			_ = w.tmps[tmpI].Drop()
		}
		w.tmps[tmpI].TmpWriter = nil
	}
	return w.readFromTmp(p[maxToRead:], destructive)
}

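// Write implements io.Writer by buffering p into the temporary piece files.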
func (w *Writer) Write(p []byte) (n int, err error) {
	oldSize := w.size
	err = w.writeIntoTmp(p)
	n = int(w.size - oldSize)
	return
}

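// readFromTmpPieceByPiece reads the buffered data back piece by piece and
// passes each piece to fn.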
func (w *Writer) readFromTmpPieceByPiece(destructive bool, fn func([]byte) error) error {
	err := w.resetTmpReads()
	if err != nil {
		return err
	}
	piece := make([]byte, w.pieceSize)
	for pos := uint64(0); pos < w.size; pos += w.pieceSize {
		end := pos + w.pieceSize
		if end > w.size {
			end = w.size
		}
		err := w.readFromTmp(piece[:end-pos], destructive)
		if err != nil {
			return err
		}
		err = fn(piece[:end-pos])
		if err != nil {
			return err
		}
	}
	return nil
}

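// Root returns the root hash calculated over the sha256 hashes of the pieces
// written so far.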
func (w *Writer) Root() hash.Hash {
	hashes := hash.Hashes{}
	for _, tmp := range w.tmps {
		h := hash.BytesToHash(tmp.h.Sum(nil))
		hashes = append(hashes, h)
	}
	return calcHashesRoot(hashes, w.pieceSize, w.size)
}

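// Flush writes the piece size, the total data size, the piece hashes and then
// the buffered data itself into the backend writer, dropping the temporary
// files as they are consumed. It returns the root hash of the piece hashes.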
func (w *Writer) Flush() (hash.Hash, error) {
	// write piece size
	_, err := w.backend.Write(bigendian.Uint32ToBytes(uint32(w.pieceSize)))
	if err != nil {
		return hash.Hash{}, err
	}
	// write total data size
	_, err = w.backend.Write(bigendian.Uint64ToBytes(w.size))
	if err != nil {
		return hash.Hash{}, err
	}
	// write piece hashes
	hashes := hash.Hashes{}
	for _, tmp := range w.tmps {
		h := hash.BytesToHash(tmp.h.Sum(nil))
		hashes = append(hashes, h)
		_, err = w.backend.Write(h[:])
		if err != nil {
			return hash.Hash{}, err
		}
	}
	root := calcHashesRoot(hashes, w.pieceSize, w.size)
	// write data and drop tmp files
	return root, w.readFromTmpPieceByPiece(true, func(piece []byte) error {
		_, err := w.backend.Write(piece)
		return err
	})
}