github.com/aergoio/aergo@v1.3.1/p2p/raftsupport/snapshotsender.go (about)

     1  /*
     2   * @file
     3   * @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package raftsupport
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"github.com/aergoio/aergo-lib/log"
    14  	"github.com/aergoio/aergo/consensus"
    15  	"github.com/aergoio/aergo/p2p/p2pcommon"
    16  	"github.com/aergoio/aergo/p2p/p2putil"
    17  	"github.com/aergoio/aergo/types"
    18  	pioutil "github.com/aergoio/etcd/pkg/ioutil"
    19  	"github.com/aergoio/etcd/raft"
    20  	"github.com/aergoio/etcd/raft/raftpb"
    21  	"github.com/aergoio/etcd/snap"
    22  	"github.com/golang/protobuf/proto"
    23  	core "github.com/libp2p/go-libp2p-core"
    24  	"io"
    25  	"time"
    26  )
    27  
// snapshotSender sends a raft database snapshot to one remote peer and reports
// the outcome through the raft accessor.
type snapshotSender struct {
	logger   *log.Logger
	nt       p2pcommon.NetworkTransport  // used to obtain the snapshot stream to the peer
	rAcc     consensus.AergoRaftAccessor // receives ReportUnreachable / ReportSnapshot calls
	stopChan chan interface{}            // a successful receive makes pushBMsg return "stopped"

	peer p2pcommon.RemotePeer // destination of the snapshot
}
    36  
    37  func newSnapshotSender(logger *log.Logger, nt p2pcommon.NetworkTransport, rAcc consensus.AergoRaftAccessor, peer p2pcommon.RemotePeer) *snapshotSender {
    38  	return &snapshotSender{logger: logger, nt: nt, rAcc: rAcc, stopChan: make(chan interface{}), peer:peer}
    39  }
    40  
    41  func (s *snapshotSender) Send(snapMsg *snap.Message) {
    42  	peer := s.peer
    43  
    44  	// 1. connect to target peer with snap protocol
    45  	stream, err := s.getSnapshotStream(peer.Meta())
    46  	if err != nil {
    47  		s.logger.Warn().Str(p2putil.LogPeerName, peer.Name()).Err(err).Msg("failed to send snapshot")
    48  		s.rAcc.ReportUnreachable(peer.ID())
    49  		s.rAcc.ReportSnapshot(peer.ID(), raft.SnapshotFailure)
    50  		snapMsg.CloseWithError(errUnreachableMember)
    51  		return
    52  	}
    53  
    54  	//
    55  	m := snapMsg.Message
    56  	body := s.createSnapBody(*snapMsg)
    57  	defer body.Close()
    58  
    59  	s.logger.Info().Uint64("index", m.Snapshot.Metadata.Index).Str(p2putil.LogPeerName, peer.Name()).Msg("start to send database snapshot")
    60  
    61  	// send bytes to target peer
    62  	err = s.pushBMsg(body, stream)
    63  	defer snapMsg.CloseWithError(err)
    64  	if err != nil {
    65  		s.logger.Warn().Uint64("index", m.Snapshot.Metadata.Index).Str(p2putil.LogPeerName, peer.Name()).Err(err).Msg("database snapshot failed to be sent out")
    66  
    67  		// errMemberRemoved is a critical error since a removed member should
    68  		// always be stopped. So we use reportCriticalError to report it to errorc.
    69  		//if err == errMemberRemoved {
    70  		//	reportCriticalError(err, s.errorc)
    71  		//}
    72  
    73  		// TODO set peer status not healthy
    74  		s.rAcc.ReportUnreachable(peer.ID())
    75  		// report SnapshotFailure to raft state machine. After raft state
    76  		// machine knows about it, it would pause a while and retry sending
    77  		// new snapshot message.
    78  		s.rAcc.ReportSnapshot(peer.ID(), raft.SnapshotFailure)
    79  		//sentFailures.WithLabelValues(to).Inc()
    80  		//snapshotSendFailures.WithLabelValues(to).Inc()
    81  		return
    82  	}
    83  
    84  	s.rAcc.ReportSnapshot(peer.ID(), raft.SnapshotFinish)
    85  	s.logger.Info().Uint64("index", m.Snapshot.Metadata.Index).Str(p2putil.LogPeerName, peer.Name()).Msg("database snapshot [index: %d, to: %s] sent out successfully")
    86  
    87  	//sentBytes.WithLabelValues(to).Add(float64(merged.TotalSize))
    88  	//snapshotSend.WithLabelValues(to).Inc()
    89  	//snapshotSendSeconds.WithLabelValues(to).Observe(time.Since(start).Seconds())
    90  
    91  }
    92  func (s *snapshotSender) pushBMsg(body io.Reader, to io.ReadWriteCloser) error {
    93  	//ctx, cancel := context.WithCancel(context.Background())
    94  	//defer cancel()
    95  
    96  	const (
    97  		WholeTimeLimit  = time.Hour * 24 * 30 // just make indefinitely long term.
    98  		ProcessingLimit = time.Minute * 20    // receiving peer should complete and response within after receiving whole snapshot data.
    99  	)
   100  
   101  	wErr := make(chan error, 1)
   102  	rResult := make(chan error, 1)
   103  	t := time.NewTimer(WholeTimeLimit)
   104  	// write snapshot bytes
   105  	go func() {
   106  		_, err := io.Copy(to, body)
   107  		if err != nil {
   108  			wErr <- err
   109  		}
   110  		// renew timer if timer is not expired yet.
   111  		if !t.Stop() {
   112  			<-t.C
   113  		}
   114  		t.Reset(ProcessingLimit)
   115  	}()
   116  
   117  	// read response of receiver
   118  	go func() {
   119  		resp, err := readWireHSResp(to)
   120  		if err == nil {
   121  			if resp.Status == types.ResultStatus_OK {
   122  				err = nil
   123  			} else {
   124  				err = fmt.Errorf("error code: %v, msg: %s", resp.Status.String(), resp.Message)
   125  			}
   126  		}
   127  		rResult <- err
   128  	}()
   129  
   130  	select {
   131  	case <-s.stopChan:
   132  		return errors.New("stopped")
   133  	case <-t.C:
   134  		return errors.New("timeout")
   135  	case r := <-wErr:
   136  		return r
   137  	case r := <-rResult:
   138  		return r
   139  	}
   140  }
   141  
   142  func (s *snapshotSender) getSnapshotStream(meta p2pcommon.PeerMeta) (core.Stream, error) {
   143  	// try connect peer with possible versions
   144  	stream, err := s.nt.GetOrCreateStream(meta, p2pcommon.RaftSnapSubAddr)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	return stream, nil
   149  }
   150  
   151  func readWireHSResp(rd io.Reader) (resp types.SnapshotResponse, err error) {
   152  	bytebuf := make([]byte, SnapRespHeaderLength)
   153  	readn, err := p2putil.ReadToLen(rd, bytebuf)
   154  	if err != nil {
   155  		return
   156  	}
   157  	if readn != SnapRespHeaderLength {
   158  		err = fmt.Errorf("wrong header length")
   159  		return
   160  	}
   161  
   162  	respLen := binary.BigEndian.Uint32(bytebuf)
   163  	bodyBuf := make([]byte, respLen)
   164  	readn, err = p2putil.ReadToLen(rd, bodyBuf)
   165  	if err != nil {
   166  		return
   167  	}
   168  	if readn != int(respLen) {
   169  		err = fmt.Errorf("wrong body length")
   170  		return
   171  	}
   172  
   173  	err = proto.Unmarshal(bodyBuf, &resp)
   174  	return
   175  }
   176  func (s *snapshotSender) createSnapBody(merged snap.Message) io.ReadCloser {
   177  	buf := new(bytes.Buffer)
   178  	enc := &RaftMsgEncoder{w: buf}
   179  	// encode raft message
   180  	if err := enc.Encode(&merged.Message); err != nil {
   181  		s.logger.Panic().Err(err).Msg("encode raft message error")
   182  	}
   183  
   184  	return &pioutil.ReaderAndCloser{
   185  		Reader: io.MultiReader(buf, merged.ReadCloser),
   186  		Closer: merged.ReadCloser,
   187  	}
   188  }
   189  
// RaftMsgEncoder encodes a raftpb.Message to its writer. Its output is meant
// to be the same as that of rafthttp.messageEncoder.
type RaftMsgEncoder struct {
	w io.Writer // destination of the encoded message
}
   194  
   195  func (enc *RaftMsgEncoder) Encode(m *raftpb.Message) error {
   196  	if err := binary.Write(enc.w, binary.BigEndian, uint64(m.Size())); err != nil {
   197  		return err
   198  	}
   199  	bytes, err := p2putil.MarshalMessageBody(m)
   200  	if err != nil {
   201  		return err
   202  	}
   203  	_, err = enc.w.Write(bytes)
   204  	return err
   205  }
   206  
// RaftMsgDecoder reads length-prefixed raftpb.Message values from its reader.
type RaftMsgDecoder struct {
	r io.Reader // source of the encoded messages
}
   210  
// readBytesLimit is the size limit Decode passes to DecodeLimit.
// ErrExceedSizeLimit is returned by DecodeLimit when the length prefix read
// from the stream is larger than the limit it was given.
var (
	readBytesLimit     uint64 = 512 * 1024 * 1024 // 512 MB
	ErrExceedSizeLimit        = errors.New("raftsupport: error limit exceeded")
)
   215  
// Decode reads the next message from the decoder's reader, rejecting any
// message whose declared length exceeds readBytesLimit.
func (dec *RaftMsgDecoder) Decode() (raftpb.Message, error) {
	return dec.DecodeLimit(readBytesLimit)
}
   219  
   220  func (dec *RaftMsgDecoder) DecodeLimit(numBytes uint64) (raftpb.Message, error) {
   221  	var m raftpb.Message
   222  	var l uint64
   223  	if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil {
   224  		return m, err
   225  	}
   226  	if l > numBytes {
   227  		return m, ErrExceedSizeLimit
   228  	}
   229  	buf := make([]byte, int(l))
   230  	if _, err := io.ReadFull(dec.r, buf); err != nil {
   231  		return m, err
   232  	}
   233  	return m, m.Unmarshal(buf)
   234  }