github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/sync/rate_limiter_test.go (about) 1 package sync 2 3 import ( 4 "context" 5 "sync" 6 "testing" 7 "time" 8 9 "github.com/libp2p/go-libp2p-core/network" 10 "github.com/libp2p/go-libp2p-core/protocol" 11 "github.com/prysmaticlabs/prysm/beacon-chain/p2p" 12 mockp2p "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing" 13 p2ptypes "github.com/prysmaticlabs/prysm/beacon-chain/p2p/types" 14 "github.com/prysmaticlabs/prysm/shared/testutil" 15 "github.com/prysmaticlabs/prysm/shared/testutil/assert" 16 "github.com/prysmaticlabs/prysm/shared/testutil/require" 17 ) 18 19 func TestNewRateLimiter(t *testing.T) { 20 rlimiter := newRateLimiter(mockp2p.NewTestP2P(t)) 21 assert.Equal(t, len(rlimiter.limiterMap), 7, "correct number of topics not registered") 22 } 23 24 func TestNewRateLimiter_FreeCorrectly(t *testing.T) { 25 rlimiter := newRateLimiter(mockp2p.NewTestP2P(t)) 26 rlimiter.free() 27 assert.Equal(t, len(rlimiter.limiterMap), 0, "rate limiter not freed correctly") 28 29 } 30 31 func TestRateLimiter_ExceedCapacity(t *testing.T) { 32 p1 := mockp2p.NewTestP2P(t) 33 p2 := mockp2p.NewTestP2P(t) 34 p1.Connect(p2) 35 rlimiter := newRateLimiter(p1) 36 37 // BlockByRange 38 topic := p2p.RPCBlocksByRangeTopicV1 + p1.Encoding().ProtocolSuffix() 39 40 wg := sync.WaitGroup{} 41 p2.BHost.SetStreamHandler(protocol.ID(topic), func(stream network.Stream) { 42 defer wg.Done() 43 code, errMsg, err := readStatusCodeNoDeadline(stream, p2.Encoding()) 44 require.NoError(t, err, "could not read incoming stream") 45 assert.Equal(t, responseCodeInvalidRequest, code, "not equal response codes") 46 assert.Equal(t, p2ptypes.ErrRateLimited.Error(), errMsg, "not equal errors") 47 }) 48 wg.Add(1) 49 stream, err := p1.BHost.NewStream(context.Background(), p2.PeerID(), protocol.ID(topic)) 50 require.NoError(t, err, "could not create stream") 51 52 err = rlimiter.validateRequest(stream, 64) 53 require.NoError(t, err, "could not validate incoming request") 54 55 // Attempt to create an error, rate limit and lead to disconnect 56 err = rlimiter.validateRequest(stream, 1000) 57 require.NotNil(t, err, "could not get error from leaky bucket") 58 59 require.NoError(t, stream.Close(), "could not close stream") 60 61 if testutil.WaitTimeout(&wg, 1*time.Second) { 62 t.Fatal("Did not receive stream within 1 sec") 63 } 64 } 65 66 func TestRateLimiter_ExceedRawCapacity(t *testing.T) { 67 p1 := mockp2p.NewTestP2P(t) 68 p2 := mockp2p.NewTestP2P(t) 69 p1.Connect(p2) 70 p1.Peers().Add(nil, p2.PeerID(), p2.BHost.Addrs()[0], network.DirOutbound) 71 72 rlimiter := newRateLimiter(p1) 73 74 // BlockByRange 75 topic := p2p.RPCBlocksByRangeTopicV1 + p1.Encoding().ProtocolSuffix() 76 77 wg := sync.WaitGroup{} 78 p2.BHost.SetStreamHandler(protocol.ID(topic), func(stream network.Stream) { 79 defer wg.Done() 80 code, errMsg, err := readStatusCodeNoDeadline(stream, p2.Encoding()) 81 require.NoError(t, err, "could not read incoming stream") 82 assert.Equal(t, responseCodeInvalidRequest, code, "not equal response codes") 83 assert.Equal(t, p2ptypes.ErrRateLimited.Error(), errMsg, "not equal errors") 84 }) 85 wg.Add(1) 86 stream, err := p1.BHost.NewStream(context.Background(), p2.PeerID(), protocol.ID(topic)) 87 require.NoError(t, err, "could not create stream") 88 89 for i := 0; i < 2*defaultBurstLimit; i++ { 90 err = rlimiter.validateRawRpcRequest(stream) 91 rlimiter.addRawStream(stream) 92 require.NoError(t, err, "could not validate incoming request") 93 } 94 // Triggers rate limit error on burst. 95 assert.ErrorContains(t, p2ptypes.ErrRateLimited.Error(), rlimiter.validateRawRpcRequest(stream)) 96 97 // Make Peer bad. 98 for i := 0; i < defaultBurstLimit; i++ { 99 assert.ErrorContains(t, p2ptypes.ErrRateLimited.Error(), rlimiter.validateRawRpcRequest(stream)) 100 } 101 assert.Equal(t, true, p1.Peers().IsBad(p2.PeerID()), "peer is not marked as a bad peer") 102 require.NoError(t, stream.Close(), "could not close stream") 103 104 if testutil.WaitTimeout(&wg, 1*time.Second) { 105 t.Fatal("Did not receive stream within 1 sec") 106 } 107 } 108 109 func Test_limiter_retrieveCollector_requiresLock(t *testing.T) { 110 l := limiter{} 111 _, err := l.retrieveCollector("") 112 require.ErrorContains(t, "caller must hold read/write lock", err) 113 }