github.com/anycable/anycable-go@v1.5.1/broadcast/http_test.go (about)

     1  package broadcast
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"log/slog"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/anycable/anycable-go/mocks"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func TestHttpHandler(t *testing.T) {
    18  	handler := &mocks.Handler{}
    19  	config := NewHTTPConfig()
    20  
    21  	secretConfig := NewHTTPConfig()
    22  	secretConfig.SecretBase = "qwerty"
    23  	broadcastKey := "42923a28b760e667fc92f7c6123bb07a282822b329dd2ef48e7aee7830d98485"
    24  
    25  	broadcaster := NewHTTPBroadcaster(handler, &config, slog.Default())
    26  	protectedBroadcaster := NewHTTPBroadcaster(handler, &secretConfig, slog.Default())
    27  
    28  	done := make(chan (error))
    29  
    30  	require.NoError(t, broadcaster.Start(done))
    31  	defer broadcaster.Shutdown(context.Background()) // nolint:errcheck
    32  
    33  	require.NoError(t, protectedBroadcaster.Start(done))
    34  	defer protectedBroadcaster.Shutdown(context.Background()) // nolint:errcheck
    35  
    36  	payload, err := json.Marshal(map[string]string{"stream": "any_test", "data": "123_test"})
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  
    41  	handler.On(
    42  		"HandleBroadcast",
    43  		payload,
    44  	)
    45  
    46  	t.Run("Handles broadcasts", func(t *testing.T) {
    47  		req, err := http.NewRequest("POST", "/", strings.NewReader(string(payload)))
    48  		require.NoError(t, err)
    49  
    50  		rr := httptest.NewRecorder()
    51  		handler := http.HandlerFunc(broadcaster.Handler)
    52  		handler.ServeHTTP(rr, req)
    53  
    54  		assert.Equal(t, http.StatusCreated, rr.Code)
    55  	})
    56  
    57  	t.Run("Rejects non-POST requests", func(t *testing.T) {
    58  		req, err := http.NewRequest("GET", "/", strings.NewReader(string(payload)))
    59  		require.NoError(t, err)
    60  
    61  		rr := httptest.NewRecorder()
    62  		handler := http.HandlerFunc(broadcaster.Handler)
    63  		handler.ServeHTTP(rr, req)
    64  
    65  		assert.Equal(t, http.StatusUnprocessableEntity, rr.Code)
    66  	})
    67  
    68  	t.Run("Rejects when authorization header is missing", func(t *testing.T) {
    69  		req, err := http.NewRequest("POST", "/", strings.NewReader(string(payload)))
    70  		require.NoError(t, err)
    71  
    72  		rr := httptest.NewRecorder()
    73  		handler := http.HandlerFunc(protectedBroadcaster.Handler)
    74  		handler.ServeHTTP(rr, req)
    75  
    76  		assert.Equal(t, http.StatusUnauthorized, rr.Code)
    77  	})
    78  
    79  	t.Run("Accepts when authorization header is valid", func(t *testing.T) {
    80  		req, err := http.NewRequest("POST", "/", strings.NewReader(string(payload)))
    81  		req.Header.Set("Authorization", "Bearer "+broadcastKey)
    82  
    83  		require.NoError(t, err)
    84  
    85  		rr := httptest.NewRecorder()
    86  		handler := http.HandlerFunc(protectedBroadcaster.Handler)
    87  		handler.ServeHTTP(rr, req)
    88  
    89  		assert.Equal(t, http.StatusCreated, rr.Code)
    90  	})
    91  }