github.com/evdatsion/aphelion-dpos-bft@v0.32.1/libs/pubsub/pubsub.go (about) 1 // Package pubsub implements a pub-sub model with a single publisher (Server) 2 // and multiple subscribers (clients). 3 // 4 // Though you can have multiple publishers by sharing a pointer to a server or 5 // by giving the same channel to each publisher and publishing messages from 6 // that channel (fan-in). 7 // 8 // Clients subscribe for messages, which could be of any type, using a query. 9 // When some message is published, we match it with all queries. If there is a 10 // match, this message will be pushed to all clients, subscribed to that query. 11 // See query subpackage for our implementation. 12 // 13 // Example: 14 // 15 // q, err := query.New("account.name='John'") 16 // if err != nil { 17 // return err 18 // } 19 // ctx, cancel := context.WithTimeout(context.Background(), 1 * time.Second) 20 // defer cancel() 21 // subscription, err := pubsub.Subscribe(ctx, "johns-transactions", q) 22 // if err != nil { 23 // return err 24 // } 25 // 26 // for { 27 // select { 28 // case msg <- subscription.Out(): 29 // // handle msg.Data() and msg.Events() 30 // case <-subscription.Cancelled(): 31 // return subscription.Err() 32 // } 33 // } 34 // 35 package pubsub 36 37 import ( 38 "context" 39 "errors" 40 "sync" 41 42 cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common" 43 ) 44 45 type operation int 46 47 const ( 48 sub operation = iota 49 pub 50 unsub 51 shutdown 52 ) 53 54 var ( 55 // ErrSubscriptionNotFound is returned when a client tries to unsubscribe 56 // from not existing subscription. 57 ErrSubscriptionNotFound = errors.New("subscription not found") 58 59 // ErrAlreadySubscribed is returned when a client tries to subscribe twice or 60 // more using the same query. 61 ErrAlreadySubscribed = errors.New("already subscribed") 62 ) 63 64 // Query defines an interface for a query to be used for subscribing. A query 65 // matches against a map of events. Each key in this map is a composite of the 66 // even type and an attribute key (e.g. "{eventType}.{eventAttrKey}") and the 67 // values are the event values that are contained under that relationship. This 68 // allows event types to repeat themselves with the same set of keys and 69 // different values. 70 type Query interface { 71 Matches(events map[string][]string) bool 72 String() string 73 } 74 75 type cmd struct { 76 op operation 77 78 // subscribe, unsubscribe 79 query Query 80 subscription *Subscription 81 clientID string 82 83 // publish 84 msg interface{} 85 events map[string][]string 86 } 87 88 // Server allows clients to subscribe/unsubscribe for messages, publishing 89 // messages with or without events, and manages internal state. 90 type Server struct { 91 cmn.BaseService 92 93 cmds chan cmd 94 cmdsCap int 95 96 // check if we have subscription before 97 // subscribing or unsubscribing 98 mtx sync.RWMutex 99 subscriptions map[string]map[string]struct{} // subscriber -> query (string) -> empty struct 100 } 101 102 // Option sets a parameter for the server. 103 type Option func(*Server) 104 105 // NewServer returns a new server. See the commentary on the Option functions 106 // for a detailed description of how to configure buffering. If no options are 107 // provided, the resulting server's queue is unbuffered. 108 func NewServer(options ...Option) *Server { 109 s := &Server{ 110 subscriptions: make(map[string]map[string]struct{}), 111 } 112 s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) 113 114 for _, option := range options { 115 option(s) 116 } 117 118 // if BufferCapacity option was not set, the channel is unbuffered 119 s.cmds = make(chan cmd, s.cmdsCap) 120 121 return s 122 } 123 124 // BufferCapacity allows you to specify capacity for the internal server's 125 // queue. Since the server, given Y subscribers, could only process X messages, 126 // this option could be used to survive spikes (e.g. high amount of 127 // transactions during peak hours). 128 func BufferCapacity(cap int) Option { 129 return func(s *Server) { 130 if cap > 0 { 131 s.cmdsCap = cap 132 } 133 } 134 } 135 136 // BufferCapacity returns capacity of the internal server's queue. 137 func (s *Server) BufferCapacity() int { 138 return s.cmdsCap 139 } 140 141 // Subscribe creates a subscription for the given client. 142 // 143 // An error will be returned to the caller if the context is canceled or if 144 // subscription already exist for pair clientID and query. 145 // 146 // outCapacity can be used to set a capacity for Subscription#Out channel (1 by 147 // default). Panics if outCapacity is less than or equal to zero. If you want 148 // an unbuffered channel, use SubscribeUnbuffered. 149 func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, outCapacity ...int) (*Subscription, error) { 150 outCap := 1 151 if len(outCapacity) > 0 { 152 if outCapacity[0] <= 0 { 153 panic("Negative or zero capacity. Use SubscribeUnbuffered if you want an unbuffered channel") 154 } 155 outCap = outCapacity[0] 156 } 157 158 return s.subscribe(ctx, clientID, query, outCap) 159 } 160 161 // SubscribeUnbuffered does the same as Subscribe, except it returns a 162 // subscription with unbuffered channel. Use with caution as it can freeze the 163 // server. 164 func (s *Server) SubscribeUnbuffered(ctx context.Context, clientID string, query Query) (*Subscription, error) { 165 return s.subscribe(ctx, clientID, query, 0) 166 } 167 168 func (s *Server) subscribe(ctx context.Context, clientID string, query Query, outCapacity int) (*Subscription, error) { 169 s.mtx.RLock() 170 clientSubscriptions, ok := s.subscriptions[clientID] 171 if ok { 172 _, ok = clientSubscriptions[query.String()] 173 } 174 s.mtx.RUnlock() 175 if ok { 176 return nil, ErrAlreadySubscribed 177 } 178 179 subscription := NewSubscription(outCapacity) 180 select { 181 case s.cmds <- cmd{op: sub, clientID: clientID, query: query, subscription: subscription}: 182 s.mtx.Lock() 183 if _, ok = s.subscriptions[clientID]; !ok { 184 s.subscriptions[clientID] = make(map[string]struct{}) 185 } 186 s.subscriptions[clientID][query.String()] = struct{}{} 187 s.mtx.Unlock() 188 return subscription, nil 189 case <-ctx.Done(): 190 return nil, ctx.Err() 191 case <-s.Quit(): 192 return nil, nil 193 } 194 } 195 196 // Unsubscribe removes the subscription on the given query. An error will be 197 // returned to the caller if the context is canceled or if subscription does 198 // not exist. 199 func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { 200 s.mtx.RLock() 201 clientSubscriptions, ok := s.subscriptions[clientID] 202 if ok { 203 _, ok = clientSubscriptions[query.String()] 204 } 205 s.mtx.RUnlock() 206 if !ok { 207 return ErrSubscriptionNotFound 208 } 209 210 select { 211 case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: 212 s.mtx.Lock() 213 delete(clientSubscriptions, query.String()) 214 if len(clientSubscriptions) == 0 { 215 delete(s.subscriptions, clientID) 216 } 217 s.mtx.Unlock() 218 return nil 219 case <-ctx.Done(): 220 return ctx.Err() 221 case <-s.Quit(): 222 return nil 223 } 224 } 225 226 // UnsubscribeAll removes all client subscriptions. An error will be returned 227 // to the caller if the context is canceled or if subscription does not exist. 228 func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { 229 s.mtx.RLock() 230 _, ok := s.subscriptions[clientID] 231 s.mtx.RUnlock() 232 if !ok { 233 return ErrSubscriptionNotFound 234 } 235 236 select { 237 case s.cmds <- cmd{op: unsub, clientID: clientID}: 238 s.mtx.Lock() 239 delete(s.subscriptions, clientID) 240 s.mtx.Unlock() 241 return nil 242 case <-ctx.Done(): 243 return ctx.Err() 244 case <-s.Quit(): 245 return nil 246 } 247 } 248 249 // NumClients returns the number of clients. 250 func (s *Server) NumClients() int { 251 s.mtx.RLock() 252 defer s.mtx.RUnlock() 253 return len(s.subscriptions) 254 } 255 256 // NumClientSubscriptions returns the number of subscriptions the client has. 257 func (s *Server) NumClientSubscriptions(clientID string) int { 258 s.mtx.RLock() 259 defer s.mtx.RUnlock() 260 return len(s.subscriptions[clientID]) 261 } 262 263 // Publish publishes the given message. An error will be returned to the caller 264 // if the context is canceled. 265 func (s *Server) Publish(ctx context.Context, msg interface{}) error { 266 return s.PublishWithEvents(ctx, msg, make(map[string][]string)) 267 } 268 269 // PublishWithEvents publishes the given message with the set of events. The set 270 // is matched with clients queries. If there is a match, the message is sent to 271 // the client. 272 func (s *Server) PublishWithEvents(ctx context.Context, msg interface{}, events map[string][]string) error { 273 select { 274 case s.cmds <- cmd{op: pub, msg: msg, events: events}: 275 return nil 276 case <-ctx.Done(): 277 return ctx.Err() 278 case <-s.Quit(): 279 return nil 280 } 281 } 282 283 // OnStop implements Service.OnStop by shutting down the server. 284 func (s *Server) OnStop() { 285 s.cmds <- cmd{op: shutdown} 286 } 287 288 // NOTE: not goroutine safe 289 type state struct { 290 // query string -> client -> subscription 291 subscriptions map[string]map[string]*Subscription 292 // query string -> queryPlusRefCount 293 queries map[string]*queryPlusRefCount 294 } 295 296 // queryPlusRefCount holds a pointer to a query and reference counter. When 297 // refCount is zero, query will be removed. 298 type queryPlusRefCount struct { 299 q Query 300 refCount int 301 } 302 303 // OnStart implements Service.OnStart by starting the server. 304 func (s *Server) OnStart() error { 305 go s.loop(state{ 306 subscriptions: make(map[string]map[string]*Subscription), 307 queries: make(map[string]*queryPlusRefCount), 308 }) 309 return nil 310 } 311 312 // OnReset implements Service.OnReset 313 func (s *Server) OnReset() error { 314 return nil 315 } 316 317 func (s *Server) loop(state state) { 318 loop: 319 for cmd := range s.cmds { 320 switch cmd.op { 321 case unsub: 322 if cmd.query != nil { 323 state.remove(cmd.clientID, cmd.query.String(), ErrUnsubscribed) 324 } else { 325 state.removeClient(cmd.clientID, ErrUnsubscribed) 326 } 327 case shutdown: 328 state.removeAll(nil) 329 break loop 330 case sub: 331 state.add(cmd.clientID, cmd.query, cmd.subscription) 332 case pub: 333 state.send(cmd.msg, cmd.events) 334 } 335 } 336 } 337 338 func (state *state) add(clientID string, q Query, subscription *Subscription) { 339 qStr := q.String() 340 341 // initialize subscription for this client per query if needed 342 if _, ok := state.subscriptions[qStr]; !ok { 343 state.subscriptions[qStr] = make(map[string]*Subscription) 344 } 345 // create subscription 346 state.subscriptions[qStr][clientID] = subscription 347 348 // initialize query if needed 349 if _, ok := state.queries[qStr]; !ok { 350 state.queries[qStr] = &queryPlusRefCount{q: q, refCount: 0} 351 } 352 // increment reference counter 353 state.queries[qStr].refCount++ 354 } 355 356 func (state *state) remove(clientID string, qStr string, reason error) { 357 clientSubscriptions, ok := state.subscriptions[qStr] 358 if !ok { 359 return 360 } 361 362 subscription, ok := clientSubscriptions[clientID] 363 if !ok { 364 return 365 } 366 367 subscription.cancel(reason) 368 369 // remove client from query map. 370 // if query has no other clients subscribed, remove it. 371 delete(state.subscriptions[qStr], clientID) 372 if len(state.subscriptions[qStr]) == 0 { 373 delete(state.subscriptions, qStr) 374 } 375 376 // decrease ref counter in queries 377 state.queries[qStr].refCount-- 378 // remove the query if nobody else is using it 379 if state.queries[qStr].refCount == 0 { 380 delete(state.queries, qStr) 381 } 382 } 383 384 func (state *state) removeClient(clientID string, reason error) { 385 for qStr, clientSubscriptions := range state.subscriptions { 386 if _, ok := clientSubscriptions[clientID]; ok { 387 state.remove(clientID, qStr, reason) 388 } 389 } 390 } 391 392 func (state *state) removeAll(reason error) { 393 for qStr, clientSubscriptions := range state.subscriptions { 394 for clientID := range clientSubscriptions { 395 state.remove(clientID, qStr, reason) 396 } 397 } 398 } 399 400 func (state *state) send(msg interface{}, events map[string][]string) { 401 for qStr, clientSubscriptions := range state.subscriptions { 402 q := state.queries[qStr].q 403 if q.Matches(events) { 404 for clientID, subscription := range clientSubscriptions { 405 if cap(subscription.out) == 0 { 406 // block on unbuffered channel 407 subscription.out <- NewMessage(msg, events) 408 } else { 409 // don't block on buffered channels 410 select { 411 case subscription.out <- NewMessage(msg, events): 412 default: 413 state.remove(clientID, qStr, ErrOutOfCapacity) 414 } 415 } 416 } 417 } 418 } 419 }