github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/network/helper_test.go (about)

     1  package network
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  	"sync/atomic"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/nspcc-dev/neo-go/internal/fakechain"
    14  	"github.com/nspcc-dev/neo-go/pkg/config"
    15  	"github.com/nspcc-dev/neo-go/pkg/io"
    16  	"github.com/nspcc-dev/neo-go/pkg/network/payload"
    17  	"github.com/stretchr/testify/require"
    18  	"go.uber.org/zap/zaptest"
    19  )
    20  
    21  type testDiscovery struct {
    22  	sync.Mutex
    23  	bad          []string
    24  	connected    []string
    25  	unregistered []string
    26  	backfill     []string
    27  }
    28  
    29  func newTestDiscovery([]string, time.Duration, Transporter) Discoverer { return new(testDiscovery) }
    30  
    31  func (d *testDiscovery) BackFill(addrs ...string) {
    32  	d.Lock()
    33  	defer d.Unlock()
    34  	d.backfill = append(d.backfill, addrs...)
    35  }
    36  func (d *testDiscovery) PoolCount() int { return 0 }
    37  func (d *testDiscovery) RegisterSelf(p AddressablePeer) {
    38  	d.Lock()
    39  	defer d.Unlock()
    40  	d.bad = append(d.bad, p.ConnectionAddr())
    41  }
    42  func (d *testDiscovery) GetFanOut() int {
    43  	d.Lock()
    44  	defer d.Unlock()
    45  	return (len(d.connected) + len(d.backfill)) * 2 / 3
    46  }
    47  func (d *testDiscovery) NetworkSize() int {
    48  	d.Lock()
    49  	defer d.Unlock()
    50  	return len(d.connected) + len(d.backfill)
    51  }
    52  func (d *testDiscovery) RegisterGood(AddressablePeer) {}
    53  func (d *testDiscovery) RegisterConnected(p AddressablePeer) {
    54  	d.Lock()
    55  	defer d.Unlock()
    56  	d.connected = append(d.connected, p.ConnectionAddr())
    57  }
    58  func (d *testDiscovery) UnregisterConnected(p AddressablePeer, force bool) {
    59  	d.Lock()
    60  	defer d.Unlock()
    61  	d.unregistered = append(d.unregistered, p.ConnectionAddr())
    62  }
    63  func (d *testDiscovery) UnconnectedPeers() []string {
    64  	d.Lock()
    65  	defer d.Unlock()
    66  	return d.unregistered
    67  }
    68  func (d *testDiscovery) RequestRemote(n int) {}
    69  func (d *testDiscovery) BadPeers() []string {
    70  	d.Lock()
    71  	defer d.Unlock()
    72  	return d.bad
    73  }
    74  func (d *testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} }
    75  
    76  var defaultMessageHandler = func(t *testing.T, msg *Message) {}
    77  
    78  type localPeer struct {
    79  	netaddr        net.TCPAddr
    80  	server         *Server
    81  	version        *payload.Version
    82  	lastBlockIndex uint32
    83  	handshaked     int32 // TODO: use atomic.Bool after #2626.
    84  	isFullNode     bool
    85  	t              *testing.T
    86  	messageHandler func(t *testing.T, msg *Message)
    87  	pingSent       int
    88  	getAddrSent    int
    89  	droppedWith    atomic.Value
    90  }
    91  
    92  func newLocalPeer(t *testing.T, s *Server) *localPeer {
    93  	naddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0")
    94  	return &localPeer{
    95  		t:              t,
    96  		server:         s,
    97  		netaddr:        *naddr,
    98  		messageHandler: defaultMessageHandler,
    99  	}
   100  }
   101  
   102  func (p *localPeer) ConnectionAddr() string {
   103  	return p.netaddr.String()
   104  }
   105  func (p *localPeer) RemoteAddr() net.Addr {
   106  	return &p.netaddr
   107  }
   108  func (p *localPeer) PeerAddr() net.Addr {
   109  	return &p.netaddr
   110  }
   111  func (p *localPeer) StartProtocol() {}
   112  func (p *localPeer) Disconnect(err error) {
   113  	if p.droppedWith.Load() == nil {
   114  		p.droppedWith.Store(err)
   115  	}
   116  	fmt.Println("peer dropped:", err)
   117  	p.server.unregister <- peerDrop{p, err}
   118  }
   119  
   120  func (p *localPeer) BroadcastPacket(_ context.Context, m []byte) error {
   121  	if len(m) == 0 {
   122  		return errors.New("empty msg")
   123  	}
   124  	msg := &Message{}
   125  	r := io.NewBinReaderFromBuf(m)
   126  	for r.Len() > 0 {
   127  		err := msg.Decode(r)
   128  		if err == nil {
   129  			p.messageHandler(p.t, msg)
   130  		}
   131  	}
   132  	return nil
   133  }
   134  func (p *localPeer) EnqueueP2PMessage(msg *Message) error {
   135  	return p.EnqueueHPMessage(msg)
   136  }
   137  func (p *localPeer) EnqueueP2PPacket(m []byte) error {
   138  	return p.BroadcastPacket(context.TODO(), m)
   139  }
   140  func (p *localPeer) BroadcastHPPacket(ctx context.Context, m []byte) error {
   141  	return p.BroadcastPacket(ctx, m)
   142  }
   143  func (p *localPeer) EnqueueHPMessage(msg *Message) error {
   144  	p.messageHandler(p.t, msg)
   145  	return nil
   146  }
   147  func (p *localPeer) EnqueueHPPacket(m []byte) error {
   148  	return p.BroadcastPacket(context.TODO(), m)
   149  }
   150  func (p *localPeer) Version() *payload.Version {
   151  	return p.version
   152  }
   153  func (p *localPeer) LastBlockIndex() uint32 {
   154  	return p.lastBlockIndex
   155  }
   156  func (p *localPeer) HandleVersion(v *payload.Version) error {
   157  	p.version = v
   158  	return nil
   159  }
   160  func (p *localPeer) SendVersion() error {
   161  	m, err := p.server.getVersionMsg(nil)
   162  	if err != nil {
   163  		return err
   164  	}
   165  	_ = p.EnqueueHPMessage(m)
   166  	return nil
   167  }
   168  func (p *localPeer) SendVersionAck(m *Message) error {
   169  	_ = p.EnqueueHPMessage(m)
   170  	return nil
   171  }
   172  func (p *localPeer) HandleVersionAck() error {
   173  	atomic.StoreInt32(&p.handshaked, 1)
   174  	return nil
   175  }
   176  func (p *localPeer) SetPingTimer() {
   177  	p.pingSent++
   178  }
   179  func (p *localPeer) HandlePing(ping *payload.Ping) error {
   180  	p.lastBlockIndex = ping.LastBlockIndex
   181  	return nil
   182  }
   183  
   184  func (p *localPeer) HandlePong(pong *payload.Ping) error {
   185  	p.lastBlockIndex = pong.LastBlockIndex
   186  	p.pingSent--
   187  	return nil
   188  }
   189  
   190  func (p *localPeer) Handshaked() bool {
   191  	return atomic.LoadInt32(&p.handshaked) != 0
   192  }
   193  
   194  func (p *localPeer) IsFullNode() bool {
   195  	return p.isFullNode
   196  }
   197  
   198  func (p *localPeer) AddGetAddrSent() {
   199  	p.getAddrSent++
   200  }
   201  func (p *localPeer) CanProcessAddr() bool {
   202  	p.getAddrSent--
   203  	return p.getAddrSent >= 0
   204  }
   205  
   206  func newTestServer(t *testing.T, serverConfig ServerConfig) *Server {
   207  	return newTestServerWithCustomCfg(t, serverConfig, nil)
   208  }
   209  
   210  func newTestServerWithCustomCfg(t *testing.T, serverConfig ServerConfig, protocolCfg func(*config.Blockchain)) *Server {
   211  	if len(serverConfig.Addresses) == 0 {
   212  		// Normally it will be done by ApplicationConfiguration.GetAddresses().
   213  		serverConfig.Addresses = []config.AnnounceableAddress{{Address: ":0"}}
   214  	}
   215  	s, err := newServerFromConstructors(serverConfig, fakechain.NewFakeChainWithCustomCfg(protocolCfg), new(fakechain.FakeStateSync), zaptest.NewLogger(t),
   216  		newFakeTransp, newTestDiscovery)
   217  	require.NoError(t, err)
   218  	return s
   219  }