github.com/ethersphere/bee/v2@v2.2.0/pkg/p2p/libp2p/internal/handshake/handshake_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package handshake_test
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"errors"
    11  	"fmt"
    12  	"testing"
    13  
    14  	"github.com/ethereum/go-ethereum/common"
    15  	"github.com/ethersphere/bee/v2/pkg/bzz"
    16  	"github.com/ethersphere/bee/v2/pkg/crypto"
    17  	"github.com/ethersphere/bee/v2/pkg/log"
    18  	"github.com/ethersphere/bee/v2/pkg/p2p"
    19  	"github.com/ethersphere/bee/v2/pkg/p2p/libp2p/internal/handshake"
    20  	"github.com/ethersphere/bee/v2/pkg/p2p/libp2p/internal/handshake/mock"
    21  	"github.com/ethersphere/bee/v2/pkg/p2p/libp2p/internal/handshake/pb"
    22  	"github.com/ethersphere/bee/v2/pkg/p2p/protobuf"
    23  
    24  	libp2ppeer "github.com/libp2p/go-libp2p/core/peer"
    25  	ma "github.com/multiformats/go-multiaddr"
    26  )
    27  
    28  //nolint:paralleltest
    29  func TestHandshake(t *testing.T) {
    30  	const (
    31  		testWelcomeMessage = "HelloWorld"
    32  	)
    33  
    34  	logger := log.Noop
    35  	networkID := uint64(3)
    36  
    37  	node1ma, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/1634/p2p/16Uiu2HAkx8ULY8cTXhdVAcMmLcH9AsTKz6uBQ7DPLKRjMLgBVYkA")
    38  	if err != nil {
    39  		t.Fatal(err)
    40  	}
    41  	node2ma, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/1634/p2p/16Uiu2HAkx8ULY8cTXhdVAcMmLcH9AsTKz6uBQ7DPLKRjMLgBVYkS")
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  	node1maBinary, err := node1ma.MarshalBinary()
    46  	if err != nil {
    47  		t.Fatal(err)
    48  	}
    49  	node2maBinary, err := node2ma.MarshalBinary()
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  	node1AddrInfo, err := libp2ppeer.AddrInfoFromP2pAddr(node1ma)
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  	node2AddrInfo, err := libp2ppeer.AddrInfoFromP2pAddr(node2ma)
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  
    62  	privateKey1, err := crypto.GenerateSecp256k1Key()
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	privateKey2, err := crypto.GenerateSecp256k1Key()
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  
    71  	nonce := common.HexToHash("0x1").Bytes()
    72  
    73  	signer1 := crypto.NewDefaultSigner(privateKey1)
    74  	signer2 := crypto.NewDefaultSigner(privateKey2)
    75  	addr, err := crypto.NewOverlayAddress(privateKey1.PublicKey, networkID, nonce)
    76  	if err != nil {
    77  		t.Fatal(err)
    78  	}
    79  	node1BzzAddress, err := bzz.NewAddress(signer1, node1ma, addr, networkID, nonce)
    80  	if err != nil {
    81  		t.Fatal(err)
    82  	}
    83  	addr2, err := crypto.NewOverlayAddress(privateKey2.PublicKey, networkID, nonce)
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	node2BzzAddress, err := bzz.NewAddress(signer2, node2ma, addr2, networkID, nonce)
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  
    92  	node1Info := handshake.Info{
    93  		BzzAddress: node1BzzAddress,
    94  		FullNode:   true,
    95  	}
    96  	node2Info := handshake.Info{
    97  		BzzAddress: node2BzzAddress,
    98  		FullNode:   true,
    99  	}
   100  
   101  	aaddresser := &AdvertisableAddresserMock{}
   102  
   103  	handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, testWelcomeMessage, true, node1AddrInfo.ID, logger)
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  
   108  	t.Run("Handshake - OK", func(t *testing.T) {
   109  		var buffer1 bytes.Buffer
   110  		var buffer2 bytes.Buffer
   111  		stream1 := mock.NewStream(&buffer1, &buffer2)
   112  		stream2 := mock.NewStream(&buffer2, &buffer1)
   113  
   114  		w, r := protobuf.NewWriterAndReader(stream2)
   115  		if err := w.WriteMsg(&pb.SynAck{
   116  			Syn: &pb.Syn{
   117  				ObservedUnderlay: node1maBinary,
   118  			},
   119  			Ack: &pb.Ack{
   120  				Address: &pb.BzzAddress{
   121  					Underlay:  node2maBinary,
   122  					Overlay:   node2BzzAddress.Overlay.Bytes(),
   123  					Signature: node2BzzAddress.Signature,
   124  				},
   125  				NetworkID:      networkID,
   126  				FullNode:       true,
   127  				Nonce:          nonce,
   128  				WelcomeMessage: testWelcomeMessage,
   129  			},
   130  		}); err != nil {
   131  			t.Fatal(err)
   132  		}
   133  		if err != nil {
   134  			t.Fatal(err)
   135  		}
   136  
   137  		res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   138  		if err != nil {
   139  			t.Fatal(err)
   140  		}
   141  
   142  		testInfo(t, *res, node2Info)
   143  
   144  		var syn pb.Syn
   145  		if err := r.ReadMsg(&syn); err != nil {
   146  			t.Fatal(err)
   147  		}
   148  
   149  		if !bytes.Equal(syn.ObservedUnderlay, node2maBinary) {
   150  			t.Fatal("bad syn")
   151  		}
   152  
   153  		var ack pb.Ack
   154  		if err := r.ReadMsg(&ack); err != nil {
   155  			t.Fatal(err)
   156  		}
   157  
   158  		if !bytes.Equal(ack.Address.Overlay, node1BzzAddress.Overlay.Bytes()) {
   159  			t.Fatal("bad ack - overlay")
   160  		}
   161  		if !bytes.Equal(ack.Address.Underlay, node1maBinary) {
   162  			t.Fatal("bad ack - underlay")
   163  		}
   164  		if !bytes.Equal(ack.Address.Signature, node1BzzAddress.Signature) {
   165  			t.Fatal("bad ack - signature")
   166  		}
   167  		if ack.NetworkID != networkID {
   168  			t.Fatal("bad ack - networkID")
   169  		}
   170  		if ack.FullNode != true {
   171  			t.Fatal("bad ack - full node")
   172  		}
   173  
   174  		if ack.WelcomeMessage != testWelcomeMessage {
   175  			t.Fatalf("Bad ack welcome message: want %s, got %s", testWelcomeMessage, ack.WelcomeMessage)
   176  		}
   177  	})
   178  
   179  	t.Run("Handshake - picker error", func(t *testing.T) {
   180  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, "", true, node1AddrInfo.ID, logger)
   181  		if err != nil {
   182  			t.Fatal(err)
   183  		}
   184  
   185  		handshakeService.SetPicker(mockPicker(func(p p2p.Peer) bool { return false }))
   186  
   187  		var buffer1 bytes.Buffer
   188  		var buffer2 bytes.Buffer
   189  		stream1 := mock.NewStream(&buffer1, &buffer2)
   190  		stream2 := mock.NewStream(&buffer2, &buffer1)
   191  
   192  		w := protobuf.NewWriter(stream2)
   193  		if err := w.WriteMsg(&pb.Syn{
   194  			ObservedUnderlay: node1maBinary,
   195  		}); err != nil {
   196  			t.Fatal(err)
   197  		}
   198  
   199  		if err := w.WriteMsg(&pb.Ack{
   200  			Address: &pb.BzzAddress{
   201  				Underlay:  node2maBinary,
   202  				Overlay:   node2BzzAddress.Overlay.Bytes(),
   203  				Signature: node2BzzAddress.Signature,
   204  			},
   205  			NetworkID: networkID,
   206  			Nonce:     nonce,
   207  			FullNode:  true,
   208  		}); err != nil {
   209  			t.Fatal(err)
   210  		}
   211  
   212  		_, err = handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   213  		expectedErr := handshake.ErrPicker
   214  		if !errors.Is(err, expectedErr) {
   215  			t.Fatal("expected:", expectedErr, "got:", err)
   216  		}
   217  	})
   218  
   219  	t.Run("Handshake - welcome message too long", func(t *testing.T) {
   220  		const LongMessage = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi consectetur urna ut lorem sollicitudin posuere. Donec sagittis laoreet sapien."
   221  
   222  		expectedErr := handshake.ErrWelcomeMessageLength
   223  		_, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, LongMessage, true, node1AddrInfo.ID, logger)
   224  		if err == nil || err.Error() != expectedErr.Error() {
   225  			t.Fatal("expected:", expectedErr, "got:", err)
   226  		}
   227  	})
   228  
   229  	t.Run("Handshake - dynamic welcome message too long", func(t *testing.T) {
   230  		const LongMessage = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi consectetur urna ut lorem sollicitudin posuere. Donec sagittis laoreet sapien."
   231  
   232  		expectedErr := handshake.ErrWelcomeMessageLength
   233  		err := handshakeService.SetWelcomeMessage(LongMessage)
   234  		if err == nil || err.Error() != expectedErr.Error() {
   235  			t.Fatal("expected:", expectedErr, "got:", err)
   236  		}
   237  	})
   238  
   239  	t.Run("Handshake - set welcome message", func(t *testing.T) {
   240  		const TestMessage = "Hi im the new test message"
   241  
   242  		err := handshakeService.SetWelcomeMessage(TestMessage)
   243  		if err != nil {
   244  			t.Fatal("Got error:", err)
   245  		}
   246  		got := handshakeService.GetWelcomeMessage()
   247  		if got != TestMessage {
   248  			t.Fatal("expected:", TestMessage, ", got:", got)
   249  		}
   250  	})
   251  
   252  	t.Run("Handshake - Syn write error", func(t *testing.T) {
   253  		testErr := errors.New("test error")
   254  		expectedErr := fmt.Errorf("write syn message: %w", testErr)
   255  		stream := &mock.Stream{}
   256  		stream.SetWriteErr(testErr, 0)
   257  		res, err := handshakeService.Handshake(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   258  		if err == nil || err.Error() != expectedErr.Error() {
   259  			t.Fatal("expected:", expectedErr, "got:", err)
   260  		}
   261  
   262  		if res != nil {
   263  			t.Fatal("handshake returned non-nil res")
   264  		}
   265  	})
   266  
   267  	t.Run("Handshake - Syn read error", func(t *testing.T) {
   268  		testErr := errors.New("test error")
   269  		expectedErr := fmt.Errorf("read synack message: %w", testErr)
   270  		stream := mock.NewStream(nil, &bytes.Buffer{})
   271  		stream.SetReadErr(testErr, 0)
   272  		res, err := handshakeService.Handshake(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   273  		if err == nil || err.Error() != expectedErr.Error() {
   274  			t.Fatal("expected:", expectedErr, "got:", err)
   275  		}
   276  
   277  		if res != nil {
   278  			t.Fatal("handshake returned non-nil res")
   279  		}
   280  	})
   281  
   282  	t.Run("Handshake - ack write error", func(t *testing.T) {
   283  		testErr := errors.New("test error")
   284  		expectedErr := fmt.Errorf("write ack message: %w", testErr)
   285  		var buffer1 bytes.Buffer
   286  		var buffer2 bytes.Buffer
   287  		stream1 := mock.NewStream(&buffer1, &buffer2)
   288  		stream1.SetWriteErr(testErr, 1)
   289  		stream2 := mock.NewStream(&buffer2, &buffer1)
   290  
   291  		w := protobuf.NewWriter(stream2)
   292  		if err := w.WriteMsg(&pb.SynAck{
   293  			Syn: &pb.Syn{
   294  				ObservedUnderlay: node1maBinary,
   295  			},
   296  			Ack: &pb.Ack{
   297  				Address: &pb.BzzAddress{
   298  					Underlay:  node2maBinary,
   299  					Overlay:   node2BzzAddress.Overlay.Bytes(),
   300  					Signature: node2BzzAddress.Signature,
   301  				},
   302  				Nonce:     nonce,
   303  				NetworkID: networkID,
   304  				FullNode:  true,
   305  			},
   306  		},
   307  		); err != nil {
   308  			t.Fatal(err)
   309  		}
   310  
   311  		res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   312  		if err == nil || err.Error() != expectedErr.Error() {
   313  			t.Fatal("expected:", expectedErr, "got:", err)
   314  		}
   315  
   316  		if res != nil {
   317  			t.Fatal("handshake returned non-nil res")
   318  		}
   319  	})
   320  
   321  	t.Run("Handshake - networkID mismatch", func(t *testing.T) {
   322  		var buffer1 bytes.Buffer
   323  		var buffer2 bytes.Buffer
   324  		stream1 := mock.NewStream(&buffer1, &buffer2)
   325  		stream2 := mock.NewStream(&buffer2, &buffer1)
   326  
   327  		w := protobuf.NewWriter(stream2)
   328  		if err := w.WriteMsg(&pb.SynAck{
   329  			Syn: &pb.Syn{
   330  				ObservedUnderlay: node1maBinary,
   331  			},
   332  			Ack: &pb.Ack{
   333  				Address: &pb.BzzAddress{
   334  					Underlay:  node2maBinary,
   335  					Overlay:   node2BzzAddress.Overlay.Bytes(),
   336  					Signature: node2BzzAddress.Signature,
   337  				},
   338  				NetworkID: 5,
   339  				FullNode:  true,
   340  			},
   341  		}); err != nil {
   342  			t.Fatal(err)
   343  		}
   344  
   345  		res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   346  		if res != nil {
   347  			t.Fatal("res should be nil")
   348  		}
   349  
   350  		if !errors.Is(err, handshake.ErrNetworkIDIncompatible) {
   351  			t.Fatalf("expected %s, got %s", handshake.ErrNetworkIDIncompatible, err)
   352  		}
   353  	})
   354  
   355  	t.Run("Handshake - invalid ack", func(t *testing.T) {
   356  		var buffer1 bytes.Buffer
   357  		var buffer2 bytes.Buffer
   358  		stream1 := mock.NewStream(&buffer1, &buffer2)
   359  		stream2 := mock.NewStream(&buffer2, &buffer1)
   360  
   361  		w := protobuf.NewWriter(stream2)
   362  		if err := w.WriteMsg(&pb.SynAck{
   363  			Syn: &pb.Syn{
   364  				ObservedUnderlay: node1maBinary,
   365  			},
   366  			Ack: &pb.Ack{
   367  				Address: &pb.BzzAddress{
   368  					Underlay:  node2maBinary,
   369  					Overlay:   node2BzzAddress.Overlay.Bytes(),
   370  					Signature: node1BzzAddress.Signature,
   371  				},
   372  				NetworkID: networkID,
   373  				FullNode:  true,
   374  			},
   375  		}); err != nil {
   376  			t.Fatal(err)
   377  		}
   378  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, testWelcomeMessage, true, node1AddrInfo.ID, logger)
   379  		if err != nil {
   380  			t.Fatal(err)
   381  		}
   382  		res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   383  		if res != nil {
   384  			t.Fatal("res should be nil")
   385  		}
   386  
   387  		if !errors.Is(err, handshake.ErrInvalidAck) {
   388  			t.Fatalf("expected %s, got %s", handshake.ErrInvalidAck, err)
   389  		}
   390  	})
   391  
   392  	t.Run("Handshake - error advertisable address", func(t *testing.T) {
   393  		var buffer1 bytes.Buffer
   394  		var buffer2 bytes.Buffer
   395  		stream1 := mock.NewStream(&buffer1, &buffer2)
   396  		stream2 := mock.NewStream(&buffer2, &buffer1)
   397  
   398  		testError := errors.New("test error")
   399  		aaddresser.err = testError
   400  		defer func() {
   401  			aaddresser.err = nil
   402  		}()
   403  
   404  		w, _ := protobuf.NewWriterAndReader(stream2)
   405  		if err := w.WriteMsg(&pb.SynAck{
   406  			Syn: &pb.Syn{
   407  				ObservedUnderlay: node1maBinary,
   408  			},
   409  			Ack: &pb.Ack{
   410  				Address: &pb.BzzAddress{
   411  					Underlay:  node2maBinary,
   412  					Overlay:   node2BzzAddress.Overlay.Bytes(),
   413  					Signature: node2BzzAddress.Signature,
   414  				},
   415  				NetworkID: networkID,
   416  				FullNode:  true,
   417  			},
   418  		}); err != nil {
   419  			t.Fatal(err)
   420  		}
   421  
   422  		res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   423  		if !errors.Is(err, testError) {
   424  			t.Fatalf("expected error %v got %v", testError, err)
   425  
   426  		}
   427  
   428  		if res != nil {
   429  			t.Fatal("expected nil res")
   430  		}
   431  
   432  	})
   433  
   434  	t.Run("Handle - OK", func(t *testing.T) {
   435  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, "", true, node1AddrInfo.ID, logger)
   436  		if err != nil {
   437  			t.Fatal(err)
   438  		}
   439  		var buffer1 bytes.Buffer
   440  		var buffer2 bytes.Buffer
   441  		stream1 := mock.NewStream(&buffer1, &buffer2)
   442  		stream2 := mock.NewStream(&buffer2, &buffer1)
   443  
   444  		w := protobuf.NewWriter(stream2)
   445  		if err := w.WriteMsg(&pb.Syn{
   446  			ObservedUnderlay: node1maBinary,
   447  		}); err != nil {
   448  			t.Fatal(err)
   449  		}
   450  
   451  		if err := w.WriteMsg(&pb.Ack{
   452  			Address: &pb.BzzAddress{
   453  				Underlay:  node2maBinary,
   454  				Overlay:   node2BzzAddress.Overlay.Bytes(),
   455  				Signature: node2BzzAddress.Signature,
   456  			},
   457  			NetworkID: networkID,
   458  			Nonce:     nonce,
   459  			FullNode:  true,
   460  		}); err != nil {
   461  			t.Fatal(err)
   462  		}
   463  
   464  		res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   465  		if err != nil {
   466  			t.Fatal(err)
   467  		}
   468  
   469  		testInfo(t, *res, node2Info)
   470  
   471  		_, r := protobuf.NewWriterAndReader(stream2)
   472  		var got pb.SynAck
   473  		if err := r.ReadMsg(&got); err != nil {
   474  			t.Fatal(err)
   475  		}
   476  
   477  		if !bytes.Equal(got.Syn.ObservedUnderlay, node2maBinary) {
   478  			t.Fatalf("got bad syn")
   479  		}
   480  
   481  		bzzAddress, err := bzz.ParseAddress(got.Ack.Address.Underlay, got.Ack.Address.Overlay, got.Ack.Address.Signature, got.Ack.Nonce, true, got.Ack.NetworkID)
   482  		if err != nil {
   483  			t.Fatal(err)
   484  		}
   485  
   486  		testInfo(t, node1Info, handshake.Info{
   487  			BzzAddress: bzzAddress,
   488  			FullNode:   got.Ack.FullNode,
   489  		})
   490  	})
   491  
   492  	t.Run("Handle - read error ", func(t *testing.T) {
   493  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger)
   494  		if err != nil {
   495  			t.Fatal(err)
   496  		}
   497  		testErr := errors.New("test error")
   498  		expectedErr := fmt.Errorf("read syn message: %w", testErr)
   499  		stream := &mock.Stream{}
   500  		stream.SetReadErr(testErr, 0)
   501  		res, err := handshakeService.Handle(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   502  		if err == nil || err.Error() != expectedErr.Error() {
   503  			t.Fatal("expected:", expectedErr, "got:", err)
   504  		}
   505  
   506  		if res != nil {
   507  			t.Fatal("handle returned non-nil res")
   508  		}
   509  	})
   510  
   511  	t.Run("Handle - write error ", func(t *testing.T) {
   512  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger)
   513  		if err != nil {
   514  			t.Fatal(err)
   515  		}
   516  		testErr := errors.New("test error")
   517  		expectedErr := fmt.Errorf("write synack message: %w", testErr)
   518  		var buffer bytes.Buffer
   519  		stream := mock.NewStream(&buffer, &buffer)
   520  		stream.SetWriteErr(testErr, 1)
   521  		w := protobuf.NewWriter(stream)
   522  		if err := w.WriteMsg(&pb.Syn{
   523  			ObservedUnderlay: node1maBinary,
   524  		}); err != nil {
   525  			t.Fatal(err)
   526  		}
   527  
   528  		res, err := handshakeService.Handle(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   529  		if err == nil || err.Error() != expectedErr.Error() {
   530  			t.Fatal("expected:", expectedErr, "got:", err)
   531  		}
   532  
   533  		if res != nil {
   534  			t.Fatal("handshake returned non-nil res")
   535  		}
   536  	})
   537  
   538  	t.Run("Handle - ack read error ", func(t *testing.T) {
   539  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger)
   540  		if err != nil {
   541  			t.Fatal(err)
   542  		}
   543  		testErr := errors.New("test error")
   544  		expectedErr := fmt.Errorf("read ack message: %w", testErr)
   545  		var buffer1 bytes.Buffer
   546  		var buffer2 bytes.Buffer
   547  		stream1 := mock.NewStream(&buffer1, &buffer2)
   548  		stream2 := mock.NewStream(&buffer2, &buffer1)
   549  		stream1.SetReadErr(testErr, 1)
   550  		w := protobuf.NewWriter(stream2)
   551  		if err := w.WriteMsg(&pb.Syn{
   552  			ObservedUnderlay: node1maBinary,
   553  		}); err != nil {
   554  			t.Fatal(err)
   555  		}
   556  
   557  		res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   558  		if err == nil || err.Error() != expectedErr.Error() {
   559  			t.Fatal("expected:", expectedErr, "got:", err)
   560  		}
   561  
   562  		if res != nil {
   563  			t.Fatal("handshake returned non-nil res")
   564  		}
   565  	})
   566  
   567  	t.Run("Handle - networkID mismatch ", func(t *testing.T) {
   568  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger)
   569  		if err != nil {
   570  			t.Fatal(err)
   571  		}
   572  		var buffer1 bytes.Buffer
   573  		var buffer2 bytes.Buffer
   574  		stream1 := mock.NewStream(&buffer1, &buffer2)
   575  		stream2 := mock.NewStream(&buffer2, &buffer1)
   576  
   577  		w := protobuf.NewWriter(stream2)
   578  		if err := w.WriteMsg(&pb.Syn{
   579  			ObservedUnderlay: node1maBinary,
   580  		}); err != nil {
   581  			t.Fatal(err)
   582  		}
   583  
   584  		if err := w.WriteMsg(&pb.Ack{
   585  			Address: &pb.BzzAddress{
   586  				Underlay:  node2maBinary,
   587  				Overlay:   node2BzzAddress.Overlay.Bytes(),
   588  				Signature: node2BzzAddress.Signature,
   589  			},
   590  			NetworkID: 5,
   591  			FullNode:  true,
   592  		}); err != nil {
   593  			t.Fatal(err)
   594  		}
   595  
   596  		res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   597  		if res != nil {
   598  			t.Fatal("res should be nil")
   599  		}
   600  
   601  		if !errors.Is(err, handshake.ErrNetworkIDIncompatible) {
   602  			t.Fatalf("expected %s, got %s", handshake.ErrNetworkIDIncompatible, err)
   603  		}
   604  	})
   605  
   606  	t.Run("Handle - invalid ack", func(t *testing.T) {
   607  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger)
   608  		if err != nil {
   609  			t.Fatal(err)
   610  		}
   611  		var buffer1 bytes.Buffer
   612  		var buffer2 bytes.Buffer
   613  		stream1 := mock.NewStream(&buffer1, &buffer2)
   614  		stream2 := mock.NewStream(&buffer2, &buffer1)
   615  
   616  		w := protobuf.NewWriter(stream2)
   617  		if err := w.WriteMsg(&pb.Syn{
   618  			ObservedUnderlay: node1maBinary,
   619  		}); err != nil {
   620  			t.Fatal(err)
   621  		}
   622  
   623  		if err := w.WriteMsg(&pb.Ack{
   624  			Address: &pb.BzzAddress{
   625  				Underlay:  node2maBinary,
   626  				Overlay:   node2BzzAddress.Overlay.Bytes(),
   627  				Signature: node1BzzAddress.Signature,
   628  			},
   629  			NetworkID: networkID,
   630  			FullNode:  true,
   631  		}); err != nil {
   632  			t.Fatal(err)
   633  		}
   634  
   635  		_, err = handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   636  		if !errors.Is(err, handshake.ErrInvalidAck) {
   637  			t.Fatalf("expected %s, got %v", handshake.ErrInvalidAck, err)
   638  		}
   639  	})
   640  
   641  	t.Run("Handle - advertisable error", func(t *testing.T) {
   642  		handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger)
   643  		if err != nil {
   644  			t.Fatal(err)
   645  		}
   646  		var buffer1 bytes.Buffer
   647  		var buffer2 bytes.Buffer
   648  		stream1 := mock.NewStream(&buffer1, &buffer2)
   649  		stream2 := mock.NewStream(&buffer2, &buffer1)
   650  
   651  		testError := errors.New("test error")
   652  		aaddresser.err = testError
   653  		defer func() {
   654  			aaddresser.err = nil
   655  		}()
   656  
   657  		w := protobuf.NewWriter(stream2)
   658  		if err := w.WriteMsg(&pb.Syn{
   659  			ObservedUnderlay: node1maBinary,
   660  		}); err != nil {
   661  			t.Fatal(err)
   662  		}
   663  
   664  		res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
   665  		if !errors.Is(err, testError) {
   666  			t.Fatal("expected error")
   667  		}
   668  
   669  		if res != nil {
   670  			t.Fatal("expected nil res")
   671  		}
   672  	})
   673  }
   674  
   675  func mockPicker(f func(p2p.Peer) bool) p2p.Picker {
   676  	return &picker{pickerFunc: f}
   677  }
   678  
   679  type picker struct {
   680  	pickerFunc func(p2p.Peer) bool
   681  }
   682  
   683  func (p *picker) Pick(peer p2p.Peer) bool {
   684  	return p.pickerFunc(peer)
   685  }
   686  
   687  // testInfo validates if two Info instances are equal.
   688  func testInfo(t *testing.T, got, want handshake.Info) {
   689  	t.Helper()
   690  	if !got.BzzAddress.Equal(want.BzzAddress) || got.FullNode != want.FullNode {
   691  		t.Fatalf("got info %+v, want %+v", got, want)
   692  	}
   693  }
   694  
   695  type AdvertisableAddresserMock struct {
   696  	advertisableAddress ma.Multiaddr
   697  	err                 error
   698  }
   699  
   700  func (a *AdvertisableAddresserMock) Resolve(observedAddress ma.Multiaddr) (ma.Multiaddr, error) {
   701  	if a.err != nil {
   702  		return nil, a.err
   703  	}
   704  
   705  	if a.advertisableAddress != nil {
   706  		return a.advertisableAddress, nil
   707  	}
   708  
   709  	return observedAddress, nil
   710  }