github.com/anycable/anycable-go@v1.5.1/hub/hub_test.go (about) 1 package hub 2 3 import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "log/slog" 8 "math/rand" 9 "sync" 10 "testing" 11 "time" 12 13 "github.com/anycable/anycable-go/common" 14 "github.com/anycable/anycable-go/encoders" 15 "github.com/stretchr/testify/assert" 16 "github.com/stretchr/testify/require" 17 ) 18 19 type MockSession struct { 20 sid string 21 incoming chan ([]byte) 22 closed bool 23 closeMu sync.Mutex 24 } 25 26 func (s *MockSession) GetID() string { 27 return s.sid 28 } 29 30 func (s *MockSession) GetIdentifiers() string { 31 return s.sid 32 } 33 34 func (s *MockSession) Send(msg encoders.EncodedMessage) { 35 s.incoming <- toJSON(msg) 36 } 37 38 func (s *MockSession) DisconnectWithMessage(msg encoders.EncodedMessage, code string) { 39 s.closeMu.Lock() 40 defer s.closeMu.Unlock() 41 42 if s.closed { 43 return 44 } 45 46 s.incoming <- toJSON(msg) 47 s.closed = true 48 } 49 50 func (s *MockSession) Closed() bool { 51 s.closeMu.Lock() 52 defer s.closeMu.Unlock() 53 54 return s.closed 55 } 56 57 func (s *MockSession) Read() ([]byte, error) { 58 timer := time.After(100 * time.Millisecond) 59 60 select { 61 case <-timer: 62 return nil, errors.New("Session hasn't received any messages") 63 case msg := <-s.incoming: 64 return msg, nil 65 } 66 } 67 68 func (s *MockSession) ReadIndifinitely() ([]byte, error) { 69 return <-s.incoming, nil 70 } 71 72 func NewMockSession(sid string) *MockSession { 73 return &MockSession{sid: sid, incoming: make(chan []byte, 256)} 74 } 75 76 func TestUnsubscribeRaceConditions(t *testing.T) { 77 hub := NewHub(2, slog.Default()) 78 79 go hub.Run() 80 defer hub.Shutdown() 81 82 session := NewMockSession("123") 83 session2 := NewMockSession("321") 84 session3 := NewMockSession("213") 85 86 hub.AddSession(session) 87 hub.SubscribeSession(session, "test", "test_channel") 88 89 hub.AddSession(session2) 90 hub.SubscribeSession(session2, "test", "test_channel") 91 92 hub.AddSession(session3) 93 hub.SubscribeSession(session3, "test", "test_channel") 94 95 hub.Broadcast("test", "hello") 96 97 _, err := session.Read() 98 assert.Nil(t, err) 99 100 _, err = session2.Read() 101 assert.Nil(t, err) 102 103 _, err = session3.Read() 104 assert.Nil(t, err) 105 106 assert.Equal(t, 3, hub.Size(), "Connections size must be equal 2") 107 108 go func() { 109 hub.Broadcast("test", "pong") 110 hub.RemoveSession(session) 111 hub.Broadcast("test", "ping") 112 }() 113 114 go func() { 115 hub.Broadcast("test", "bye-bye") 116 hub.RemoveSession(session3) 117 hub.Broadcast("test", "meow-meow") 118 }() 119 120 for i := 1; i < 5; i++ { 121 _, err = session2.Read() 122 assert.Nil(t, err) 123 } 124 125 _, err = session2.Read() 126 assert.NotNil(t, err) 127 128 assert.Equal(t, 1, hub.Size(), "Connections size must be equal 1") 129 } 130 131 func TestUnsubscribeSession(t *testing.T) { 132 hub := NewHub(2, slog.Default()) 133 134 go hub.Run() 135 defer hub.Shutdown() 136 137 session := NewMockSession("123") 138 hub.AddSession(session) 139 140 hub.SubscribeSession(session, "test", "test_channel") 141 hub.SubscribeSession(session, "test2", "test_channel") 142 143 hub.Broadcast("test", "\"hello\"") 144 145 msg, err := session.Read() 146 assert.Nil(t, err) 147 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"hello\"}", string(msg)) 148 149 hub.UnsubscribeSession(session, "test", "test_channel") 150 151 hub.Broadcast("test", "\"goodbye\"") 152 153 _, err = session.Read() 154 assert.NotNil(t, err) 155 156 hub.Broadcast("test2", "\"bye\"") 157 158 msg, err = session.Read() 159 assert.Nil(t, err) 160 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"bye\"}", string(msg)) 161 162 hub.unsubscribeSessionFromAllChannels(session) 163 164 hub.Broadcast("test2", "\"goodbye\"") 165 166 _, err = session.Read() 167 assert.NotNil(t, err) 168 } 169 170 func TestSubscribeSession(t *testing.T) { 171 hub := NewHub(2, slog.Default()) 172 173 go hub.Run() 174 defer hub.Shutdown() 175 176 session := NewMockSession("123") 177 hub.AddSession(session) 178 179 t.Run("Subscribe to a single channel", func(t *testing.T) { 180 hub.SubscribeSession(session, "test", "test_channel") 181 182 hub.Broadcast("test", "\"hello\"") 183 184 msg, err := session.Read() 185 assert.Nil(t, err) 186 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"hello\"}", string(msg)) 187 }) 188 189 t.Run("Successful to the same stream from multiple channels", func(t *testing.T) { 190 hub.SubscribeSession(session, "test", "test_channel") 191 hub.SubscribeSession(session, "test", "test_channel2") 192 193 hub.Broadcast("test", "\"hello twice\"") 194 195 received := []string{} 196 197 msg, err := session.Read() 198 assert.Nil(t, err) 199 received = append(received, string(msg)) 200 201 msg, err = session.Read() 202 assert.Nil(t, err) 203 received = append(received, string(msg)) 204 205 assert.Contains(t, received, "{\"identifier\":\"test_channel\",\"message\":\"hello twice\"}") 206 assert.Contains(t, received, "{\"identifier\":\"test_channel2\",\"message\":\"hello twice\"}") 207 }) 208 } 209 210 func TestRemoteDisconnect(t *testing.T) { 211 hub := NewHub(2, slog.Default()) 212 213 go hub.Run() 214 defer hub.Shutdown() 215 216 session := NewMockSession("123") 217 hub.AddSession(session) 218 219 t.Run("Disconnect session", func(t *testing.T) { 220 hub.RemoteDisconnect(&common.RemoteDisconnectMessage{Identifier: "123", Reconnect: false}) 221 222 msg, err := session.Read() 223 assert.Nil(t, err) 224 assert.Equal(t, "{\"type\":\"disconnect\",\"reason\":\"remote\",\"reconnect\":false}", string(msg)) 225 226 assert.True(t, session.Closed()) 227 }) 228 } 229 230 func TestBroadcastMessage(t *testing.T) { 231 hub := NewHub(2, slog.Default()) 232 233 go hub.Run() 234 defer hub.Shutdown() 235 236 session := NewMockSession("123") 237 hub.AddSession(session) 238 hub.SubscribeSession(session, "test", "test_channel") 239 240 t.Run("Broadcast without stream data", func(t *testing.T) { 241 hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: "\"ciao\""}) 242 243 msg, err := session.Read() 244 assert.Nil(t, err) 245 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\"}", string(msg)) 246 }) 247 248 t.Run("Broadcast with stream data", func(t *testing.T) { 249 hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: "\"ciao\"", Epoch: "xyz", Offset: 2022}) 250 251 msg, err := session.Read() 252 assert.Nil(t, err) 253 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\",\"stream_id\":\"test\",\"epoch\":\"xyz\",\"offset\":2022}", string(msg)) 254 }) 255 256 t.Run("Broadcast with exclude_socket", func(t *testing.T) { 257 session2 := NewMockSession("234") 258 hub.AddSession(session2) 259 hub.SubscribeSession(session2, "test", "test_channel") 260 261 hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: "\"ciao\""}) 262 263 msg, err := session.Read() 264 assert.Nil(t, err) 265 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\"}", string(msg)) 266 267 msg, err = session2.Read() 268 assert.Nil(t, err) 269 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"ciao\"}", string(msg)) 270 271 hub.BroadcastMessage(&common.StreamMessage{ 272 Stream: "test", 273 Data: "\"hoi!\"", 274 Meta: &common.StreamMessageMetadata{ 275 ExcludeSocket: "234", 276 }, 277 }) 278 279 msg, err = session.Read() 280 assert.Nil(t, err) 281 assert.Equal(t, "{\"identifier\":\"test_channel\",\"message\":\"hoi!\"}", string(msg)) 282 283 msg, err = session2.Read() 284 assert.Nil(t, msg) 285 require.Error(t, err) 286 assert.Contains(t, err.Error(), "hasn't received any messages") 287 }) 288 } 289 290 func TestBroadcastOrder(t *testing.T) { 291 hub := NewHub(10, slog.Default()) 292 293 go hub.Run() 294 defer hub.Shutdown() 295 296 session := NewMockSession("123") 297 hub.AddSession(session) 298 hub.SubscribeSession(session, "test", "test_channel") 299 300 session2 := NewMockSession("321") 301 hub.AddSession(session2) 302 hub.SubscribeSession(session2, "test2", "test2_channel") 303 304 var wg sync.WaitGroup 305 306 wg.Add(1) 307 go func() { 308 for i := 1; i <= 5; i++ { 309 hub.BroadcastMessage(&common.StreamMessage{Stream: "test", Data: fmt.Sprintf("%d", i)}) 310 } 311 wg.Done() 312 }() 313 314 wg.Add(1) 315 go func() { 316 for i := 1; i <= 5; i++ { 317 hub.BroadcastMessage(&common.StreamMessage{Stream: "test2", Data: fmt.Sprintf("%d", i)}) 318 } 319 wg.Done() 320 }() 321 322 wg.Wait() 323 324 for i := 1; i <= 5; i++ { 325 msg, err := session.Read() 326 assert.Nil(t, err) 327 assert.Equal(t, fmt.Sprintf("{\"identifier\":\"test_channel\",\"message\":%d}", i), string(msg)) 328 } 329 330 for i := 1; i <= 5; i++ { 331 msg, err := session2.Read() 332 assert.Nil(t, err) 333 assert.Equal(t, fmt.Sprintf("{\"identifier\":\"test2_channel\",\"message\":%d}", i), string(msg)) 334 } 335 } 336 337 func TestBuildMessageJSON(t *testing.T) { 338 expected := []byte("{\"identifier\":\"chat\",\"message\":{\"text\":\"hello!\"}}") 339 actual := toJSON(buildMessage(&common.StreamMessage{Data: "{\"text\":\"hello!\"}"}, "chat")) 340 assert.Equal(t, expected, actual) 341 } 342 343 func TestBuildMessageString(t *testing.T) { 344 expected := []byte("{\"identifier\":\"chat\",\"message\":\"plain string\"}") 345 actual := toJSON(buildMessage(&common.StreamMessage{Data: "\"plain string\""}, "chat")) 346 assert.Equal(t, expected, actual) 347 } 348 349 type benchmarkConfig struct { 350 hubPoolSize int 351 totalStreams int 352 totalSessions int 353 streamsPerSession int 354 payload string 355 } 356 357 func BenchmarkBroadcast(b *testing.B) { 358 configs := []benchmarkConfig{} 359 360 poolSizes := []int{128, 16, 2, 1} 361 streamNums := [][]int{ 362 {1000, 10}, 363 {100, 10}, 364 {10, 3}, 365 } 366 sessionsNum := 10000 367 payload := "\"A quick brow fox bla-bla-bla\"" 368 369 for _, streamNum := range streamNums { 370 for _, poolSize := range poolSizes { 371 configs = append(configs, benchmarkConfig{poolSize, streamNum[0], sessionsNum, streamNum[1], payload}) 372 } 373 } 374 375 for _, config := range configs { 376 b.Run(fmt.Sprintf("%v", config), func(b *testing.B) { 377 broadcastsPerStream := b.N / config.totalStreams 378 messagesPerSession := config.streamsPerSession * broadcastsPerStream 379 380 hub := NewHub(config.hubPoolSize, slog.Default()) 381 382 go hub.Run() 383 defer hub.Shutdown() 384 385 var wg sync.WaitGroup 386 var streams []string 387 388 for i := 0; i < config.totalStreams; i++ { 389 stream := fmt.Sprintf("stream_%d", i) 390 streams = append(streams, stream) 391 } 392 393 for i := 0; i < config.totalSessions; i++ { 394 sid := fmt.Sprintf("%d", i) 395 session := NewMockSession(sid) 396 397 wg.Add(1) 398 399 go func() { 400 countdown := 0 401 for { 402 if countdown >= messagesPerSession { 403 wg.Done() 404 break 405 } 406 407 session.ReadIndifinitely() // nolint:errcheck 408 countdown++ 409 } 410 }() 411 412 hub.AddSession(session) 413 414 for j := 0; j < config.streamsPerSession; j++ { 415 channel := fmt.Sprintf("test_channel%d", j) 416 417 stream := streams[rand.Intn(len(streams))] // nolint:gosec 418 419 hub.SubscribeSession(session, stream, channel) 420 } 421 } 422 423 b.ResetTimer() 424 425 for _, stream := range streams { 426 for i := 0; i < broadcastsPerStream; i++ { 427 hub.Broadcast(stream, config.payload) 428 } 429 } 430 431 wg.Wait() 432 b.StopTimer() 433 }) 434 } 435 } 436 437 func toJSON(msg encoders.EncodedMessage) []byte { 438 b, err := json.Marshal(&msg) 439 if err != nil { 440 panic("Failed to build JSON 😲") 441 } 442 443 return b 444 }