github.com/anycable/anycable-go@v1.5.1/sse/handler_test.go (about) 1 package sse 2 3 import ( 4 "bytes" 5 "context" 6 "io" 7 "log/slog" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/anycable/anycable-go/broker" 15 "github.com/anycable/anycable-go/common" 16 "github.com/anycable/anycable-go/metrics" 17 "github.com/anycable/anycable-go/mocks" 18 "github.com/anycable/anycable-go/node" 19 "github.com/anycable/anycable-go/pubsub" 20 "github.com/anycable/anycable-go/server" 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/mock" 23 "github.com/stretchr/testify/require" 24 ) 25 26 type streamingWriter struct { 27 httptest.ResponseRecorder 28 29 stream chan []byte 30 } 31 32 func newStreamingWriter(w *httptest.ResponseRecorder) *streamingWriter { 33 return &streamingWriter{ 34 ResponseRecorder: *w, 35 stream: make(chan []byte, 100), 36 } 37 } 38 39 func (w *streamingWriter) Write(data []byte) (int, error) { 40 events := bytes.Split(data, []byte("\n\n")) 41 42 for _, event := range events { 43 if len(event) > 0 { 44 w.stream <- event 45 } 46 } 47 48 return w.ResponseRecorder.Write(data) 49 } 50 51 func (w *streamingWriter) ReadEvent(ctx context.Context) (string, error) { 52 for { 53 select { 54 case <-ctx.Done(): 55 return "", ctx.Err() 56 case event := <-w.stream: 57 return string(event), nil 58 } 59 } 60 } 61 62 var _ http.ResponseWriter = (*streamingWriter)(nil) 63 64 func TestSSEHandler(t *testing.T) { 65 appNode, controller := buildNode() 66 conf := NewConfig() 67 68 dconfig := node.NewDisconnectQueueConfig() 69 dconfig.Rate = 1 70 disconnector := node.NewDisconnectQueue(appNode, &dconfig, slog.Default()) 71 appNode.SetDisconnector(disconnector) 72 73 go appNode.Start() // nolint: errcheck 74 defer appNode.Shutdown(context.Background()) // nolint: errcheck 75 76 headersExtractor := &server.DefaultHeadersExtractor{} 77 78 handler := SSEHandler(appNode, context.Background(), headersExtractor, &conf, slog.Default()) 79 80 controller. 81 On("Shutdown"). 82 Return(nil) 83 84 controller. 85 On("Disconnect", mock.Anything, mock.Anything, mock.Anything, mock.Anything). 86 Return(nil) 87 88 t.Run("headers", func(t *testing.T) { 89 w := httptest.NewRecorder() 90 req, _ := http.NewRequest("GET", "/", nil) 91 92 handler.ServeHTTP(w, req) 93 94 assert.Equal(t, "text/event-stream; charset=utf-8", w.Header().Get("Content-Type")) 95 assert.Equal(t, "private, no-cache, no-store, must-revalidate, max-age=0", w.Header().Get("Cache-Control")) 96 assert.Equal(t, "no-cache", w.Header().Get("Pragma")) 97 assert.Equal(t, "keep-alive", w.Header().Get("Connection")) 98 }) 99 100 t.Run("headers + CORS", func(t *testing.T) { 101 w := httptest.NewRecorder() 102 req, _ := http.NewRequest("GET", "/", nil) 103 req.Header.Set("Origin", "http://www.example.com") 104 105 corsConf := NewConfig() 106 corsConf.AllowedOrigins = "*.example.com" 107 108 corsHandler := SSEHandler(appNode, context.Background(), headersExtractor, &corsConf, slog.Default()) 109 110 corsHandler.ServeHTTP(w, req) 111 112 assert.Equal(t, "http://www.example.com", w.Header().Get("Access-Control-Allow-Origin")) 113 }) 114 115 t.Run("OPTIONS", func(t *testing.T) { 116 w := httptest.NewRecorder() 117 req, _ := http.NewRequest("OPTIONS", "/", nil) 118 119 handler.ServeHTTP(w, req) 120 121 assert.Equal(t, http.StatusOK, w.Code) 122 }) 123 124 t.Run("non-GET/OPTIONS/POST", func(t *testing.T) { 125 w := httptest.NewRecorder() 126 req, _ := http.NewRequest("PUT", "/", nil) 127 128 handler.ServeHTTP(w, req) 129 130 assert.Equal(t, http.StatusMethodNotAllowed, w.Code) 131 }) 132 133 t.Run("when authentication fails", func(t *testing.T) { 134 defer assertNoSessions(t, appNode) 135 136 controller. 137 On("Authenticate", "sid-fail", mock.Anything). 138 Return(&common.ConnectResult{ 139 Status: common.FAILURE, 140 Transmissions: []string{`{"type":"disconnect"}`}, 141 }, nil) 142 143 req, _ := http.NewRequest("GET", "/?channel=room_1", nil) 144 req.Header.Set("X-Request-ID", "sid-fail") 145 146 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 147 defer cancel() 148 149 req = req.WithContext(ctx) 150 151 w := httptest.NewRecorder() 152 handler.ServeHTTP(w, req) 153 154 require.Equal(t, http.StatusUnauthorized, w.Code) 155 assert.Empty(t, w.Body.String()) 156 }) 157 158 t.Run("GET request with identifier", func(t *testing.T) { 159 defer assertNoSessions(t, appNode) 160 161 controller. 162 On("Authenticate", "sid-gut", mock.Anything). 163 Return(&common.ConnectResult{ 164 Identifier: "se2023", 165 Status: common.SUCCESS, 166 Transmissions: []string{`{"type":"welcome"}`}, 167 }, nil) 168 169 controller. 170 On("Subscribe", "sid-gut", mock.Anything, "se2023", "chat_1"). 171 Return(&common.CommandResult{ 172 Status: common.SUCCESS, 173 Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`}, 174 Streams: []string{"messages_1"}, 175 }, nil) 176 177 req, _ := http.NewRequest("GET", "/?identifier=chat_1", nil) 178 req.Header.Set("X-Request-ID", "sid-gut") 179 180 ctx_, release := context.WithTimeout(context.Background(), 2*time.Second) 181 defer release() 182 183 ctx, cancel := context.WithCancel(ctx_) 184 defer cancel() 185 186 req = req.WithContext(ctx) 187 188 w := httptest.NewRecorder() 189 sw := newStreamingWriter(w) 190 191 go handler.ServeHTTP(sw, req) 192 193 msg, err := sw.ReadEvent(ctx) 194 require.NoError(t, err) 195 assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg) 196 197 msg, err = sw.ReadEvent(ctx) 198 require.NoError(t, err) 199 assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"chat_1"}`, msg) 200 201 appNode.Broadcast(&common.StreamMessage{Stream: "messages_1", Data: `{"content":"hello"}`}) 202 203 msg, err = sw.ReadEvent(ctx) 204 require.NoError(t, err) 205 assert.Equal(t, `data: {"content":"hello"}`, msg) 206 207 require.Equal(t, http.StatusOK, w.Code) 208 }) 209 210 t.Run("GET request with turbo_signed_stream_name", func(t *testing.T) { 211 defer assertNoSessions(t, appNode) 212 213 controller. 214 On("Authenticate", "sid-turbo", mock.Anything). 215 Return(&common.ConnectResult{ 216 Identifier: "se2023", 217 Status: common.SUCCESS, 218 Transmissions: []string{`{"type":"welcome"}`}, 219 }, nil) 220 221 turbo_identifier := `{"channel":"Turbo::StreamsChannel","signed_stream_name":"chat_1"}` 222 223 controller. 224 On("Subscribe", "sid-turbo", mock.Anything, "se2023", turbo_identifier). 225 Return(&common.CommandResult{ 226 Status: common.SUCCESS, 227 Transmissions: []string{`{"type":"confirm","identifier":"turbo_1"}`}, 228 Streams: []string{"chat_1"}, 229 }, nil) 230 231 req, _ := http.NewRequest("GET", "/?turbo_signed_stream_name=chat_1", nil) 232 req.Header.Set("X-Request-ID", "sid-turbo") 233 234 ctx_, release := context.WithTimeout(context.Background(), 2*time.Second) 235 defer release() 236 237 ctx, cancel := context.WithCancel(ctx_) 238 defer cancel() 239 240 req = req.WithContext(ctx) 241 242 w := httptest.NewRecorder() 243 sw := newStreamingWriter(w) 244 245 go handler.ServeHTTP(sw, req) 246 247 msg, err := sw.ReadEvent(ctx) 248 require.NoError(t, err) 249 assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg) 250 251 msg, err = sw.ReadEvent(ctx) 252 require.NoError(t, err) 253 assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"turbo_1"}`, msg) 254 255 require.Equal(t, http.StatusOK, w.Code) 256 }) 257 258 t.Run("GET request with stream", func(t *testing.T) { 259 defer assertNoSessions(t, appNode) 260 261 controller. 262 On("Authenticate", "sid-public-stream", mock.Anything). 263 Return(&common.ConnectResult{ 264 Identifier: "se2024", 265 Status: common.SUCCESS, 266 Transmissions: []string{`{"type":"welcome"}`}, 267 }, nil) 268 269 identifier := `{"channel":"$pubsub","stream_name":"chat_1"}` 270 271 controller. 272 On("Subscribe", "sid-public-stream", mock.Anything, "se2024", identifier). 273 Return(&common.CommandResult{ 274 Status: common.SUCCESS, 275 Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`}, 276 Streams: []string{"chat_1"}, 277 }, nil) 278 279 req, _ := http.NewRequest("GET", "/?stream=chat_1", nil) 280 req.Header.Set("X-Request-ID", "sid-public-stream") 281 282 ctx_, release := context.WithTimeout(context.Background(), 2*time.Second) 283 defer release() 284 285 ctx, cancel := context.WithCancel(ctx_) 286 defer cancel() 287 288 req = req.WithContext(ctx) 289 290 w := httptest.NewRecorder() 291 sw := newStreamingWriter(w) 292 293 go handler.ServeHTTP(sw, req) 294 295 msg, err := sw.ReadEvent(ctx) 296 require.NoError(t, err) 297 assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg) 298 299 msg, err = sw.ReadEvent(ctx) 300 require.NoError(t, err) 301 assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"chat_1"}`, msg) 302 303 require.Equal(t, http.StatusOK, w.Code) 304 }) 305 306 t.Run("GET request with signed_stream", func(t *testing.T) { 307 defer assertNoSessions(t, appNode) 308 309 controller. 310 On("Authenticate", "sid-signed-stream", mock.Anything). 311 Return(&common.ConnectResult{ 312 Identifier: "se2024", 313 Status: common.SUCCESS, 314 Transmissions: []string{`{"type":"welcome"}`}, 315 }, nil) 316 317 identifier := `{"channel":"$pubsub","signed_stream_name":"secretto"}` 318 319 controller. 320 On("Subscribe", "sid-signed-stream", mock.Anything, "se2024", identifier). 321 Return(&common.CommandResult{ 322 Status: common.SUCCESS, 323 Transmissions: []string{`{"type":"confirm","identifier":"secret_chat_1"}`}, 324 Streams: []string{"chat_1"}, 325 }, nil) 326 327 req, _ := http.NewRequest("GET", "/?signed_stream=secretto", nil) 328 req.Header.Set("X-Request-ID", "sid-signed-stream") 329 330 ctx_, release := context.WithTimeout(context.Background(), 2*time.Second) 331 defer release() 332 333 ctx, cancel := context.WithCancel(ctx_) 334 defer cancel() 335 336 req = req.WithContext(ctx) 337 338 w := httptest.NewRecorder() 339 sw := newStreamingWriter(w) 340 341 go handler.ServeHTTP(sw, req) 342 343 msg, err := sw.ReadEvent(ctx) 344 require.NoError(t, err) 345 assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg) 346 347 msg, err = sw.ReadEvent(ctx) 348 require.NoError(t, err) 349 assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"secret_chat_1"}`, msg) 350 351 require.Equal(t, http.StatusOK, w.Code) 352 }) 353 354 t.Run("GET request with channel + rejected", func(t *testing.T) { 355 defer assertNoSessions(t, appNode) 356 357 controller. 358 On("Authenticate", "sid-reject", mock.Anything). 359 Return(&common.ConnectResult{ 360 Identifier: "se2034", 361 Status: common.SUCCESS, 362 Transmissions: []string{`{"type":"welcome"}`}, 363 }, nil) 364 365 controller. 366 On("Subscribe", "sid-reject", mock.Anything, "se2034", `{"channel":"room_1"}`). 367 Return(&common.CommandResult{ 368 Status: common.FAILURE, 369 Transmissions: []string{`{"type":"reject","identifier":"room_1"}`}, 370 }, nil) 371 372 req, _ := http.NewRequest("GET", "/?channel=room_1", nil) 373 req.Header.Set("X-Request-ID", "sid-reject") 374 375 ctx_, release := context.WithTimeout(context.Background(), 2*time.Second) 376 defer release() 377 378 ctx, cancel := context.WithCancel(ctx_) 379 defer cancel() 380 381 req = req.WithContext(ctx) 382 383 w := httptest.NewRecorder() 384 385 handler.ServeHTTP(w, req) 386 387 require.Equal(t, http.StatusBadRequest, w.Code) 388 assert.Empty(t, w.Body.String()) 389 390 controller.AssertCalled(t, "Subscribe", "sid-reject", mock.Anything, "se2034", `{"channel":"room_1"}`) 391 }) 392 393 t.Run("GET request without channel or identifier", func(t *testing.T) { 394 req, _ := http.NewRequest("GET", "/", nil) 395 396 w := httptest.NewRecorder() 397 handler.ServeHTTP(w, req) 398 399 require.Equal(t, http.StatusBadRequest, w.Code) 400 assert.Empty(t, w.Body.String()) 401 }) 402 403 t.Run("POST request without commands + server shutdown", func(t *testing.T) { 404 defer assertNoSessions(t, appNode) 405 406 controller. 407 On("Authenticate", "sid-post-no-op", mock.Anything). 408 Return(&common.ConnectResult{ 409 Identifier: "se2023-09-06", 410 Status: common.SUCCESS, 411 Transmissions: []string{`{"type":"welcome"}`}, 412 }, nil) 413 414 req, _ := http.NewRequest("POST", "/", nil) 415 req.Header.Set("X-Request-ID", "sid-post-no-op") 416 417 ctx_, release := context.WithTimeout(context.Background(), 2*time.Second) 418 defer release() 419 420 ctx, cancel := context.WithCancel(ctx_) 421 defer cancel() 422 423 req = req.WithContext(ctx) 424 425 w := httptest.NewRecorder() 426 sw := newStreamingWriter(w) 427 428 shutdownCtx, shutdownFn := context.WithCancel(context.Background()) 429 430 shutdownHandler := SSEHandler(appNode, shutdownCtx, headersExtractor, &conf, slog.Default()) 431 432 go shutdownHandler.ServeHTTP(sw, req) 433 434 msg, err := sw.ReadEvent(ctx) 435 require.NoError(t, err) 436 assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg) 437 438 shutdownFn() 439 440 msg, err = sw.ReadEvent(ctx) 441 require.NoError(t, err) 442 assert.Equal(t, "event: disconnect\n"+`data: {"type":"disconnect","reason":"server_restart","reconnect":true}`, msg) 443 444 require.Equal(t, http.StatusOK, w.Code) 445 }) 446 447 t.Run("POST request with multiple subscriptions", func(t *testing.T) { 448 defer assertNoSessions(t, appNode) 449 450 controller. 451 On("Authenticate", "sid-post", mock.Anything). 452 Return(&common.ConnectResult{ 453 Identifier: "se2023-09-06", 454 Status: common.SUCCESS, 455 Transmissions: []string{`{"type":"welcome"}`}, 456 }, nil) 457 458 controller. 459 On("Subscribe", "sid-post", mock.Anything, "se2023-09-06", "chat_1"). 460 Return(&common.CommandResult{ 461 Status: common.SUCCESS, 462 Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`}, 463 Streams: []string{"messages_1"}, 464 }, nil) 465 466 controller. 467 On("Subscribe", "sid-post", mock.Anything, "se2023-09-06", "presence_1"). 468 Return(&common.CommandResult{ 469 Status: common.SUCCESS, 470 Transmissions: []string{`{"type":"confirm","identifier":"presence_1"}`}, 471 Streams: []string{"presence_1"}, 472 }, nil) 473 474 req, _ := http.NewRequest("POST", "/", nil) 475 req.Header.Set("X-Request-ID", "sid-post") 476 req.Body = io.NopCloser( 477 strings.NewReader("{\"command\":\"subscribe\",\"identifier\":\"chat_1\"}\n{\"command\":\"subscribe\",\"identifier\":\"presence_1\"}"), 478 ) 479 480 ctx_, release := context.WithTimeout(context.Background(), 2*time.Second) 481 defer release() 482 483 ctx, cancel := context.WithCancel(ctx_) 484 defer cancel() 485 486 req = req.WithContext(ctx) 487 488 w := httptest.NewRecorder() 489 sw := newStreamingWriter(w) 490 491 go handler.ServeHTTP(sw, req) 492 493 msg, err := sw.ReadEvent(ctx) 494 require.NoError(t, err) 495 assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg) 496 497 msg, err = sw.ReadEvent(ctx) 498 require.NoError(t, err) 499 assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"chat_1"}`, msg) 500 501 msg, err = sw.ReadEvent(ctx) 502 require.NoError(t, err) 503 assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"presence_1"}`, msg) 504 505 appNode.Broadcast(&common.StreamMessage{Stream: "messages_1", Data: `{"content":"hello"}`}) 506 507 msg, err = sw.ReadEvent(ctx) 508 require.NoError(t, err) 509 assert.Equal(t, `data: {"identifier":"chat_1","message":{"content":"hello"}}`, msg) 510 511 appNode.Broadcast(&common.StreamMessage{Stream: "presence_1", Data: `{"type":"join","user_id":1}`}) 512 513 msg, err = sw.ReadEvent(ctx) 514 require.NoError(t, err) 515 assert.Equal(t, `data: {"identifier":"presence_1","message":{"type":"join","user_id":1}}`, msg) 516 517 require.Equal(t, http.StatusOK, w.Code) 518 }) 519 } 520 521 // This a helper method to ensure no sessions left after test (so no global state is left). 522 // Session may be removed from the hub asynchrounously, so we need to wait for it. 523 func assertNoSessions(t *testing.T, n *node.Node) { 524 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 525 defer cancel() 526 527 done := make(chan struct{}) 528 529 go func() { 530 for { 531 if n.Size() == 0 { 532 close(done) 533 return 534 } 535 536 time.Sleep(100 * time.Millisecond) 537 } 538 }() 539 540 select { 541 case <-ctx.Done(): 542 require.Fail(t, "Timeout waiting for sessions to be removed") 543 case <-done: 544 } 545 } 546 547 type immediateDisconnector struct { 548 n *node.Node 549 } 550 551 func (d *immediateDisconnector) Enqueue(s *node.Session) error { 552 return d.n.DisconnectNow(s) 553 } 554 555 func (immediateDisconnector) Run() error { return nil } 556 func (immediateDisconnector) Shutdown(ctx context.Context) error { return nil } 557 func (immediateDisconnector) Size() int { return 0 } 558 559 func buildNode() (*node.Node, *mocks.Controller) { 560 controller := &mocks.Controller{} 561 config := node.NewConfig() 562 config.HubGopoolSize = 2 563 n := node.NewNode(&config, node.WithController(controller), node.WithInstrumenter(metrics.NewMetrics(nil, 10, slog.Default()))) 564 n.SetBroker(broker.NewLegacyBroker(pubsub.NewLegacySubscriber(n))) 565 n.SetDisconnector(&immediateDisconnector{n}) 566 return n, controller 567 }