github.com/ethersphere/bee/v2@v2.2.0/pkg/p2p/libp2p/headers_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package libp2p_test
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/ethersphere/bee/v2/pkg/p2p"
    14  	"github.com/ethersphere/bee/v2/pkg/p2p/libp2p"
    15  	"github.com/ethersphere/bee/v2/pkg/swarm"
    16  )
    17  
    18  func TestHeaders(t *testing.T) {
    19  	t.Parallel()
    20  
    21  	headers := p2p.Headers{
    22  		"test-header-key": []byte("header-value"),
    23  		"other-key":       []byte("other-value"),
    24  	}
    25  
    26  	ctx, cancel := context.WithCancel(context.Background())
    27  	defer cancel()
    28  
    29  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
    30  		FullNode: true,
    31  	}})
    32  
    33  	s2, overlay2 := newService(t, 1, libp2pServiceOpts{})
    34  
    35  	var gotHeaders p2p.Headers
    36  	handled := make(chan struct{})
    37  	if err := s1.AddProtocol(newTestProtocol(func(ctx context.Context, p p2p.Peer, stream p2p.Stream) error {
    38  		if ctx == nil {
    39  			t.Fatal("missing context")
    40  		}
    41  		if !p.Address.Equal(overlay2) {
    42  			t.Fatalf("got peer %v, want %v", p.Address, overlay2)
    43  		}
    44  		gotHeaders = stream.Headers()
    45  		close(handled)
    46  		return nil
    47  	})); err != nil {
    48  		t.Fatal(err)
    49  	}
    50  
    51  	addr := serviceUnderlayAddress(t, s1)
    52  
    53  	if _, err := s2.Connect(ctx, addr); err != nil {
    54  		t.Fatal(err)
    55  	}
    56  
    57  	stream, err := s2.NewStream(ctx, overlay1, headers, testProtocolName, testProtocolVersion, testStreamName)
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	defer stream.Close()
    62  
    63  	select {
    64  	case <-handled:
    65  	case <-time.After(30 * time.Second):
    66  		t.Fatal("timeout waiting for handler")
    67  	}
    68  
    69  	if fmt.Sprint(gotHeaders) != fmt.Sprint(headers) {
    70  		t.Errorf("got headers %+v, want %+v", gotHeaders, headers)
    71  	}
    72  }
    73  
    74  func TestHeaders_empty(t *testing.T) {
    75  	t.Parallel()
    76  
    77  	ctx, cancel := context.WithCancel(context.Background())
    78  	defer cancel()
    79  
    80  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
    81  		FullNode: true,
    82  	}})
    83  
    84  	s2, overlay2 := newService(t, 1, libp2pServiceOpts{})
    85  
    86  	var gotHeaders p2p.Headers
    87  	handled := make(chan struct{})
    88  	if err := s1.AddProtocol(newTestProtocol(func(ctx context.Context, p p2p.Peer, stream p2p.Stream) error {
    89  		if ctx == nil {
    90  			t.Fatal("missing context")
    91  		}
    92  		if !p.Address.Equal(overlay2) {
    93  			t.Fatalf("got peer %v, want %v", p.Address, overlay2)
    94  		}
    95  		gotHeaders = stream.Headers()
    96  		close(handled)
    97  		return nil
    98  	})); err != nil {
    99  		t.Fatal(err)
   100  	}
   101  
   102  	addr := serviceUnderlayAddress(t, s1)
   103  
   104  	if _, err := s2.Connect(ctx, addr); err != nil {
   105  		t.Fatal(err)
   106  	}
   107  
   108  	stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
   109  	if err != nil {
   110  		t.Fatal(err)
   111  	}
   112  	defer stream.Close()
   113  
   114  	select {
   115  	case <-handled:
   116  	case <-time.After(30 * time.Second):
   117  		t.Fatal("timeout waiting for handler")
   118  	}
   119  
   120  	if len(gotHeaders) != 0 {
   121  		t.Errorf("got headers %+v, want none", gotHeaders)
   122  	}
   123  }
   124  
   125  func TestHeadler(t *testing.T) {
   126  	t.Parallel()
   127  
   128  	receivedHeaders := p2p.Headers{
   129  		"test-header-key": []byte("header-value"),
   130  		"other-key":       []byte("other-value"),
   131  	}
   132  	sentHeaders := p2p.Headers{
   133  		"sent-header-key": []byte("sent-value"),
   134  		"other-sent-key":  []byte("other-sent-value"),
   135  	}
   136  
   137  	ctx, cancel := context.WithCancel(context.Background())
   138  	defer cancel()
   139  
   140  	s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
   141  		FullNode: true,
   142  	}})
   143  
   144  	s2, _ := newService(t, 1, libp2pServiceOpts{})
   145  
   146  	var gotReceivedHeaders p2p.Headers
   147  	handled := make(chan struct{})
   148  	if err := s1.AddProtocol(p2p.ProtocolSpec{
   149  		Name:    testProtocolName,
   150  		Version: testProtocolVersion,
   151  		StreamSpecs: []p2p.StreamSpec{
   152  			{
   153  				Name: testStreamName,
   154  				Handler: func(_ context.Context, _ p2p.Peer, stream p2p.Stream) error {
   155  					return nil
   156  				},
   157  				Headler: func(headers p2p.Headers, address swarm.Address) p2p.Headers {
   158  					defer close(handled)
   159  					gotReceivedHeaders = headers
   160  					return sentHeaders
   161  				},
   162  			},
   163  		},
   164  	}); err != nil {
   165  		t.Fatal(err)
   166  	}
   167  
   168  	addr := serviceUnderlayAddress(t, s1)
   169  
   170  	if _, err := s2.Connect(ctx, addr); err != nil {
   171  		t.Fatal(err)
   172  	}
   173  
   174  	stream, err := s2.NewStream(ctx, overlay1, receivedHeaders, testProtocolName, testProtocolVersion, testStreamName)
   175  	if err != nil {
   176  		t.Fatal(err)
   177  	}
   178  	defer stream.Close()
   179  
   180  	select {
   181  	case <-handled:
   182  	case <-time.After(30 * time.Second):
   183  		t.Fatal("timeout waiting for handler")
   184  	}
   185  
   186  	if fmt.Sprint(gotReceivedHeaders) != fmt.Sprint(receivedHeaders) {
   187  		t.Errorf("got received headers %+v, want %+v", gotReceivedHeaders, receivedHeaders)
   188  	}
   189  
   190  	gotSentHeaders := stream.Headers()
   191  	if fmt.Sprint(gotSentHeaders) != fmt.Sprint(sentHeaders) {
   192  		t.Errorf("got sent headers %+v, want %+v", gotSentHeaders, sentHeaders)
   193  	}
   194  }