github.com/anycable/anycable-go@v1.5.1/broadcast/redis_test.go (about)

     1  package broadcast
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"log/slog"
     9  	"os"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/anycable/anycable-go/mocks"
    15  	rconfig "github.com/anycable/anycable-go/redis"
    16  	"github.com/anycable/anycable-go/utils"
    17  	"github.com/redis/rueidis"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/mock"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  var (
    24  	redisAvailable = false
    25  	redisURL       = os.Getenv("REDIS_URL")
    26  )
    27  
    28  // Check if Redis is available and skip tests otherwise
    29  func init() {
    30  	config := rconfig.NewRedisConfig()
    31  
    32  	if redisURL != "" {
    33  		config.URL = redisURL
    34  	}
    35  
    36  	options, err := config.ToRueidisOptions()
    37  
    38  	if err != nil {
    39  		fmt.Printf("Failed to parse Redis URL: %v", err)
    40  		return
    41  	}
    42  
    43  	c, err := rueidis.NewClient(*options)
    44  
    45  	if err != nil {
    46  		fmt.Printf("Failed to connect to Redis: %v", err)
    47  		return
    48  	}
    49  
    50  	err = c.Do(context.Background(), c.B().Arbitrary("PING").Build()).Error()
    51  
    52  	redisAvailable = err == nil
    53  
    54  	if !redisAvailable {
    55  		return
    56  	}
    57  
    58  	c.Do(context.Background(), c.B().XgroupDestroy().Key("__anycable__").Group("bx").Build())
    59  }
    60  
    61  func TestRedisBroadcaster(t *testing.T) {
    62  	if !redisAvailable {
    63  		t.Skip("Skipping Redis tests: no Redis available")
    64  		return
    65  	}
    66  
    67  	config := rconfig.NewRedisConfig()
    68  
    69  	if redisURL != "" {
    70  		config.URL = redisURL
    71  	}
    72  
    73  	config.StreamReadBlockMilliseconds = 500
    74  
    75  	handler := &mocks.Handler{}
    76  	errchan := make(chan error)
    77  	broadcasts := make(chan map[string]string, 10)
    78  
    79  	payload := utils.ToJSON(map[string]string{"stream": "any_test", "data": "123_test"})
    80  
    81  	handler.On(
    82  		"HandleBroadcast",
    83  		mock.Anything,
    84  	).Run(func(args mock.Arguments) {
    85  		data := args.Get(0).([]byte)
    86  		var msg map[string]string
    87  		json.Unmarshal(data, &msg) // nolint: errcheck
    88  
    89  		broadcasts <- msg
    90  	})
    91  
    92  	t.Run("Handles broadcasts", func(t *testing.T) {
    93  		broadcaster := NewRedisBroadcaster(handler, &config, slog.Default())
    94  
    95  		err := broadcaster.Start(errchan)
    96  		require.NoError(t, err)
    97  
    98  		defer broadcaster.Shutdown(context.Background()) // nolint:errcheck
    99  
   100  		require.NoError(t, broadcaster.initClient())
   101  
   102  		require.NoError(t, waitRedisStreamConsumers(broadcaster.client, 1))
   103  
   104  		require.NoError(t, publishToRedisStream(broadcaster.client, "__anycable__", string(payload)))
   105  
   106  		messages := drainBroadcasts(broadcasts)
   107  		require.Equalf(t, 1, len(messages), "Expected 1 message, got %d", len(messages))
   108  
   109  		msg := messages[0]
   110  
   111  		assert.Equal(t, "any_test", msg["stream"])
   112  		assert.Equal(t, "123_test", msg["data"])
   113  	})
   114  
   115  	t.Run("With multiple subscribers", func(t *testing.T) {
   116  		broadcaster := NewRedisBroadcaster(handler, &config, slog.Default())
   117  
   118  		err := broadcaster.Start(errchan)
   119  		require.NoError(t, err)
   120  
   121  		defer broadcaster.Shutdown(context.Background()) // nolint:errcheck
   122  
   123  		require.NoError(t, broadcaster.initClient())
   124  
   125  		broadcaster2 := NewRedisBroadcaster(handler, &config, slog.Default())
   126  		err = broadcaster2.Start(errchan)
   127  		require.NoError(t, err)
   128  
   129  		defer broadcaster2.Shutdown(context.Background()) // nolint:errcheck
   130  
   131  		require.NoError(t, waitRedisStreamConsumers(broadcaster.client, 2))
   132  
   133  		require.NoError(t, publishToRedisStream(broadcaster.client, "__anycable__",
   134  			string(utils.ToJSON(map[string]string{"stream": "any_test", "data": "123_test"})),
   135  		))
   136  
   137  		require.NoError(t, publishToRedisStream(broadcaster.client, "__anycable__",
   138  			string(utils.ToJSON(map[string]string{"stream": "any_test", "data": "124_test"})),
   139  		))
   140  
   141  		require.NoError(t, publishToRedisStream(broadcaster.client, "__anycable__",
   142  			string(utils.ToJSON(map[string]string{"stream": "any_test", "data": "125_test"})),
   143  		))
   144  
   145  		messages := drainBroadcasts(broadcasts)
   146  
   147  		require.Equalf(t, 3, len(messages), "Expected 3 messages, got %d", len(messages))
   148  	})
   149  }
   150  
   151  func TestRedisBroadcasterAcksClaims(t *testing.T) {
   152  	if !redisAvailable {
   153  		t.Skip("Skipping Redis tests: no Redis available")
   154  		return
   155  	}
   156  
   157  	config := rconfig.NewRedisConfig()
   158  	// Make it short to avoid sleeping for too long in tests
   159  	config.StreamReadBlockMilliseconds = 100
   160  
   161  	if redisURL != "" {
   162  		config.URL = redisURL
   163  	}
   164  
   165  	handler := &mocks.Handler{}
   166  	broadcaster := NewRedisBroadcaster(handler, &config, slog.Default())
   167  
   168  	errchan := make(chan error)
   169  	broadcasts := make(chan string, 10)
   170  
   171  	closed := false
   172  
   173  	handler.On(
   174  		"HandleBroadcast",
   175  		mock.Anything,
   176  	).Run(func(args mock.Arguments) {
   177  		msg := string(args.Get(0).([]byte))
   178  		broadcasts <- msg
   179  
   180  		if msg == "2" && !closed {
   181  			closed = true
   182  			// Close the connection to prevent consumer from ack-ing the message
   183  			broadcaster.client.Close()
   184  			broadcaster.reconnectAttempt = config.MaxReconnectAttempts + 1
   185  		}
   186  	})
   187  
   188  	err := broadcaster.Start(errchan)
   189  	require.NoError(t, err)
   190  	defer broadcaster.Shutdown(context.Background()) // nolint:errcheck
   191  
   192  	require.NoError(t, broadcaster.initClient())
   193  	require.NoError(t, waitRedisStreamConsumers(broadcaster.client, 1))
   194  
   195  	require.NoError(t, publishToRedisStream(broadcaster.client, "__anycable__", "1"))
   196  	require.NoError(t, publishToRedisStream(broadcaster.client, "__anycable__", "2"))
   197  
   198  	broadcaster2 := NewRedisBroadcaster(handler, &config, slog.Default())
   199  	err = broadcaster2.Start(errchan)
   200  	require.NoError(t, err)
   201  	defer broadcaster2.Shutdown(context.Background()) // nolint:errcheck
   202  
   203  	require.NoError(t, broadcaster2.initClient())
   204  	require.NoError(t, waitRedisStreamConsumers(broadcaster2.client, 1))
   205  
   206  	// We should wait for at least 2*blockTime to mark older consumer as stale
   207  	// and claim its messages
   208  	time.Sleep(300 * time.Millisecond)
   209  
   210  	messages := drainBroadcasts(broadcasts)
   211  	require.Equalf(t, 3, len(messages), "Expected 3 messages, got %d", len(messages))
   212  
   213  	assert.Equal(t, "1", messages[0])
   214  	assert.Equal(t, "2", messages[1])
   215  	// We haven't acked the last message within the first broadcaster,
   216  	// so the second one must have picked it up
   217  	assert.Equal(t, "2", messages[1])
   218  }
   219  
   220  func drainBroadcasts[T any](ch chan T) []T {
   221  	buffer := make([]T, 0)
   222  
   223  out:
   224  	for {
   225  		select {
   226  		case msg := <-ch:
   227  			buffer = append(buffer, msg)
   228  		case <-time.After(time.Second):
   229  			break out
   230  		}
   231  	}
   232  
   233  	return buffer
   234  }
   235  
   236  func publishToRedisStream(client rueidis.Client, stream string, payload string) error {
   237  	if client == nil {
   238  		return errors.New("No Redis client configured")
   239  	}
   240  
   241  	res := client.Do(context.Background(),
   242  		client.B().Xadd().Key(stream).Id("*").FieldValue().FieldValue("payload", payload).Build(),
   243  	)
   244  
   245  	return res.Error()
   246  }
   247  
   248  func waitRedisStreamConsumers(client rueidis.Client, count int) error {
   249  	if client == nil {
   250  		return errors.New("No Redis client configured")
   251  	}
   252  
   253  	attempts := 0
   254  
   255  	for {
   256  		if attempts > 5 {
   257  			return errors.New("No stream consumer were created")
   258  		}
   259  
   260  		res := client.Do(context.Background(), client.B().Arbitrary("client", "list").Build())
   261  		clientsStr, err := res.ToString()
   262  
   263  		if err == nil {
   264  			clients := strings.Split(clientsStr, "\n")
   265  
   266  			readers := 0
   267  			for _, clientMsg := range clients {
   268  				if clientMsg == "" {
   269  					continue
   270  				}
   271  
   272  				clientCmd := strings.Split(strings.Split(clientMsg, "cmd=")[1], " ")[0]
   273  
   274  				if clientCmd == "xreadgroup" {
   275  					readers++
   276  				}
   277  			}
   278  
   279  			if readers >= count {
   280  				return nil
   281  			}
   282  		}
   283  
   284  		time.Sleep(500 * time.Millisecond)
   285  		attempts++
   286  	}
   287  }