github.com/unicornultrafoundation/go-u2u@v1.0.0-rc1.0.20240205080301-e74a83d3fadc/u2u/genesisstore/fileshash/write_file.go (about) 1 package fileshash 2 3 import ( 4 "crypto/sha256" 5 "errors" 6 hasher "hash" 7 "io" 8 9 "github.com/unicornultrafoundation/go-helios/common/bigendian" 10 "github.com/unicornultrafoundation/go-helios/hash" 11 12 "github.com/unicornultrafoundation/go-u2u/utils/ioread" 13 ) 14 15 type TmpWriter interface { 16 io.ReadWriteSeeker 17 io.Closer 18 Drop() error 19 } 20 21 type tmpWriter struct { 22 TmpWriter 23 h hasher.Hash 24 } 25 26 type Writer struct { 27 backend io.Writer 28 29 openTmp func(int) TmpWriter 30 tmps []tmpWriter 31 tmpReadPos uint64 32 33 size uint64 34 35 pieceSize uint64 36 } 37 38 func WrapWriter(backend io.Writer, pieceSize uint64, openTmp func(int) TmpWriter) *Writer { 39 return &Writer{ 40 backend: backend, 41 openTmp: openTmp, 42 pieceSize: pieceSize, 43 } 44 } 45 46 func (w *Writer) writeIntoTmp(p []byte) error { 47 if len(p) == 0 { 48 return nil 49 } 50 if w.size/w.pieceSize >= uint64(len(w.tmps)) { 51 tmpI := len(w.tmps) 52 f := w.openTmp(len(w.tmps)) 53 if tmpI > 0 { 54 err := w.tmps[tmpI-1].Close() 55 if err != nil { 56 return err 57 } 58 w.tmps[tmpI-1].TmpWriter = nil 59 } 60 w.tmps = append(w.tmps, tmpWriter{ 61 TmpWriter: f, 62 h: sha256.New(), 63 }) 64 } 65 currentPosInTmp := w.size % w.pieceSize 66 maxToWrite := w.pieceSize - currentPosInTmp 67 if maxToWrite > uint64(len(p)) { 68 maxToWrite = uint64(len(p)) 69 } 70 n, err := w.tmps[len(w.tmps)-1].Write(p[:maxToWrite]) 71 w.tmps[len(w.tmps)-1].h.Write(p[:maxToWrite]) 72 w.size += uint64(n) 73 if err != nil { 74 return err 75 } 76 return w.writeIntoTmp(p[maxToWrite:]) 77 } 78 79 func (w *Writer) resetTmpReads() error { 80 for _, tmp := range w.tmps { 81 if tmp.TmpWriter != nil { 82 _, err := tmp.Seek(0, io.SeekStart) 83 if err != nil { 84 return err 85 } 86 } 87 } 88 w.tmpReadPos = 0 89 return nil 90 } 91 92 func (w *Writer) readFromTmp(p []byte, destructive bool) error { 93 if len(p) == 0 { 94 return nil 95 } 96 tmpI := w.tmpReadPos / w.pieceSize 97 if tmpI > uint64(len(w.tmps)) { 98 return errors.New("all tmp files are consumed") 99 } 100 if w.tmps[tmpI].TmpWriter == nil { 101 w.tmps[tmpI].TmpWriter = w.openTmp(int(tmpI)) 102 } 103 currentPosInTmp := w.tmpReadPos % w.pieceSize 104 maxToRead := w.pieceSize - currentPosInTmp 105 if maxToRead > uint64(len(p)) { 106 maxToRead = uint64(len(p)) 107 } 108 err := ioread.ReadAll(w.tmps[tmpI], p[:maxToRead]) 109 if err != nil { 110 return err 111 } 112 w.tmpReadPos += maxToRead 113 if w.tmpReadPos%w.pieceSize == 0 { 114 _ = w.tmps[tmpI].Close() 115 if destructive { 116 _ = w.tmps[tmpI].Drop() 117 } 118 w.tmps[tmpI].TmpWriter = nil 119 } 120 return w.readFromTmp(p[maxToRead:], destructive) 121 } 122 123 func (w *Writer) Write(p []byte) (n int, err error) { 124 oldSize := w.size 125 err = w.writeIntoTmp(p) 126 n = int(w.size - oldSize) 127 return 128 } 129 130 func (w *Writer) readFromTmpPieceByPiece(destructive bool, fn func([]byte) error) error { 131 err := w.resetTmpReads() 132 if err != nil { 133 return err 134 } 135 piece := make([]byte, w.pieceSize) 136 for pos := uint64(0); pos < w.size; pos += w.pieceSize { 137 end := pos + w.pieceSize 138 if end > w.size { 139 end = w.size 140 } 141 err := w.readFromTmp(piece[:end-pos], destructive) 142 if err != nil { 143 return err 144 } 145 err = fn(piece[:end-pos]) 146 if err != nil { 147 return err 148 } 149 } 150 return nil 151 } 152 153 func (w *Writer) Root() hash.Hash { 154 hashes := hash.Hashes{} 155 for _, tmp := range w.tmps { 156 h := hash.BytesToHash(tmp.h.Sum(nil)) 157 hashes = append(hashes, h) 158 } 159 return calcHashesRoot(hashes, w.pieceSize, w.size) 160 } 161 162 func (w *Writer) Flush() (hash.Hash, error) { 163 // write piece 164 _, err := w.backend.Write(bigendian.Uint32ToBytes(uint32(w.pieceSize))) 165 if err != nil { 166 return hash.Hash{}, err 167 } 168 // write size 169 _, err = w.backend.Write(bigendian.Uint64ToBytes(w.size)) 170 if err != nil { 171 return hash.Hash{}, err 172 } 173 // write piece hashes 174 hashes := hash.Hashes{} 175 for _, tmp := range w.tmps { 176 h := hash.BytesToHash(tmp.h.Sum(nil)) 177 hashes = append(hashes, h) 178 _, err = w.backend.Write(h[:]) 179 if err != nil { 180 return hash.Hash{}, err 181 } 182 } 183 root := calcHashesRoot(hashes, w.pieceSize, w.size) 184 // write data and drop tmp files 185 return root, w.readFromTmpPieceByPiece(true, func(piece []byte) error { 186 _, err := w.backend.Write(piece) 187 return err 188 }) 189 }