github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/protocols/protocol_test.go (about)

     1  package protocols
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/neatio-net/neatio/network/p2p"
    11  	"github.com/neatio-net/neatio/network/p2p/discover"
    12  	"github.com/neatio-net/neatio/network/p2p/simulations/adapters"
    13  	p2ptest "github.com/neatio-net/neatio/network/p2p/testing"
    14  )
    15  
    16  type hs0 struct {
    17  	C uint
    18  }
    19  
    20  type kill struct {
    21  	C discover.NodeID
    22  }
    23  
    24  type drop struct {
    25  }
    26  
    27  type protoHandshake struct {
    28  	Version   uint
    29  	NetworkID string
    30  }
    31  
    32  func checkProtoHandshake(testVersion uint, testNetworkID string) func(interface{}) error {
    33  	return func(rhs interface{}) error {
    34  		remote := rhs.(*protoHandshake)
    35  		if remote.NetworkID != testNetworkID {
    36  			return fmt.Errorf("%s (!= %s)", remote.NetworkID, testNetworkID)
    37  		}
    38  
    39  		if remote.Version != testVersion {
    40  			return fmt.Errorf("%d (!= %d)", remote.Version, testVersion)
    41  		}
    42  		return nil
    43  	}
    44  }
    45  
    46  func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) error {
    47  	spec := &Spec{
    48  		Name:       "test",
    49  		Version:    42,
    50  		MaxMsgSize: 10 * 1024,
    51  		Messages: []interface{}{
    52  			protoHandshake{},
    53  			hs0{},
    54  			kill{},
    55  			drop{},
    56  		},
    57  	}
    58  	return func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
    59  		peer := NewPeer(p, rw, spec)
    60  
    61  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    62  		defer cancel()
    63  		phs := &protoHandshake{42, "420"}
    64  		hsCheck := checkProtoHandshake(phs.Version, phs.NetworkID)
    65  		_, err := peer.Handshake(ctx, phs, hsCheck)
    66  		if err != nil {
    67  			return err
    68  		}
    69  
    70  		lhs := &hs0{42}
    71  
    72  		hs, err := peer.Handshake(ctx, lhs, nil)
    73  		if err != nil {
    74  			return err
    75  		}
    76  
    77  		if rmhs := hs.(*hs0); rmhs.C > lhs.C {
    78  			return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C)
    79  		}
    80  
    81  		handle := func(msg interface{}) error {
    82  			switch msg := msg.(type) {
    83  
    84  			case *protoHandshake:
    85  				return errors.New("duplicate handshake")
    86  
    87  			case *hs0:
    88  				rhs := msg
    89  				if rhs.C > lhs.C {
    90  					return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C)
    91  				}
    92  				lhs.C += rhs.C
    93  				return peer.Send(lhs)
    94  
    95  			case *kill:
    96  
    97  				id := msg.C
    98  				pp.Get(id).Drop(errors.New("killed"))
    99  				return nil
   100  
   101  			case *drop:
   102  
   103  				return errors.New("dropped")
   104  
   105  			default:
   106  				return fmt.Errorf("unknown message type: %T", msg)
   107  			}
   108  		}
   109  
   110  		pp.Add(peer)
   111  		defer pp.Remove(peer)
   112  		return peer.Run(handle)
   113  	}
   114  }
   115  
   116  func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester {
   117  	conf := adapters.RandomNodeConfig()
   118  	return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp))
   119  }
   120  
   121  func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange {
   122  
   123  	return []p2ptest.Exchange{
   124  		{
   125  			Expects: []p2ptest.Expect{
   126  				{
   127  					Code: 0,
   128  					Msg:  &protoHandshake{42, "420"},
   129  					Peer: id,
   130  				},
   131  			},
   132  		},
   133  		{
   134  			Triggers: []p2ptest.Trigger{
   135  				{
   136  					Code: 0,
   137  					Msg:  proto,
   138  					Peer: id,
   139  				},
   140  			},
   141  		},
   142  	}
   143  }
   144  
   145  func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) {
   146  	pp := p2ptest.NewTestPeerPool()
   147  	s := protocolTester(t, pp)
   148  
   149  	id := s.IDs[0]
   150  	if err := s.TestExchanges(protoHandshakeExchange(id, proto)...); err != nil {
   151  		t.Fatal(err)
   152  	}
   153  	var disconnects []*p2ptest.Disconnect
   154  	for i, err := range errs {
   155  		disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err})
   156  	}
   157  	if err := s.TestDisconnected(disconnects...); err != nil {
   158  		t.Fatal(err)
   159  	}
   160  }
   161  
   162  func TestProtoHandshakeVersionMismatch(t *testing.T) {
   163  	runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error()))
   164  }
   165  
   166  func TestProtoHandshakeNetworkIDMismatch(t *testing.T) {
   167  	runProtoHandshake(t, &protoHandshake{42, "421"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 421 (!= 420)").Error()))
   168  }
   169  
   170  func TestProtoHandshakeSuccess(t *testing.T) {
   171  	runProtoHandshake(t, &protoHandshake{42, "420"})
   172  }
   173  
   174  func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange {
   175  
   176  	return []p2ptest.Exchange{
   177  		{
   178  			Expects: []p2ptest.Expect{
   179  				{
   180  					Code: 1,
   181  					Msg:  &hs0{42},
   182  					Peer: id,
   183  				},
   184  			},
   185  		},
   186  		{
   187  			Triggers: []p2ptest.Trigger{
   188  				{
   189  					Code: 1,
   190  					Msg:  &hs0{resp},
   191  					Peer: id,
   192  				},
   193  			},
   194  		},
   195  	}
   196  }
   197  
   198  func runModuleHandshake(t *testing.T, resp uint, errs ...error) {
   199  	pp := p2ptest.NewTestPeerPool()
   200  	s := protocolTester(t, pp)
   201  	id := s.IDs[0]
   202  	if err := s.TestExchanges(protoHandshakeExchange(id, &protoHandshake{42, "420"})...); err != nil {
   203  		t.Fatal(err)
   204  	}
   205  	if err := s.TestExchanges(moduleHandshakeExchange(id, resp)...); err != nil {
   206  		t.Fatal(err)
   207  	}
   208  	var disconnects []*p2ptest.Disconnect
   209  	for i, err := range errs {
   210  		disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err})
   211  	}
   212  	if err := s.TestDisconnected(disconnects...); err != nil {
   213  		t.Fatal(err)
   214  	}
   215  }
   216  
   217  func TestModuleHandshakeError(t *testing.T) {
   218  	runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42"))
   219  }
   220  
   221  func TestModuleHandshakeSuccess(t *testing.T) {
   222  	runModuleHandshake(t, 42)
   223  }
   224  
   225  func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange {
   226  
   227  	return []p2ptest.Exchange{
   228  		{
   229  			Label: "primary handshake",
   230  			Expects: []p2ptest.Expect{
   231  				{
   232  					Code: 0,
   233  					Msg:  &protoHandshake{42, "420"},
   234  					Peer: a,
   235  				},
   236  				{
   237  					Code: 0,
   238  					Msg:  &protoHandshake{42, "420"},
   239  					Peer: b,
   240  				},
   241  			},
   242  		},
   243  		{
   244  			Label: "module handshake",
   245  			Triggers: []p2ptest.Trigger{
   246  				{
   247  					Code: 0,
   248  					Msg:  &protoHandshake{42, "420"},
   249  					Peer: a,
   250  				},
   251  				{
   252  					Code: 0,
   253  					Msg:  &protoHandshake{42, "420"},
   254  					Peer: b,
   255  				},
   256  			},
   257  			Expects: []p2ptest.Expect{
   258  				{
   259  					Code: 1,
   260  					Msg:  &hs0{42},
   261  					Peer: a,
   262  				},
   263  				{
   264  					Code: 1,
   265  					Msg:  &hs0{42},
   266  					Peer: b,
   267  				},
   268  			},
   269  		},
   270  
   271  		{Label: "alternative module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{41}, Peer: a},
   272  			{Code: 1, Msg: &hs0{41}, Peer: b}}},
   273  		{Label: "repeated module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{1}, Peer: a}}},
   274  		{Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{{Code: 1, Msg: &hs0{43}, Peer: a}}}}
   275  }
   276  
   277  func runMultiplePeers(t *testing.T, peer int, errs ...error) {
   278  	pp := p2ptest.NewTestPeerPool()
   279  	s := protocolTester(t, pp)
   280  
   281  	if err := s.TestExchanges(testMultiPeerSetup(s.IDs[0], s.IDs[1])...); err != nil {
   282  		t.Fatal(err)
   283  	}
   284  
   285  	tick := time.NewTicker(10 * time.Millisecond)
   286  	timeout := time.NewTimer(1 * time.Second)
   287  WAIT:
   288  	for {
   289  		select {
   290  		case <-tick.C:
   291  			if pp.Has(s.IDs[0]) {
   292  				break WAIT
   293  			}
   294  		case <-timeout.C:
   295  			t.Fatal("timeout")
   296  		}
   297  	}
   298  	if !pp.Has(s.IDs[1]) {
   299  		t.Fatalf("missing peer test-1: %v (%v)", pp, s.IDs)
   300  	}
   301  
   302  	err := s.TestExchanges(p2ptest.Exchange{
   303  		Triggers: []p2ptest.Trigger{
   304  			{
   305  				Code: 2,
   306  				Msg:  &kill{s.IDs[peer]},
   307  				Peer: s.IDs[0],
   308  			},
   309  		},
   310  	})
   311  
   312  	if err != nil {
   313  		t.Fatal(err)
   314  	}
   315  
   316  	err = s.TestExchanges(p2ptest.Exchange{
   317  		Triggers: []p2ptest.Trigger{
   318  			{
   319  				Code: 3,
   320  				Msg:  &drop{},
   321  				Peer: s.IDs[(peer+1)%2],
   322  			},
   323  		},
   324  	})
   325  
   326  	if err != nil {
   327  		t.Fatal(err)
   328  	}
   329  
   330  	var disconnects []*p2ptest.Disconnect
   331  	for i, err := range errs {
   332  		disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err})
   333  	}
   334  	if err := s.TestDisconnected(disconnects...); err != nil {
   335  		t.Fatal(err)
   336  	}
   337  
   338  	if pp.Has(s.IDs[peer]) {
   339  		t.Fatalf("peer test-%v not dropped: %v (%v)", peer, pp, s.IDs)
   340  	}
   341  
   342  }
   343  
   344  func TestMultiplePeersDropSelf(t *testing.T) {
   345  	runMultiplePeers(t, 0,
   346  		fmt.Errorf("subprotocol error"),
   347  		fmt.Errorf("Message handler error: (msg code 3): dropped"),
   348  	)
   349  }
   350  
   351  func TestMultiplePeersDropOther(t *testing.T) {
   352  	runMultiplePeers(t, 1,
   353  		fmt.Errorf("Message handler error: (msg code 3): dropped"),
   354  		fmt.Errorf("subprotocol error"),
   355  	)
   356  }