github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/cluster/pubsub_topic.go (about)

     1  package cluster
     2  
     3  import (
     4  	"context"
     5  	"log/slog"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/asynkron/protoactor-go/actor"
    10  	"github.com/asynkron/protoactor-go/eventstream"
    11  	"golang.org/x/exp/maps"
    12  )
    13  
// TopicActorKind is the kind name associated with TopicActor.
// NOTE(review): presumably the cluster kind under which topic actors are
// registered and looked up — confirm at the registration site.
const TopicActorKind = "prototopic"
    15  
// TopicActor is the actor behind a single pub-sub topic. It keeps the topic's
// subscribers in memory, persists them through a KeyValueStore, and forwards
// published batches to the delivery actor on each subscriber's member.
//
// Fields:
//   - topic: the topic name, taken from the actor's cluster identity in onStarted.
//   - subscribers: the current subscribers, keyed by a comparable form of
//     their identity.
//   - subscriptionStore: the store the subscriber list is loaded from and
//     saved to.
//   - topologySubscription: the event-stream subscription created in onStarted
//     and released in onStopping.
//   - shouldThrottle: the throttle consulted before most log statements in
//     this type.
type TopicActor struct {
	topic                string
	subscribers          map[subscribeIdentityStruct]*SubscriberIdentity
	subscriptionStore    KeyValueStore[*Subscribers]
	topologySubscription *eventstream.Subscription
	shouldThrottle       actor.ShouldThrottle
}
    23  
    24  func NewTopicActor(store KeyValueStore[*Subscribers], logger *slog.Logger) *TopicActor {
    25  	return &TopicActor{
    26  		subscriptionStore: store,
    27  		subscribers:       make(map[subscribeIdentityStruct]*SubscriberIdentity),
    28  		shouldThrottle: actor.NewThrottleWithLogger(logger, 10, time.Second, func(logger *slog.Logger, count int32) {
    29  			logger.Info("[TopicActor] Throttled logs", slog.Int("count", int(count)))
    30  		}),
    31  	}
    32  }
    33  
    34  func (t *TopicActor) Receive(c actor.Context) {
    35  	switch msg := c.Message().(type) {
    36  	case *actor.Started:
    37  		t.onStarted(c)
    38  	case *actor.Stopping:
    39  		t.onStopping(c)
    40  	case *actor.ReceiveTimeout:
    41  		t.onReceiveTimeout(c)
    42  	case *Initialize:
    43  		t.onInitialize(c, msg)
    44  	case *SubscribeRequest:
    45  		t.onSubscribe(c, msg)
    46  	case *UnsubscribeRequest:
    47  		t.onUnsubscribe(c, msg)
    48  	case *PubSubBatch:
    49  		t.onPubSubBatch(c, msg)
    50  	case *NotifyAboutFailingSubscribersRequest:
    51  		t.onNotifyAboutFailingSubscribers(c, msg)
    52  	case *ClusterTopology:
    53  		t.onClusterTopologyChanged(c, msg)
    54  	}
    55  }
    56  
    57  func (t *TopicActor) onStarted(c actor.Context) {
    58  	t.topic = GetClusterIdentity(c).Identity
    59  	t.topologySubscription = c.ActorSystem().EventStream.Subscribe(func(evt interface{}) {
    60  		if clusterTopology, ok := evt.(*ClusterTopology); ok {
    61  			c.Send(c.Self(), clusterTopology)
    62  		}
    63  	})
    64  
    65  	sub := t.loadSubscriptions(t.topic, c.Logger())
    66  	if sub.Subscribers != nil {
    67  		for _, subscriber := range sub.Subscribers {
    68  			t.subscribers[newSubscribeIdentityStruct(subscriber)] = subscriber
    69  		}
    70  	}
    71  	t.unsubscribeSubscribersOnMembersThatLeft(c)
    72  
    73  	c.Logger().Debug("Topic started", slog.String("topic", t.topic))
    74  }
    75  
    76  func (t *TopicActor) onStopping(c actor.Context) {
    77  	if t.topologySubscription != nil {
    78  		c.ActorSystem().EventStream.Unsubscribe(t.topologySubscription)
    79  		t.topologySubscription = nil
    80  	}
    81  }
    82  
    83  func (t *TopicActor) onReceiveTimeout(c actor.Context) {
    84  	c.Stop(c.Self())
    85  }
    86  
    87  func (t *TopicActor) onInitialize(c actor.Context, msg *Initialize) {
    88  	if msg.IdleTimeout != nil {
    89  		duration := msg.IdleTimeout.AsDuration()
    90  		if duration > 0 {
    91  			c.SetReceiveTimeout(duration)
    92  		}
    93  	}
    94  	c.Respond(&Acknowledge{})
    95  }
    96  
// pidAndSubscriber pairs a subscriber with the PID resolved for it. It is
// used by onPubSubBatch to group subscribers by the address of their PID.
type pidAndSubscriber struct {
	pid        *actor.PID
	subscriber *SubscriberIdentity
}
   101  
   102  // onPubSubBatch handles a PubSubBatch message, sends the message to all subscribers
   103  func (t *TopicActor) onPubSubBatch(c actor.Context, batch *PubSubBatch) {
   104  	// map subscribers to map[address][](pid, subscriber)
   105  	members := make(map[string][]pidAndSubscriber)
   106  	for _, identity := range t.subscribers {
   107  		pid := t.getPID(c, identity)
   108  		if pid != nil {
   109  			members[pid.Address] = append(members[pid.Address], pidAndSubscriber{pid: pid, subscriber: identity})
   110  		}
   111  	}
   112  
   113  	// send message to each member
   114  	for address, member := range members {
   115  		subscribersOnMember := t.getSubscribersForAddress(member)
   116  		deliveryMessage := &DeliverBatchRequest{
   117  			Subscribers: subscribersOnMember,
   118  			PubSubBatch: batch,
   119  			Topic:       t.topic,
   120  		}
   121  		deliveryPid := actor.NewPID(address, PubSubDeliveryName)
   122  		c.Send(deliveryPid, deliveryMessage)
   123  	}
   124  	c.Respond(&PublishResponse{})
   125  }
   126  
   127  // getSubscribersForAddress returns the subscribers for the given member list
   128  func (t *TopicActor) getSubscribersForAddress(members []pidAndSubscriber) *Subscribers {
   129  	subscribers := make([]*SubscriberIdentity, len(members))
   130  	for i, member := range members {
   131  		subscribers[i] = member.subscriber
   132  	}
   133  	return &Subscribers{Subscribers: subscribers}
   134  }
   135  
   136  // getPID returns the PID of the subscriber
   137  func (t *TopicActor) getPID(c actor.Context, subscriber *SubscriberIdentity) *actor.PID {
   138  	if pid := subscriber.GetPid(); pid != nil {
   139  		return pid
   140  	}
   141  
   142  	return t.getClusterIdentityPid(c, subscriber.GetClusterIdentity())
   143  }
   144  
   145  // getClusterIdentityPid returns the PID of the clusterIdentity actor
   146  func (t *TopicActor) getClusterIdentityPid(c actor.Context, identity *ClusterIdentity) *actor.PID {
   147  	if identity == nil {
   148  		return nil
   149  	}
   150  
   151  	return GetCluster(c.ActorSystem()).Get(identity.Identity, identity.Kind)
   152  }
   153  
   154  // onNotifyAboutFailingSubscribers handles a NotifyAboutFailingSubscribersRequest message
   155  func (t *TopicActor) onNotifyAboutFailingSubscribers(c actor.Context, msg *NotifyAboutFailingSubscribersRequest) {
   156  	t.unsubscribeUnreachablePidSubscribers(c, msg.InvalidDeliveries)
   157  	t.logDeliveryErrors(msg.InvalidDeliveries, c.Logger())
   158  	c.Respond(&NotifyAboutFailingSubscribersResponse{})
   159  }
   160  
   161  // logDeliveryErrors logs the delivery errors in one log line
   162  func (t *TopicActor) logDeliveryErrors(reports []*SubscriberDeliveryReport, logger *slog.Logger) {
   163  	if len(reports) > 0 || t.shouldThrottle() == actor.Open {
   164  		subscribers := make([]string, len(reports))
   165  		for i, report := range reports {
   166  			subscribers[i] = report.Subscriber.String()
   167  		}
   168  		logger.Error("Topic following subscribers could not process the batch", slog.String("topic", t.topic), slog.String("subscribers", strings.Join(subscribers, ",")))
   169  	}
   170  }
   171  
   172  // unsubscribeUnreachablePidSubscribers deletes all subscribers that have a PID that is unreachable
   173  func (t *TopicActor) unsubscribeUnreachablePidSubscribers(_ actor.Context, allInvalidDeliveryReports []*SubscriberDeliveryReport) {
   174  	subscribers := make([]subscribeIdentityStruct, 0, len(allInvalidDeliveryReports))
   175  	for _, r := range allInvalidDeliveryReports {
   176  		if r.Subscriber.GetPid() != nil && r.Status == DeliveryStatus_SubscriberNoLongerReachable {
   177  			subscribers = append(subscribers, newSubscribeIdentityStruct(r.Subscriber))
   178  		}
   179  	}
   180  	t.removeSubscribers(subscribers, nil)
   181  }
   182  
   183  // onClusterTopologyChanged handles a ClusterTopology message
   184  func (t *TopicActor) onClusterTopologyChanged(ctx actor.Context, msg *ClusterTopology) {
   185  	if len(msg.Left) > 0 {
   186  		addressMap := make(map[string]struct{})
   187  		for _, member := range msg.Left {
   188  			addressMap[member.Address()] = struct{}{}
   189  		}
   190  
   191  		subscribersThatLeft := make([]subscribeIdentityStruct, 0, len(msg.Left))
   192  
   193  		for identityStruct, identity := range t.subscribers {
   194  			if pid := identity.GetPid(); pid != nil {
   195  				if _, ok := addressMap[pid.Address]; ok {
   196  					subscribersThatLeft = append(subscribersThatLeft, identityStruct)
   197  				}
   198  			}
   199  		}
   200  		t.removeSubscribers(subscribersThatLeft, ctx.Logger())
   201  	}
   202  }
   203  
   204  // unsubscribeSubscribersOnMembersThatLeft removes subscribers that are on members that left the clusterIdentity
   205  func (t *TopicActor) unsubscribeSubscribersOnMembersThatLeft(c actor.Context) {
   206  	members := GetCluster(c.ActorSystem()).MemberList.Members()
   207  	activeMemberAddresses := make(map[string]struct{})
   208  	for _, member := range members.Members() {
   209  		activeMemberAddresses[member.Address()] = struct{}{}
   210  	}
   211  
   212  	subscribersThatLeft := make([]subscribeIdentityStruct, 0)
   213  	for s := range t.subscribers {
   214  		if s.isPID {
   215  			if _, ok := activeMemberAddresses[s.pid.address]; !ok {
   216  				subscribersThatLeft = append(subscribersThatLeft, s)
   217  			}
   218  		}
   219  	}
   220  	t.removeSubscribers(subscribersThatLeft, nil)
   221  }
   222  
   223  // removeSubscribers remove subscribers from the topic
   224  func (t *TopicActor) removeSubscribers(subscribersThatLeft []subscribeIdentityStruct, logger *slog.Logger) {
   225  	if len(subscribersThatLeft) > 0 {
   226  		for _, subscriber := range subscribersThatLeft {
   227  			delete(t.subscribers, subscriber)
   228  		}
   229  		if t.shouldThrottle() == actor.Open {
   230  			logger.Warn("Topic removed subscribers, because they are dead or they are on members that left the clusterIdentity:", slog.String("topic", t.topic), slog.Any("subscribers", subscribersThatLeft))
   231  		}
   232  		t.saveSubscriptionsInTopicActor(logger)
   233  	}
   234  }
   235  
   236  // loadSubscriptions loads the subscriptions for the topic from the subscription store
   237  func (t *TopicActor) loadSubscriptions(topic string, logger *slog.Logger) *Subscribers {
   238  	// TODO: cancellation logic config?
   239  	state, err := t.subscriptionStore.Get(context.Background(), topic)
   240  	if err != nil {
   241  		if t.shouldThrottle() == actor.Open {
   242  			logger.Error("Error when loading subscriptions", slog.String("topic", topic), slog.Any("error", err))
   243  		}
   244  		return &Subscribers{}
   245  	}
   246  	if state == nil {
   247  		return &Subscribers{}
   248  	}
   249  	logger.Debug("Loaded subscriptions for topic", slog.String("topic", topic), slog.Any("subscriptions", state))
   250  	return state
   251  }
   252  
   253  // saveSubscriptionsInTopicActor saves the TopicActor.subscribers for the TopicActor.topic to the subscription store
   254  func (t *TopicActor) saveSubscriptionsInTopicActor(logger *slog.Logger) {
   255  	var subscribers *Subscribers = &Subscribers{Subscribers: maps.Values(t.subscribers)}
   256  
   257  	// TODO: cancellation logic config?
   258  	logger.Debug("Saving subscriptions for topic", slog.String("topic", t.topic), slog.Any("subscriptions", subscribers))
   259  	err := t.subscriptionStore.Set(context.Background(), t.topic, subscribers)
   260  	if err != nil && t.shouldThrottle() == actor.Open {
   261  		logger.Error("Error when saving subscriptions", slog.String("topic", t.topic), slog.Any("error", err))
   262  	}
   263  }
   264  
   265  func (t *TopicActor) onUnsubscribe(c actor.Context, msg *UnsubscribeRequest) {
   266  	delete(t.subscribers, newSubscribeIdentityStruct(msg.Subscriber))
   267  	t.saveSubscriptionsInTopicActor(c.Logger())
   268  	c.Respond(&UnsubscribeResponse{})
   269  }
   270  
   271  func (t *TopicActor) onSubscribe(c actor.Context, msg *SubscribeRequest) {
   272  	t.subscribers[newSubscribeIdentityStruct(msg.Subscriber)] = msg.Subscriber
   273  	c.Logger().Debug("Topic subscribed", slog.String("topic", t.topic), slog.Any("subscriber", msg.Subscriber))
   274  	t.saveSubscriptionsInTopicActor(c.Logger())
   275  	c.Respond(&SubscribeResponse{})
   276  }
   277  
   278  // pidStruct is a struct that represents a PID
   279  // It is used to implement the comparison interface
   280  type pidStruct struct {
   281  	address   string
   282  	id        string
   283  	requestId uint32
   284  }
   285  
   286  // newPIDStruct creates a new pidStruct from a *actor.PID
   287  func newPidStruct(pid *actor.PID) pidStruct {
   288  	return pidStruct{
   289  		address:   pid.Address,
   290  		id:        pid.Id,
   291  		requestId: pid.RequestId,
   292  	}
   293  }
   294  
   295  // toPID converts a pidStruct to a *actor.PID
   296  func (p pidStruct) toPID() *actor.PID {
   297  	return &actor.PID{
   298  		Address:   p.address,
   299  		Id:        p.id,
   300  		RequestId: p.requestId,
   301  	}
   302  }
   303  
   304  type clusterIdentityStruct struct {
   305  	identity string
   306  	kind     string
   307  }
   308  
   309  // newClusterIdentityStruct creates a new clusterIdentityStruct from a *ClusterIdentity
   310  func newClusterIdentityStruct(clusterIdentity *ClusterIdentity) clusterIdentityStruct {
   311  	return clusterIdentityStruct{
   312  		identity: clusterIdentity.Identity,
   313  		kind:     clusterIdentity.Kind,
   314  	}
   315  }
   316  
   317  // toClusterIdentity converts a clusterIdentityStruct to a *ClusterIdentity
   318  func (c clusterIdentityStruct) toClusterIdentity() *ClusterIdentity {
   319  	return &ClusterIdentity{
   320  		Identity: c.identity,
   321  		Kind:     c.kind,
   322  	}
   323  }
   324  
// subscribeIdentityStruct is a comparable, plain-value form of a
// SubscriberIdentity. It is the key type of TopicActor.subscribers.
//
// isPID tells which of the two identity forms is populated: when true, pid
// holds the subscriber's PID and clusterIdentity is left at its zero value;
// when false, clusterIdentity holds the subscriber's cluster identity and pid
// is left at its zero value.
type subscribeIdentityStruct struct {
	isPID           bool
	pid             pidStruct
	clusterIdentity clusterIdentityStruct
}
   332  
   333  // newSubscriberIdentityStruct creates a new subscriberIdentityStruct from a *SubscriberIdentity
   334  func newSubscribeIdentityStruct(subscriberIdentity *SubscriberIdentity) subscribeIdentityStruct {
   335  	if subscriberIdentity.GetPid() != nil {
   336  		return subscribeIdentityStruct{
   337  			isPID: true,
   338  			pid:   newPidStruct(subscriberIdentity.GetPid()),
   339  		}
   340  	}
   341  	return subscribeIdentityStruct{
   342  		isPID:           false,
   343  		clusterIdentity: newClusterIdentityStruct(subscriberIdentity.GetClusterIdentity()),
   344  	}
   345  }
   346  
   347  // toSubscriberIdentity converts a subscribeIdentityStruct to a *SubscriberIdentity
   348  func (s subscribeIdentityStruct) toSubscriberIdentity() *SubscriberIdentity {
   349  	if s.isPID {
   350  		return &SubscriberIdentity{
   351  			Identity: &SubscriberIdentity_Pid{Pid: s.pid.toPID()},
   352  		}
   353  	}
   354  	return &SubscriberIdentity{
   355  		Identity: &SubscriberIdentity_ClusterIdentity{ClusterIdentity: s.clusterIdentity.toClusterIdentity()},
   356  	}
   357  }