github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/neatptc/protocol_test.go (about)

     1  package neatptc
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/neatlab/neatio/chain/core/types"
    10  	"github.com/neatlab/neatio/neatptc/downloader"
    11  	"github.com/neatlab/neatio/network/p2p"
    12  	"github.com/neatlab/neatio/utilities/common"
    13  	"github.com/neatlab/neatio/utilities/crypto"
    14  	"github.com/neatlab/neatio/utilities/rlp"
    15  )
    16  
    17  func init() {
    18  
    19  }
    20  
    21  var testAccount, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
    22  
    23  func TestStatusMsgErrors62(t *testing.T) { testStatusMsgErrors(t, 62) }
    24  func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) }
    25  
    26  func testStatusMsgErrors(t *testing.T, protocol int) {
    27  	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
    28  	var (
    29  		genesis = pm.blockchain.Genesis()
    30  		head    = pm.blockchain.CurrentHeader()
    31  		td      = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
    32  	)
    33  	defer pm.Stop()
    34  
    35  	tests := []struct {
    36  		code      uint64
    37  		data      interface{}
    38  		wantError error
    39  	}{
    40  		{
    41  			code: TxMsg, data: []interface{}{},
    42  			wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"),
    43  		},
    44  		{
    45  			code: StatusMsg, data: statusData{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash()},
    46  			wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", protocol),
    47  		},
    48  		{
    49  			code: StatusMsg, data: statusData{uint32(protocol), 999, td, head.Hash(), genesis.Hash()},
    50  			wantError: errResp(ErrNetworkIdMismatch, "999 (!= 1)"),
    51  		},
    52  		{
    53  			code: StatusMsg, data: statusData{uint32(protocol), DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}},
    54  			wantError: errResp(ErrGenesisBlockMismatch, "0300000000000000 (!= %x)", genesis.Hash().Bytes()[:8]),
    55  		},
    56  	}
    57  
    58  	for i, test := range tests {
    59  		p, errc := newTestPeer("peer", protocol, pm, false)
    60  
    61  		go p2p.Send(p.app, test.code, test.data)
    62  
    63  		select {
    64  		case err := <-errc:
    65  			if err == nil {
    66  				t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError)
    67  			} else if err.Error() != test.wantError.Error() {
    68  				t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError)
    69  			}
    70  		case <-time.After(2 * time.Second):
    71  			t.Errorf("protocol did not shut down within 2 seconds")
    72  		}
    73  		p.close()
    74  	}
    75  }
    76  
    77  func TestRecvTransactions62(t *testing.T) { testRecvTransactions(t, 62) }
    78  func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) }
    79  
    80  func testRecvTransactions(t *testing.T, protocol int) {
    81  	txAdded := make(chan []*types.Transaction)
    82  	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
    83  	pm.acceptTxs = 1
    84  	p, _ := newTestPeer("peer", protocol, pm, true)
    85  	defer pm.Stop()
    86  	defer p.close()
    87  
    88  	tx := newTestTransaction(testAccount, 0, 0)
    89  	if err := p2p.Send(p.app, TxMsg, []interface{}{tx}); err != nil {
    90  		t.Fatalf("send error: %v", err)
    91  	}
    92  	select {
    93  	case added := <-txAdded:
    94  		if len(added) != 1 {
    95  			t.Errorf("wrong number of added transactions: got %d, want 1", len(added))
    96  		} else if added[0].Hash() != tx.Hash() {
    97  			t.Errorf("added wrong tx hash: got %v, want %v", added[0].Hash(), tx.Hash())
    98  		}
    99  	case <-time.After(2 * time.Second):
   100  		t.Errorf("no TxPreEvent received within 2 seconds")
   101  	}
   102  }
   103  
   104  func TestSendTransactions62(t *testing.T) { testSendTransactions(t, 62) }
   105  func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) }
   106  
   107  func testSendTransactions(t *testing.T, protocol int) {
   108  	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
   109  	defer pm.Stop()
   110  
   111  	const txsize = txsyncPackSize / 10
   112  	alltxs := make([]*types.Transaction, 100)
   113  	for nonce := range alltxs {
   114  		alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), txsize)
   115  	}
   116  	pm.txpool.AddRemotes(alltxs)
   117  
   118  	var wg sync.WaitGroup
   119  	checktxs := func(p *testPeer) {
   120  		defer wg.Done()
   121  		defer p.close()
   122  		seen := make(map[common.Hash]bool)
   123  		for _, tx := range alltxs {
   124  			seen[tx.Hash()] = false
   125  		}
   126  		for n := 0; n < len(alltxs) && !t.Failed(); {
   127  			var txs []*types.Transaction
   128  			msg, err := p.app.ReadMsg()
   129  			if err != nil {
   130  				t.Errorf("%v: read error: %v", p.Peer, err)
   131  			} else if msg.Code != TxMsg {
   132  				t.Errorf("%v: got code %d, want TxMsg", p.Peer, msg.Code)
   133  			}
   134  			if err := msg.Decode(&txs); err != nil {
   135  				t.Errorf("%v: %v", p.Peer, err)
   136  			}
   137  			for _, tx := range txs {
   138  				hash := tx.Hash()
   139  				seentx, want := seen[hash]
   140  				if seentx {
   141  					t.Errorf("%v: got tx more than once: %x", p.Peer, hash)
   142  				}
   143  				if !want {
   144  					t.Errorf("%v: got unexpected tx: %x", p.Peer, hash)
   145  				}
   146  				seen[hash] = true
   147  				n++
   148  			}
   149  		}
   150  	}
   151  	for i := 0; i < 3; i++ {
   152  		p, _ := newTestPeer(fmt.Sprintf("peer #%d", i), protocol, pm, true)
   153  		wg.Add(1)
   154  		go checktxs(p)
   155  	}
   156  	wg.Wait()
   157  }
   158  
   159  func TestGetBlockHeadersDataEncodeDecode(t *testing.T) {
   160  
   161  	var hash common.Hash
   162  	for i := range hash {
   163  		hash[i] = byte(i)
   164  	}
   165  
   166  	tests := []struct {
   167  		packet *getBlockHeadersData
   168  		fail   bool
   169  	}{
   170  
   171  		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}}},
   172  		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}}},
   173  
   174  		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}},
   175  		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}},
   176  
   177  		{fail: true, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash, Number: 314}}},
   178  	}
   179  
   180  	for i, tt := range tests {
   181  		bytes, err := rlp.EncodeToBytes(tt.packet)
   182  		if err != nil && !tt.fail {
   183  			t.Fatalf("test %d: failed to encode packet: %v", i, err)
   184  		} else if err == nil && tt.fail {
   185  			t.Fatalf("test %d: encode should have failed", i)
   186  		}
   187  		if !tt.fail {
   188  			packet := new(getBlockHeadersData)
   189  			if err := rlp.DecodeBytes(bytes, packet); err != nil {
   190  				t.Fatalf("test %d: failed to decode packet: %v", i, err)
   191  			}
   192  			if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount ||
   193  				packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse {
   194  				t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet)
   195  			}
   196  		}
   197  	}
   198  }