github.com/anycable/anycable-go@v1.5.1/pubsub/subscriber_test.go (about)

     1  package pubsub
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/anycable/anycable-go/common"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  type TestHandler struct {
    14  	messages chan (*common.StreamMessage)
    15  	commands chan (*common.RemoteCommandMessage)
    16  }
    17  
    18  var _ Handler = (*TestHandler)(nil)
    19  
    20  func NewTestHandler() *TestHandler {
    21  	return &TestHandler{
    22  		messages: make(chan *common.StreamMessage, 10),
    23  		commands: make(chan *common.RemoteCommandMessage, 10),
    24  	}
    25  }
    26  
    27  func (h *TestHandler) Broadcast(msg *common.StreamMessage) {
    28  	h.messages <- msg
    29  }
    30  
    31  func (h *TestHandler) ExecuteRemoteCommand(cmd *common.RemoteCommandMessage) {
    32  	h.commands <- cmd
    33  }
    34  
    35  func (h *TestHandler) Receive() *common.StreamMessage {
    36  	timer := time.After(100 * time.Millisecond)
    37  
    38  	select {
    39  	case <-timer:
    40  		return nil
    41  	case msg := <-h.messages:
    42  		return msg
    43  	}
    44  }
    45  
    46  func (h *TestHandler) ReceiveCommand() *common.RemoteCommandMessage {
    47  	timer := time.After(100 * time.Millisecond)
    48  
    49  	select {
    50  	case <-timer:
    51  		return nil
    52  	case msg := <-h.commands:
    53  		return msg
    54  	}
    55  }
    56  
    57  type subscriberFactory = func(handler *TestHandler) Subscriber
    58  type subscriptionWaiter = func(subscriber Subscriber, stream string) error
    59  
    60  func SharedSubscriberTests(t *testing.T, factory subscriberFactory, wait subscriptionWaiter) {
    61  	handler := NewTestHandler()
    62  	subscriber := factory(handler)
    63  	done := make(chan error)
    64  
    65  	err := subscriber.Start(done)
    66  	require.NoError(t, err)
    67  
    68  	require.NoError(t, wait(subscriber, "internal"))
    69  
    70  	defer subscriber.Shutdown(context.Background()) // nolint:errcheck
    71  
    72  	t.Run("Broadcast", func(t *testing.T) {
    73  		// Sbscribers may rely on known subscriptions
    74  		subscriber.Subscribe("test")
    75  		require.NoError(t, wait(subscriber, "test"))
    76  
    77  		subscriber.Broadcast(&common.StreamMessage{Stream: "test", Data: "boo"})
    78  
    79  		msg := handler.Receive()
    80  		require.NotNil(t, msg)
    81  		assert.Equal(t, "boo", msg.Data)
    82  	})
    83  
    84  	t.Run("Broadcast commands", func(t *testing.T) {
    85  		subscriber.BroadcastCommand(&common.RemoteCommandMessage{Command: "test", Payload: []byte(`{"foo":"bar"}`)})
    86  
    87  		cmd := handler.ReceiveCommand()
    88  		require.NotNil(t, cmd)
    89  		assert.Equal(t, "test", cmd.Command)
    90  	})
    91  
    92  	if !subscriber.IsMultiNode() {
    93  		return
    94  	}
    95  
    96  	// Tests for multi-node subscribers require at least two handler and subscribers to
    97  	// test re-transmission
    98  
    99  	otherHandler := NewTestHandler()
   100  	otherSubscriber := factory(otherHandler)
   101  
   102  	err = otherSubscriber.Start(done)
   103  	require.NoError(t, err)
   104  
   105  	require.NoError(t, wait(otherSubscriber, "internal"))
   106  
   107  	defer otherSubscriber.Shutdown(context.Background()) // nolint:errcheck
   108  
   109  	t.Run("Subscribe - Broadcast", func(t *testing.T) {
   110  		subscriber.Subscribe("a")
   111  		otherSubscriber.Subscribe("b")
   112  		otherSubscriber.Subscribe("a")
   113  
   114  		require.NoError(t, wait(subscriber, "a"))
   115  		require.NoError(t, wait(otherSubscriber, "a"))
   116  		require.NoError(t, wait(otherSubscriber, "b"))
   117  
   118  		subscriber.Broadcast(&common.StreamMessage{Stream: "a", Data: "1"})
   119  
   120  		msg := handler.Receive()
   121  		require.NotNil(t, msg)
   122  		assert.Equal(t, "1", msg.Data)
   123  		assert.Equal(t, "a", msg.Stream)
   124  
   125  		nextMsg := handler.Receive()
   126  		assert.Nilf(t, nextMsg, "Must broadcast message once")
   127  
   128  		msg = otherHandler.Receive()
   129  		require.NotNil(t, msg)
   130  		assert.Equal(t, "1", msg.Data)
   131  		assert.Equal(t, "a", msg.Stream)
   132  
   133  		nextMsg = otherHandler.Receive()
   134  		assert.Nilf(t, nextMsg, "Must broadcast message once")
   135  
   136  		subscriber.Broadcast(&common.StreamMessage{Stream: "b", Data: "2"})
   137  
   138  		msg = handler.Receive()
   139  		assert.Nilf(t, msg, "Should not broadcast message for unknown stream")
   140  
   141  		msg = otherHandler.Receive()
   142  		require.NotNil(t, msg)
   143  		assert.Equal(t, "2", msg.Data)
   144  		assert.Equal(t, "b", msg.Stream)
   145  	})
   146  
   147  	t.Run("Re-transmit commands", func(t *testing.T) {
   148  		subscriber.BroadcastCommand(&common.RemoteCommandMessage{Command: "test"})
   149  
   150  		cmd := handler.ReceiveCommand()
   151  		require.NotNil(t, cmd)
   152  		assert.Equal(t, "test", cmd.Command)
   153  
   154  		cmd = otherHandler.ReceiveCommand()
   155  		require.NotNil(t, cmd)
   156  		assert.Equal(t, "test", cmd.Command)
   157  	})
   158  
   159  	t.Run("Subscribe - Broadcast - Unsubscribe - Broadcast", func(t *testing.T) {
   160  		subscriber.Subscribe("a")
   161  		otherSubscriber.Subscribe("a")
   162  
   163  		require.NoError(t, wait(subscriber, "a"))
   164  		require.NoError(t, wait(otherSubscriber, "a"))
   165  
   166  		subscriber.Broadcast(&common.StreamMessage{Stream: "a", Data: "1"})
   167  
   168  		msg := handler.Receive()
   169  		require.NotNil(t, msg)
   170  		assert.Equal(t, "1", msg.Data)
   171  		assert.Equal(t, "a", msg.Stream)
   172  
   173  		msg = otherHandler.Receive()
   174  		require.NotNil(t, msg)
   175  		assert.Equal(t, "1", msg.Data)
   176  		assert.Equal(t, "a", msg.Stream)
   177  
   178  		subscriber.Unsubscribe("a")
   179  		require.NoError(t, wait(subscriber, "-a"))
   180  
   181  		subscriber.Broadcast(&common.StreamMessage{Stream: "a", Data: "2"})
   182  
   183  		msg = handler.Receive()
   184  		assert.Nilf(t, msg, "Should not broadcast message for unsubscribed stream")
   185  
   186  		msg = otherHandler.Receive()
   187  		require.NotNil(t, msg)
   188  		assert.Equal(t, "2", msg.Data)
   189  		assert.Equal(t, "a", msg.Stream)
   190  	})
   191  }
   192  
   193  func TestLegacySubscriber(t *testing.T) {
   194  	SharedSubscriberTests(t, func(handler *TestHandler) Subscriber {
   195  		return NewLegacySubscriber(handler)
   196  	}, func(subscriber Subscriber, stream string) error { return nil })
   197  }