github.com/anycable/anycable-go@v1.5.1/broadcast/redis.go (about)

     1  package broadcast
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	rconfig "github.com/anycable/anycable-go/redis"
    13  	"github.com/anycable/anycable-go/utils"
    14  
    15  	nanoid "github.com/matoous/go-nanoid"
    16  	"github.com/redis/rueidis"
    17  )
    18  
    19  // RedisBroadcaster represents Redis broadcaster using Redis streams
    20  type RedisBroadcaster struct {
    21  	node   Handler
    22  	config *rconfig.RedisConfig
    23  
    24  	// Unique consumer identifier
    25  	consumerName string
    26  
    27  	client        rueidis.Client
    28  	clientOptions *rueidis.ClientOption
    29  	clientMu      sync.RWMutex
    30  
    31  	reconnectAttempt int
    32  
    33  	shutdownCh chan struct{}
    34  	finishedCh chan struct{}
    35  
    36  	log *slog.Logger
    37  }
    38  
    39  var _ Broadcaster = (*RedisBroadcaster)(nil)
    40  
    41  // NewRedisBroadcaster builds a new RedisSubscriber struct
    42  func NewRedisBroadcaster(node Handler, config *rconfig.RedisConfig, l *slog.Logger) *RedisBroadcaster {
    43  	name, _ := nanoid.Nanoid(6)
    44  
    45  	return &RedisBroadcaster{
    46  		node:         node,
    47  		config:       config,
    48  		consumerName: name,
    49  		log:          l.With("context", "broadcast").With("provider", "redisx").With("id", name),
    50  		shutdownCh:   make(chan struct{}),
    51  		finishedCh:   make(chan struct{}),
    52  	}
    53  }
    54  
    55  func (s *RedisBroadcaster) IsFanout() bool {
    56  	return false
    57  }
    58  
    59  func (s *RedisBroadcaster) Start(done chan error) error {
    60  	options, err := s.config.ToRueidisOptions()
    61  
    62  	if err != nil {
    63  		return err
    64  	}
    65  
    66  	if s.config.IsSentinel() { //nolint:gocritic
    67  		s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (sentinels)", s.config.Hostnames()))
    68  	} else if s.config.IsCluster() {
    69  		s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (cluster)", s.config.Hostnames()))
    70  	} else {
    71  		s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %s", s.config.Hostname()))
    72  	}
    73  
    74  	s.clientOptions = options
    75  
    76  	go s.runReader(done)
    77  
    78  	return nil
    79  }
    80  
    81  func (s *RedisBroadcaster) Shutdown(ctx context.Context) error {
    82  	s.clientMu.RLock()
    83  	defer s.clientMu.RUnlock()
    84  
    85  	if s.client == nil {
    86  		return nil
    87  	}
    88  
    89  	s.log.Debug("shutting down Redis broadcaster")
    90  
    91  	close(s.shutdownCh)
    92  
    93  	<-s.finishedCh
    94  
    95  	res := s.client.Do(
    96  		context.Background(),
    97  		s.client.B().XgroupDelconsumer().Key(s.config.Channel).Group(s.config.Group).Consumername(s.consumerName).Build(),
    98  	)
    99  
   100  	err := res.Error()
   101  
   102  	if err != nil {
   103  		s.log.Error("failed to remove Redis stream consumer", "error", err)
   104  	}
   105  
   106  	s.client.Close()
   107  
   108  	return nil
   109  }
   110  
   111  func (s *RedisBroadcaster) initClient() error {
   112  	s.clientMu.Lock()
   113  	defer s.clientMu.Unlock()
   114  
   115  	if s.client != nil {
   116  		return nil
   117  	}
   118  
   119  	c, err := rueidis.NewClient(*s.clientOptions)
   120  
   121  	if err != nil {
   122  		return err
   123  	}
   124  
   125  	s.client = c
   126  
   127  	return nil
   128  }
   129  
   130  func (s *RedisBroadcaster) runReader(done chan (error)) {
   131  	err := s.initClient()
   132  
   133  	if err != nil {
   134  		s.log.Error("failed to connect to Redis", "error", err)
   135  		s.maybeReconnect(done)
   136  		return
   137  	}
   138  
   139  	// First, create a consumer group for the stream
   140  	err = s.client.Do(context.Background(),
   141  		s.client.B().XgroupCreate().Key(s.config.Channel).Group(s.config.Group).Id("$").Mkstream().Build(),
   142  	).Error()
   143  
   144  	if err != nil {
   145  		if redisErr, ok := rueidis.IsRedisErr(err); ok {
   146  			if strings.HasPrefix(redisErr.Error(), "BUSYGROUP") {
   147  				s.log.Debug("Redis consumer group already exists")
   148  			} else {
   149  				s.log.Error("failed to create consumer group", "error", err)
   150  				s.maybeReconnect(done)
   151  				return
   152  			}
   153  		}
   154  	}
   155  
   156  	s.reconnectAttempt = 0
   157  
   158  	readBlockMilliseconds := s.config.StreamReadBlockMilliseconds
   159  	var lastClaimedAt int64
   160  
   161  	for {
   162  		select {
   163  		case <-s.shutdownCh:
   164  			s.log.Debug("stop consuming stream")
   165  			close(s.finishedCh)
   166  			return
   167  		default:
   168  			if lastClaimedAt+readBlockMilliseconds < time.Now().UnixMilli() {
   169  				reclaimed, err := s.autoclaimMessages(readBlockMilliseconds)
   170  
   171  				if err != nil {
   172  					s.log.Error("failed to claim from Redis stream", "error", err)
   173  					s.maybeReconnect(done)
   174  					return
   175  				}
   176  
   177  				lastClaimedAt = time.Now().UnixMilli()
   178  
   179  				if len(reclaimed) > 0 {
   180  					s.log.Debug("reclaimed messages", "size", len(reclaimed))
   181  
   182  					s.broadcastXrange(reclaimed)
   183  				}
   184  			}
   185  
   186  			messages, err := s.readFromStream(readBlockMilliseconds)
   187  
   188  			if err != nil {
   189  				s.log.Error("failed to read from Redis stream", "error", err)
   190  				s.maybeReconnect(done)
   191  				return
   192  			}
   193  
   194  			if messages != nil {
   195  				s.broadcastXrange(messages)
   196  			}
   197  		}
   198  	}
   199  }
   200  
   201  func (s *RedisBroadcaster) readFromStream(blockTime int64) ([]rueidis.XRangeEntry, error) {
   202  	streamRes := s.client.Do(context.Background(),
   203  		s.client.B().Xreadgroup().Group(s.config.Group, s.consumerName).Block(blockTime).Streams().Key(s.config.Channel).Id(">").Build(),
   204  	)
   205  
   206  	res, _ := streamRes.AsXRead()
   207  	err := streamRes.Error()
   208  
   209  	if err != nil && !rueidis.IsRedisNil(err) {
   210  		return nil, err
   211  	}
   212  
   213  	if res == nil {
   214  		return nil, nil
   215  	}
   216  
   217  	if messages, ok := res[s.config.Channel]; ok {
   218  		return messages, nil
   219  	}
   220  
   221  	return nil, nil
   222  }
   223  
   224  func (s *RedisBroadcaster) autoclaimMessages(blockTime int64) ([]rueidis.XRangeEntry, error) {
   225  	claimRes := s.client.Do(context.Background(),
   226  		s.client.B().Xautoclaim().Key(s.config.Channel).Group(s.config.Group).Consumer(s.consumerName).MinIdleTime(fmt.Sprintf("%d", blockTime)).Start("0-0").Build(),
   227  	)
   228  
   229  	arr, err := claimRes.ToArray()
   230  
   231  	if err != nil && !rueidis.IsRedisNil(err) {
   232  		return nil, err
   233  	}
   234  
   235  	if arr == nil {
   236  		return nil, nil
   237  	}
   238  
   239  	if len(arr) < 2 {
   240  		return nil, fmt.Errorf("autoclaim failed: got %d elements, wanted 2", len(arr))
   241  	}
   242  
   243  	ranges, err := arr[1].AsXRange()
   244  
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	return ranges, nil
   250  }
   251  
   252  func (s *RedisBroadcaster) broadcastXrange(messages []rueidis.XRangeEntry) {
   253  	for _, message := range messages {
   254  		if payload, pok := message.FieldValues["payload"]; pok {
   255  			s.log.Debug("received broadcast")
   256  			s.node.HandleBroadcast([]byte(payload))
   257  
   258  			ackRes := s.client.DoMulti(context.Background(),
   259  				s.client.B().Xack().Key(s.config.Channel).Group(s.config.Group).Id(message.ID).Build(),
   260  				s.client.B().Xdel().Key(s.config.Channel).Id(message.ID).Build(),
   261  			)
   262  
   263  			err := ackRes[0].Error()
   264  
   265  			if err != nil {
   266  				s.log.Error("failed to ack message", "error", err)
   267  			}
   268  		}
   269  	}
   270  }
   271  
   272  func (s *RedisBroadcaster) maybeReconnect(done chan (error)) {
   273  	if s.reconnectAttempt >= s.config.MaxReconnectAttempts {
   274  		close(s.finishedCh)
   275  		done <- errors.New("failed to reconnect to Redis: attempts exceeded") //nolint:stylecheck
   276  		return
   277  	}
   278  
   279  	s.reconnectAttempt++
   280  
   281  	delay := utils.NextRetry(s.reconnectAttempt - 1)
   282  
   283  	s.log.Info(fmt.Sprintf("next Redis reconnect attempt in %s", delay))
   284  	time.Sleep(delay)
   285  
   286  	s.log.Info("reconnecting to Redis...")
   287  
   288  	s.clientMu.Lock()
   289  
   290  	if s.client != nil {
   291  		s.client.Close()
   292  		s.client = nil
   293  	}
   294  
   295  	s.clientMu.Unlock()
   296  
   297  	go s.runReader(done)
   298  }