github.com/anycable/anycable-go@v1.5.1/pubsub/nats.go (about)

     1  package pubsub
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log/slog"
     7  	"sync"
     8  
     9  	"github.com/anycable/anycable-go/common"
    10  	"github.com/anycable/anycable-go/logger"
    11  	nconfig "github.com/anycable/anycable-go/nats"
    12  	"github.com/anycable/anycable-go/utils"
    13  
    14  	"github.com/nats-io/nats.go"
    15  )
    16  
    17  type NATSSubscriber struct {
    18  	node   Handler
    19  	config *nconfig.NATSConfig
    20  
    21  	conn *nats.Conn
    22  
    23  	subscriptions map[string]*nats.Subscription
    24  	subMu         sync.RWMutex
    25  
    26  	log *slog.Logger
    27  }
    28  
    29  var _ Subscriber = (*NATSSubscriber)(nil)
    30  
    31  // NewNATSSubscriber creates a NATS subscriber using pub/sub
    32  func NewNATSSubscriber(node Handler, config *nconfig.NATSConfig, l *slog.Logger) (*NATSSubscriber, error) {
    33  	return &NATSSubscriber{
    34  		node:          node,
    35  		config:        config,
    36  		subscriptions: make(map[string]*nats.Subscription),
    37  		log:           l.With("context", "pubsub"),
    38  	}, nil
    39  }
    40  
    41  func (s *NATSSubscriber) Start(done chan (error)) error {
    42  	connectOptions := []nats.Option{
    43  		nats.RetryOnFailedConnect(true),
    44  		nats.MaxReconnects(s.config.MaxReconnectAttempts),
    45  		nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
    46  			if err != nil {
    47  				s.log.Warn("connection failed", "error", err)
    48  			}
    49  		}),
    50  		nats.ReconnectHandler(func(nc *nats.Conn) {
    51  			s.log.Info("connection restored", "url", nc.ConnectedUrl())
    52  		}),
    53  	}
    54  
    55  	if s.config.DontRandomizeServers {
    56  		connectOptions = append(connectOptions, nats.DontRandomize())
    57  	}
    58  
    59  	nc, err := nats.Connect(s.config.Servers, connectOptions...)
    60  
    61  	if err != nil {
    62  		return err
    63  	}
    64  
    65  	s.log.Info(fmt.Sprintf("Starting NATS pub/sub: %s", s.config.Servers))
    66  
    67  	s.conn = nc
    68  
    69  	s.Subscribe(s.config.InternalChannel)
    70  
    71  	return nil
    72  }
    73  
    74  func (s *NATSSubscriber) Shutdown(ctx context.Context) error {
    75  	if s.conn != nil {
    76  		s.conn.Close()
    77  	}
    78  
    79  	return nil
    80  }
    81  
    82  func (s *NATSSubscriber) IsMultiNode() bool {
    83  	return true
    84  }
    85  
    86  func (s *NATSSubscriber) Subscribe(stream string) {
    87  	s.subMu.RLock()
    88  	if _, ok := s.subscriptions[stream]; ok {
    89  		s.subMu.RUnlock()
    90  		return
    91  	}
    92  
    93  	s.subMu.RUnlock()
    94  
    95  	s.subMu.Lock()
    96  	defer s.subMu.Unlock()
    97  
    98  	sub, err := s.conn.Subscribe(stream, s.handleMessage)
    99  
   100  	if err != nil {
   101  		s.log.Error("failed to subscribe", "stream", stream, "error", err)
   102  		return
   103  	}
   104  
   105  	s.subscriptions[stream] = sub
   106  }
   107  
   108  func (s *NATSSubscriber) Unsubscribe(stream string) {
   109  	s.subMu.Lock()
   110  	defer s.subMu.Unlock()
   111  
   112  	if sub, ok := s.subscriptions[stream]; ok {
   113  		delete(s.subscriptions, stream)
   114  		sub.Unsubscribe() // nolint:errcheck
   115  	}
   116  }
   117  
   118  func (s *NATSSubscriber) Broadcast(msg *common.StreamMessage) {
   119  	s.Publish(msg.Stream, msg)
   120  }
   121  
   122  func (s *NATSSubscriber) BroadcastCommand(cmd *common.RemoteCommandMessage) {
   123  	s.Publish(s.config.InternalChannel, cmd)
   124  }
   125  
   126  func (s *NATSSubscriber) Publish(stream string, msg interface{}) {
   127  	s.log.With("channel", stream).Debug("publish message", "data", msg)
   128  
   129  	if err := s.conn.Publish(stream, utils.ToJSON(msg)); err != nil {
   130  		s.log.Error("failed to publish message", "error", err)
   131  	}
   132  }
   133  
   134  func (s *NATSSubscriber) handleMessage(m *nats.Msg) {
   135  	msg, err := common.PubSubMessageFromJSON(m.Data)
   136  
   137  	if err != nil {
   138  		s.log.Warn("failed to parse pubsub message", "data", logger.CompactValue(m.Data), "error", err)
   139  		return
   140  	}
   141  
   142  	switch v := msg.(type) {
   143  	case common.StreamMessage:
   144  		s.log.With("channel", m.Subject).Debug("received broadcast message")
   145  		s.node.Broadcast(&v)
   146  	case common.RemoteCommandMessage:
   147  		s.log.With("channel", m.Subject).Debug("received remote command")
   148  		s.node.ExecuteRemoteCommand(&v)
   149  	default:
   150  		s.log.With("channel", m.Subject).Warn("received unknown message", "data", logger.CompactValue(m.Data))
   151  	}
   152  }