github.com/koko1123/flow-go-1@v0.29.6/network/p2p/connection/connection_gater_test.go (about)

     1  package connection_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/libp2p/go-libp2p/core/control"
    11  	"github.com/libp2p/go-libp2p/core/network"
    12  	"github.com/libp2p/go-libp2p/core/peer"
    13  	"github.com/libp2p/go-libp2p/p2p/net/swarm"
    14  	"github.com/stretchr/testify/mock"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/koko1123/flow-go-1/model/flow"
    18  	"github.com/koko1123/flow-go-1/module/irrecoverable"
    19  	"github.com/koko1123/flow-go-1/network/channels"
    20  	"github.com/koko1123/flow-go-1/network/internal/p2pfixtures"
    21  	"github.com/koko1123/flow-go-1/network/internal/testutils"
    22  	"github.com/koko1123/flow-go-1/network/p2p"
    23  	mockp2p "github.com/koko1123/flow-go-1/network/p2p/mock"
    24  	p2ptest "github.com/koko1123/flow-go-1/network/p2p/test"
    25  	"github.com/koko1123/flow-go-1/utils/unittest"
    26  )
    27  
    28  // TestConnectionGating tests node allow listing by peer ID.
    29  func TestConnectionGating(t *testing.T) {
    30  	ctx, cancel := context.WithCancel(context.Background())
    31  	signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx)
    32  
    33  	sporkID := unittest.IdentifierFixture()
    34  
    35  	// create 2 nodes
    36  	node1Peers := unittest.NewProtectedMap[peer.ID, struct{}]()
    37  	node1, node1Id := p2ptest.NodeFixture(
    38  		t,
    39  		sporkID,
    40  		t.Name(),
    41  		p2ptest.WithConnectionGater(testutils.NewConnectionGater(func(p peer.ID) error {
    42  			if !node1Peers.Has(p) {
    43  				return fmt.Errorf("id not found: %s", p.String())
    44  			}
    45  			return nil
    46  		})))
    47  
    48  	node2Peers := unittest.NewProtectedMap[peer.ID, struct{}]()
    49  	node2, node2Id := p2ptest.NodeFixture(
    50  		t,
    51  		sporkID,
    52  		t.Name(),
    53  		p2ptest.WithConnectionGater(testutils.NewConnectionGater(func(p peer.ID) error {
    54  			if !node2Peers.Has(p) {
    55  				return fmt.Errorf("id not found: %s", p.String())
    56  			}
    57  			return nil
    58  		})))
    59  
    60  	nodes := []p2p.LibP2PNode{node1, node2}
    61  	ids := flow.IdentityList{&node1Id, &node2Id}
    62  	p2ptest.StartNodes(t, signalerCtx, nodes, 100*time.Millisecond)
    63  	defer p2ptest.StopNodes(t, nodes, cancel, 100*time.Millisecond)
    64  
    65  	p2pfixtures.AddNodesToEachOthersPeerStore(t, nodes, ids)
    66  
    67  	t.Run("outbound connection to a disallowed node is rejected", func(t *testing.T) {
    68  		// although nodes have each other addresses, they are not in the allow-lists of each other.
    69  		// so they should not be able to connect to each other.
    70  		p2pfixtures.EnsureNoStreamCreationBetweenGroups(t, ctx, []p2p.LibP2PNode{node1}, []p2p.LibP2PNode{node2}, func(t *testing.T, err error) {
    71  			require.True(t, errors.Is(err, swarm.ErrGaterDisallowedConnection))
    72  		})
    73  	})
    74  
    75  	t.Run("inbound connection from an allowed node is rejected", func(t *testing.T) {
    76  		// for an inbound connection to be established both nodes should be in each other's allow-lists.
    77  		// the connection gater on the dialing node is checking the allow-list upon dialing.
    78  		// the connection gater on the listening node is checking the allow-list upon accepting the connection.
    79  
    80  		// add node2 to node1's allow list, but not the other way around.
    81  		node1Peers.Add(node2.Host().ID(), struct{}{})
    82  
    83  		// now node2 should be able to connect to node1.
    84  		// from node1 -> node2 shouldn't work
    85  		p2pfixtures.EnsureNoStreamCreation(t, ctx, []p2p.LibP2PNode{node1}, []p2p.LibP2PNode{node2})
    86  
    87  		// however, from node2 -> node1 should also NOT work, since node 1 is not in node2's allow list for dialing!
    88  		p2pfixtures.EnsureNoStreamCreation(t, ctx, []p2p.LibP2PNode{node2}, []p2p.LibP2PNode{node1})
    89  	})
    90  
    91  	t.Run("outbound connection to an approved node is allowed", func(t *testing.T) {
    92  		// adding both nodes to each other's allow lists.
    93  		node1Peers.Add(node2.Host().ID(), struct{}{})
    94  		node2Peers.Add(node1.Host().ID(), struct{}{})
    95  
    96  		// now both nodes should be able to connect to each other.
    97  		p2ptest.EnsureStreamCreationInBothDirections(t, ctx, []p2p.LibP2PNode{node1, node2})
    98  	})
    99  }
   100  
   101  // TestConnectionGater_InterceptUpgrade tests the connection gater only upgrades the connections to the allow-listed peers.
   102  // Upgrading a connection means that the connection is the last phase of the connection establishment process.
   103  // It means that the connection is ready to be used for sending and receiving messages.
   104  // It checks that no disallowed peer can upgrade the connection.
   105  func TestConnectionGater_InterceptUpgrade(t *testing.T) {
   106  	ctx, cancel := context.WithCancel(context.Background())
   107  	signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx)
   108  	sporkId := unittest.IdentifierFixture()
   109  	defer cancel()
   110  
   111  	count := 5
   112  	nodes := make([]p2p.LibP2PNode, 0, count)
   113  	inbounds := make([]chan string, 0, count)
   114  
   115  	disallowedPeerIds := unittest.NewProtectedMap[peer.ID, struct{}]()
   116  	allPeerIds := make(peer.IDSlice, 0, count)
   117  
   118  	connectionGater := mockp2p.NewConnectionGater(t)
   119  	for i := 0; i < count; i++ {
   120  		handler, inbound := p2ptest.StreamHandlerFixture(t)
   121  		node, _ := p2ptest.NodeFixture(
   122  			t,
   123  			sporkId,
   124  			t.Name(),
   125  			p2ptest.WithRole(flow.RoleConsensus),
   126  			p2ptest.WithDefaultStreamHandler(handler),
   127  			// enable peer manager, with a 1-second refresh rate, and connection pruning enabled.
   128  			p2ptest.WithPeerManagerEnabled(true, 1*time.Second, func() peer.IDSlice {
   129  				list := make(peer.IDSlice, 0)
   130  				for _, pid := range allPeerIds {
   131  					if !disallowedPeerIds.Has(pid) {
   132  						list = append(list, pid)
   133  					}
   134  				}
   135  				return list
   136  			}),
   137  			p2ptest.WithConnectionGater(connectionGater))
   138  
   139  		nodes = append(nodes, node)
   140  		allPeerIds = append(allPeerIds, node.Host().ID())
   141  		inbounds = append(inbounds, inbound)
   142  	}
   143  
   144  	connectionGater.On("InterceptSecured", mock.Anything, mock.Anything, mock.Anything).
   145  		Return(func(_ network.Direction, p peer.ID, _ network.ConnMultiaddrs) bool {
   146  			return !disallowedPeerIds.Has(p)
   147  		})
   148  
   149  	connectionGater.On("InterceptPeerDial", mock.Anything).Return(func(p peer.ID) bool {
   150  		return !disallowedPeerIds.Has(p)
   151  	})
   152  
   153  	// we don't inspect connections during "accept" and "dial" phases as the peer IDs are not available at those phases.
   154  	connectionGater.On("InterceptAddrDial", mock.Anything, mock.Anything).Return(true)
   155  	connectionGater.On("InterceptAccept", mock.Anything).Return(true)
   156  
   157  	// adds first node to disallowed list
   158  	disallowedPeerIds.Add(nodes[0].Host().ID(), struct{}{})
   159  
   160  	// starts the nodes
   161  	p2ptest.StartNodes(t, signalerCtx, nodes, 1*time.Second)
   162  	defer p2ptest.StopNodes(t, nodes, cancel, 1*time.Second)
   163  
   164  	ensureCommunicationSilenceAmongGroups(t, ctx, sporkId, nodes[:1], nodes[1:])
   165  
   166  	// Checks that only the allowed nodes can establish an upgradable connection.
   167  	// We intentionally mock this after checking for communication silence.
   168  	// As no connection to/from a disallowed node should ever reach the upgradable connection stage.
   169  	connectionGater.On("InterceptUpgraded", mock.Anything).Run(func(args mock.Arguments) {
   170  		conn, ok := args.Get(0).(network.Conn)
   171  		require.True(t, ok)
   172  
   173  		remote := conn.RemotePeer()
   174  		require.False(t, disallowedPeerIds.Has(remote))
   175  
   176  		local := conn.LocalPeer()
   177  		require.False(t, disallowedPeerIds.Has(local))
   178  	}).Return(true, control.DisconnectReason(0))
   179  
   180  	ensureCommunicationOverAllProtocols(t, ctx, sporkId, nodes[1:], inbounds[1:])
   181  }
   182  
   183  // TestConnectionGater_Disallow_Integration tests that when a peer is disallowed, it is disconnected from all other peers, and
   184  // cannot connect, exchange unicast, or pubsub messages to any other peers.
   185  // It also checked that the allowed peers can still communicate with each other.
   186  func TestConnectionGater_Disallow_Integration(t *testing.T) {
   187  	ctx, cancel := context.WithCancel(context.Background())
   188  	signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx)
   189  	sporkId := unittest.IdentifierFixture()
   190  	defer cancel()
   191  
   192  	count := 5
   193  	nodes := make([]p2p.LibP2PNode, 0, 5)
   194  	ids := flow.IdentityList{}
   195  	inbounds := make([]chan string, 0, 5)
   196  
   197  	disallowedList := unittest.NewProtectedMap[*flow.Identity, struct{}]()
   198  
   199  	for i := 0; i < count; i++ {
   200  		handler, inbound := p2ptest.StreamHandlerFixture(t)
   201  		node, id := p2ptest.NodeFixture(
   202  			t,
   203  			sporkId,
   204  			t.Name(),
   205  			p2ptest.WithRole(flow.RoleConsensus),
   206  			p2ptest.WithDefaultStreamHandler(handler),
   207  			// enable peer manager, with a 1-second refresh rate, and connection pruning enabled.
   208  			p2ptest.WithPeerManagerEnabled(true, 1*time.Second, func() peer.IDSlice {
   209  				list := make(peer.IDSlice, 0)
   210  				for _, id := range ids {
   211  					if disallowedList.Has(id) {
   212  						continue
   213  					}
   214  
   215  					pid, err := unittest.PeerIDFromFlowID(id)
   216  					require.NoError(t, err)
   217  
   218  					list = append(list, pid)
   219  				}
   220  				return list
   221  			}),
   222  			p2ptest.WithConnectionGater(testutils.NewConnectionGater(func(pid peer.ID) error {
   223  				return disallowedList.ForEach(func(id *flow.Identity, _ struct{}) error {
   224  					bid, err := unittest.PeerIDFromFlowID(id)
   225  					require.NoError(t, err)
   226  					if bid == pid {
   227  						return fmt.Errorf("disallow-listed")
   228  					}
   229  					return nil
   230  				})
   231  			})))
   232  
   233  		nodes = append(nodes, node)
   234  		ids = append(ids, &id)
   235  		inbounds = append(inbounds, inbound)
   236  	}
   237  
   238  	p2ptest.StartNodes(t, signalerCtx, nodes, 1*time.Second)
   239  	defer p2ptest.StopNodes(t, nodes, cancel, 1*time.Second)
   240  
   241  	p2ptest.LetNodesDiscoverEachOther(t, ctx, nodes, ids)
   242  
   243  	// ensures that all nodes are connected to each other, and they can exchange messages over the pubsub and unicast.
   244  	ensureCommunicationOverAllProtocols(t, ctx, sporkId, nodes, inbounds)
   245  
   246  	// now we add one of the nodes (the last node) to the disallow-list.
   247  	disallowedList.Add(ids[len(ids)-1], struct{}{})
   248  	// let peer manager prune the connections to the disallow-listed node.
   249  	time.Sleep(1 * time.Second)
   250  	// ensures no connection, unicast, or pubsub going to or coming from the disallow-listed node.
   251  	ensureCommunicationSilenceAmongGroups(t, ctx, sporkId, nodes[:count-1], nodes[count-1:])
   252  
   253  	// now we add another node (the second last node) to the disallowed list.
   254  	disallowedList.Add(ids[len(ids)-2], struct{}{})
   255  	// let peer manager prune the connections to the disallow-listed node.
   256  	time.Sleep(1 * time.Second)
   257  	// ensures no connection, unicast, or pubsub going to and coming from the disallow-listed nodes.
   258  	ensureCommunicationSilenceAmongGroups(t, ctx, sporkId, nodes[:count-2], nodes[count-2:])
   259  	// ensures that all nodes are other non-disallow-listed nodes can exchange messages over the pubsub and unicast.
   260  	ensureCommunicationOverAllProtocols(t, ctx, sporkId, nodes[:count-2], inbounds[:count-2])
   261  }
   262  
   263  // ensureCommunicationSilenceAmongGroups ensures no connection, unicast, or pubsub going to or coming from between the two groups of nodes.
   264  func ensureCommunicationSilenceAmongGroups(t *testing.T, ctx context.Context, sporkId flow.Identifier, groupA []p2p.LibP2PNode, groupB []p2p.LibP2PNode) {
   265  	// ensures no connection, unicast, or pubsub going to the disallow-listed nodes
   266  	p2pfixtures.EnsureNotConnectedBetweenGroups(t, ctx, groupA, groupB)
   267  	p2pfixtures.EnsureNoPubsubExchangeBetweenGroups(t, ctx, groupA, groupB, func() (interface{}, channels.Topic) {
   268  		blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId)
   269  		return unittest.ProposalFixture(), blockTopic
   270  	})
   271  	p2pfixtures.EnsureNoStreamCreationBetweenGroups(t, ctx, groupA, groupB)
   272  }
   273  
   274  // ensureCommunicationOverAllProtocols ensures that all nodes are connected to each other, and they can exchange messages over the pubsub and unicast.
   275  func ensureCommunicationOverAllProtocols(t *testing.T, ctx context.Context, sporkId flow.Identifier, nodes []p2p.LibP2PNode, inbounds []chan string) {
   276  	p2ptest.EnsureConnected(t, ctx, nodes)
   277  	p2ptest.EnsurePubsubMessageExchange(t, ctx, nodes, func() (interface{}, channels.Topic) {
   278  		blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId)
   279  		return unittest.ProposalFixture(), blockTopic
   280  	})
   281  	p2pfixtures.EnsureMessageExchangeOverUnicast(t, ctx, nodes, inbounds, p2pfixtures.LongStringMessageFactoryFixture(t))
   282  }