github.com/anycable/anycable-go@v1.5.1/pubsub/redis.go (about)

     1  package pubsub
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/anycable/anycable-go/common"
    12  	"github.com/anycable/anycable-go/logger"
    13  	rconfig "github.com/anycable/anycable-go/redis"
    14  	"github.com/anycable/anycable-go/utils"
    15  	"github.com/redis/rueidis"
    16  )
    17  
// subscriptionState is the lifecycle state of a subscriptionEntry.
// It is a type alias for int, not a distinct named type.
type subscriptionState = int

const (
	// subscriptionPending: a SUBSCRIBE has been requested (by Subscribe, or
	// again after a reconnect) but Redis has not confirmed it yet.
	subscriptionPending subscriptionState = iota
	// subscriptionCreated: Redis confirmed the subscription.
	subscriptionCreated
	// subscriptionPendingUnsubscribe: an UNSUBSCRIBE has been requested; the
	// entry is dropped once Redis confirms it (or when reconnecting).
	subscriptionPendingUnsubscribe
)
    25  
// subscriptionEntry tracks one Redis channel managed by RedisSubscriber.
type subscriptionEntry struct {
	// id is the Redis channel (stream) name and the key of the entry in
	// RedisSubscriber.subscriptions.
	id    string
	state subscriptionState
}
    30  
// RedisSubscriber implements Subscriber on top of Redis pub/sub through a
// rueidis client. Messages received from Redis are handed to node.
type RedisSubscriber struct {
	node   Handler
	config *rconfig.RedisConfig

	// client is created lazily by initClient from clientOptions; accesses
	// go through clientMu.
	client           rueidis.Client
	clientOptions    *rueidis.ClientOption
	clientMu         sync.RWMutex
	// reconnectAttempt counts consecutive reconnect attempts. It is
	// incremented by maybeReconnect and reset to 0 when the subscription to
	// the internal channel is confirmed.
	reconnectAttempt int

	// subscriptions maps channel names to their entries; guarded by subMu.
	subscriptions map[string]*subscriptionEntry
	subMu         sync.RWMutex

	// streamsCh queues entries for the pub/sub routine; the entry's state when
	// it is dequeued selects SUBSCRIBE or UNSUBSCRIBE.
	streamsCh  chan (*subscriptionEntry)
	// shutdownCh is closed by Shutdown to stop the pub/sub routine.
	shutdownCh chan struct{}

	log *slog.Logger
}

// Compile-time check that *RedisSubscriber satisfies Subscriber.
var _ Subscriber = (*RedisSubscriber)(nil)
    48  
    49  var _ Subscriber = (*RedisSubscriber)(nil)
    50  
    51  // NewRedisSubscriber creates a Redis subscriber using pub/sub
    52  func NewRedisSubscriber(node Handler, config *rconfig.RedisConfig, l *slog.Logger) (*RedisSubscriber, error) {
    53  	options, err := config.ToRueidisOptions()
    54  
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	return &RedisSubscriber{
    60  		node:          node,
    61  		config:        config,
    62  		clientOptions: options,
    63  		subscriptions: make(map[string]*subscriptionEntry),
    64  		log:           l.With("context", "pubsub"),
    65  		streamsCh:     make(chan *subscriptionEntry, 1024),
    66  		shutdownCh:    make(chan struct{}),
    67  	}, nil
    68  }
    69  
    70  func (s *RedisSubscriber) Start(done chan (error)) error {
    71  	if s.config.IsSentinel() { //nolint:gocritic
    72  		s.log.Info(fmt.Sprintf("Starting Redis pub/sub (sentinels): %v", s.config.Hostnames()))
    73  	} else if s.config.IsCluster() {
    74  		s.log.Info(fmt.Sprintf("Starting Redis pub/sub (cluster): %v", s.config.Hostnames()))
    75  	} else {
    76  		s.log.Info(fmt.Sprintf("Starting Redis pub/sub: %s", s.config.Hostname()))
    77  	}
    78  
    79  	go s.runPubSub(done)
    80  
    81  	s.Subscribe(s.config.InternalChannel)
    82  
    83  	return nil
    84  }
    85  
    86  func (s *RedisSubscriber) Shutdown(ctx context.Context) error {
    87  	s.clientMu.RLock()
    88  	defer s.clientMu.RUnlock()
    89  
    90  	if s.client == nil {
    91  		return nil
    92  	}
    93  
    94  	s.log.Debug("shutting down Redis pub/sub")
    95  
    96  	// First, shutdown the pub/sub routine
    97  	close(s.shutdownCh)
    98  	s.client.Close()
    99  
   100  	return nil
   101  }
   102  
   103  func (s *RedisSubscriber) IsMultiNode() bool {
   104  	return true
   105  }
   106  
   107  func (s *RedisSubscriber) Subscribe(stream string) {
   108  	s.subMu.Lock()
   109  	s.subscriptions[stream] = &subscriptionEntry{state: subscriptionPending, id: stream}
   110  	entry := s.subscriptions[stream]
   111  	s.subMu.Unlock()
   112  
   113  	s.streamsCh <- entry
   114  }
   115  
   116  func (s *RedisSubscriber) Unsubscribe(stream string) {
   117  	s.subMu.Lock()
   118  	if _, ok := s.subscriptions[stream]; !ok {
   119  		s.subMu.Unlock()
   120  		return
   121  	}
   122  
   123  	entry := s.subscriptions[stream]
   124  	entry.state = subscriptionPendingUnsubscribe
   125  
   126  	s.streamsCh <- entry
   127  	s.subMu.Unlock()
   128  }
   129  
   130  func (s *RedisSubscriber) Broadcast(msg *common.StreamMessage) {
   131  	s.Publish(msg.Stream, msg)
   132  }
   133  
   134  func (s *RedisSubscriber) BroadcastCommand(cmd *common.RemoteCommandMessage) {
   135  	s.Publish(s.config.InternalChannel, cmd)
   136  }
   137  
   138  func (s *RedisSubscriber) Publish(stream string, msg interface{}) {
   139  	s.clientMu.RLock()
   140  
   141  	if s.client == nil {
   142  		s.clientMu.RUnlock()
   143  		return
   144  	}
   145  
   146  	ctx := context.Background()
   147  	client := s.client
   148  
   149  	s.clientMu.RUnlock()
   150  
   151  	s.log.With("channel", stream).Debug("publish message", "data", msg)
   152  
   153  	client.Do(ctx, client.B().Publish().Channel(stream).Message(string(utils.ToJSON(msg))).Build())
   154  }
   155  
   156  func (s *RedisSubscriber) initClient() error {
   157  	s.clientMu.Lock()
   158  	defer s.clientMu.Unlock()
   159  
   160  	if s.client != nil {
   161  		return nil
   162  	}
   163  
   164  	c, err := rueidis.NewClient(*s.clientOptions)
   165  
   166  	if err != nil {
   167  		return err
   168  	}
   169  
   170  	s.client = c
   171  
   172  	return nil
   173  }
   174  
   175  func (s *RedisSubscriber) runPubSub(done chan (error)) {
   176  	err := s.initClient()
   177  
   178  	if err != nil {
   179  		s.log.Error("failed to connect to Redis", "error", err)
   180  		s.maybeReconnect(done)
   181  		return
   182  	}
   183  
   184  	client, cancel := s.client.Dedicate()
   185  	defer cancel()
   186  
   187  	wait := client.SetPubSubHooks(rueidis.PubSubHooks{
   188  		OnSubscription: func(m rueidis.PubSubSubscription) {
   189  			s.subMu.Lock()
   190  			defer s.subMu.Unlock()
   191  
   192  			if m.Kind == "subscribe" && m.Channel == s.config.InternalChannel {
   193  				if s.reconnectAttempt > 0 {
   194  					s.log.Info("reconnected to Redis")
   195  				}
   196  				s.reconnectAttempt = 0
   197  			}
   198  
   199  			if entry, ok := s.subscriptions[m.Channel]; ok {
   200  				if entry.state == subscriptionPending && m.Kind == "subscribe" {
   201  					s.log.With("channel", m.Channel).Debug("subscribed")
   202  					entry.state = subscriptionCreated
   203  				}
   204  
   205  				if entry.state == subscriptionPendingUnsubscribe && m.Kind == "unsubscribe" {
   206  					s.log.With("channel", m.Channel).Debug("unsubscribed")
   207  					delete(s.subscriptions, entry.id)
   208  				}
   209  			}
   210  		},
   211  		OnMessage: func(m rueidis.PubSubMessage) {
   212  			msg, err := common.PubSubMessageFromJSON([]byte(m.Message))
   213  
   214  			if err != nil {
   215  				s.log.Warn("failed to parse pubsub message", "data", logger.CompactValue(m.Message), "error", err)
   216  				return
   217  			}
   218  
   219  			switch v := msg.(type) {
   220  			case common.StreamMessage:
   221  				s.log.With("channel", m.Channel).Debug("received broadcast message")
   222  				s.node.Broadcast(&v)
   223  			case common.RemoteCommandMessage:
   224  				s.log.With("channel", m.Channel).Debug("received remote command")
   225  				s.node.ExecuteRemoteCommand(&v)
   226  			}
   227  		},
   228  	})
   229  
   230  	for {
   231  		select {
   232  		case err := <-wait:
   233  			if err != nil {
   234  				s.log.Error("Redis pub/sub disconnected", "error", err)
   235  			}
   236  
   237  			s.maybeReconnect(done)
   238  
   239  			return
   240  		case <-s.shutdownCh:
   241  			s.log.Debug("close pub/sub channel")
   242  			return
   243  		case entry := <-s.streamsCh:
   244  			ctx := context.Background()
   245  
   246  			switch entry.state {
   247  			case subscriptionPending:
   248  				s.log.With("channel", entry.id).Debug("subscribing")
   249  				client.Do(ctx, client.B().Subscribe().Channel(entry.id).Build())
   250  			case subscriptionPendingUnsubscribe:
   251  				s.log.With("channel", entry.id).Debug("unsubscribing")
   252  				client.Do(ctx, client.B().Unsubscribe().Channel(entry.id).Build())
   253  			}
   254  		}
   255  	}
   256  }
   257  
   258  func (s *RedisSubscriber) subscriptionEntry(stream string) *subscriptionEntry {
   259  	s.subMu.RLock()
   260  	defer s.subMu.RUnlock()
   261  
   262  	if entry, ok := s.subscriptions[stream]; ok {
   263  		return entry
   264  	}
   265  
   266  	return nil
   267  }
   268  
   269  func (s *RedisSubscriber) maybeReconnect(done chan (error)) {
   270  	if s.reconnectAttempt >= s.config.MaxReconnectAttempts {
   271  		done <- errors.New("failed to reconnect to Redis: attempts exceeded") //nolint:stylecheck
   272  		return
   273  	}
   274  
   275  	s.clientMu.RLock()
   276  	if s.client != nil {
   277  		// Make sure client knows about connection failure,
   278  		// so the next attempt to Publish won't fail
   279  		s.client.Do(context.Background(), s.client.B().Arbitrary("ping").Build())
   280  	}
   281  	s.clientMu.RUnlock()
   282  
   283  	s.subMu.Lock()
   284  	toRemove := []string{}
   285  
   286  	for key, sub := range s.subscriptions {
   287  		if sub.state == subscriptionCreated {
   288  			sub.state = subscriptionPending
   289  		}
   290  
   291  		if sub.state == subscriptionPendingUnsubscribe {
   292  			toRemove = append(toRemove, key)
   293  		}
   294  	}
   295  
   296  	for _, key := range toRemove {
   297  		delete(s.subscriptions, key)
   298  	}
   299  	s.subMu.Unlock()
   300  
   301  	s.reconnectAttempt++
   302  
   303  	delay := utils.NextRetry(s.reconnectAttempt - 1)
   304  
   305  	s.log.Info(fmt.Sprintf("next Redis reconnect attempt in %s", delay))
   306  	time.Sleep(delay)
   307  
   308  	s.log.Info("reconnecting to Redis...")
   309  
   310  	go s.runPubSub(done)
   311  
   312  	s.subMu.RLock()
   313  	defer s.subMu.RUnlock()
   314  
   315  	for _, sub := range s.subscriptions {
   316  		if sub.state == subscriptionPending {
   317  			s.log.Debug("resubscribing to stream", "stream", sub.id)
   318  			s.streamsCh <- sub
   319  		}
   320  	}
   321  }