github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/network/extpool/pool_test.go (about)

     1  package extpool
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  
     7  	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
     8  	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
     9  	"github.com/nspcc-dev/neo-go/pkg/network/payload"
    10  	"github.com/nspcc-dev/neo-go/pkg/util"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func TestAddGet(t *testing.T) {
    15  	bc := newTestChain()
    16  	bc.height = 10
    17  
    18  	p := New(bc, 100)
    19  	t.Run("invalid witness", func(t *testing.T) {
    20  		ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}}
    21  		p.testAdd(t, false, errVerification, ep)
    22  	})
    23  	t.Run("disallowed sender", func(t *testing.T) {
    24  		ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x41}}
    25  		p.testAdd(t, false, errDisallowedSender, ep)
    26  	})
    27  	t.Run("bad height", func(t *testing.T) {
    28  		ep := &payload.Extensible{ValidBlockEnd: 9}
    29  		p.testAdd(t, false, errInvalidHeight, ep)
    30  
    31  		ep = &payload.Extensible{ValidBlockEnd: 10}
    32  		p.testAdd(t, false, nil, ep)
    33  	})
    34  	t.Run("good", func(t *testing.T) {
    35  		ep := &payload.Extensible{ValidBlockEnd: 100}
    36  		p.testAdd(t, true, nil, ep)
    37  		require.Equal(t, ep, p.Get(ep.Hash()))
    38  
    39  		p.testAdd(t, false, nil, ep)
    40  	})
    41  }
    42  
    43  func TestCapacityLimit(t *testing.T) {
    44  	bc := newTestChain()
    45  	bc.height = 10
    46  
    47  	t.Run("invalid capacity", func(t *testing.T) {
    48  		require.Panics(t, func() { New(bc, 0) })
    49  	})
    50  
    51  	p := New(bc, 3)
    52  
    53  	first := &payload.Extensible{ValidBlockEnd: 11}
    54  	p.testAdd(t, true, nil, first)
    55  
    56  	for _, height := range []uint32{12, 13} {
    57  		ep := &payload.Extensible{ValidBlockEnd: height}
    58  		p.testAdd(t, true, nil, ep)
    59  	}
    60  
    61  	require.NotNil(t, p.Get(first.Hash()))
    62  
    63  	ok, err := p.Add(&payload.Extensible{ValidBlockEnd: 14})
    64  	require.True(t, ok)
    65  	require.NoError(t, err)
    66  
    67  	require.Nil(t, p.Get(first.Hash()))
    68  }
    69  
    70  // This test checks that sender count is updated
    71  // when oldest payload is removed during `Add`.
    72  func TestDecreaseSenderOnEvict(t *testing.T) {
    73  	bc := newTestChain()
    74  	bc.height = 10
    75  
    76  	p := New(bc, 2)
    77  	senders := []util.Uint160{{1}, {2}, {3}}
    78  	for i := uint32(11); i < 17; i++ {
    79  		ep := &payload.Extensible{Sender: senders[i%3], ValidBlockEnd: i}
    80  		p.testAdd(t, true, nil, ep)
    81  	}
    82  }
    83  
    84  func TestRemoveStale(t *testing.T) {
    85  	bc := newTestChain()
    86  	bc.height = 10
    87  
    88  	p := New(bc, 100)
    89  	eps := []*payload.Extensible{
    90  		{ValidBlockEnd: 11},                             // small height
    91  		{ValidBlockEnd: 12},                             // good
    92  		{Sender: util.Uint160{0x11}, ValidBlockEnd: 12}, // invalid sender
    93  		{Sender: util.Uint160{0x12}, ValidBlockEnd: 12}, // invalid witness
    94  	}
    95  	for i := range eps {
    96  		p.testAdd(t, true, nil, eps[i])
    97  	}
    98  	bc.verifyWitness = func(u util.Uint160) bool { return u[0] != 0x12 }
    99  	bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 }
   100  	p.RemoveStale(11)
   101  	require.Nil(t, p.Get(eps[0].Hash()))
   102  	require.Equal(t, eps[1], p.Get(eps[1].Hash()))
   103  	require.Nil(t, p.Get(eps[2].Hash()))
   104  	require.Nil(t, p.Get(eps[3].Hash()))
   105  }
   106  
   107  func (p *Pool) testAdd(t *testing.T, expectedOk bool, expectedErr error, ep *payload.Extensible) {
   108  	ok, err := p.Add(ep)
   109  	if expectedErr != nil {
   110  		require.ErrorIs(t, err, expectedErr)
   111  	} else {
   112  		require.NoError(t, err)
   113  	}
   114  	require.Equal(t, expectedOk, ok)
   115  }
   116  
   117  type testChain struct {
   118  	Ledger
   119  	height        uint32
   120  	verifyWitness func(util.Uint160) bool
   121  	isAllowed     func(util.Uint160) bool
   122  }
   123  
   // errVerification is the error the stub chain's VerifyWitness returns when
   // its witness policy rejects a sender.
   var (
   	errVerification = errors.New("verification failed")
   )
   125  
   126  func newTestChain() *testChain {
   127  	return &testChain{
   128  		verifyWitness: func(u util.Uint160) bool {
   129  			return u[0] != 0x42
   130  		},
   131  		isAllowed: func(u util.Uint160) bool {
   132  			return u[0] != 0x42 && u[0] != 0x41
   133  		},
   134  	}
   135  }
   136  func (c *testChain) VerifyWitness(u util.Uint160, _ hash.Hashable, _ *transaction.Witness, _ int64) (int64, error) {
   137  	if !c.verifyWitness(u) {
   138  		return 0, errVerification
   139  	}
   140  	return 0, nil
   141  }
   142  func (c *testChain) IsExtensibleAllowed(u util.Uint160) bool {
   143  	return c.isAllowed(u)
   144  }
   145  func (c *testChain) BlockHeight() uint32 { return c.height }