github.com/anycable/anycable-go@v1.5.1/pubsub/nats_test.go (about)

     1  package pubsub
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/anycable/anycable-go/common"
    13  	"github.com/anycable/anycable-go/enats"
    14  	"github.com/anycable/anycable-go/nats"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  
    18  	nats_server "github.com/nats-io/nats.go"
    19  )
    20  
    21  func TestNATSCommon(t *testing.T) {
    22  	server := buildNATSServer()
    23  	err := server.Start()
    24  	require.NoError(t, err)
    25  	defer server.Shutdown(context.Background()) // nolint:errcheck
    26  
    27  	config := nats.NewNATSConfig()
    28  
    29  	SharedSubscriberTests(t, func(handler *TestHandler) Subscriber {
    30  		sub, err := NewNATSSubscriber(handler, &config, slog.Default())
    31  
    32  		if err != nil {
    33  			panic(err)
    34  		}
    35  
    36  		return sub
    37  	}, waitNATSSubscription)
    38  }
    39  
    40  func TestNATSReconnect(t *testing.T) {
    41  	server := buildNATSServer()
    42  	err := server.Start()
    43  	require.NoError(t, err)
    44  	defer server.Shutdown(context.Background()) // nolint:errcheck
    45  
    46  	handler := NewTestHandler()
    47  	config := nats.NewNATSConfig()
    48  
    49  	subscriber, err := NewNATSSubscriber(handler, &config, slog.Default())
    50  	require.NoError(t, err)
    51  
    52  	done := make(chan error)
    53  
    54  	err = subscriber.Start(done)
    55  	require.NoError(t, err)
    56  
    57  	defer subscriber.Shutdown(context.Background()) // nolint:errcheck
    58  
    59  	require.NoError(t, waitNATSSubscription(subscriber, "internal"))
    60  
    61  	subscriber.Subscribe("reconnectos")
    62  	require.NoError(t, waitNATSSubscription(subscriber, "reconnectos"))
    63  
    64  	subscriber.Broadcast(&common.StreamMessage{Stream: "reconnectos", Data: "2023"})
    65  
    66  	msg := handler.Receive()
    67  	require.NotNil(t, msg)
    68  	assert.Equal(t, "2023", msg.Data)
    69  
    70  	// Reload NATS server
    71  	err = server.Shutdown(context.Background())
    72  	require.NoError(t, err)
    73  	err = server.Start()
    74  	require.NoError(t, err)
    75  
    76  	err = waitNATSConnectionActive(subscriber)
    77  	require.NoError(t, err)
    78  
    79  	subscriber.Broadcast(&common.StreamMessage{Stream: "reconnectos", Data: "2023"})
    80  
    81  	msg = handler.Receive()
    82  	require.NotNil(t, msg)
    83  	assert.Equal(t, "2023", msg.Data)
    84  }
    85  
    86  func waitNATSSubscription(subscriber Subscriber, stream string) error {
    87  	s := subscriber.(*NATSSubscriber)
    88  
    89  	err := waitNATSConnectionActive(s)
    90  
    91  	if err != nil {
    92  		return err
    93  	}
    94  
    95  	if stream == "internal" {
    96  		stream = s.config.InternalChannel
    97  	}
    98  
    99  	unsubscribing := false
   100  
   101  	if strings.HasPrefix(stream, "-") {
   102  		unsubscribing = true
   103  		stream = strings.Replace(stream, "-", "", 1)
   104  	}
   105  
   106  	attempts := 0
   107  
   108  	for {
   109  		if attempts > 5 {
   110  			if unsubscribing {
   111  				return fmt.Errorf("Timeout exceeded to unsubscribe from stream: %s", stream)
   112  			} else {
   113  				return fmt.Errorf("Timeout exceeded to subscribe to stream: %s", stream)
   114  			}
   115  		}
   116  
   117  		s.subMu.RLock()
   118  		sub := s.subscriptions[stream]
   119  		s.subMu.RUnlock()
   120  
   121  		if unsubscribing {
   122  			if sub == nil {
   123  				return nil
   124  			}
   125  		} else {
   126  			if sub == nil {
   127  				return fmt.Errorf("No pending subscription: %s", stream)
   128  			}
   129  
   130  			// We cannot get the subscription's status, so let's add a bit of delay here
   131  			time.Sleep(100 * time.Millisecond)
   132  
   133  			return nil
   134  		}
   135  
   136  		time.Sleep(100 * time.Millisecond)
   137  		attempts++
   138  	}
   139  }
   140  
   141  func waitNATSConnectionActive(s *NATSSubscriber) error {
   142  	attempts := 0
   143  
   144  	for {
   145  		if attempts > 5 {
   146  			return errors.New("Connection wasn't restored")
   147  		}
   148  
   149  		if s.conn.Status() == nats_server.CONNECTED {
   150  			return nil
   151  		}
   152  
   153  		time.Sleep(500 * time.Millisecond)
   154  		attempts++
   155  	}
   156  }
   157  
   158  func buildNATSServer() *enats.Service {
   159  	conf := enats.NewConfig()
   160  	service := enats.NewService(&conf, slog.Default())
   161  
   162  	return service
   163  }