github.com/anycable/anycable-go@v1.5.1/sse/handler_test.go (about)

     1  package sse
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  	"log/slog"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/anycable/anycable-go/broker"
    15  	"github.com/anycable/anycable-go/common"
    16  	"github.com/anycable/anycable-go/metrics"
    17  	"github.com/anycable/anycable-go/mocks"
    18  	"github.com/anycable/anycable-go/node"
    19  	"github.com/anycable/anycable-go/pubsub"
    20  	"github.com/anycable/anycable-go/server"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/mock"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  type streamingWriter struct {
    27  	httptest.ResponseRecorder
    28  
    29  	stream chan []byte
    30  }
    31  
    32  func newStreamingWriter(w *httptest.ResponseRecorder) *streamingWriter {
    33  	return &streamingWriter{
    34  		ResponseRecorder: *w,
    35  		stream:           make(chan []byte, 100),
    36  	}
    37  }
    38  
    39  func (w *streamingWriter) Write(data []byte) (int, error) {
    40  	events := bytes.Split(data, []byte("\n\n"))
    41  
    42  	for _, event := range events {
    43  		if len(event) > 0 {
    44  			w.stream <- event
    45  		}
    46  	}
    47  
    48  	return w.ResponseRecorder.Write(data)
    49  }
    50  
    51  func (w *streamingWriter) ReadEvent(ctx context.Context) (string, error) {
    52  	for {
    53  		select {
    54  		case <-ctx.Done():
    55  			return "", ctx.Err()
    56  		case event := <-w.stream:
    57  			return string(event), nil
    58  		}
    59  	}
    60  }
    61  
    62  var _ http.ResponseWriter = (*streamingWriter)(nil)
    63  
    64  func TestSSEHandler(t *testing.T) {
    65  	appNode, controller := buildNode()
    66  	conf := NewConfig()
    67  
    68  	dconfig := node.NewDisconnectQueueConfig()
    69  	dconfig.Rate = 1
    70  	disconnector := node.NewDisconnectQueue(appNode, &dconfig, slog.Default())
    71  	appNode.SetDisconnector(disconnector)
    72  
    73  	go appNode.Start()                           // nolint: errcheck
    74  	defer appNode.Shutdown(context.Background()) // nolint: errcheck
    75  
    76  	headersExtractor := &server.DefaultHeadersExtractor{}
    77  
    78  	handler := SSEHandler(appNode, context.Background(), headersExtractor, &conf, slog.Default())
    79  
    80  	controller.
    81  		On("Shutdown").
    82  		Return(nil)
    83  
    84  	controller.
    85  		On("Disconnect", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
    86  		Return(nil)
    87  
    88  	t.Run("headers", func(t *testing.T) {
    89  		w := httptest.NewRecorder()
    90  		req, _ := http.NewRequest("GET", "/", nil)
    91  
    92  		handler.ServeHTTP(w, req)
    93  
    94  		assert.Equal(t, "text/event-stream; charset=utf-8", w.Header().Get("Content-Type"))
    95  		assert.Equal(t, "private, no-cache, no-store, must-revalidate, max-age=0", w.Header().Get("Cache-Control"))
    96  		assert.Equal(t, "no-cache", w.Header().Get("Pragma"))
    97  		assert.Equal(t, "keep-alive", w.Header().Get("Connection"))
    98  	})
    99  
   100  	t.Run("headers + CORS", func(t *testing.T) {
   101  		w := httptest.NewRecorder()
   102  		req, _ := http.NewRequest("GET", "/", nil)
   103  		req.Header.Set("Origin", "http://www.example.com")
   104  
   105  		corsConf := NewConfig()
   106  		corsConf.AllowedOrigins = "*.example.com"
   107  
   108  		corsHandler := SSEHandler(appNode, context.Background(), headersExtractor, &corsConf, slog.Default())
   109  
   110  		corsHandler.ServeHTTP(w, req)
   111  
   112  		assert.Equal(t, "http://www.example.com", w.Header().Get("Access-Control-Allow-Origin"))
   113  	})
   114  
   115  	t.Run("OPTIONS", func(t *testing.T) {
   116  		w := httptest.NewRecorder()
   117  		req, _ := http.NewRequest("OPTIONS", "/", nil)
   118  
   119  		handler.ServeHTTP(w, req)
   120  
   121  		assert.Equal(t, http.StatusOK, w.Code)
   122  	})
   123  
   124  	t.Run("non-GET/OPTIONS/POST", func(t *testing.T) {
   125  		w := httptest.NewRecorder()
   126  		req, _ := http.NewRequest("PUT", "/", nil)
   127  
   128  		handler.ServeHTTP(w, req)
   129  
   130  		assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
   131  	})
   132  
   133  	t.Run("when authentication fails", func(t *testing.T) {
   134  		defer assertNoSessions(t, appNode)
   135  
   136  		controller.
   137  			On("Authenticate", "sid-fail", mock.Anything).
   138  			Return(&common.ConnectResult{
   139  				Status:        common.FAILURE,
   140  				Transmissions: []string{`{"type":"disconnect"}`},
   141  			}, nil)
   142  
   143  		req, _ := http.NewRequest("GET", "/?channel=room_1", nil)
   144  		req.Header.Set("X-Request-ID", "sid-fail")
   145  
   146  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   147  		defer cancel()
   148  
   149  		req = req.WithContext(ctx)
   150  
   151  		w := httptest.NewRecorder()
   152  		handler.ServeHTTP(w, req)
   153  
   154  		require.Equal(t, http.StatusUnauthorized, w.Code)
   155  		assert.Empty(t, w.Body.String())
   156  	})
   157  
   158  	t.Run("GET request with identifier", func(t *testing.T) {
   159  		defer assertNoSessions(t, appNode)
   160  
   161  		controller.
   162  			On("Authenticate", "sid-gut", mock.Anything).
   163  			Return(&common.ConnectResult{
   164  				Identifier:    "se2023",
   165  				Status:        common.SUCCESS,
   166  				Transmissions: []string{`{"type":"welcome"}`},
   167  			}, nil)
   168  
   169  		controller.
   170  			On("Subscribe", "sid-gut", mock.Anything, "se2023", "chat_1").
   171  			Return(&common.CommandResult{
   172  				Status:        common.SUCCESS,
   173  				Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`},
   174  				Streams:       []string{"messages_1"},
   175  			}, nil)
   176  
   177  		req, _ := http.NewRequest("GET", "/?identifier=chat_1", nil)
   178  		req.Header.Set("X-Request-ID", "sid-gut")
   179  
   180  		ctx_, release := context.WithTimeout(context.Background(), 2*time.Second)
   181  		defer release()
   182  
   183  		ctx, cancel := context.WithCancel(ctx_)
   184  		defer cancel()
   185  
   186  		req = req.WithContext(ctx)
   187  
   188  		w := httptest.NewRecorder()
   189  		sw := newStreamingWriter(w)
   190  
   191  		go handler.ServeHTTP(sw, req)
   192  
   193  		msg, err := sw.ReadEvent(ctx)
   194  		require.NoError(t, err)
   195  		assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg)
   196  
   197  		msg, err = sw.ReadEvent(ctx)
   198  		require.NoError(t, err)
   199  		assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"chat_1"}`, msg)
   200  
   201  		appNode.Broadcast(&common.StreamMessage{Stream: "messages_1", Data: `{"content":"hello"}`})
   202  
   203  		msg, err = sw.ReadEvent(ctx)
   204  		require.NoError(t, err)
   205  		assert.Equal(t, `data: {"content":"hello"}`, msg)
   206  
   207  		require.Equal(t, http.StatusOK, w.Code)
   208  	})
   209  
   210  	t.Run("GET request with turbo_signed_stream_name", func(t *testing.T) {
   211  		defer assertNoSessions(t, appNode)
   212  
   213  		controller.
   214  			On("Authenticate", "sid-turbo", mock.Anything).
   215  			Return(&common.ConnectResult{
   216  				Identifier:    "se2023",
   217  				Status:        common.SUCCESS,
   218  				Transmissions: []string{`{"type":"welcome"}`},
   219  			}, nil)
   220  
   221  		turbo_identifier := `{"channel":"Turbo::StreamsChannel","signed_stream_name":"chat_1"}`
   222  
   223  		controller.
   224  			On("Subscribe", "sid-turbo", mock.Anything, "se2023", turbo_identifier).
   225  			Return(&common.CommandResult{
   226  				Status:        common.SUCCESS,
   227  				Transmissions: []string{`{"type":"confirm","identifier":"turbo_1"}`},
   228  				Streams:       []string{"chat_1"},
   229  			}, nil)
   230  
   231  		req, _ := http.NewRequest("GET", "/?turbo_signed_stream_name=chat_1", nil)
   232  		req.Header.Set("X-Request-ID", "sid-turbo")
   233  
   234  		ctx_, release := context.WithTimeout(context.Background(), 2*time.Second)
   235  		defer release()
   236  
   237  		ctx, cancel := context.WithCancel(ctx_)
   238  		defer cancel()
   239  
   240  		req = req.WithContext(ctx)
   241  
   242  		w := httptest.NewRecorder()
   243  		sw := newStreamingWriter(w)
   244  
   245  		go handler.ServeHTTP(sw, req)
   246  
   247  		msg, err := sw.ReadEvent(ctx)
   248  		require.NoError(t, err)
   249  		assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg)
   250  
   251  		msg, err = sw.ReadEvent(ctx)
   252  		require.NoError(t, err)
   253  		assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"turbo_1"}`, msg)
   254  
   255  		require.Equal(t, http.StatusOK, w.Code)
   256  	})
   257  
   258  	t.Run("GET request with stream", func(t *testing.T) {
   259  		defer assertNoSessions(t, appNode)
   260  
   261  		controller.
   262  			On("Authenticate", "sid-public-stream", mock.Anything).
   263  			Return(&common.ConnectResult{
   264  				Identifier:    "se2024",
   265  				Status:        common.SUCCESS,
   266  				Transmissions: []string{`{"type":"welcome"}`},
   267  			}, nil)
   268  
   269  		identifier := `{"channel":"$pubsub","stream_name":"chat_1"}`
   270  
   271  		controller.
   272  			On("Subscribe", "sid-public-stream", mock.Anything, "se2024", identifier).
   273  			Return(&common.CommandResult{
   274  				Status:        common.SUCCESS,
   275  				Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`},
   276  				Streams:       []string{"chat_1"},
   277  			}, nil)
   278  
   279  		req, _ := http.NewRequest("GET", "/?stream=chat_1", nil)
   280  		req.Header.Set("X-Request-ID", "sid-public-stream")
   281  
   282  		ctx_, release := context.WithTimeout(context.Background(), 2*time.Second)
   283  		defer release()
   284  
   285  		ctx, cancel := context.WithCancel(ctx_)
   286  		defer cancel()
   287  
   288  		req = req.WithContext(ctx)
   289  
   290  		w := httptest.NewRecorder()
   291  		sw := newStreamingWriter(w)
   292  
   293  		go handler.ServeHTTP(sw, req)
   294  
   295  		msg, err := sw.ReadEvent(ctx)
   296  		require.NoError(t, err)
   297  		assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg)
   298  
   299  		msg, err = sw.ReadEvent(ctx)
   300  		require.NoError(t, err)
   301  		assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"chat_1"}`, msg)
   302  
   303  		require.Equal(t, http.StatusOK, w.Code)
   304  	})
   305  
   306  	t.Run("GET request with signed_stream", func(t *testing.T) {
   307  		defer assertNoSessions(t, appNode)
   308  
   309  		controller.
   310  			On("Authenticate", "sid-signed-stream", mock.Anything).
   311  			Return(&common.ConnectResult{
   312  				Identifier:    "se2024",
   313  				Status:        common.SUCCESS,
   314  				Transmissions: []string{`{"type":"welcome"}`},
   315  			}, nil)
   316  
   317  		identifier := `{"channel":"$pubsub","signed_stream_name":"secretto"}`
   318  
   319  		controller.
   320  			On("Subscribe", "sid-signed-stream", mock.Anything, "se2024", identifier).
   321  			Return(&common.CommandResult{
   322  				Status:        common.SUCCESS,
   323  				Transmissions: []string{`{"type":"confirm","identifier":"secret_chat_1"}`},
   324  				Streams:       []string{"chat_1"},
   325  			}, nil)
   326  
   327  		req, _ := http.NewRequest("GET", "/?signed_stream=secretto", nil)
   328  		req.Header.Set("X-Request-ID", "sid-signed-stream")
   329  
   330  		ctx_, release := context.WithTimeout(context.Background(), 2*time.Second)
   331  		defer release()
   332  
   333  		ctx, cancel := context.WithCancel(ctx_)
   334  		defer cancel()
   335  
   336  		req = req.WithContext(ctx)
   337  
   338  		w := httptest.NewRecorder()
   339  		sw := newStreamingWriter(w)
   340  
   341  		go handler.ServeHTTP(sw, req)
   342  
   343  		msg, err := sw.ReadEvent(ctx)
   344  		require.NoError(t, err)
   345  		assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg)
   346  
   347  		msg, err = sw.ReadEvent(ctx)
   348  		require.NoError(t, err)
   349  		assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"secret_chat_1"}`, msg)
   350  
   351  		require.Equal(t, http.StatusOK, w.Code)
   352  	})
   353  
   354  	t.Run("GET request with channel + rejected", func(t *testing.T) {
   355  		defer assertNoSessions(t, appNode)
   356  
   357  		controller.
   358  			On("Authenticate", "sid-reject", mock.Anything).
   359  			Return(&common.ConnectResult{
   360  				Identifier:    "se2034",
   361  				Status:        common.SUCCESS,
   362  				Transmissions: []string{`{"type":"welcome"}`},
   363  			}, nil)
   364  
   365  		controller.
   366  			On("Subscribe", "sid-reject", mock.Anything, "se2034", `{"channel":"room_1"}`).
   367  			Return(&common.CommandResult{
   368  				Status:        common.FAILURE,
   369  				Transmissions: []string{`{"type":"reject","identifier":"room_1"}`},
   370  			}, nil)
   371  
   372  		req, _ := http.NewRequest("GET", "/?channel=room_1", nil)
   373  		req.Header.Set("X-Request-ID", "sid-reject")
   374  
   375  		ctx_, release := context.WithTimeout(context.Background(), 2*time.Second)
   376  		defer release()
   377  
   378  		ctx, cancel := context.WithCancel(ctx_)
   379  		defer cancel()
   380  
   381  		req = req.WithContext(ctx)
   382  
   383  		w := httptest.NewRecorder()
   384  
   385  		handler.ServeHTTP(w, req)
   386  
   387  		require.Equal(t, http.StatusBadRequest, w.Code)
   388  		assert.Empty(t, w.Body.String())
   389  
   390  		controller.AssertCalled(t, "Subscribe", "sid-reject", mock.Anything, "se2034", `{"channel":"room_1"}`)
   391  	})
   392  
   393  	t.Run("GET request without channel or identifier", func(t *testing.T) {
   394  		req, _ := http.NewRequest("GET", "/", nil)
   395  
   396  		w := httptest.NewRecorder()
   397  		handler.ServeHTTP(w, req)
   398  
   399  		require.Equal(t, http.StatusBadRequest, w.Code)
   400  		assert.Empty(t, w.Body.String())
   401  	})
   402  
   403  	t.Run("POST request without commands + server shutdown", func(t *testing.T) {
   404  		defer assertNoSessions(t, appNode)
   405  
   406  		controller.
   407  			On("Authenticate", "sid-post-no-op", mock.Anything).
   408  			Return(&common.ConnectResult{
   409  				Identifier:    "se2023-09-06",
   410  				Status:        common.SUCCESS,
   411  				Transmissions: []string{`{"type":"welcome"}`},
   412  			}, nil)
   413  
   414  		req, _ := http.NewRequest("POST", "/", nil)
   415  		req.Header.Set("X-Request-ID", "sid-post-no-op")
   416  
   417  		ctx_, release := context.WithTimeout(context.Background(), 2*time.Second)
   418  		defer release()
   419  
   420  		ctx, cancel := context.WithCancel(ctx_)
   421  		defer cancel()
   422  
   423  		req = req.WithContext(ctx)
   424  
   425  		w := httptest.NewRecorder()
   426  		sw := newStreamingWriter(w)
   427  
   428  		shutdownCtx, shutdownFn := context.WithCancel(context.Background())
   429  
   430  		shutdownHandler := SSEHandler(appNode, shutdownCtx, headersExtractor, &conf, slog.Default())
   431  
   432  		go shutdownHandler.ServeHTTP(sw, req)
   433  
   434  		msg, err := sw.ReadEvent(ctx)
   435  		require.NoError(t, err)
   436  		assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg)
   437  
   438  		shutdownFn()
   439  
   440  		msg, err = sw.ReadEvent(ctx)
   441  		require.NoError(t, err)
   442  		assert.Equal(t, "event: disconnect\n"+`data: {"type":"disconnect","reason":"server_restart","reconnect":true}`, msg)
   443  
   444  		require.Equal(t, http.StatusOK, w.Code)
   445  	})
   446  
   447  	t.Run("POST request with multiple subscriptions", func(t *testing.T) {
   448  		defer assertNoSessions(t, appNode)
   449  
   450  		controller.
   451  			On("Authenticate", "sid-post", mock.Anything).
   452  			Return(&common.ConnectResult{
   453  				Identifier:    "se2023-09-06",
   454  				Status:        common.SUCCESS,
   455  				Transmissions: []string{`{"type":"welcome"}`},
   456  			}, nil)
   457  
   458  		controller.
   459  			On("Subscribe", "sid-post", mock.Anything, "se2023-09-06", "chat_1").
   460  			Return(&common.CommandResult{
   461  				Status:        common.SUCCESS,
   462  				Transmissions: []string{`{"type":"confirm","identifier":"chat_1"}`},
   463  				Streams:       []string{"messages_1"},
   464  			}, nil)
   465  
   466  		controller.
   467  			On("Subscribe", "sid-post", mock.Anything, "se2023-09-06", "presence_1").
   468  			Return(&common.CommandResult{
   469  				Status:        common.SUCCESS,
   470  				Transmissions: []string{`{"type":"confirm","identifier":"presence_1"}`},
   471  				Streams:       []string{"presence_1"},
   472  			}, nil)
   473  
   474  		req, _ := http.NewRequest("POST", "/", nil)
   475  		req.Header.Set("X-Request-ID", "sid-post")
   476  		req.Body = io.NopCloser(
   477  			strings.NewReader("{\"command\":\"subscribe\",\"identifier\":\"chat_1\"}\n{\"command\":\"subscribe\",\"identifier\":\"presence_1\"}"),
   478  		)
   479  
   480  		ctx_, release := context.WithTimeout(context.Background(), 2*time.Second)
   481  		defer release()
   482  
   483  		ctx, cancel := context.WithCancel(ctx_)
   484  		defer cancel()
   485  
   486  		req = req.WithContext(ctx)
   487  
   488  		w := httptest.NewRecorder()
   489  		sw := newStreamingWriter(w)
   490  
   491  		go handler.ServeHTTP(sw, req)
   492  
   493  		msg, err := sw.ReadEvent(ctx)
   494  		require.NoError(t, err)
   495  		assert.Equal(t, "event: welcome\n"+`data: {"type":"welcome"}`, msg)
   496  
   497  		msg, err = sw.ReadEvent(ctx)
   498  		require.NoError(t, err)
   499  		assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"chat_1"}`, msg)
   500  
   501  		msg, err = sw.ReadEvent(ctx)
   502  		require.NoError(t, err)
   503  		assert.Equal(t, "event: confirm\n"+`data: {"type":"confirm","identifier":"presence_1"}`, msg)
   504  
   505  		appNode.Broadcast(&common.StreamMessage{Stream: "messages_1", Data: `{"content":"hello"}`})
   506  
   507  		msg, err = sw.ReadEvent(ctx)
   508  		require.NoError(t, err)
   509  		assert.Equal(t, `data: {"identifier":"chat_1","message":{"content":"hello"}}`, msg)
   510  
   511  		appNode.Broadcast(&common.StreamMessage{Stream: "presence_1", Data: `{"type":"join","user_id":1}`})
   512  
   513  		msg, err = sw.ReadEvent(ctx)
   514  		require.NoError(t, err)
   515  		assert.Equal(t, `data: {"identifier":"presence_1","message":{"type":"join","user_id":1}}`, msg)
   516  
   517  		require.Equal(t, http.StatusOK, w.Code)
   518  	})
   519  }
   520  
   521  // This a helper method to ensure no sessions left after test (so no global state is left).
   522  // Session may be removed from the hub asynchrounously, so we need to wait for it.
   523  func assertNoSessions(t *testing.T, n *node.Node) {
   524  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   525  	defer cancel()
   526  
   527  	done := make(chan struct{})
   528  
   529  	go func() {
   530  		for {
   531  			if n.Size() == 0 {
   532  				close(done)
   533  				return
   534  			}
   535  
   536  			time.Sleep(100 * time.Millisecond)
   537  		}
   538  	}()
   539  
   540  	select {
   541  	case <-ctx.Done():
   542  		require.Fail(t, "Timeout waiting for sessions to be removed")
   543  	case <-done:
   544  	}
   545  }
   546  
   547  type immediateDisconnector struct {
   548  	n *node.Node
   549  }
   550  
   551  func (d *immediateDisconnector) Enqueue(s *node.Session) error {
   552  	return d.n.DisconnectNow(s)
   553  }
   554  
   555  func (immediateDisconnector) Run() error                         { return nil }
   556  func (immediateDisconnector) Shutdown(ctx context.Context) error { return nil }
   557  func (immediateDisconnector) Size() int                          { return 0 }
   558  
   559  func buildNode() (*node.Node, *mocks.Controller) {
   560  	controller := &mocks.Controller{}
   561  	config := node.NewConfig()
   562  	config.HubGopoolSize = 2
   563  	n := node.NewNode(&config, node.WithController(controller), node.WithInstrumenter(metrics.NewMetrics(nil, 10, slog.Default())))
   564  	n.SetBroker(broker.NewLegacyBroker(pubsub.NewLegacySubscriber(n)))
   565  	n.SetDisconnector(&immediateDisconnector{n})
   566  	return n, controller
   567  }