github.com/anycable/anycable-go@v1.5.1/broadcast/http_test.go (about) 1 package broadcast 2 3 import ( 4 "context" 5 "encoding/json" 6 "log/slog" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 12 "github.com/anycable/anycable-go/mocks" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/require" 15 ) 16 17 func TestHttpHandler(t *testing.T) { 18 handler := &mocks.Handler{} 19 config := NewHTTPConfig() 20 21 secretConfig := NewHTTPConfig() 22 secretConfig.SecretBase = "qwerty" 23 broadcastKey := "42923a28b760e667fc92f7c6123bb07a282822b329dd2ef48e7aee7830d98485" 24 25 broadcaster := NewHTTPBroadcaster(handler, &config, slog.Default()) 26 protectedBroadcaster := NewHTTPBroadcaster(handler, &secretConfig, slog.Default()) 27 28 done := make(chan (error)) 29 30 require.NoError(t, broadcaster.Start(done)) 31 defer broadcaster.Shutdown(context.Background()) // nolint:errcheck 32 33 require.NoError(t, protectedBroadcaster.Start(done)) 34 defer protectedBroadcaster.Shutdown(context.Background()) // nolint:errcheck 35 36 payload, err := json.Marshal(map[string]string{"stream": "any_test", "data": "123_test"}) 37 if err != nil { 38 t.Fatal(err) 39 } 40 41 handler.On( 42 "HandleBroadcast", 43 payload, 44 ) 45 46 t.Run("Handles broadcasts", func(t *testing.T) { 47 req, err := http.NewRequest("POST", "/", strings.NewReader(string(payload))) 48 require.NoError(t, err) 49 50 rr := httptest.NewRecorder() 51 handler := http.HandlerFunc(broadcaster.Handler) 52 handler.ServeHTTP(rr, req) 53 54 assert.Equal(t, http.StatusCreated, rr.Code) 55 }) 56 57 t.Run("Rejects non-POST requests", func(t *testing.T) { 58 req, err := http.NewRequest("GET", "/", strings.NewReader(string(payload))) 59 require.NoError(t, err) 60 61 rr := httptest.NewRecorder() 62 handler := http.HandlerFunc(broadcaster.Handler) 63 handler.ServeHTTP(rr, req) 64 65 assert.Equal(t, http.StatusUnprocessableEntity, rr.Code) 66 }) 67 68 t.Run("Rejects when authorization header is missing", func(t *testing.T) { 69 req, err := http.NewRequest("POST", "/", strings.NewReader(string(payload))) 70 require.NoError(t, err) 71 72 rr := httptest.NewRecorder() 73 handler := http.HandlerFunc(protectedBroadcaster.Handler) 74 handler.ServeHTTP(rr, req) 75 76 assert.Equal(t, http.StatusUnauthorized, rr.Code) 77 }) 78 79 t.Run("Accepts when authorization header is valid", func(t *testing.T) { 80 req, err := http.NewRequest("POST", "/", strings.NewReader(string(payload))) 81 req.Header.Set("Authorization", "Bearer "+broadcastKey) 82 83 require.NoError(t, err) 84 85 rr := httptest.NewRecorder() 86 handler := http.HandlerFunc(protectedBroadcaster.Handler) 87 handler.ServeHTTP(rr, req) 88 89 assert.Equal(t, http.StatusCreated, rr.Code) 90 }) 91 }