github.com/Cloud-Foundations/Dominator@v0.3.4/lib/rsync/getBlocks.go (about)

     1  package rsync
     2  
     3  import (
     4  	"crypto/sha512"
     5  	"fmt"
     6  	"io"
     7  
     8  	"github.com/Cloud-Foundations/Dominator/lib/errors"
     9  	"github.com/Cloud-Foundations/Dominator/lib/hash"
    10  	proto "github.com/Cloud-Foundations/Dominator/proto/rsync"
    11  )
    12  
    13  type measuringConn struct {
    14  	Conn
    15  	stats Stats
    16  }
    17  
    18  func getBlocks(rawConn Conn, decoder Decoder, encoder Encoder, reader io.Reader,
    19  	writer io.WriteSeeker, totalBytes, readerBytes uint64) (Stats, error) {
    20  	blockOrder := sizeToOrder(totalBytes) >> 1
    21  	if blockOrder < 9 {
    22  		blockOrder = 9
    23  	} else if blockOrder > 32 {
    24  		blockOrder = 32
    25  	}
    26  	blockSize := uint64(1 << blockOrder)
    27  	if reader == nil {
    28  		readerBytes = 0
    29  	}
    30  	numBlocks := readerBytes >> blockOrder
    31  	request := proto.GetBlocksRequest{
    32  		BlockOrder: blockOrder,
    33  		NumBlocks:  numBlocks,
    34  	}
    35  	conn := &measuringConn{Conn: rawConn}
    36  	if err := encoder.Encode(request); err != nil {
    37  		return Stats{}, fmt.Errorf("error encoding request: %s", err)
    38  	}
    39  	if err := conn.Flush(); err != nil {
    40  		return Stats{}, err
    41  	}
    42  	errChannel := make(chan error, 1)
    43  	go func() { errChannel <- readBlocks(writer, decoder, conn, blockOrder) }()
    44  	for index := uint64(0); index < numBlocks; index++ {
    45  		select {
    46  		case err := <-errChannel:
    47  			if err != nil {
    48  				return Stats{}, err
    49  			}
    50  			return Stats{}, errors.New("premature end of blocks")
    51  		default:
    52  		}
    53  		hasher := sha512.New()
    54  		if _, err := io.CopyN(hasher, reader, int64(blockSize)); err != nil {
    55  			return Stats{}, err
    56  		}
    57  		var hashVal hash.Hash
    58  		copy(hashVal[:], hasher.Sum(nil))
    59  		if _, err := conn.Write(hashVal[:]); err != nil {
    60  			return Stats{}, err
    61  		}
    62  		if index == 0 {
    63  			if err := conn.Flush(); err != nil {
    64  				return Stats{}, err
    65  			}
    66  		}
    67  	}
    68  	if err := conn.Flush(); err != nil {
    69  		return Stats{}, err
    70  	}
    71  	if err := <-errChannel; err != nil {
    72  		return Stats{}, err
    73  	}
    74  	return conn.stats, nil
    75  }
    76  
    77  func readBlocks(writer io.WriteSeeker, decoder Decoder, reader io.Reader,
    78  	blockOrder uint8) error {
    79  	var numBytesReceived uint64
    80  	for {
    81  		var block proto.Block
    82  		if err := decoder.Decode(&block); err != nil {
    83  			return fmt.Errorf("error decoding block: %s", err)
    84  		}
    85  		if err := errors.New(block.Error); err != nil {
    86  			return err
    87  		}
    88  		if block.Size < 1 {
    89  			return nil
    90  		}
    91  		offset := int64(block.Index << blockOrder)
    92  		if _, err := writer.Seek(offset, io.SeekStart); err != nil {
    93  			return err
    94  		}
    95  		if _, err := io.CopyN(writer, reader, int64(block.Size)); err != nil {
    96  			return err
    97  		}
    98  		numBytesReceived += block.Size
    99  	}
   100  }
   101  
   102  func sizeToOrder(blockSize uint64) uint8 {
   103  	order := uint8(0)
   104  	for i := uint8(0); i < 64; i++ {
   105  		if 1<<i&blockSize != 0 {
   106  			order = i
   107  		}
   108  	}
   109  	return order
   110  }
   111  
   112  func (conn *measuringConn) Read(b []byte) (int, error) {
   113  	nRead, err := conn.Conn.Read(b)
   114  	conn.stats.NumRead += uint64(nRead)
   115  	return nRead, err
   116  }
   117  
   118  func (conn *measuringConn) Write(b []byte) (int, error) {
   119  	nWritten, err := conn.Conn.Write(b)
   120  	conn.stats.NumWritten += uint64(nWritten)
   121  	return nWritten, err
   122  }