github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/p2p/channel_test.go (about)

     1  /*
     2   * Copyright (C) 2020 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package p2p
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"net"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  
    31  	"github.com/mysteriumnetwork/node/core/port"
    32  	"github.com/mysteriumnetwork/node/pb"
    33  )
    34  
    35  func TestChannelFullCommunicationFlow(t *testing.T) {
    36  	provider, consumer, err := createTestChannels()
    37  	require.NoError(t, err)
    38  	defer provider.Close()
    39  	defer consumer.Close()
    40  
    41  	t.Run("Test publish subscribe pattern", func(t *testing.T) {
    42  		consumerReceivedMsg := make(chan *pb.PingPong, 1)
    43  		providerReceivedMsg := make(chan *pb.PingPong, 1)
    44  
    45  		consumer.Handle("ping.pong", func(c Context) error {
    46  			var res pb.PingPong
    47  			err := c.Request().UnmarshalProto(&res)
    48  			assert.NoError(t, err)
    49  			consumerReceivedMsg <- &res
    50  			return c.OK()
    51  		})
    52  
    53  		provider.Handle("ping.pong", func(c Context) error {
    54  			var res pb.PingPong
    55  			err := c.Request().UnmarshalProto(&res)
    56  			assert.NoError(t, err)
    57  			providerReceivedMsg <- &res
    58  			return c.OK()
    59  		})
    60  
    61  		publishedConsumerMsg := &pb.PingPong{Value: "Consumer BigZ"}
    62  		msg := ProtoMessage(publishedConsumerMsg)
    63  		_, err := consumer.Send(context.Background(), "ping.pong", msg)
    64  		assert.NoError(t, err)
    65  
    66  		publishedProviderMsg := &pb.PingPong{Value: "Provider SmallZ"}
    67  		msg = ProtoMessage(publishedProviderMsg)
    68  		_, err = provider.Send(context.Background(), "ping.pong", msg)
    69  		assert.NoError(t, err)
    70  
    71  		select {
    72  		case v := <-consumerReceivedMsg:
    73  			assert.Equal(t, publishedProviderMsg.Value, v.Value)
    74  		case <-time.After(100 * time.Millisecond):
    75  			t.Fatal("did not received message from channel consumer subscription")
    76  		}
    77  
    78  		select {
    79  		case v := <-providerReceivedMsg:
    80  			assert.Equal(t, publishedConsumerMsg.Value, v.Value)
    81  		case <-time.After(100 * time.Millisecond):
    82  			t.Fatal("did not received message from channel provider subscription")
    83  		}
    84  	})
    85  
    86  	t.Run("Test request reply pattern", func(t *testing.T) {
    87  		provider.Handle("testreq", func(c Context) error {
    88  			var req pb.PingPong
    89  			err := c.Request().UnmarshalProto(&req)
    90  			assert.NoError(t, err)
    91  
    92  			msg := ProtoMessage(&pb.PingPong{Value: req.Value + "-pong"})
    93  			assert.NoError(t, err)
    94  			return c.OkWithReply(msg)
    95  		})
    96  
    97  		msg := ProtoMessage(&pb.PingPong{Value: "ping"})
    98  		res, err := consumer.Send(context.Background(), "testreq", msg)
    99  		assert.NoError(t, err)
   100  
   101  		var resMsg pb.PingPong
   102  		err = res.UnmarshalProto(&resMsg)
   103  		assert.NoError(t, err)
   104  		assert.Equal(t, "ping-pong", resMsg.Value)
   105  	})
   106  
   107  	t.Run("Test concurrent requests", func(t *testing.T) {
   108  		var wg sync.WaitGroup
   109  		provider.Handle("concurrent", func(c Context) error {
   110  			wg.Done()
   111  			return c.OK()
   112  		})
   113  
   114  		for i := 0; i < 10; i++ {
   115  			wg.Add(1)
   116  			go func() {
   117  				_, err := consumer.Send(context.Background(), "concurrent", &Message{Data: []byte{}})
   118  				assert.NoError(t, err)
   119  			}()
   120  		}
   121  
   122  		wg.Wait()
   123  	})
   124  
   125  	t.Run("Test slow topicHandlers are not blocking", func(t *testing.T) {
   126  		provider.Handle("slow", func(c Context) error {
   127  			time.Sleep(time.Hour)
   128  			return c.OK()
   129  		})
   130  
   131  		provider.Handle("fast", func(c Context) error {
   132  			return c.OK()
   133  		})
   134  
   135  		slowStarted := make(chan struct{})
   136  		go func() {
   137  			slowStarted <- struct{}{}
   138  			consumer.Send(context.Background(), "slow", &Message{})
   139  		}()
   140  
   141  		fastFinished := make(chan struct{})
   142  		go func() {
   143  			<-slowStarted
   144  			consumer.Send(context.Background(), "fast", &Message{})
   145  			fastFinished <- struct{}{}
   146  		}()
   147  
   148  		select {
   149  		case <-fastFinished:
   150  		case <-time.After(time.Second):
   151  			t.Fatal("slow handler blocks concurrent send")
   152  		}
   153  	})
   154  
   155  	t.Run("Test peer returns public error", func(t *testing.T) {
   156  		provider.Handle("get-error", func(c Context) error {
   157  			return c.Error(errors.New("I don't like you"))
   158  		})
   159  
   160  		_, err := consumer.Send(context.Background(), "get-error", &Message{Data: []byte("hello")})
   161  		assert.EqualError(t, err, "public peer error: I don't like you")
   162  	})
   163  
   164  	t.Run("Test peer returns internal error", func(t *testing.T) {
   165  		provider.Handle("get-error", func(c Context) error {
   166  			return errors.New("I don't like you")
   167  		})
   168  
   169  		_, err := consumer.Send(context.Background(), "get-error", &Message{Data: []byte("hello")})
   170  		assert.EqualError(t, err, "peer error: I don't like you")
   171  	})
   172  
   173  	t.Run("Test peer returns handler not found error", func(t *testing.T) {
   174  		_, err := consumer.Send(context.Background(), "ping", &Message{Data: []byte("hello")})
   175  		if !errors.Is(err, ErrHandlerNotFound) {
   176  			t.Fatalf("expect handler not found err, got %v", err)
   177  		}
   178  	})
   179  }
   180  
   181  func TestChannel_Send_Timeout(t *testing.T) {
   182  	provider, consumer, err := createTestChannels()
   183  	require.NoError(t, err)
   184  	defer provider.Close()
   185  	defer consumer.Close()
   186  
   187  	t.Run("Test timeout for long not responding peer", func(t *testing.T) {
   188  		provider.Handle("timeout", func(c Context) error {
   189  			time.Sleep(time.Hour)
   190  			return c.OK()
   191  		})
   192  
   193  		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   194  		defer cancel()
   195  		_, err = consumer.Send(ctx, "timeout", &Message{Data: []byte("ping")})
   196  		if !errors.Is(err, ErrSendTimeout) {
   197  			t.Fatalf("expect timeout err, got: %v", err)
   198  		}
   199  	})
   200  
   201  	t.Run("Test timeout when peer is closed", func(t *testing.T) {
   202  		provider.Close()
   203  
   204  		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   205  		defer cancel()
   206  		_, err = consumer.Send(ctx, "timeout", &Message{Data: []byte("ping")})
   207  		if !errors.Is(err, ErrSendTimeout) {
   208  			t.Fatalf("expect timeout err, got: %v", err)
   209  		}
   210  	})
   211  }
   212  
   213  func TestChannel_Send_To_When_Peer_Starts_Later(t *testing.T) {
   214  	provider, consumer, err := createTestChannels()
   215  	require.NoError(t, err)
   216  	defer consumer.Close()
   217  	defer provider.Close()
   218  
   219  	// Close provider channel to simulate unstable network.
   220  	// Consumer will try to send messages and during first 50 ms
   221  	// they will not reach provider peer, but since kcp will try
   222  	// keep resending packets they will finally reach opened
   223  	// provider peer.
   224  	addr := provider.(*channel).tr.remoteConn.LocalAddr().(*net.UDPAddr)
   225  	err = provider.Close()
   226  	require.NoError(t, err)
   227  	go func() {
   228  		time.Sleep(50 * time.Millisecond)
   229  		provider, err := reopenChannel(provider.(*channel), addr)
   230  
   231  		require.NoError(t, err)
   232  		provider.Handle("timeout", func(c Context) error {
   233  			return c.OK()
   234  		})
   235  	}()
   236  
   237  	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   238  	defer cancel()
   239  	_, err = consumer.Send(ctx, "timeout", &Message{Data: []byte("ping")})
   240  	require.NoError(t, err)
   241  }
   242  
   243  func TestChannel_Detect_And_Update_Peer_Addr(t *testing.T) {
   244  	provider, consumer, err := createTestChannels()
   245  	require.NoError(t, err)
   246  	defer consumer.Close()
   247  	defer provider.Close()
   248  
   249  	provider.Handle("ping", func(c Context) error {
   250  		return c.OK()
   251  	})
   252  
   253  	// Close consumer peer and reopen channel with new local addr.
   254  	consumer.Close()
   255  	consumer, err = reopenChannelWithNewLocalAddr(consumer.(*channel))
   256  	require.NoError(t, err)
   257  
   258  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   259  	defer cancel()
   260  	_, err = consumer.Send(ctx, "ping", &Message{Data: []byte("pingasssas")})
   261  }
   262  
   263  func BenchmarkChannel_Send(b *testing.B) {
   264  	provider, consumer, err := createTestChannels()
   265  	require.NoError(b, err)
   266  	defer provider.Close()
   267  	defer consumer.Close()
   268  
   269  	provider.Handle("bench", func(c Context) error {
   270  		return c.OkWithReply(&Message{Data: []byte("I'm still OK")})
   271  	})
   272  
   273  	b.ResetTimer()
   274  
   275  	for i := 0; i < b.N; i++ {
   276  		res, err := consumer.Send(context.Background(), "bench", &Message{Data: []byte("Catch this!")})
   277  		require.NoError(b, err)
   278  		require.NotNil(b, res)
   279  	}
   280  }
   281  
   282  func reopenChannel(c *channel, addr *net.UDPAddr) (*channel, error) {
   283  	punchedConn, err := net.DialUDP("udp4", addr, c.peer.addr())
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  	ch, err := newChannel(punchedConn, c.privateKey, c.peer.publicKey, 1)
   288  	if err != nil {
   289  		return nil, err
   290  	}
   291  	ch.launchReadSendLoops()
   292  	return ch, err
   293  }
   294  
   295  func reopenChannelWithNewLocalAddr(c *channel) (*channel, error) {
   296  	punchedConn, err := net.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, c.peer.addr())
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  	ch, err := newChannel(punchedConn, c.privateKey, c.peer.publicKey, 1)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  	ch.launchReadSendLoops()
   305  	return ch, err
   306  }
   307  
   308  func createTestChannels() (Channel, Channel, error) {
   309  	ports, err := acquirePorts(2)
   310  	if err != nil {
   311  		return nil, nil, err
   312  	}
   313  	providerPort := ports[0]
   314  	consumerPort := ports[1]
   315  
   316  	providerConn, err := net.DialUDP("udp4", &net.UDPAddr{Port: providerPort}, &net.UDPAddr{Port: consumerPort})
   317  	if err != nil {
   318  		return nil, nil, err
   319  	}
   320  
   321  	consumerConn, err := net.DialUDP("udp4", &net.UDPAddr{Port: consumerPort}, &net.UDPAddr{Port: providerPort})
   322  	if err != nil {
   323  		return nil, nil, err
   324  	}
   325  
   326  	providerPublicKey, providerPrivateKey, err := GenerateKey()
   327  	if err != nil {
   328  		return nil, nil, err
   329  	}
   330  	consumerPublicKey, consumerPrivateKey, err := GenerateKey()
   331  	if err != nil {
   332  		return nil, nil, err
   333  	}
   334  
   335  	provider, err := newChannel(providerConn, providerPrivateKey, consumerPublicKey, 1)
   336  	if err != nil {
   337  		return nil, nil, err
   338  	}
   339  	provider.launchReadSendLoops()
   340  
   341  	consumer, err := newChannel(consumerConn, consumerPrivateKey, providerPublicKey, 1)
   342  	if err != nil {
   343  		return nil, nil, err
   344  	}
   345  	consumer.launchReadSendLoops()
   346  
   347  	return provider, consumer, nil
   348  }
   349  
   350  func acquirePorts(n int) ([]int, error) {
   351  	portPool := port.NewFixedRangePool(port.Range{Start: 10000, End: 60000})
   352  	ports, err := portPool.AcquireMultiple(n)
   353  	if err != nil {
   354  		return nil, err
   355  	}
   356  	var res []int
   357  	for _, v := range ports {
   358  		res = append(res, v.Num())
   359  	}
   360  	return res, nil
   361  }