github.com/anycable/anycable-go@v1.5.1/hub/hub_test.go (about)

     1  package hub
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"math/rand"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/anycable/anycable-go/common"
    14  	"github.com/anycable/anycable-go/encoders"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  type MockSession struct {
    20  	sid      string
    21  	incoming chan ([]byte)
    22  	closed   bool
    23  	closeMu  sync.Mutex
    24  }
    25  
    26  func (s *MockSession) GetID() string {
    27  	return s.sid
    28  }
    29  
    30  func (s *MockSession) GetIdentifiers() string {
    31  	return s.sid
    32  }
    33  
    34  func (s *MockSession) Send(msg encoders.EncodedMessage) {
    35  	s.incoming <- toJSON(msg)
    36  }
    37  
    38  func (s *MockSession) DisconnectWithMessage(msg encoders.EncodedMessage, code string) {
    39  	s.closeMu.Lock()
    40  	defer s.closeMu.Unlock()
    41  
    42  	if s.closed {
    43  		return
    44  	}
    45  
    46  	s.incoming <- toJSON(msg)
    47  	s.closed = true
    48  }
    49  
    50  func (s *MockSession) Closed() bool {
    51  	s.closeMu.Lock()
    52  	defer s.closeMu.Unlock()
    53  
    54  	return s.closed
    55  }
    56  
    57  func (s *MockSession) Read() ([]byte, error) {
    58  	timer := time.After(100 * time.Millisecond)
    59  
    60  	select {
    61  	case <-timer:
    62  		return nil, errors.New("Session hasn't received any messages")
    63  	case msg := <-s.incoming:
    64  		return msg, nil
    65  	}
    66  }
    67  
    68  func (s *MockSession) ReadIndifinitely() ([]byte, error) {
    69  	return <-s.incoming, nil
    70  }
    71  
    72  func NewMockSession(sid string) *MockSession {
    73  	return &MockSession{sid: sid, incoming: make(chan []byte, 256)}
    74  }
    75  
    76  func TestUnsubscribeRaceConditions(t *testing.T) {
    77  	hub := NewHub(2, slog.Default())
    78  
    79  	go hub.Run()
    80  	defer hub.Shutdown()
    81  
    82  	session := NewMockSession("123")
    83  	session2 := NewMockSession("321")
    84  	session3 := NewMockSession("213")
    85  
    86  	hub.AddSession(session)
    87  	hub.SubscribeSession(session, "test", "test_channel")
    88  
    89  	hub.AddSession(session2)
    90  	hub.SubscribeSession(session2, "test", "test_channel")
    91  
    92  	hub.AddSession(session3)
    93  	hub.SubscribeSession(session3, "test", "test_channel")
    94  
    95  	hub.Broadcast("test", "hello")
    96  
    97  	_, err := session.Read()
    98  	assert.Nil(t, err)
    99  
   100  	_, err = session2.Read()
   101  	assert.Nil(t, err)
   102  
   103  	_, err = session3.Read()
   104  	assert.Nil(t, err)
   105  
   106  	assert.Equal(t, 3, hub.Size(), "Connections size must be equal 2")
   107  
   108  	go func() {
   109  		hub.Broadcast("test", "pong")
   110  		hub.RemoveSession(session)
   111  		hub.Broadcast("test", "ping")
   112  	}()
   113  
   114  	go func() {
   115  		hub.Broadcast("test", "bye-bye")
   116  		hub.RemoveSession(session3)
   117  		hub.Broadcast("test", "meow-meow")
   118  	}()
   119  
   120  	for i := 1; i < 5; i++ {
   121  		_, err = session2.Read()
   122  		assert.Nil(t, err)
   123  	}
   124  
   125  	_, err = session2.Read()
   126  	assert.NotNil(t, err)
   127  
   128  	assert.Equal(t, 1, hub.Size(), "Connections size must be equal 1")
   129  }
   130  
   131  func TestUnsubscribeSession(t *testing.T) {
   132  	hub := NewHub(2, slog.Default())
   133  
   134  	go hub.Run()
   135  	defer hub.Shutdown()
   136  
   137  	session := NewMockSession("123")
   138  	hub.AddSession(session)
   139  
   140  	hub.SubscribeSession(session, "test", "test_channel")
   141  	hub.SubscribeSession(session, "test2", "test_channel")
   142  
   143  	hub.Broadcast("test", "\"hello\"")
   144  
   145  	msg, err := session.Read()
   146  	assert.Nil(t, err)
   147  	assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"hello\"}", string(msg))
   148  
   149  	hub.UnsubscribeSession(session, "test", "test_channel")
   150  
   151  	hub.Broadcast("test", "\"goodbye\"")
   152  
   153  	_, err = session.Read()
   154  	assert.NotNil(t, err)
   155  
   156  	hub.Broadcast("test2", "\"bye\"")
   157  
   158  	msg, err = session.Read()
   159  	assert.Nil(t, err)
   160  	assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"bye\"}", string(msg))
   161  
   162  	hub.unsubscribeSessionFromAllChannels(session)
   163  
   164  	hub.Broadcast("test2", "\"goodbye\"")
   165  
   166  	_, err = session.Read()
   167  	assert.NotNil(t, err)
   168  }
   169  
   170  func TestSubscribeSession(t *testing.T) {
   171  	hub := NewHub(2, slog.Default())
   172  
   173  	go hub.Run()
   174  	defer hub.Shutdown()
   175  
   176  	session := NewMockSession("123")
   177  	hub.AddSession(session)
   178  
   179  	t.Run("Subscribe to a single channel", func(t *testing.T) {
   180  		hub.SubscribeSession(session, "test", "test_channel")
   181  
   182  		hub.Broadcast("test", "\"hello\"")
   183  
   184  		msg, err := session.Read()
   185  		assert.Nil(t, err)
   186  		assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"hello\"}", string(msg))
   187  	})
   188  
   189  	t.Run("Successful to the same stream from multiple channels", func(t *testing.T) {
   190  		hub.SubscribeSession(session, "test", "test_channel")
   191  		hub.SubscribeSession(session, "test", "test_channel2")
   192  
   193  		hub.Broadcast("test", "\"hello twice\"")
   194  
   195  		received := []string{}
   196  
   197  		msg, err := session.Read()
   198  		assert.Nil(t, err)
   199  		received = append(received, string(msg))
   200  
   201  		msg, err = session.Read()
   202  		assert.Nil(t, err)
   203  		received = append(received, string(msg))
   204  
   205  		assert.Contains(t, received, "{\"identifier\":\"test_channel\",\"message\":\"hello twice\"}")
   206  		assert.Contains(t, received, "{\"identifier\":\"test_channel2\",\"message\":\"hello twice\"}")
   207  	})
   208  }
   209  
   210  func TestRemoteDisconnect(t *testing.T) {
   211  	hub := NewHub(2, slog.Default())
   212  
   213  	go hub.Run()
   214  	defer hub.Shutdown()
   215  
   216  	session := NewMockSession("123")
   217  	hub.AddSession(session)
   218  
   219  	t.Run("Disconnect session", func(t *testing.T) {
   220  		hub.RemoteDisconnect(&common.RemoteDisconnectMessage{Identifier: "123", Reconnect: false})
   221  
   222  		msg, err := session.Read()
   223  		assert.Nil(t, err)
   224  		assert.Equal(t, "{\"type\":\"disconnect\",\"reason\":\"remote\",\"reconnect\":false}", string(msg))
   225  
   226  		assert.True(t, session.Closed())
   227  	})
   228  }
   229  
   230  func TestBroadcastMessage(t *testing.T) {
   231  	hub := NewHub(2, slog.Default())
   232  
   233  	go hub.Run()
   234  	defer hub.Shutdown()
   235  
   236  	session := NewMockSession("123")
   237  	hub.AddSession(session)
   238  	hub.SubscribeSession(session, "test", "test_channel")
   239  
   240  	t.Run("Broadcast without stream data", func(t *testing.T) {
   241  		hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: "\"ciao\""})
   242  
   243  		msg, err := session.Read()
   244  		assert.Nil(t, err)
   245  		assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\"}", string(msg))
   246  	})
   247  
   248  	t.Run("Broadcast with stream data", func(t *testing.T) {
   249  		hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: "\"ciao\"", Epoch: "xyz", Offset: 2022})
   250  
   251  		msg, err := session.Read()
   252  		assert.Nil(t, err)
   253  		assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\",\"stream_id\":\"test\",\"epoch\":\"xyz\",\"offset\":2022}", string(msg))
   254  	})
   255  
   256  	t.Run("Broadcast with exclude_socket", func(t *testing.T) {
   257  		session2 := NewMockSession("234")
   258  		hub.AddSession(session2)
   259  		hub.SubscribeSession(session2, "test", "test_channel")
   260  
   261  		hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: "\"ciao\""})
   262  
   263  		msg, err := session.Read()
   264  		assert.Nil(t, err)
   265  		assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\"}", string(msg))
   266  
   267  		msg, err = session2.Read()
   268  		assert.Nil(t, err)
   269  		assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\"}", string(msg))
   270  
   271  		hub.BroadcastMessage(&common.StreamMessage{
   272  			Stream: "test",
   273  			Data:   "\"hoi!\"",
   274  			Meta: &common.StreamMessageMetadata{
   275  				ExcludeSocket: "234",
   276  			},
   277  		})
   278  
   279  		msg, err = session.Read()
   280  		assert.Nil(t, err)
   281  		assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"hoi!\"}", string(msg))
   282  
   283  		msg, err = session2.Read()
   284  		assert.Nil(t, msg)
   285  		require.Error(t, err)
   286  		assert.Contains(t, err.Error(), "hasn't received any messages")
   287  	})
   288  }
   289  
   290  func TestBroadcastOrder(t *testing.T) {
   291  	hub := NewHub(10, slog.Default())
   292  
   293  	go hub.Run()
   294  	defer hub.Shutdown()
   295  
   296  	session := NewMockSession("123")
   297  	hub.AddSession(session)
   298  	hub.SubscribeSession(session, "test", "test_channel")
   299  
   300  	session2 := NewMockSession("321")
   301  	hub.AddSession(session2)
   302  	hub.SubscribeSession(session2, "test2", "test2_channel")
   303  
   304  	var wg sync.WaitGroup
   305  
   306  	wg.Add(1)
   307  	go func() {
   308  		for i := 1; i <= 5; i++ {
   309  			hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: fmt.Sprintf("%d", i)})
   310  		}
   311  		wg.Done()
   312  	}()
   313  
   314  	wg.Add(1)
   315  	go func() {
   316  		for i := 1; i <= 5; i++ {
   317  			hub.BroadcastMessage(&common.StreamMessage{Stream: "test2", Data: fmt.Sprintf("%d", i)})
   318  		}
   319  		wg.Done()
   320  	}()
   321  
   322  	wg.Wait()
   323  
   324  	for i := 1; i <= 5; i++ {
   325  		msg, err := session.Read()
   326  		assert.Nil(t, err)
   327  		assert.Equal(t, fmt.Sprintf("{\"identifier\":\"test_channel\",\"message\":%d}", i), string(msg))
   328  	}
   329  
   330  	for i := 1; i <= 5; i++ {
   331  		msg, err := session2.Read()
   332  		assert.Nil(t, err)
   333  		assert.Equal(t, fmt.Sprintf("{\"identifier\":\"test2_channel\",\"message\":%d}", i), string(msg))
   334  	}
   335  }
   336  
   337  func TestBuildMessageJSON(t *testing.T) {
   338  	expected := []byte("{\"identifier\":\"chat\",\"message\":{\"text\":\"hello!\"}}")
   339  	actual := toJSON(buildMessage(&common.StreamMessage{Data: "{\"text\":\"hello!\"}"}, "chat"))
   340  	assert.Equal(t, expected, actual)
   341  }
   342  
   343  func TestBuildMessageString(t *testing.T) {
   344  	expected := []byte("{\"identifier\":\"chat\",\"message\":\"plain string\"}")
   345  	actual := toJSON(buildMessage(&common.StreamMessage{Data: "\"plain string\""}, "chat"))
   346  	assert.Equal(t, expected, actual)
   347  }
   348  
   // benchmarkConfig describes one BenchmarkBroadcast scenario. Keep the field
// order stable: sub-benchmark names are derived from the %v rendering of this
// struct.
type benchmarkConfig struct {
	// Pool size passed to NewHub, followed by the shape of the subscription
	// graph: number of streams, number of sessions, subscriptions per session.
	hubPoolSize, totalStreams, totalSessions, streamsPerSession int

	// Payload broadcast to every stream.
	payload string
}
   356  
   357  func BenchmarkBroadcast(b *testing.B) {
   358  	configs := []benchmarkConfig{}
   359  
   360  	poolSizes := []int{128, 16, 2, 1}
   361  	streamNums := [][]int{
   362  		{1000, 10},
   363  		{100, 10},
   364  		{10, 3},
   365  	}
   366  	sessionsNum := 10000
   367  	payload := "\"A quick brow fox bla-bla-bla\""
   368  
   369  	for _, streamNum := range streamNums {
   370  		for _, poolSize := range poolSizes {
   371  			configs = append(configs, benchmarkConfig{poolSize, streamNum[0], sessionsNum, streamNum[1], payload})
   372  		}
   373  	}
   374  
   375  	for _, config := range configs {
   376  		b.Run(fmt.Sprintf("%v", config), func(b *testing.B) {
   377  			broadcastsPerStream := b.N / config.totalStreams
   378  			messagesPerSession := config.streamsPerSession * broadcastsPerStream
   379  
   380  			hub := NewHub(config.hubPoolSize, slog.Default())
   381  
   382  			go hub.Run()
   383  			defer hub.Shutdown()
   384  
   385  			var wg sync.WaitGroup
   386  			var streams []string
   387  
   388  			for i := 0; i < config.totalStreams; i++ {
   389  				stream := fmt.Sprintf("stream_%d", i)
   390  				streams = append(streams, stream)
   391  			}
   392  
   393  			for i := 0; i < config.totalSessions; i++ {
   394  				sid := fmt.Sprintf("%d", i)
   395  				session := NewMockSession(sid)
   396  
   397  				wg.Add(1)
   398  
   399  				go func() {
   400  					countdown := 0
   401  					for {
   402  						if countdown >= messagesPerSession {
   403  							wg.Done()
   404  							break
   405  						}
   406  
   407  						session.ReadIndifinitely() // nolint:errcheck
   408  						countdown++
   409  					}
   410  				}()
   411  
   412  				hub.AddSession(session)
   413  
   414  				for j := 0; j < config.streamsPerSession; j++ {
   415  					channel := fmt.Sprintf("test_channel%d", j)
   416  
   417  					stream := streams[rand.Intn(len(streams))] // nolint:gosec
   418  
   419  					hub.SubscribeSession(session, stream, channel)
   420  				}
   421  			}
   422  
   423  			b.ResetTimer()
   424  
   425  			for _, stream := range streams {
   426  				for i := 0; i < broadcastsPerStream; i++ {
   427  					hub.Broadcast(stream, config.payload)
   428  				}
   429  			}
   430  
   431  			wg.Wait()
   432  			b.StopTimer()
   433  		})
   434  	}
   435  }
   436  
   437  func toJSON(msg encoders.EncodedMessage) []byte {
   438  	b, err := json.Marshal(&msg)
   439  	if err != nil {
   440  		panic("Failed to build JSON 😲")
   441  	}
   442  
   443  	return b
   444  }