github.com/ethersphere/bee/v2@v2.2.0/pkg/p2p/libp2p/protocols_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package libp2p_test
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"sync"
    11  	"sync/atomic"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/ethersphere/bee/v2/pkg/p2p"
    16  	"github.com/ethersphere/bee/v2/pkg/p2p/libp2p"
    17  	"github.com/ethersphere/bee/v2/pkg/spinlock"
    18  	libp2pm "github.com/libp2p/go-libp2p"
    19  	"github.com/libp2p/go-libp2p/core/host"
    20  	protocol "github.com/libp2p/go-libp2p/core/protocol"
    21  	bhost "github.com/libp2p/go-libp2p/p2p/host/basic"
    22  	swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
    23  	"github.com/multiformats/go-multistream"
    24  )
    25  
    26  func TestNewStream(t *testing.T) {
    27  	t.Parallel()
    28  
    29  	ctx, cancel := context.WithCancel(context.Background())
    30  	defer cancel()
    31  
    32  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
    33  		FullNode: true,
    34  	}})
    35  
    36  	s2, _ := newService(t, 1, libp2pServiceOpts{})
    37  
    38  	if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
    39  		return nil
    40  	})); err != nil {
    41  		t.Fatal(err)
    42  	}
    43  
    44  	addr := serviceUnderlayAddress(t, s1)
    45  
    46  	if _, err := s2.Connect(ctx, addr); err != nil {
    47  		t.Fatal(err)
    48  	}
    49  
    50  	stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  	if err := stream.Close(); err != nil {
    55  		t.Fatal(err)
    56  	}
    57  }
    58  
    59  // TestNewStream_OnlyFull tests that the handler gets the full
    60  // node information communicated correctly.
    61  func TestNewStream_OnlyFull(t *testing.T) {
    62  	t.Parallel()
    63  
    64  	ctx, cancel := context.WithCancel(context.Background())
    65  	defer cancel()
    66  
    67  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
    68  		FullNode: true,
    69  	}})
    70  
    71  	s2, _ := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
    72  		FullNode: true,
    73  	}})
    74  
    75  	if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
    76  		if !p.FullNode {
    77  			t.Error("expected full node")
    78  		}
    79  		return nil
    80  	})); err != nil {
    81  		t.Fatal(err)
    82  	}
    83  
    84  	addr := serviceUnderlayAddress(t, s1)
    85  
    86  	if _, err := s2.Connect(ctx, addr); err != nil {
    87  		t.Fatal(err)
    88  	}
    89  
    90  	stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
    91  	if err != nil {
    92  		t.Fatal(err)
    93  	}
    94  	if err := stream.Close(); err != nil {
    95  		t.Fatal(err)
    96  	}
    97  }
    98  
    99  // TestNewStream_Mixed tests that the handler gets the full
   100  // node information communicated correctly for light node
   101  func TestNewStream_Mixed(t *testing.T) {
   102  	t.Parallel()
   103  
   104  	ctx, cancel := context.WithCancel(context.Background())
   105  	defer cancel()
   106  
   107  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
   108  		FullNode: true,
   109  	}})
   110  
   111  	s2, _ := newService(t, 1, libp2pServiceOpts{})
   112  
   113  	if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
   114  		if p.FullNode {
   115  			t.Error("expected light node")
   116  		}
   117  		return nil
   118  	})); err != nil {
   119  		t.Fatal(err)
   120  	}
   121  
   122  	addr := serviceUnderlayAddress(t, s1)
   123  
   124  	if _, err := s2.Connect(ctx, addr); err != nil {
   125  		t.Fatal(err)
   126  	}
   127  
   128  	stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  	if err := stream.Close(); err != nil {
   133  		t.Fatal(err)
   134  	}
   135  }
   136  
   137  // TestNewStreamMulti is a regression test to see that we trigger
   138  // the right handler when multiple streams are registered under
   139  // a single protocol.
   140  func TestNewStreamMulti(t *testing.T) {
   141  	t.Parallel()
   142  
   143  	ctx, cancel := context.WithCancel(context.Background())
   144  	defer cancel()
   145  
   146  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
   147  		FullNode: true,
   148  	}})
   149  
   150  	var (
   151  		h1calls, h2calls int32
   152  		h1               = func(_ context.Context, p p2p.Peer, s p2p.Stream) error {
   153  			defer s.Close()
   154  			_ = atomic.AddInt32(&h1calls, 1)
   155  			return nil
   156  		}
   157  		h2 = func(_ context.Context, p p2p.Peer, s p2p.Stream) error {
   158  			defer s.Close()
   159  			_ = atomic.AddInt32(&h2calls, 1)
   160  			return nil
   161  		}
   162  	)
   163  	s2, _ := newService(t, 1, libp2pServiceOpts{})
   164  
   165  	if err := s1.AddProtocol(newTestMultiProtocol(h1, h2)); err != nil {
   166  		t.Fatal(err)
   167  	}
   168  
   169  	addr := serviceUnderlayAddress(t, s1)
   170  
   171  	if _, err := s2.Connect(ctx, addr); err != nil {
   172  		t.Fatal(err)
   173  	}
   174  
   175  	stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
   176  	if err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	if err := stream.FullClose(); err != nil {
   180  		t.Fatal(err)
   181  	}
   182  	if atomic.LoadInt32(&h1calls) != 1 {
   183  		t.Fatal("handler should have been called but wasn't")
   184  	}
   185  	if atomic.LoadInt32(&h2calls) > 0 {
   186  		t.Fatal("handler should not have been called")
   187  	}
   188  }
   189  
   190  func TestNewStream_errNotSupported(t *testing.T) {
   191  	t.Parallel()
   192  
   193  	ctx, cancel := context.WithCancel(context.Background())
   194  	defer cancel()
   195  
   196  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
   197  		FullNode: true,
   198  	}})
   199  
   200  	s2, _ := newService(t, 1, libp2pServiceOpts{})
   201  
   202  	addr := serviceUnderlayAddress(t, s1)
   203  
   204  	// connect nodes
   205  	if _, err := s2.Connect(ctx, addr); err != nil {
   206  		t.Fatal(err)
   207  	}
   208  
   209  	// test for missing protocol
   210  	_, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
   211  	expectErrNotSupported(t, err)
   212  
   213  	// add protocol
   214  	if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
   215  		return nil
   216  	})); err != nil {
   217  		t.Fatal(err)
   218  	}
   219  
   220  	// test for incorrect protocol name
   221  	_, err = s2.NewStream(ctx, overlay1, nil, testProtocolName+"invalid", testProtocolVersion, testStreamName)
   222  	expectErrNotSupported(t, err)
   223  
   224  	// test for incorrect stream name
   225  	_, err = s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName+"invalid")
   226  	expectErrNotSupported(t, err)
   227  }
   228  
   229  func TestNewStream_semanticVersioning(t *testing.T) {
   230  	t.Parallel()
   231  
   232  	ctx, cancel := context.WithCancel(context.Background())
   233  	defer cancel()
   234  
   235  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
   236  		FullNode: true,
   237  	}})
   238  
   239  	s2, _ := newService(t, 1, libp2pServiceOpts{})
   240  
   241  	addr := serviceUnderlayAddress(t, s1)
   242  
   243  	if _, err := s2.Connect(ctx, addr); err != nil {
   244  		t.Fatal(err)
   245  	}
   246  
   247  	if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
   248  		return nil
   249  	})); err != nil {
   250  		t.Fatal(err)
   251  	}
   252  
   253  	for _, tc := range []struct {
   254  		version   string
   255  		supported bool
   256  	}{
   257  		{version: "0", supported: false},
   258  		{version: "1", supported: false},
   259  		{version: "2", supported: false},
   260  		{version: "3", supported: false},
   261  		{version: "4", supported: false},
   262  		{version: "a", supported: false},
   263  		{version: "invalid", supported: false},
   264  		{version: "0.0.0", supported: false},
   265  		{version: "0.1.0", supported: false},
   266  		{version: "1.0.0", supported: false},
   267  		{version: "2.0.0", supported: true},
   268  		{version: "2.2.0", supported: true},
   269  		{version: "2.3.0", supported: true},
   270  		{version: "2.3.1", supported: true},
   271  		{version: "2.3.4", supported: true},
   272  		{version: "2.3.5", supported: true},
   273  		{version: "2.3.5-beta", supported: true},
   274  		{version: "2.3.5+beta", supported: true},
   275  		{version: "2.3.6", supported: true},
   276  		{version: "2.3.6-beta", supported: true},
   277  		{version: "2.3.6+beta", supported: true},
   278  		{version: "2.4.0", supported: false},
   279  		{version: "3.0.0", supported: false},
   280  	} {
   281  		_, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, tc.version, testStreamName)
   282  		if tc.supported {
   283  			if err != nil {
   284  				t.Fatal(err)
   285  			}
   286  		} else {
   287  			expectErrNotSupported(t, err)
   288  		}
   289  	}
   290  }
   291  
   292  func TestDisconnectError(t *testing.T) {
   293  	t.Parallel()
   294  
   295  	ctx, cancel := context.WithCancel(context.Background())
   296  	defer cancel()
   297  
   298  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
   299  		FullNode: true,
   300  	}})
   301  
   302  	s2, overlay2 := newService(t, 1, libp2pServiceOpts{})
   303  
   304  	if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
   305  		return p2p.NewDisconnectError(errors.New("test error"))
   306  	})); err != nil {
   307  		t.Fatal(err)
   308  	}
   309  
   310  	addr := serviceUnderlayAddress(t, s1)
   311  
   312  	if _, err := s2.Connect(ctx, addr); err != nil {
   313  		t.Fatal(err)
   314  	}
   315  
   316  	expectPeers(t, s1, overlay2)
   317  
   318  	// error is not checked as opening a new stream should cause disconnect from s1 which is async and can make errors in newStream function
   319  	// it is important to validate that disconnect will happen after NewStream()
   320  	_, _ = s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
   321  	expectPeersEventually(t, s1)
   322  }
   323  
   324  func TestConnectDisconnectEvents(t *testing.T) {
   325  	t.Parallel()
   326  
   327  	ctx, cancel := context.WithCancel(context.Background())
   328  	defer cancel()
   329  
   330  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
   331  		FullNode: true,
   332  	}})
   333  
   334  	s2, _ := newService(t, 1, libp2pServiceOpts{})
   335  	testProtocol := newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
   336  		return nil
   337  	})
   338  
   339  	cinCount, coutCount, dinCount, doutCount := 0, 0, 0, 0
   340  	var countMU sync.Mutex
   341  
   342  	testProtocol.ConnectIn = func(c context.Context, p p2p.Peer) error {
   343  		countMU.Lock()
   344  		cinCount++
   345  		countMU.Unlock()
   346  		return nil
   347  	}
   348  
   349  	testProtocol.ConnectOut = func(c context.Context, p p2p.Peer) error {
   350  		countMU.Lock()
   351  		coutCount++
   352  		countMU.Unlock()
   353  		return nil
   354  	}
   355  
   356  	testProtocol.DisconnectIn = func(p p2p.Peer) error {
   357  		countMU.Lock()
   358  		dinCount++
   359  		countMU.Unlock()
   360  		return nil
   361  	}
   362  
   363  	testProtocol.DisconnectOut = func(p p2p.Peer) error {
   364  		countMU.Lock()
   365  		doutCount++
   366  		countMU.Unlock()
   367  		return nil
   368  	}
   369  
   370  	if err := s1.AddProtocol(testProtocol); err != nil {
   371  		t.Fatal(err)
   372  	}
   373  
   374  	if err := s2.AddProtocol(testProtocol); err != nil {
   375  		t.Fatal(err)
   376  	}
   377  
   378  	addr := serviceUnderlayAddress(t, s1)
   379  
   380  	if _, err := s2.Connect(ctx, addr); err != nil {
   381  		t.Fatal(err)
   382  	}
   383  
   384  	expectCounter(t, &cinCount, 1, &countMU)
   385  	expectCounter(t, &coutCount, 1, &countMU)
   386  	expectCounter(t, &dinCount, 0, &countMU)
   387  	expectCounter(t, &doutCount, 0, &countMU)
   388  
   389  	if err := s2.Disconnect(overlay1, "test disconnect"); err != nil {
   390  		t.Fatal(err)
   391  	}
   392  
   393  	cinCount = 0
   394  	coutCount = 0
   395  
   396  	expectCounter(t, &cinCount, 0, &countMU)
   397  	expectCounter(t, &coutCount, 0, &countMU)
   398  	expectCounter(t, &dinCount, 1, &countMU)
   399  	expectCounter(t, &doutCount, 1, &countMU)
   400  
   401  }
   402  
   403  func TestPing(t *testing.T) {
   404  	t.Parallel()
   405  
   406  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   407  	defer cancel()
   408  
   409  	s1, _ := newService(t, 1, libp2pServiceOpts{
   410  		libp2pOpts: libp2p.WithHostFactory(
   411  			func(...libp2pm.Option) (host.Host, error) {
   412  				return bhost.NewHost(swarmt.GenSwarm(t), &bhost.HostOpts{EnablePing: true})
   413  			},
   414  		),
   415  	})
   416  
   417  	s2, _ := newService(t, 1, libp2pServiceOpts{
   418  		libp2pOpts: libp2p.WithHostFactory(
   419  			func(...libp2pm.Option) (host.Host, error) {
   420  				host, err := bhost.NewHost(swarmt.GenSwarm(t), &bhost.HostOpts{EnablePing: true})
   421  				if err != nil {
   422  					t.Fatalf("start host: %v", err)
   423  				}
   424  				host.Start()
   425  				return host, nil
   426  			},
   427  		),
   428  	})
   429  
   430  	addr := serviceUnderlayAddress(t, s1)
   431  
   432  	if _, err := s2.Ping(ctx, addr); err != nil {
   433  		t.Fatal(err)
   434  	}
   435  }
   436  
   437  const (
   438  	testProtocolName     = "testing"
   439  	testProtocolVersion  = "2.3.4"
   440  	testStreamName       = "messages"
   441  	testSecondStreamName = "cookies"
   442  )
   443  
   444  func newTestProtocol(h p2p.HandlerFunc) p2p.ProtocolSpec {
   445  	return p2p.ProtocolSpec{
   446  		Name:    testProtocolName,
   447  		Version: testProtocolVersion,
   448  		StreamSpecs: []p2p.StreamSpec{
   449  			{
   450  				Name:    testStreamName,
   451  				Handler: h,
   452  			},
   453  		},
   454  	}
   455  }
   456  
   457  func newTestMultiProtocol(h1, h2 p2p.HandlerFunc) p2p.ProtocolSpec {
   458  	return p2p.ProtocolSpec{
   459  		Name:    testProtocolName,
   460  		Version: testProtocolVersion,
   461  		StreamSpecs: []p2p.StreamSpec{
   462  			{
   463  				Name:    testStreamName,
   464  				Handler: h1,
   465  			},
   466  			{
   467  				Name:    testSecondStreamName,
   468  				Handler: h2,
   469  			},
   470  		},
   471  	}
   472  }
   473  
   474  func expectErrNotSupported(t *testing.T, err error) {
   475  	t.Helper()
   476  	if e := (*p2p.IncompatibleStreamError)(nil); !errors.As(err, &e) {
   477  		t.Fatalf("got error %v, want %T", err, e)
   478  	}
   479  	var e2 multistream.ErrNotSupported[protocol.ID]
   480  	if !errors.As(err, &e2) {
   481  		t.Fatalf("got error %v, want %v", err, &e2)
   482  	}
   483  }
   484  
   485  func expectCounter(t *testing.T, c *int, expected int, mtx *sync.Mutex) {
   486  	t.Helper()
   487  
   488  	err := spinlock.Wait(time.Second, func() bool {
   489  		mtx.Lock()
   490  		defer mtx.Unlock()
   491  		return *c == expected
   492  	})
   493  	if err != nil {
   494  		t.Fatal("timed out waiting for counter to be set")
   495  	}
   496  }