github.com/anycable/anycable-go@v1.5.1/pubsub/redis_test.go (about)

     1  package pubsub
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"os"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/anycable/anycable-go/common"
    14  	rconfig "github.com/anycable/anycable-go/redis"
    15  	"github.com/redis/rueidis"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  )
    19  
    20  var (
    21  	redisAvailable = false
    22  	redisURL       = os.Getenv("REDIS_URL")
    23  )
    24  
    25  // Check if Redis is available and skip tests otherwise
    26  func init() {
    27  	config := rconfig.NewRedisConfig()
    28  
    29  	if redisURL != "" {
    30  		config.URL = redisURL
    31  	}
    32  
    33  	subscriber, err := NewRedisSubscriber(nil, &config, slog.Default())
    34  	if err != nil {
    35  		fmt.Printf("Failed to create redis subscriber: %v", err)
    36  		return
    37  	}
    38  
    39  	err = subscriber.Start(make(chan error))
    40  
    41  	if err != nil {
    42  		fmt.Printf("Failed to start Redis subscriber: %v", err)
    43  		return
    44  	}
    45  
    46  	err = subscriber.initClient()
    47  	if err != nil {
    48  		fmt.Printf("No Redis detected at %s: %v", config.URL, err)
    49  		return
    50  	}
    51  
    52  	defer subscriber.Shutdown(context.Background()) // nolint:errcheck
    53  
    54  	c := subscriber.client
    55  
    56  	err = c.Do(context.Background(), c.B().Arbitrary("PING").Build()).Error()
    57  
    58  	redisAvailable = err == nil
    59  }
    60  
    61  func TestRedisCommon(t *testing.T) {
    62  	if !redisAvailable {
    63  		t.Skip("Skipping Redis tests: no Redis available")
    64  		return
    65  	}
    66  
    67  	config := rconfig.NewRedisConfig()
    68  
    69  	if redisURL != "" {
    70  		config.URL = redisURL
    71  	}
    72  
    73  	SharedSubscriberTests(t, func(handler *TestHandler) Subscriber {
    74  		sub, err := NewRedisSubscriber(handler, &config, slog.Default())
    75  
    76  		if err != nil {
    77  			panic(err)
    78  		}
    79  
    80  		return sub
    81  	}, waitRedisSubscription)
    82  }
    83  
    84  func TestRedisReconnect(t *testing.T) {
    85  	if !redisAvailable {
    86  		t.Skip("Skipping Redis tests: no Redis available")
    87  		return
    88  	}
    89  
    90  	handler := NewTestHandler()
    91  	config := rconfig.NewRedisConfig()
    92  
    93  	if redisURL != "" {
    94  		config.URL = redisURL
    95  	}
    96  
    97  	subscriber, err := NewRedisSubscriber(handler, &config, slog.Default())
    98  	require.NoError(t, err)
    99  
   100  	done := make(chan error)
   101  
   102  	err = subscriber.Start(done)
   103  	require.NoError(t, err)
   104  
   105  	defer subscriber.Shutdown(context.Background()) // nolint:errcheck
   106  
   107  	require.NoError(t, waitRedisSubscription(subscriber, "internal"))
   108  
   109  	subscriber.Subscribe("reconnectos")
   110  	require.NoError(t, waitRedisSubscription(subscriber, "reconnectos"))
   111  
   112  	subscriber.Broadcast(&common.StreamMessage{Stream: "reconnectos", Data: "2022"})
   113  
   114  	msg := handler.Receive()
   115  	require.NotNil(t, msg)
   116  	assert.Equal(t, "2022", msg.Data)
   117  
   118  	// Drop Redis pus/sub connections
   119  	require.NoError(t, dropRedisPubSubConnections(subscriber.client))
   120  	require.NoError(t, waitRedisPubSubConnections(subscriber.client))
   121  
   122  	require.NoError(t, waitRedisSubscription(subscriber, "reconnectos"))
   123  
   124  	subscriber.Broadcast(&common.StreamMessage{Stream: "reconnectos", Data: "2023"})
   125  
   126  	msg = handler.Receive()
   127  	require.NotNil(t, msg)
   128  	assert.Equal(t, "2023", msg.Data)
   129  }
   130  
   131  func waitRedisSubscription(subscriber Subscriber, stream string) error {
   132  	s := subscriber.(*RedisSubscriber)
   133  
   134  	if stream == "internal" {
   135  		stream = s.config.InternalChannel
   136  	}
   137  
   138  	unsubscribing := false
   139  
   140  	if strings.HasPrefix(stream, "-") {
   141  		unsubscribing = true
   142  		stream = strings.Replace(stream, "-", "", 1)
   143  	}
   144  
   145  	attempts := 0
   146  
   147  	for {
   148  		if attempts > 5 {
   149  			if unsubscribing {
   150  				return fmt.Errorf("Timeout exceeded to unsubscribe from stream: %s", stream)
   151  			} else {
   152  				return fmt.Errorf("Timeout exceeded to subscribe to stream: %s", stream)
   153  			}
   154  		}
   155  
   156  		s.subMu.RLock()
   157  		entry := s.subscriptionEntry(stream)
   158  		state := subscriptionPending
   159  		if entry != nil {
   160  			state = entry.state
   161  		}
   162  		s.subMu.RUnlock()
   163  
   164  		if unsubscribing {
   165  			if entry == nil {
   166  				return nil
   167  			}
   168  		} else {
   169  			if entry == nil {
   170  				return fmt.Errorf("No pending subscription: %s", stream)
   171  			}
   172  
   173  			if state == subscriptionCreated {
   174  				return nil
   175  			}
   176  		}
   177  
   178  		time.Sleep(100 * time.Millisecond)
   179  		attempts++
   180  	}
   181  }
   182  
   183  // Mimics Rails implementation: https://github.com/rails/rails/blob/6d581c43a77b8945df3d427273d357b67c303077/actioncable/test/subscription_adapter/redis_test.rb#L51-L67
   184  func dropRedisPubSubConnections(client rueidis.Client) error {
   185  	res := client.Do(context.Background(), client.B().Arbitrary("client", "kill", "type", "pubsub").Build())
   186  
   187  	_, err := res.AsInt64()
   188  
   189  	return err
   190  }
   191  
   192  func waitRedisPubSubConnections(client rueidis.Client) error {
   193  	attempts := 0
   194  
   195  	for {
   196  		if attempts > 5 {
   197  			return errors.New("No pub/sub connection were created")
   198  		}
   199  
   200  		res := client.Do(context.Background(), client.B().Arbitrary("pubsub", "channels").Build())
   201  		channels, err := res.ToArray()
   202  
   203  		if err == nil && len(channels) > 0 {
   204  			return nil
   205  		}
   206  
   207  		time.Sleep(500 * time.Millisecond)
   208  		attempts++
   209  	}
   210  }