github.com/wfusion/gofusion@v1.1.14/common/infra/watermill/pubsub/gochannel/pubsub.go (about)

     1  package gochannel
     2  
     3  import (
     4  	"context"
     5  	"math/rand"
     6  	"sync"
     7  
     8  	"github.com/pkg/errors"
     9  
    10  	"github.com/wfusion/gofusion/common/infra/watermill"
    11  	"github.com/wfusion/gofusion/common/infra/watermill/message"
    12  	"github.com/wfusion/gofusion/common/utils"
    13  )
    14  
    15  // Config holds the GoChannel Pub/Sub's configuration options.
    16  type Config struct {
    17  	// Output channel buffer size.
    18  	OutputChannelBuffer int64
    19  
    20  	// If persistent is set to true, when subscriber subscribes to the topic,
    21  	// it will receive all previously produced messages.
    22  	//
    23  	// All messages are persisted to the memory (simple slice),
    24  	// so be aware that with large amount of messages you can go out of the memory.
    25  	Persistent bool
    26  
    27  	// When true, Publish will block until subscriber Ack's the message.
    28  	// If there are no subscribers, Publish will not block (also when Persistent is true).
    29  	BlockPublishUntilSubscriberAck bool
    30  
    31  	ConsumerGroup string
    32  
    33  	AppID string
    34  }
    35  
    36  // GoChannel is the simplest Pub/Sub implementation.
    37  // It is based on Golang's channels which are sent within the process.
    38  //
    39  // GoChannel has no global state,
    40  // that means that you need to use the same instance for Publishing and Subscribing!
    41  //
    42  // When GoChannel is persistent, messages order is not guaranteed.
    43  type GoChannel struct {
    44  	config Config
    45  	logger watermill.LoggerAdapter
    46  
    47  	subscribersWg          sync.WaitGroup
    48  	subscribers            map[string]map[string][]*subscriber
    49  	subscribersLock        sync.RWMutex
    50  	subscribersByTopicLock sync.Map // map of *sync.Mutex
    51  
    52  	closed     bool
    53  	closedLock sync.Mutex
    54  	closing    chan struct{}
    55  
    56  	persistedMessages     map[string][]*message.Message
    57  	persistedMessagesLock sync.RWMutex
    58  }
    59  
    60  // NewGoChannel creates new GoChannel Pub/Sub.
    61  //
    62  // This GoChannel is not persistent.
    63  // That means if you send a message to a topic to which no subscriber is subscribed, that message will be discarded.
    64  func NewGoChannel(config Config, logger watermill.LoggerAdapter) *GoChannel {
    65  	if logger == nil {
    66  		logger = watermill.NopLogger{}
    67  	}
    68  
    69  	return &GoChannel{
    70  		config: config,
    71  
    72  		subscribers:            make(map[string]map[string][]*subscriber),
    73  		subscribersByTopicLock: sync.Map{},
    74  		logger: logger.With(watermill.LogFields{
    75  			"pubsub_uuid": utils.ShortUUID(),
    76  		}),
    77  
    78  		closing: make(chan struct{}),
    79  
    80  		persistedMessages: map[string][]*message.Message{},
    81  	}
    82  }
    83  
    84  // Publish in GoChannel is NOT blocking until all consumers consume.
    85  // Messages will be sent in background.
    86  //
    87  // Messages may be persisted or not, depending on persistent attribute.
    88  func (g *GoChannel) Publish(ctx context.Context, topic string, messages ...*message.Message) error {
    89  	if g.isClosed() {
    90  		return errors.New("Pub/Sub closed")
    91  	}
    92  
    93  	messagesToPublish := make(message.Messages, len(messages))
    94  	for i, msg := range messages {
    95  		messagesToPublish[i] = msg.Copy()
    96  	}
    97  
    98  	g.subscribersLock.RLock()
    99  	defer g.subscribersLock.RUnlock()
   100  
   101  	subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{})
   102  	subLock.(*sync.Mutex).Lock()
   103  	defer subLock.(*sync.Mutex).Unlock()
   104  
   105  	if g.config.Persistent {
   106  		g.persistedMessagesLock.Lock()
   107  		if _, ok := g.persistedMessages[topic]; !ok {
   108  			g.persistedMessages[topic] = make([]*message.Message, 0)
   109  		}
   110  		g.persistedMessages[topic] = append(g.persistedMessages[topic], messagesToPublish...)
   111  		g.persistedMessagesLock.Unlock()
   112  	}
   113  
   114  	for i := range messagesToPublish {
   115  		msg := messagesToPublish[i]
   116  
   117  		ackedBySubscribers, err := g.sendMessage(ctx, topic, msg)
   118  		if err != nil {
   119  			return err
   120  		}
   121  
   122  		if g.config.BlockPublishUntilSubscriberAck {
   123  			g.waitForAckFromSubscribers(msg, ackedBySubscribers)
   124  		}
   125  	}
   126  
   127  	return nil
   128  }
   129  
   130  func (g *GoChannel) waitForAckFromSubscribers(msg *message.Message, ackedByConsumer <-chan struct{}) {
   131  	logFields := watermill.LogFields{"message_uuid": msg.UUID}
   132  	g.logger.Debug("[Common] watermill gochannel waiting for subscribers ack", logFields)
   133  
   134  	select {
   135  	case <-ackedByConsumer:
   136  		g.logger.Trace("[Common] watermill gochannel message acked by subscribers", logFields)
   137  	case <-g.closing:
   138  		g.logger.Trace("[Common] watermill gochannel closing pub/sub before ack from subscribers", logFields)
   139  	}
   140  }
   141  
   142  func (g *GoChannel) sendMessage(ctx context.Context, topic string, message *message.Message) (<-chan struct{}, error) {
   143  	subscribers := g.topicSubscribers(topic)
   144  	ackedBySubscribers := make(chan struct{})
   145  
   146  	logFields := watermill.LogFields{"message_uuid": message.UUID, "topic": topic}
   147  
   148  	if len(subscribers) == 0 {
   149  		close(ackedBySubscribers)
   150  		g.logger.Info("[Common] watermill gochannel none subscribers to send message", logFields)
   151  		return ackedBySubscribers, nil
   152  	}
   153  
   154  	go func(subscribers map[string][]*subscriber) {
   155  		wg := &sync.WaitGroup{}
   156  
   157  		if noneGroupSubs, ok := subscribers[""]; ok {
   158  			for i := range noneGroupSubs {
   159  				subscriber := noneGroupSubs[i]
   160  
   161  				wg.Add(1)
   162  				go func() {
   163  					subscriber.sendMessageToSubscriber(message, logFields, g.config)
   164  					wg.Done()
   165  				}()
   166  			}
   167  			delete(subscribers, "")
   168  		}
   169  		for _, subs := range subscribers {
   170  			rand.Shuffle(len(subs), func(i, j int) { subs[i], subs[j] = subs[j], subs[i] })
   171  			subscriber := subs[0]
   172  			wg.Add(1)
   173  			go func() {
   174  				subscriber.sendMessageToSubscriber(message, logFields, g.config)
   175  				wg.Done()
   176  			}()
   177  		}
   178  
   179  		wg.Wait()
   180  		close(ackedBySubscribers)
   181  	}(subscribers)
   182  
   183  	return ackedBySubscribers, nil
   184  }
   185  
   186  // Subscribe returns channel to which all published messages are sent.
   187  // Messages are not persisted. If there are no subscribers and message is produced it will be gone.
   188  //
   189  // There are no consumer groups support etc. Every consumer will receive every produced message.
   190  func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) {
   191  	g.closedLock.Lock()
   192  
   193  	if g.closed {
   194  		g.closedLock.Unlock()
   195  		return nil, errors.New("pub/sub closed")
   196  	}
   197  
   198  	g.subscribersWg.Add(1)
   199  	g.closedLock.Unlock()
   200  
   201  	g.subscribersLock.Lock()
   202  
   203  	subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{})
   204  	subLock.(*sync.Mutex).Lock()
   205  
   206  	s := &subscriber{
   207  		ctx:           ctx,
   208  		uuid:          utils.UUID(),
   209  		outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer),
   210  		logger:        g.logger,
   211  		closing:       make(chan struct{}),
   212  		g:             g,
   213  	}
   214  
   215  	go func(s *subscriber, g *GoChannel) {
   216  		select {
   217  		case <-ctx.Done():
   218  			// unblock
   219  		case <-g.closing:
   220  			// unblock
   221  		}
   222  
   223  		s.Close()
   224  
   225  		g.subscribersLock.Lock()
   226  		defer g.subscribersLock.Unlock()
   227  
   228  		subLock, _ := g.subscribersByTopicLock.Load(topic)
   229  		subLock.(*sync.Mutex).Lock()
   230  		defer subLock.(*sync.Mutex).Unlock()
   231  
   232  		g.removeSubscriber(topic, g.config.ConsumerGroup, s)
   233  		g.subscribersWg.Done()
   234  	}(s, g)
   235  
   236  	if !g.config.Persistent {
   237  		defer g.subscribersLock.Unlock()
   238  		defer subLock.(*sync.Mutex).Unlock()
   239  
   240  		g.addSubscriber(topic, g.config.ConsumerGroup, s)
   241  
   242  		return s.outputChannel, nil
   243  	}
   244  
   245  	go func(s *subscriber) {
   246  		defer g.subscribersLock.Unlock()
   247  		defer subLock.(*sync.Mutex).Unlock()
   248  
   249  		g.persistedMessagesLock.RLock()
   250  		messages, ok := g.persistedMessages[topic]
   251  		g.persistedMessagesLock.RUnlock()
   252  
   253  		if ok && g.config.ConsumerGroup == "" {
   254  			for i := 0; i < len(messages); i++ {
   255  				msg := g.persistedMessages[topic][i]
   256  				logFields := watermill.LogFields{"message_uuid": msg.UUID, "topic": topic}
   257  
   258  				go s.sendMessageToSubscriber(msg, logFields, g.config)
   259  			}
   260  		}
   261  
   262  		g.addSubscriber(topic, g.config.ConsumerGroup, s)
   263  	}(s)
   264  
   265  	return s.outputChannel, nil
   266  }
   267  
   268  func (g *GoChannel) addSubscriber(topic, group string, s *subscriber) {
   269  	if _, ok := g.subscribers[topic]; !ok {
   270  		g.subscribers[topic] = make(map[string][]*subscriber)
   271  	}
   272  	g.subscribers[topic][group] = append(g.subscribers[topic][group], s)
   273  }
   274  
   275  func (g *GoChannel) removeSubscriber(topic, group string, toRemove *subscriber) {
   276  	removed := false
   277  	for _, groupSub := range g.subscribers[topic] {
   278  		for i, sub := range groupSub {
   279  			if sub == toRemove {
   280  				g.subscribers[topic][group] = append(g.subscribers[topic][group][:i],
   281  					g.subscribers[topic][group][i+1:]...)
   282  				removed = true
   283  				break
   284  			}
   285  		}
   286  
   287  	}
   288  	if !removed {
   289  		panic("cannot remove subscriber, not found " + toRemove.uuid)
   290  	}
   291  }
   292  
   293  func (g *GoChannel) topicSubscribers(topic string) map[string][]*subscriber {
   294  	subscribers, ok := g.subscribers[topic]
   295  	if !ok {
   296  		return nil
   297  	}
   298  
   299  	// let's do a copy to avoid race conditions and deadlocks due to lock
   300  	subscribersCopy := make(map[string][]*subscriber, len(subscribers))
   301  	for group, subs := range subscribers {
   302  		subscribersCopy[group] = make([]*subscriber, len(subs))
   303  		copy(subscribersCopy[group], subs)
   304  	}
   305  
   306  	return subscribersCopy
   307  }
   308  
   309  func (g *GoChannel) isClosed() bool {
   310  	g.closedLock.Lock()
   311  	defer g.closedLock.Unlock()
   312  
   313  	return g.closed
   314  }
   315  
   316  // Close closes the GoChannel Pub/Sub.
   317  func (g *GoChannel) Close() error {
   318  	g.closedLock.Lock()
   319  	defer g.closedLock.Unlock()
   320  
   321  	if g.closed {
   322  		return nil
   323  	}
   324  
   325  	g.closed = true
   326  	close(g.closing)
   327  
   328  	g.logger.Debug("[Common] watermill gochannel closing pub/sub, waiting for subscribers", nil)
   329  	g.subscribersWg.Wait()
   330  
   331  	g.logger.Info("[Common] watermill gochannel pub/sub closed", nil)
   332  	g.persistedMessages = nil
   333  
   334  	return nil
   335  }
   336  
   337  type subscriber struct {
   338  	ctx context.Context
   339  
   340  	uuid string
   341  
   342  	sending       sync.Mutex
   343  	outputChannel chan *message.Message
   344  
   345  	logger  watermill.LoggerAdapter
   346  	closed  bool
   347  	closing chan struct{}
   348  
   349  	g *GoChannel
   350  }
   351  
   352  func (s *subscriber) Close() {
   353  	if s.closed {
   354  		return
   355  	}
   356  	close(s.closing)
   357  
   358  	s.logger.Debug("[Common] watermill gochannel closing subscriber, waiting for sending lock", nil)
   359  
   360  	// ensuring that we are not sending to closed channel
   361  	s.sending.Lock()
   362  	defer s.sending.Unlock()
   363  
   364  	s.logger.Debug("[Common] watermill gochannel pub/sub subscriber closed", nil)
   365  	s.closed = true
   366  
   367  	close(s.outputChannel)
   368  }
   369  
   370  func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields, conf Config) {
   371  	s.sending.Lock()
   372  	defer s.sending.Unlock()
   373  
   374  	rawMessageID := utils.NginxID()
   375  	ctx := context.WithValue(s.ctx, watermill.ContextKeyMessageUUID, msg.UUID)
   376  	ctx = context.WithValue(ctx, watermill.ContextKeyRawMessageID, rawMessageID)
   377  	ctx, cancelCtx := context.WithCancel(ctx)
   378  	defer cancelCtx()
   379  	msg.Metadata[watermill.ContextKeyMessageUUID] = msg.UUID
   380  	msg.Metadata[watermill.ContextKeyRawMessageID] = rawMessageID
   381  	msg.Metadata[watermill.MessageHeaderAppID] = conf.AppID
   382  
   383  SendToSubscriber:
   384  	for {
   385  		// copy the message to prevent ack/nack propagation to other consumers
   386  		// also allows to make retries on a fresh copy of the original message
   387  		msgToSend := msg.Copy()
   388  		msgToSend.SetContext(ctx)
   389  
   390  		s.logger.Trace("[Common] watermill gochannel sending msg to subscriber", logFields)
   391  
   392  		if s.closed {
   393  			s.logger.Info("[Common] watermill gochannel pub/sub closed, discarding msg", logFields)
   394  			return
   395  		}
   396  
   397  		select {
   398  		case s.outputChannel <- msgToSend:
   399  			s.logger.Trace("[Common] watermill gochannel sent message to subscriber", logFields)
   400  		case <-s.closing:
   401  			s.logger.Trace("[Common] watermill gochannel closing, message discarded", logFields)
   402  			return
   403  		}
   404  
   405  		select {
   406  		case <-msgToSend.Acked():
   407  			s.logger.Trace("[Common] watermill gochannel message acked", logFields)
   408  			return
   409  		case <-msgToSend.Nacked():
   410  			s.logger.Trace("[Common] watermill gochannel nack received, resending message", logFields)
   411  			continue SendToSubscriber
   412  		case <-s.closing:
   413  			s.logger.Trace("[Common] watermill gochannel closing, message discarded", logFields)
   414  			return
   415  		}
   416  	}
   417  }