github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/cluster/pubsub_topic.go (about) 1 package cluster 2 3 import ( 4 "context" 5 "log/slog" 6 "strings" 7 "time" 8 9 "github.com/asynkron/protoactor-go/actor" 10 "github.com/asynkron/protoactor-go/eventstream" 11 "golang.org/x/exp/maps" 12 ) 13 14 const TopicActorKind = "prototopic" 15 16 type TopicActor struct { 17 topic string 18 subscribers map[subscribeIdentityStruct]*SubscriberIdentity 19 subscriptionStore KeyValueStore[*Subscribers] 20 topologySubscription *eventstream.Subscription 21 shouldThrottle actor.ShouldThrottle 22 } 23 24 func NewTopicActor(store KeyValueStore[*Subscribers], logger *slog.Logger) *TopicActor { 25 return &TopicActor{ 26 subscriptionStore: store, 27 subscribers: make(map[subscribeIdentityStruct]*SubscriberIdentity), 28 shouldThrottle: actor.NewThrottleWithLogger(logger, 10, time.Second, func(logger *slog.Logger, count int32) { 29 logger.Info("[TopicActor] Throttled logs", slog.Int("count", int(count))) 30 }), 31 } 32 } 33 34 func (t *TopicActor) Receive(c actor.Context) { 35 switch msg := c.Message().(type) { 36 case *actor.Started: 37 t.onStarted(c) 38 case *actor.Stopping: 39 t.onStopping(c) 40 case *actor.ReceiveTimeout: 41 t.onReceiveTimeout(c) 42 case *Initialize: 43 t.onInitialize(c, msg) 44 case *SubscribeRequest: 45 t.onSubscribe(c, msg) 46 case *UnsubscribeRequest: 47 t.onUnsubscribe(c, msg) 48 case *PubSubBatch: 49 t.onPubSubBatch(c, msg) 50 case *NotifyAboutFailingSubscribersRequest: 51 t.onNotifyAboutFailingSubscribers(c, msg) 52 case *ClusterTopology: 53 t.onClusterTopologyChanged(c, msg) 54 } 55 } 56 57 func (t *TopicActor) onStarted(c actor.Context) { 58 t.topic = GetClusterIdentity(c).Identity 59 t.topologySubscription = c.ActorSystem().EventStream.Subscribe(func(evt interface{}) { 60 if clusterTopology, ok := evt.(*ClusterTopology); ok { 61 c.Send(c.Self(), clusterTopology) 62 } 63 }) 64 65 sub := t.loadSubscriptions(t.topic, c.Logger()) 66 if sub.Subscribers != nil { 67 for _, subscriber := range sub.Subscribers { 68 t.subscribers[newSubscribeIdentityStruct(subscriber)] = subscriber 69 } 70 } 71 t.unsubscribeSubscribersOnMembersThatLeft(c) 72 73 c.Logger().Debug("Topic started", slog.String("topic", t.topic)) 74 } 75 76 func (t *TopicActor) onStopping(c actor.Context) { 77 if t.topologySubscription != nil { 78 c.ActorSystem().EventStream.Unsubscribe(t.topologySubscription) 79 t.topologySubscription = nil 80 } 81 } 82 83 func (t *TopicActor) onReceiveTimeout(c actor.Context) { 84 c.Stop(c.Self()) 85 } 86 87 func (t *TopicActor) onInitialize(c actor.Context, msg *Initialize) { 88 if msg.IdleTimeout != nil { 89 duration := msg.IdleTimeout.AsDuration() 90 if duration > 0 { 91 c.SetReceiveTimeout(duration) 92 } 93 } 94 c.Respond(&Acknowledge{}) 95 } 96 97 type pidAndSubscriber struct { 98 pid *actor.PID 99 subscriber *SubscriberIdentity 100 } 101 102 // onPubSubBatch handles a PubSubBatch message, sends the message to all subscribers 103 func (t *TopicActor) onPubSubBatch(c actor.Context, batch *PubSubBatch) { 104 // map subscribers to map[address][](pid, subscriber) 105 members := make(map[string][]pidAndSubscriber) 106 for _, identity := range t.subscribers { 107 pid := t.getPID(c, identity) 108 if pid != nil { 109 members[pid.Address] = append(members[pid.Address], pidAndSubscriber{pid: pid, subscriber: identity}) 110 } 111 } 112 113 // send message to each member 114 for address, member := range members { 115 subscribersOnMember := t.getSubscribersForAddress(member) 116 deliveryMessage := &DeliverBatchRequest{ 117 Subscribers: subscribersOnMember, 118 PubSubBatch: batch, 119 Topic: t.topic, 120 } 121 deliveryPid := actor.NewPID(address, PubSubDeliveryName) 122 c.Send(deliveryPid, deliveryMessage) 123 } 124 c.Respond(&PublishResponse{}) 125 } 126 127 // getSubscribersForAddress returns the subscribers for the given member list 128 func (t *TopicActor) getSubscribersForAddress(members []pidAndSubscriber) *Subscribers { 129 subscribers := make([]*SubscriberIdentity, len(members)) 130 for i, member := range members { 131 subscribers[i] = member.subscriber 132 } 133 return &Subscribers{Subscribers: subscribers} 134 } 135 136 // getPID returns the PID of the subscriber 137 func (t *TopicActor) getPID(c actor.Context, subscriber *SubscriberIdentity) *actor.PID { 138 if pid := subscriber.GetPid(); pid != nil { 139 return pid 140 } 141 142 return t.getClusterIdentityPid(c, subscriber.GetClusterIdentity()) 143 } 144 145 // getClusterIdentityPid returns the PID of the clusterIdentity actor 146 func (t *TopicActor) getClusterIdentityPid(c actor.Context, identity *ClusterIdentity) *actor.PID { 147 if identity == nil { 148 return nil 149 } 150 151 return GetCluster(c.ActorSystem()).Get(identity.Identity, identity.Kind) 152 } 153 154 // onNotifyAboutFailingSubscribers handles a NotifyAboutFailingSubscribersRequest message 155 func (t *TopicActor) onNotifyAboutFailingSubscribers(c actor.Context, msg *NotifyAboutFailingSubscribersRequest) { 156 t.unsubscribeUnreachablePidSubscribers(c, msg.InvalidDeliveries) 157 t.logDeliveryErrors(msg.InvalidDeliveries, c.Logger()) 158 c.Respond(&NotifyAboutFailingSubscribersResponse{}) 159 } 160 161 // logDeliveryErrors logs the delivery errors in one log line 162 func (t *TopicActor) logDeliveryErrors(reports []*SubscriberDeliveryReport, logger *slog.Logger) { 163 if len(reports) > 0 || t.shouldThrottle() == actor.Open { 164 subscribers := make([]string, len(reports)) 165 for i, report := range reports { 166 subscribers[i] = report.Subscriber.String() 167 } 168 logger.Error("Topic following subscribers could not process the batch", slog.String("topic", t.topic), slog.String("subscribers", strings.Join(subscribers, ","))) 169 } 170 } 171 172 // unsubscribeUnreachablePidSubscribers deletes all subscribers that have a PID that is unreachable 173 func (t *TopicActor) unsubscribeUnreachablePidSubscribers(_ actor.Context, allInvalidDeliveryReports []*SubscriberDeliveryReport) { 174 subscribers := make([]subscribeIdentityStruct, 0, len(allInvalidDeliveryReports)) 175 for _, r := range allInvalidDeliveryReports { 176 if r.Subscriber.GetPid() != nil && r.Status == DeliveryStatus_SubscriberNoLongerReachable { 177 subscribers = append(subscribers, newSubscribeIdentityStruct(r.Subscriber)) 178 } 179 } 180 t.removeSubscribers(subscribers, nil) 181 } 182 183 // onClusterTopologyChanged handles a ClusterTopology message 184 func (t *TopicActor) onClusterTopologyChanged(ctx actor.Context, msg *ClusterTopology) { 185 if len(msg.Left) > 0 { 186 addressMap := make(map[string]struct{}) 187 for _, member := range msg.Left { 188 addressMap[member.Address()] = struct{}{} 189 } 190 191 subscribersThatLeft := make([]subscribeIdentityStruct, 0, len(msg.Left)) 192 193 for identityStruct, identity := range t.subscribers { 194 if pid := identity.GetPid(); pid != nil { 195 if _, ok := addressMap[pid.Address]; ok { 196 subscribersThatLeft = append(subscribersThatLeft, identityStruct) 197 } 198 } 199 } 200 t.removeSubscribers(subscribersThatLeft, ctx.Logger()) 201 } 202 } 203 204 // unsubscribeSubscribersOnMembersThatLeft removes subscribers that are on members that left the clusterIdentity 205 func (t *TopicActor) unsubscribeSubscribersOnMembersThatLeft(c actor.Context) { 206 members := GetCluster(c.ActorSystem()).MemberList.Members() 207 activeMemberAddresses := make(map[string]struct{}) 208 for _, member := range members.Members() { 209 activeMemberAddresses[member.Address()] = struct{}{} 210 } 211 212 subscribersThatLeft := make([]subscribeIdentityStruct, 0) 213 for s := range t.subscribers { 214 if s.isPID { 215 if _, ok := activeMemberAddresses[s.pid.address]; !ok { 216 subscribersThatLeft = append(subscribersThatLeft, s) 217 } 218 } 219 } 220 t.removeSubscribers(subscribersThatLeft, nil) 221 } 222 223 // removeSubscribers remove subscribers from the topic 224 func (t *TopicActor) removeSubscribers(subscribersThatLeft []subscribeIdentityStruct, logger *slog.Logger) { 225 if len(subscribersThatLeft) > 0 { 226 for _, subscriber := range subscribersThatLeft { 227 delete(t.subscribers, subscriber) 228 } 229 if t.shouldThrottle() == actor.Open { 230 logger.Warn("Topic removed subscribers, because they are dead or they are on members that left the clusterIdentity:", slog.String("topic", t.topic), slog.Any("subscribers", subscribersThatLeft)) 231 } 232 t.saveSubscriptionsInTopicActor(logger) 233 } 234 } 235 236 // loadSubscriptions loads the subscriptions for the topic from the subscription store 237 func (t *TopicActor) loadSubscriptions(topic string, logger *slog.Logger) *Subscribers { 238 // TODO: cancellation logic config? 239 state, err := t.subscriptionStore.Get(context.Background(), topic) 240 if err != nil { 241 if t.shouldThrottle() == actor.Open { 242 logger.Error("Error when loading subscriptions", slog.String("topic", topic), slog.Any("error", err)) 243 } 244 return &Subscribers{} 245 } 246 if state == nil { 247 return &Subscribers{} 248 } 249 logger.Debug("Loaded subscriptions for topic", slog.String("topic", topic), slog.Any("subscriptions", state)) 250 return state 251 } 252 253 // saveSubscriptionsInTopicActor saves the TopicActor.subscribers for the TopicActor.topic to the subscription store 254 func (t *TopicActor) saveSubscriptionsInTopicActor(logger *slog.Logger) { 255 var subscribers *Subscribers = &Subscribers{Subscribers: maps.Values(t.subscribers)} 256 257 // TODO: cancellation logic config? 258 logger.Debug("Saving subscriptions for topic", slog.String("topic", t.topic), slog.Any("subscriptions", subscribers)) 259 err := t.subscriptionStore.Set(context.Background(), t.topic, subscribers) 260 if err != nil && t.shouldThrottle() == actor.Open { 261 logger.Error("Error when saving subscriptions", slog.String("topic", t.topic), slog.Any("error", err)) 262 } 263 } 264 265 func (t *TopicActor) onUnsubscribe(c actor.Context, msg *UnsubscribeRequest) { 266 delete(t.subscribers, newSubscribeIdentityStruct(msg.Subscriber)) 267 t.saveSubscriptionsInTopicActor(c.Logger()) 268 c.Respond(&UnsubscribeResponse{}) 269 } 270 271 func (t *TopicActor) onSubscribe(c actor.Context, msg *SubscribeRequest) { 272 t.subscribers[newSubscribeIdentityStruct(msg.Subscriber)] = msg.Subscriber 273 c.Logger().Debug("Topic subscribed", slog.String("topic", t.topic), slog.Any("subscriber", msg.Subscriber)) 274 t.saveSubscriptionsInTopicActor(c.Logger()) 275 c.Respond(&SubscribeResponse{}) 276 } 277 278 // pidStruct is a struct that represents a PID 279 // It is used to implement the comparison interface 280 type pidStruct struct { 281 address string 282 id string 283 requestId uint32 284 } 285 286 // newPIDStruct creates a new pidStruct from a *actor.PID 287 func newPidStruct(pid *actor.PID) pidStruct { 288 return pidStruct{ 289 address: pid.Address, 290 id: pid.Id, 291 requestId: pid.RequestId, 292 } 293 } 294 295 // toPID converts a pidStruct to a *actor.PID 296 func (p pidStruct) toPID() *actor.PID { 297 return &actor.PID{ 298 Address: p.address, 299 Id: p.id, 300 RequestId: p.requestId, 301 } 302 } 303 304 type clusterIdentityStruct struct { 305 identity string 306 kind string 307 } 308 309 // newClusterIdentityStruct creates a new clusterIdentityStruct from a *ClusterIdentity 310 func newClusterIdentityStruct(clusterIdentity *ClusterIdentity) clusterIdentityStruct { 311 return clusterIdentityStruct{ 312 identity: clusterIdentity.Identity, 313 kind: clusterIdentity.Kind, 314 } 315 } 316 317 // toClusterIdentity converts a clusterIdentityStruct to a *ClusterIdentity 318 func (c clusterIdentityStruct) toClusterIdentity() *ClusterIdentity { 319 return &ClusterIdentity{ 320 Identity: c.identity, 321 Kind: c.kind, 322 } 323 } 324 325 // subscriberIdentityStruct is a struct that represents a SubscriberIdentity 326 // It is used to implement the comparison interface 327 type subscribeIdentityStruct struct { 328 isPID bool 329 pid pidStruct 330 clusterIdentity clusterIdentityStruct 331 } 332 333 // newSubscriberIdentityStruct creates a new subscriberIdentityStruct from a *SubscriberIdentity 334 func newSubscribeIdentityStruct(subscriberIdentity *SubscriberIdentity) subscribeIdentityStruct { 335 if subscriberIdentity.GetPid() != nil { 336 return subscribeIdentityStruct{ 337 isPID: true, 338 pid: newPidStruct(subscriberIdentity.GetPid()), 339 } 340 } 341 return subscribeIdentityStruct{ 342 isPID: false, 343 clusterIdentity: newClusterIdentityStruct(subscriberIdentity.GetClusterIdentity()), 344 } 345 } 346 347 // toSubscriberIdentity converts a subscribeIdentityStruct to a *SubscriberIdentity 348 func (s subscribeIdentityStruct) toSubscriberIdentity() *SubscriberIdentity { 349 if s.isPID { 350 return &SubscriberIdentity{ 351 Identity: &SubscriberIdentity_Pid{Pid: s.pid.toPID()}, 352 } 353 } 354 return &SubscriberIdentity{ 355 Identity: &SubscriberIdentity_ClusterIdentity{ClusterIdentity: s.clusterIdentity.toClusterIdentity()}, 356 } 357 }