github.com/Cloud-Foundations/Dominator@v0.3.4/lib/rsync/getBlocks.go (about) 1 package rsync 2 3 import ( 4 "crypto/sha512" 5 "fmt" 6 "io" 7 8 "github.com/Cloud-Foundations/Dominator/lib/errors" 9 "github.com/Cloud-Foundations/Dominator/lib/hash" 10 proto "github.com/Cloud-Foundations/Dominator/proto/rsync" 11 ) 12 13 type measuringConn struct { 14 Conn 15 stats Stats 16 } 17 18 func getBlocks(rawConn Conn, decoder Decoder, encoder Encoder, reader io.Reader, 19 writer io.WriteSeeker, totalBytes, readerBytes uint64) (Stats, error) { 20 blockOrder := sizeToOrder(totalBytes) >> 1 21 if blockOrder < 9 { 22 blockOrder = 9 23 } else if blockOrder > 32 { 24 blockOrder = 32 25 } 26 blockSize := uint64(1 << blockOrder) 27 if reader == nil { 28 readerBytes = 0 29 } 30 numBlocks := readerBytes >> blockOrder 31 request := proto.GetBlocksRequest{ 32 BlockOrder: blockOrder, 33 NumBlocks: numBlocks, 34 } 35 conn := &measuringConn{Conn: rawConn} 36 if err := encoder.Encode(request); err != nil { 37 return Stats{}, fmt.Errorf("error encoding request: %s", err) 38 } 39 if err := conn.Flush(); err != nil { 40 return Stats{}, err 41 } 42 errChannel := make(chan error, 1) 43 go func() { errChannel <- readBlocks(writer, decoder, conn, blockOrder) }() 44 for index := uint64(0); index < numBlocks; index++ { 45 select { 46 case err := <-errChannel: 47 if err != nil { 48 return Stats{}, err 49 } 50 return Stats{}, errors.New("premature end of blocks") 51 default: 52 } 53 hasher := sha512.New() 54 if _, err := io.CopyN(hasher, reader, int64(blockSize)); err != nil { 55 return Stats{}, err 56 } 57 var hashVal hash.Hash 58 copy(hashVal[:], hasher.Sum(nil)) 59 if _, err := conn.Write(hashVal[:]); err != nil { 60 return Stats{}, err 61 } 62 if index == 0 { 63 if err := conn.Flush(); err != nil { 64 return Stats{}, err 65 } 66 } 67 } 68 if err := conn.Flush(); err != nil { 69 return Stats{}, err 70 } 71 if err := <-errChannel; err != nil { 72 return Stats{}, err 73 } 74 return conn.stats, nil 75 } 76 77 func readBlocks(writer io.WriteSeeker, decoder Decoder, reader io.Reader, 78 blockOrder uint8) error { 79 var numBytesReceived uint64 80 for { 81 var block proto.Block 82 if err := decoder.Decode(&block); err != nil { 83 return fmt.Errorf("error decoding block: %s", err) 84 } 85 if err := errors.New(block.Error); err != nil { 86 return err 87 } 88 if block.Size < 1 { 89 return nil 90 } 91 offset := int64(block.Index << blockOrder) 92 if _, err := writer.Seek(offset, io.SeekStart); err != nil { 93 return err 94 } 95 if _, err := io.CopyN(writer, reader, int64(block.Size)); err != nil { 96 return err 97 } 98 numBytesReceived += block.Size 99 } 100 } 101 102 func sizeToOrder(blockSize uint64) uint8 { 103 order := uint8(0) 104 for i := uint8(0); i < 64; i++ { 105 if 1<<i&blockSize != 0 { 106 order = i 107 } 108 } 109 return order 110 } 111 112 func (conn *measuringConn) Read(b []byte) (int, error) { 113 nRead, err := conn.Conn.Read(b) 114 conn.stats.NumRead += uint64(nRead) 115 return nRead, err 116 } 117 118 func (conn *measuringConn) Write(b []byte) (int, error) { 119 nWritten, err := conn.Conn.Write(b) 120 conn.stats.NumWritten += uint64(nWritten) 121 return nWritten, err 122 }