github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/network/extpool/pool_test.go (about) 1 package extpool 2 3 import ( 4 "errors" 5 "testing" 6 7 "github.com/nspcc-dev/neo-go/pkg/core/transaction" 8 "github.com/nspcc-dev/neo-go/pkg/crypto/hash" 9 "github.com/nspcc-dev/neo-go/pkg/network/payload" 10 "github.com/nspcc-dev/neo-go/pkg/util" 11 "github.com/stretchr/testify/require" 12 ) 13 14 func TestAddGet(t *testing.T) { 15 bc := newTestChain() 16 bc.height = 10 17 18 p := New(bc, 100) 19 t.Run("invalid witness", func(t *testing.T) { 20 ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}} 21 p.testAdd(t, false, errVerification, ep) 22 }) 23 t.Run("disallowed sender", func(t *testing.T) { 24 ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x41}} 25 p.testAdd(t, false, errDisallowedSender, ep) 26 }) 27 t.Run("bad height", func(t *testing.T) { 28 ep := &payload.Extensible{ValidBlockEnd: 9} 29 p.testAdd(t, false, errInvalidHeight, ep) 30 31 ep = &payload.Extensible{ValidBlockEnd: 10} 32 p.testAdd(t, false, nil, ep) 33 }) 34 t.Run("good", func(t *testing.T) { 35 ep := &payload.Extensible{ValidBlockEnd: 100} 36 p.testAdd(t, true, nil, ep) 37 require.Equal(t, ep, p.Get(ep.Hash())) 38 39 p.testAdd(t, false, nil, ep) 40 }) 41 } 42 43 func TestCapacityLimit(t *testing.T) { 44 bc := newTestChain() 45 bc.height = 10 46 47 t.Run("invalid capacity", func(t *testing.T) { 48 require.Panics(t, func() { New(bc, 0) }) 49 }) 50 51 p := New(bc, 3) 52 53 first := &payload.Extensible{ValidBlockEnd: 11} 54 p.testAdd(t, true, nil, first) 55 56 for _, height := range []uint32{12, 13} { 57 ep := &payload.Extensible{ValidBlockEnd: height} 58 p.testAdd(t, true, nil, ep) 59 } 60 61 require.NotNil(t, p.Get(first.Hash())) 62 63 ok, err := p.Add(&payload.Extensible{ValidBlockEnd: 14}) 64 require.True(t, ok) 65 require.NoError(t, err) 66 67 require.Nil(t, p.Get(first.Hash())) 68 } 69 70 // This test checks that sender count is updated 71 // when oldest payload is removed during `Add`. 72 func TestDecreaseSenderOnEvict(t *testing.T) { 73 bc := newTestChain() 74 bc.height = 10 75 76 p := New(bc, 2) 77 senders := []util.Uint160{{1}, {2}, {3}} 78 for i := uint32(11); i < 17; i++ { 79 ep := &payload.Extensible{Sender: senders[i%3], ValidBlockEnd: i} 80 p.testAdd(t, true, nil, ep) 81 } 82 } 83 84 func TestRemoveStale(t *testing.T) { 85 bc := newTestChain() 86 bc.height = 10 87 88 p := New(bc, 100) 89 eps := []*payload.Extensible{ 90 {ValidBlockEnd: 11}, // small height 91 {ValidBlockEnd: 12}, // good 92 {Sender: util.Uint160{0x11}, ValidBlockEnd: 12}, // invalid sender 93 {Sender: util.Uint160{0x12}, ValidBlockEnd: 12}, // invalid witness 94 } 95 for i := range eps { 96 p.testAdd(t, true, nil, eps[i]) 97 } 98 bc.verifyWitness = func(u util.Uint160) bool { return u[0] != 0x12 } 99 bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 } 100 p.RemoveStale(11) 101 require.Nil(t, p.Get(eps[0].Hash())) 102 require.Equal(t, eps[1], p.Get(eps[1].Hash())) 103 require.Nil(t, p.Get(eps[2].Hash())) 104 require.Nil(t, p.Get(eps[3].Hash())) 105 } 106 107 func (p *Pool) testAdd(t *testing.T, expectedOk bool, expectedErr error, ep *payload.Extensible) { 108 ok, err := p.Add(ep) 109 if expectedErr != nil { 110 require.ErrorIs(t, err, expectedErr) 111 } else { 112 require.NoError(t, err) 113 } 114 require.Equal(t, expectedOk, ok) 115 } 116 117 type testChain struct { 118 Ledger 119 height uint32 120 verifyWitness func(util.Uint160) bool 121 isAllowed func(util.Uint160) bool 122 } 123 124 var errVerification = errors.New("verification failed") 125 126 func newTestChain() *testChain { 127 return &testChain{ 128 verifyWitness: func(u util.Uint160) bool { 129 return u[0] != 0x42 130 }, 131 isAllowed: func(u util.Uint160) bool { 132 return u[0] != 0x42 && u[0] != 0x41 133 }, 134 } 135 } 136 func (c *testChain) VerifyWitness(u util.Uint160, _ hash.Hashable, _ *transaction.Witness, _ int64) (int64, error) { 137 if !c.verifyWitness(u) { 138 return 0, errVerification 139 } 140 return 0, nil 141 } 142 func (c *testChain) IsExtensibleAllowed(u util.Uint160) bool { 143 return c.isAllowed(u) 144 } 145 func (c *testChain) BlockHeight() uint32 { return c.height }