github.com/amazechain/amc@v0.1.3/internal/sync/rpc_ping.go (about)

     1  package sync
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/amazechain/amc/api/protocol/sync_pb"
     8  	ssztype "github.com/amazechain/amc/common/types/ssz"
     9  	"github.com/amazechain/amc/internal/p2p"
    10  	p2ptypes "github.com/amazechain/amc/internal/p2p/types"
    11  	"time"
    12  
    13  	libp2pcore "github.com/libp2p/go-libp2p/core"
    14  	"github.com/libp2p/go-libp2p/core/peer"
    15  )
    16  
    17  // pingHandler reads the incoming ping rpc message from the peer.
    18  func (s *Service) pingHandler(_ context.Context, msg interface{}, stream libp2pcore.Stream) error {
    19  	SetRPCStreamDeadlines(stream)
    20  
    21  	m, ok := msg.(*ssztype.SSZUint64)
    22  	if !ok {
    23  		return fmt.Errorf("wrong message type for ping, got %T, wanted *uint64", msg)
    24  	}
    25  	if err := s.rateLimiter.validateRequest(stream, 1); err != nil {
    26  		return err
    27  	}
    28  	s.rateLimiter.add(stream, 1)
    29  	valid, err := s.validateSequenceNum(*m, stream.Conn().RemotePeer())
    30  	if err != nil {
    31  		// Descore peer for giving us a bad sequence number.
    32  		if errors.Is(err, p2ptypes.ErrInvalidSequenceNum) {
    33  			s.cfg.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer())
    34  			s.writeErrorResponseToStream(responseCodeInvalidRequest, p2ptypes.ErrInvalidSequenceNum.Error(), stream)
    35  		}
    36  		return err
    37  	}
    38  	if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
    39  		return err
    40  	}
    41  	sq := s.cfg.p2p.GetPing()
    42  	if _, err := s.cfg.p2p.Encoding().EncodeWithMaxLength(stream, sq); err != nil {
    43  		return err
    44  	}
    45  
    46  	closeStream(stream)
    47  
    48  	if valid {
    49  		s.cfg.p2p.Peers().SetPing(stream.Conn().RemotePeer(), &sync_pb.Ping{SeqNumber: uint64(*m)})
    50  		// If the sequence number was valid we're done.
    51  		return nil
    52  	}
    53  
    54  	return nil
    55  }
    56  
    57  func (s *Service) sendPingRequest(ctx context.Context, id peer.ID) error {
    58  	ctx, cancel := context.WithTimeout(ctx, respTimeout)
    59  	defer cancel()
    60  
    61  	pingReq := ssztype.SSZUint64(s.cfg.p2p.GetPing().SeqNumber)
    62  	topic, err := p2p.TopicFromMessage(p2p.PingMessageName)
    63  	if err != nil {
    64  		return err
    65  	}
    66  	stream, err := s.cfg.p2p.Send(ctx, &pingReq, topic, id)
    67  	if err != nil {
    68  		return err
    69  	}
    70  	currentTime := time.Now()
    71  	defer closeStream(stream)
    72  
    73  	code, errMsg, err := ReadStatusCode(stream, s.cfg.p2p.Encoding())
    74  	if err != nil {
    75  		return err
    76  	}
    77  	// Records the latency of the ping request for that peer.
    78  	s.cfg.p2p.Host().Peerstore().RecordLatency(id, time.Now().Sub(currentTime))
    79  
    80  	if code != 0 {
    81  		s.cfg.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer())
    82  		return errors.New(errMsg)
    83  	}
    84  	pingResponse := new(ssztype.SSZUint64)
    85  	if err := s.cfg.p2p.Encoding().DecodeWithMaxLength(stream, pingResponse); err != nil {
    86  		return err
    87  	}
    88  	valid, err := s.validateSequenceNum(*pingResponse, stream.Conn().RemotePeer())
    89  	if err != nil {
    90  		// Descore peer for giving us a bad sequence number.
    91  		if errors.Is(err, p2ptypes.ErrInvalidSequenceNum) {
    92  			s.cfg.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer())
    93  		}
    94  		return err
    95  	}
    96  	if valid {
    97  		return nil
    98  	}
    99  	s.cfg.p2p.Peers().SetPing(stream.Conn().RemotePeer(), &sync_pb.Ping{SeqNumber: uint64(*pingResponse)})
   100  	return nil
   101  }
   102  
   103  // validates the peer's sequence number.
   104  func (s *Service) validateSequenceNum(seq ssztype.SSZUint64, id peer.ID) (bool, error) {
   105  	md, err := s.cfg.p2p.Peers().GetPing(id)
   106  	if err != nil {
   107  		return false, err
   108  	}
   109  	//
   110  	if md == nil {
   111  		return true, nil
   112  	}
   113  	// Return error on invalid sequence number.
   114  	if md.GetSeqNumber() > uint64(seq) {
   115  		return false, p2ptypes.ErrInvalidSequenceNum
   116  	}
   117  	return md.GetSeqNumber() <= uint64(seq), nil
   118  }