github.com/pure-x-eth/consensus_tm@v0.0.0-20230502163723-e3c2ff987250/libs/pubsub/pubsub.go (about) 1 // Package pubsub implements a pub-sub model with a single publisher (Server) 2 // and multiple subscribers (clients). 3 // 4 // Though you can have multiple publishers by sharing a pointer to a server or 5 // by giving the same channel to each publisher and publishing messages from 6 // that channel (fan-in). 7 // 8 // Clients subscribe for messages, which could be of any type, using a query. 9 // When some message is published, we match it with all queries. If there is a 10 // match, this message will be pushed to all clients, subscribed to that query. 11 // See query subpackage for our implementation. 12 // 13 // Example: 14 // 15 // q, err := query.New("account.name='John'") 16 // if err != nil { 17 // return err 18 // } 19 // ctx, cancel := context.WithTimeout(context.Background(), 1 * time.Second) 20 // defer cancel() 21 // subscription, err := pubsub.Subscribe(ctx, "johns-transactions", q) 22 // if err != nil { 23 // return err 24 // } 25 // 26 // for { 27 // select { 28 // case msg <- subscription.Out(): 29 // // handle msg.Data() and msg.Events() 30 // case <-subscription.Cancelled(): 31 // return subscription.Err() 32 // } 33 // } 34 package pubsub 35 36 import ( 37 "context" 38 "errors" 39 "fmt" 40 41 "github.com/pure-x-eth/consensus_tm/libs/service" 42 tmsync "github.com/pure-x-eth/consensus_tm/libs/sync" 43 ) 44 45 type operation int 46 47 const ( 48 sub operation = iota 49 pub 50 unsub 51 shutdown 52 ) 53 54 var ( 55 // ErrSubscriptionNotFound is returned when a client tries to unsubscribe 56 // from not existing subscription. 57 ErrSubscriptionNotFound = errors.New("subscription not found") 58 59 // ErrAlreadySubscribed is returned when a client tries to subscribe twice or 60 // more using the same query. 61 ErrAlreadySubscribed = errors.New("already subscribed") 62 ) 63 64 // Query defines an interface for a query to be used for subscribing. A query 65 // matches against a map of events. Each key in this map is a composite of the 66 // even type and an attribute key (e.g. "{eventType}.{eventAttrKey}") and the 67 // values are the event values that are contained under that relationship. This 68 // allows event types to repeat themselves with the same set of keys and 69 // different values. 70 type Query interface { 71 Matches(events map[string][]string) (bool, error) 72 String() string 73 } 74 75 type cmd struct { 76 op operation 77 78 // subscribe, unsubscribe 79 query Query 80 subscription *Subscription 81 clientID string 82 83 // publish 84 msg interface{} 85 events map[string][]string 86 } 87 88 // Server allows clients to subscribe/unsubscribe for messages, publishing 89 // messages with or without events, and manages internal state. 90 type Server struct { 91 service.BaseService 92 93 cmds chan cmd 94 cmdsCap int 95 96 // check if we have subscription before 97 // subscribing or unsubscribing 98 mtx tmsync.RWMutex 99 subscriptions map[string]map[string]struct{} // subscriber -> query (string) -> empty struct 100 } 101 102 // Option sets a parameter for the server. 103 type Option func(*Server) 104 105 // NewServer returns a new server. See the commentary on the Option functions 106 // for a detailed description of how to configure buffering. If no options are 107 // provided, the resulting server's queue is unbuffered. 108 func NewServer(options ...Option) *Server { 109 s := &Server{ 110 subscriptions: make(map[string]map[string]struct{}), 111 } 112 s.BaseService = *service.NewBaseService(nil, "PubSub", s) 113 114 for _, option := range options { 115 option(s) 116 } 117 118 // if BufferCapacity option was not set, the channel is unbuffered 119 s.cmds = make(chan cmd, s.cmdsCap) 120 121 return s 122 } 123 124 // BufferCapacity allows you to specify capacity for the internal server's 125 // queue. Since the server, given Y subscribers, could only process X messages, 126 // this option could be used to survive spikes (e.g. high amount of 127 // transactions during peak hours). 128 func BufferCapacity(cap int) Option { 129 return func(s *Server) { 130 if cap > 0 { 131 s.cmdsCap = cap 132 } 133 } 134 } 135 136 // BufferCapacity returns capacity of the internal server's queue. 137 func (s *Server) BufferCapacity() int { 138 return s.cmdsCap 139 } 140 141 // Subscribe creates a subscription for the given client. 142 // 143 // An error will be returned to the caller if the context is canceled or if 144 // subscription already exist for pair clientID and query. 145 // 146 // outCapacity can be used to set a capacity for Subscription#Out channel (1 by 147 // default). Panics if outCapacity is less than or equal to zero. If you want 148 // an unbuffered channel, use SubscribeUnbuffered. 149 func (s *Server) Subscribe( 150 ctx context.Context, 151 clientID string, 152 query Query, 153 outCapacity ...int) (*Subscription, error) { 154 outCap := 1 155 if len(outCapacity) > 0 { 156 if outCapacity[0] <= 0 { 157 panic("Negative or zero capacity. Use SubscribeUnbuffered if you want an unbuffered channel") 158 } 159 outCap = outCapacity[0] 160 } 161 162 return s.subscribe(ctx, clientID, query, outCap) 163 } 164 165 // SubscribeUnbuffered does the same as Subscribe, except it returns a 166 // subscription with unbuffered channel. Use with caution as it can freeze the 167 // server. 168 func (s *Server) SubscribeUnbuffered(ctx context.Context, clientID string, query Query) (*Subscription, error) { 169 return s.subscribe(ctx, clientID, query, 0) 170 } 171 172 func (s *Server) subscribe(ctx context.Context, clientID string, query Query, outCapacity int) (*Subscription, error) { 173 s.mtx.RLock() 174 clientSubscriptions, ok := s.subscriptions[clientID] 175 if ok { 176 _, ok = clientSubscriptions[query.String()] 177 } 178 s.mtx.RUnlock() 179 if ok { 180 return nil, ErrAlreadySubscribed 181 } 182 183 subscription := NewSubscription(outCapacity) 184 select { 185 case s.cmds <- cmd{op: sub, clientID: clientID, query: query, subscription: subscription}: 186 s.mtx.Lock() 187 if _, ok = s.subscriptions[clientID]; !ok { 188 s.subscriptions[clientID] = make(map[string]struct{}) 189 } 190 s.subscriptions[clientID][query.String()] = struct{}{} 191 s.mtx.Unlock() 192 return subscription, nil 193 case <-ctx.Done(): 194 return nil, ctx.Err() 195 case <-s.Quit(): 196 return nil, errors.New("service is shutting down") 197 } 198 } 199 200 // Unsubscribe removes the subscription on the given query. An error will be 201 // returned to the caller if the context is canceled or if subscription does 202 // not exist. 203 func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { 204 s.mtx.RLock() 205 clientSubscriptions, ok := s.subscriptions[clientID] 206 if ok { 207 _, ok = clientSubscriptions[query.String()] 208 } 209 s.mtx.RUnlock() 210 if !ok { 211 return ErrSubscriptionNotFound 212 } 213 214 select { 215 case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: 216 s.mtx.Lock() 217 delete(clientSubscriptions, query.String()) 218 if len(clientSubscriptions) == 0 { 219 delete(s.subscriptions, clientID) 220 } 221 s.mtx.Unlock() 222 return nil 223 case <-ctx.Done(): 224 return ctx.Err() 225 case <-s.Quit(): 226 return nil 227 } 228 } 229 230 // UnsubscribeAll removes all client subscriptions. An error will be returned 231 // to the caller if the context is canceled or if subscription does not exist. 232 func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { 233 s.mtx.RLock() 234 _, ok := s.subscriptions[clientID] 235 s.mtx.RUnlock() 236 if !ok { 237 return ErrSubscriptionNotFound 238 } 239 240 select { 241 case s.cmds <- cmd{op: unsub, clientID: clientID}: 242 s.mtx.Lock() 243 delete(s.subscriptions, clientID) 244 s.mtx.Unlock() 245 return nil 246 case <-ctx.Done(): 247 return ctx.Err() 248 case <-s.Quit(): 249 return nil 250 } 251 } 252 253 // NumClients returns the number of clients. 254 func (s *Server) NumClients() int { 255 s.mtx.RLock() 256 defer s.mtx.RUnlock() 257 return len(s.subscriptions) 258 } 259 260 // NumClientSubscriptions returns the number of subscriptions the client has. 261 func (s *Server) NumClientSubscriptions(clientID string) int { 262 s.mtx.RLock() 263 defer s.mtx.RUnlock() 264 return len(s.subscriptions[clientID]) 265 } 266 267 // Publish publishes the given message. An error will be returned to the caller 268 // if the context is canceled. 269 func (s *Server) Publish(ctx context.Context, msg interface{}) error { 270 return s.PublishWithEvents(ctx, msg, make(map[string][]string)) 271 } 272 273 // PublishWithEvents publishes the given message with the set of events. The set 274 // is matched with clients queries. If there is a match, the message is sent to 275 // the client. 276 func (s *Server) PublishWithEvents(ctx context.Context, msg interface{}, events map[string][]string) error { 277 select { 278 case s.cmds <- cmd{op: pub, msg: msg, events: events}: 279 return nil 280 case <-ctx.Done(): 281 return ctx.Err() 282 case <-s.Quit(): 283 return nil 284 } 285 } 286 287 // OnStop implements Service.OnStop by shutting down the server. 288 func (s *Server) OnStop() { 289 s.cmds <- cmd{op: shutdown} 290 } 291 292 // NOTE: not goroutine safe 293 type state struct { 294 // query string -> client -> subscription 295 subscriptions map[string]map[string]*Subscription 296 // query string -> queryPlusRefCount 297 queries map[string]*queryPlusRefCount 298 } 299 300 // queryPlusRefCount holds a pointer to a query and reference counter. When 301 // refCount is zero, query will be removed. 302 type queryPlusRefCount struct { 303 q Query 304 refCount int 305 } 306 307 // OnStart implements Service.OnStart by starting the server. 308 func (s *Server) OnStart() error { 309 go s.loop(state{ 310 subscriptions: make(map[string]map[string]*Subscription), 311 queries: make(map[string]*queryPlusRefCount), 312 }) 313 return nil 314 } 315 316 // OnReset implements Service.OnReset 317 func (s *Server) OnReset() error { 318 return nil 319 } 320 321 func (s *Server) loop(state state) { 322 loop: 323 for cmd := range s.cmds { 324 switch cmd.op { 325 case unsub: 326 if cmd.query != nil { 327 state.remove(cmd.clientID, cmd.query.String(), ErrUnsubscribed) 328 } else { 329 state.removeClient(cmd.clientID, ErrUnsubscribed) 330 } 331 case shutdown: 332 state.removeAll(nil) 333 break loop 334 case sub: 335 state.add(cmd.clientID, cmd.query, cmd.subscription) 336 case pub: 337 if err := state.send(cmd.msg, cmd.events); err != nil { 338 s.Logger.Error("Error querying for events", "err", err) 339 } 340 } 341 } 342 } 343 344 func (state *state) add(clientID string, q Query, subscription *Subscription) { 345 qStr := q.String() 346 347 // initialize subscription for this client per query if needed 348 if _, ok := state.subscriptions[qStr]; !ok { 349 state.subscriptions[qStr] = make(map[string]*Subscription) 350 } 351 // create subscription 352 state.subscriptions[qStr][clientID] = subscription 353 354 // initialize query if needed 355 if _, ok := state.queries[qStr]; !ok { 356 state.queries[qStr] = &queryPlusRefCount{q: q, refCount: 0} 357 } 358 // increment reference counter 359 state.queries[qStr].refCount++ 360 } 361 362 func (state *state) remove(clientID string, qStr string, reason error) { 363 clientSubscriptions, ok := state.subscriptions[qStr] 364 if !ok { 365 return 366 } 367 368 subscription, ok := clientSubscriptions[clientID] 369 if !ok { 370 return 371 } 372 373 subscription.cancel(reason) 374 375 // remove client from query map. 376 // if query has no other clients subscribed, remove it. 377 delete(state.subscriptions[qStr], clientID) 378 if len(state.subscriptions[qStr]) == 0 { 379 delete(state.subscriptions, qStr) 380 } 381 382 // decrease ref counter in queries 383 state.queries[qStr].refCount-- 384 // remove the query if nobody else is using it 385 if state.queries[qStr].refCount == 0 { 386 delete(state.queries, qStr) 387 } 388 } 389 390 func (state *state) removeClient(clientID string, reason error) { 391 for qStr, clientSubscriptions := range state.subscriptions { 392 if _, ok := clientSubscriptions[clientID]; ok { 393 state.remove(clientID, qStr, reason) 394 } 395 } 396 } 397 398 func (state *state) removeAll(reason error) { 399 for qStr, clientSubscriptions := range state.subscriptions { 400 for clientID := range clientSubscriptions { 401 state.remove(clientID, qStr, reason) 402 } 403 } 404 } 405 406 func (state *state) send(msg interface{}, events map[string][]string) error { 407 for qStr, clientSubscriptions := range state.subscriptions { 408 q := state.queries[qStr].q 409 410 match, err := q.Matches(events) 411 if err != nil { 412 return fmt.Errorf("failed to match against query %s: %w", q.String(), err) 413 } 414 415 if match { 416 for clientID, subscription := range clientSubscriptions { 417 if cap(subscription.out) == 0 { 418 // block on unbuffered channel 419 subscription.out <- NewMessage(msg, events) 420 } else { 421 // don't block on buffered channels 422 select { 423 case subscription.out <- NewMessage(msg, events): 424 default: 425 state.remove(clientID, qStr, ErrOutOfCapacity) 426 } 427 } 428 } 429 } 430 } 431 432 return nil 433 }