github.com/evdatsion/aphelion-dpos-bft@v0.32.1/libs/pubsub/pubsub.go (about)

     1  // Package pubsub implements a pub-sub model with a single publisher (Server)
     2  // and multiple subscribers (clients).
     3  //
// You can still have multiple publishers, either by sharing a pointer to a
// server or by giving the same channel to each publisher and publishing
// messages from that channel (fan-in).
     7  //
// Clients subscribe to messages, which can be of any type, using a query.
// When a message is published, it is matched against every query; if there is
// a match, the message is pushed to all clients subscribed to that query.
// See the query subpackage for our implementation.
    12  //
    13  // Example:
    14  //
    15  //     q, err := query.New("account.name='John'")
    16  //     if err != nil {
    17  //         return err
    18  //     }
//     ctx, cancel := context.WithTimeout(context.Background(), 1 * time.Second)
//     defer cancel()
//     // s is a started *pubsub.Server
//     subscription, err := s.Subscribe(ctx, "johns-transactions", q)
//     if err != nil {
//         return err
//     }
//
//     for {
//         select {
//         case msg := <-subscription.Out():
    29  //             // handle msg.Data() and msg.Events()
    30  //         case <-subscription.Cancelled():
    31  //             return subscription.Err()
    32  //         }
    33  //     }
    34  //
    35  package pubsub
    36  
    37  import (
    38  	"context"
    39  	"errors"
    40  	"sync"
    41  
    42  	cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common"
    43  )
    44  
    45  type operation int
    46  
    47  const (
    48  	sub operation = iota
    49  	pub
    50  	unsub
    51  	shutdown
    52  )
    53  
    54  var (
    55  	// ErrSubscriptionNotFound is returned when a client tries to unsubscribe
    56  	// from not existing subscription.
    57  	ErrSubscriptionNotFound = errors.New("subscription not found")
    58  
    59  	// ErrAlreadySubscribed is returned when a client tries to subscribe twice or
    60  	// more using the same query.
    61  	ErrAlreadySubscribed = errors.New("already subscribed")
    62  )
    63  
    64  // Query defines an interface for a query to be used for subscribing. A query
    65  // matches against a map of events. Each key in this map is a composite of the
    66  // even type and an attribute key (e.g. "{eventType}.{eventAttrKey}") and the
    67  // values are the event values that are contained under that relationship. This
    68  // allows event types to repeat themselves with the same set of keys and
    69  // different values.
    70  type Query interface {
    71  	Matches(events map[string][]string) bool
    72  	String() string
    73  }
    74  
    75  type cmd struct {
    76  	op operation
    77  
    78  	// subscribe, unsubscribe
    79  	query        Query
    80  	subscription *Subscription
    81  	clientID     string
    82  
    83  	// publish
    84  	msg    interface{}
    85  	events map[string][]string
    86  }
    87  
    88  // Server allows clients to subscribe/unsubscribe for messages, publishing
    89  // messages with or without events, and manages internal state.
    90  type Server struct {
    91  	cmn.BaseService
    92  
    93  	cmds    chan cmd
    94  	cmdsCap int
    95  
    96  	// check if we have subscription before
    97  	// subscribing or unsubscribing
    98  	mtx           sync.RWMutex
    99  	subscriptions map[string]map[string]struct{} // subscriber -> query (string) -> empty struct
   100  }
   101  
   102  // Option sets a parameter for the server.
   103  type Option func(*Server)
   104  
   105  // NewServer returns a new server. See the commentary on the Option functions
   106  // for a detailed description of how to configure buffering. If no options are
   107  // provided, the resulting server's queue is unbuffered.
   108  func NewServer(options ...Option) *Server {
   109  	s := &Server{
   110  		subscriptions: make(map[string]map[string]struct{}),
   111  	}
   112  	s.BaseService = *cmn.NewBaseService(nil, "PubSub", s)
   113  
   114  	for _, option := range options {
   115  		option(s)
   116  	}
   117  
   118  	// if BufferCapacity option was not set, the channel is unbuffered
   119  	s.cmds = make(chan cmd, s.cmdsCap)
   120  
   121  	return s
   122  }
   123  
   124  // BufferCapacity allows you to specify capacity for the internal server's
   125  // queue. Since the server, given Y subscribers, could only process X messages,
   126  // this option could be used to survive spikes (e.g. high amount of
   127  // transactions during peak hours).
   128  func BufferCapacity(cap int) Option {
   129  	return func(s *Server) {
   130  		if cap > 0 {
   131  			s.cmdsCap = cap
   132  		}
   133  	}
   134  }
   135  
   136  // BufferCapacity returns capacity of the internal server's queue.
   137  func (s *Server) BufferCapacity() int {
   138  	return s.cmdsCap
   139  }
   140  
   141  // Subscribe creates a subscription for the given client.
   142  //
   143  // An error will be returned to the caller if the context is canceled or if
   144  // subscription already exist for pair clientID and query.
   145  //
   146  // outCapacity can be used to set a capacity for Subscription#Out channel (1 by
   147  // default). Panics if outCapacity is less than or equal to zero. If you want
   148  // an unbuffered channel, use SubscribeUnbuffered.
   149  func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, outCapacity ...int) (*Subscription, error) {
   150  	outCap := 1
   151  	if len(outCapacity) > 0 {
   152  		if outCapacity[0] <= 0 {
   153  			panic("Negative or zero capacity. Use SubscribeUnbuffered if you want an unbuffered channel")
   154  		}
   155  		outCap = outCapacity[0]
   156  	}
   157  
   158  	return s.subscribe(ctx, clientID, query, outCap)
   159  }
   160  
   161  // SubscribeUnbuffered does the same as Subscribe, except it returns a
   162  // subscription with unbuffered channel. Use with caution as it can freeze the
   163  // server.
   164  func (s *Server) SubscribeUnbuffered(ctx context.Context, clientID string, query Query) (*Subscription, error) {
   165  	return s.subscribe(ctx, clientID, query, 0)
   166  }
   167  
   168  func (s *Server) subscribe(ctx context.Context, clientID string, query Query, outCapacity int) (*Subscription, error) {
   169  	s.mtx.RLock()
   170  	clientSubscriptions, ok := s.subscriptions[clientID]
   171  	if ok {
   172  		_, ok = clientSubscriptions[query.String()]
   173  	}
   174  	s.mtx.RUnlock()
   175  	if ok {
   176  		return nil, ErrAlreadySubscribed
   177  	}
   178  
   179  	subscription := NewSubscription(outCapacity)
   180  	select {
   181  	case s.cmds <- cmd{op: sub, clientID: clientID, query: query, subscription: subscription}:
   182  		s.mtx.Lock()
   183  		if _, ok = s.subscriptions[clientID]; !ok {
   184  			s.subscriptions[clientID] = make(map[string]struct{})
   185  		}
   186  		s.subscriptions[clientID][query.String()] = struct{}{}
   187  		s.mtx.Unlock()
   188  		return subscription, nil
   189  	case <-ctx.Done():
   190  		return nil, ctx.Err()
   191  	case <-s.Quit():
   192  		return nil, nil
   193  	}
   194  }
   195  
   196  // Unsubscribe removes the subscription on the given query. An error will be
   197  // returned to the caller if the context is canceled or if subscription does
   198  // not exist.
   199  func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error {
   200  	s.mtx.RLock()
   201  	clientSubscriptions, ok := s.subscriptions[clientID]
   202  	if ok {
   203  		_, ok = clientSubscriptions[query.String()]
   204  	}
   205  	s.mtx.RUnlock()
   206  	if !ok {
   207  		return ErrSubscriptionNotFound
   208  	}
   209  
   210  	select {
   211  	case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}:
   212  		s.mtx.Lock()
   213  		delete(clientSubscriptions, query.String())
   214  		if len(clientSubscriptions) == 0 {
   215  			delete(s.subscriptions, clientID)
   216  		}
   217  		s.mtx.Unlock()
   218  		return nil
   219  	case <-ctx.Done():
   220  		return ctx.Err()
   221  	case <-s.Quit():
   222  		return nil
   223  	}
   224  }
   225  
   226  // UnsubscribeAll removes all client subscriptions. An error will be returned
   227  // to the caller if the context is canceled or if subscription does not exist.
   228  func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error {
   229  	s.mtx.RLock()
   230  	_, ok := s.subscriptions[clientID]
   231  	s.mtx.RUnlock()
   232  	if !ok {
   233  		return ErrSubscriptionNotFound
   234  	}
   235  
   236  	select {
   237  	case s.cmds <- cmd{op: unsub, clientID: clientID}:
   238  		s.mtx.Lock()
   239  		delete(s.subscriptions, clientID)
   240  		s.mtx.Unlock()
   241  		return nil
   242  	case <-ctx.Done():
   243  		return ctx.Err()
   244  	case <-s.Quit():
   245  		return nil
   246  	}
   247  }
   248  
   249  // NumClients returns the number of clients.
   250  func (s *Server) NumClients() int {
   251  	s.mtx.RLock()
   252  	defer s.mtx.RUnlock()
   253  	return len(s.subscriptions)
   254  }
   255  
   256  // NumClientSubscriptions returns the number of subscriptions the client has.
   257  func (s *Server) NumClientSubscriptions(clientID string) int {
   258  	s.mtx.RLock()
   259  	defer s.mtx.RUnlock()
   260  	return len(s.subscriptions[clientID])
   261  }
   262  
   263  // Publish publishes the given message. An error will be returned to the caller
   264  // if the context is canceled.
   265  func (s *Server) Publish(ctx context.Context, msg interface{}) error {
   266  	return s.PublishWithEvents(ctx, msg, make(map[string][]string))
   267  }
   268  
   269  // PublishWithEvents publishes the given message with the set of events. The set
   270  // is matched with clients queries. If there is a match, the message is sent to
   271  // the client.
   272  func (s *Server) PublishWithEvents(ctx context.Context, msg interface{}, events map[string][]string) error {
   273  	select {
   274  	case s.cmds <- cmd{op: pub, msg: msg, events: events}:
   275  		return nil
   276  	case <-ctx.Done():
   277  		return ctx.Err()
   278  	case <-s.Quit():
   279  		return nil
   280  	}
   281  }
   282  
   283  // OnStop implements Service.OnStop by shutting down the server.
   284  func (s *Server) OnStop() {
   285  	s.cmds <- cmd{op: shutdown}
   286  }
   287  
   288  // NOTE: not goroutine safe
   289  type state struct {
   290  	// query string -> client -> subscription
   291  	subscriptions map[string]map[string]*Subscription
   292  	// query string -> queryPlusRefCount
   293  	queries map[string]*queryPlusRefCount
   294  }
   295  
   296  // queryPlusRefCount holds a pointer to a query and reference counter. When
   297  // refCount is zero, query will be removed.
   298  type queryPlusRefCount struct {
   299  	q        Query
   300  	refCount int
   301  }
   302  
   303  // OnStart implements Service.OnStart by starting the server.
   304  func (s *Server) OnStart() error {
   305  	go s.loop(state{
   306  		subscriptions: make(map[string]map[string]*Subscription),
   307  		queries:       make(map[string]*queryPlusRefCount),
   308  	})
   309  	return nil
   310  }
   311  
   312  // OnReset implements Service.OnReset
   313  func (s *Server) OnReset() error {
   314  	return nil
   315  }
   316  
   317  func (s *Server) loop(state state) {
   318  loop:
   319  	for cmd := range s.cmds {
   320  		switch cmd.op {
   321  		case unsub:
   322  			if cmd.query != nil {
   323  				state.remove(cmd.clientID, cmd.query.String(), ErrUnsubscribed)
   324  			} else {
   325  				state.removeClient(cmd.clientID, ErrUnsubscribed)
   326  			}
   327  		case shutdown:
   328  			state.removeAll(nil)
   329  			break loop
   330  		case sub:
   331  			state.add(cmd.clientID, cmd.query, cmd.subscription)
   332  		case pub:
   333  			state.send(cmd.msg, cmd.events)
   334  		}
   335  	}
   336  }
   337  
   338  func (state *state) add(clientID string, q Query, subscription *Subscription) {
   339  	qStr := q.String()
   340  
   341  	// initialize subscription for this client per query if needed
   342  	if _, ok := state.subscriptions[qStr]; !ok {
   343  		state.subscriptions[qStr] = make(map[string]*Subscription)
   344  	}
   345  	// create subscription
   346  	state.subscriptions[qStr][clientID] = subscription
   347  
   348  	// initialize query if needed
   349  	if _, ok := state.queries[qStr]; !ok {
   350  		state.queries[qStr] = &queryPlusRefCount{q: q, refCount: 0}
   351  	}
   352  	// increment reference counter
   353  	state.queries[qStr].refCount++
   354  }
   355  
   356  func (state *state) remove(clientID string, qStr string, reason error) {
   357  	clientSubscriptions, ok := state.subscriptions[qStr]
   358  	if !ok {
   359  		return
   360  	}
   361  
   362  	subscription, ok := clientSubscriptions[clientID]
   363  	if !ok {
   364  		return
   365  	}
   366  
   367  	subscription.cancel(reason)
   368  
   369  	// remove client from query map.
   370  	// if query has no other clients subscribed, remove it.
   371  	delete(state.subscriptions[qStr], clientID)
   372  	if len(state.subscriptions[qStr]) == 0 {
   373  		delete(state.subscriptions, qStr)
   374  	}
   375  
   376  	// decrease ref counter in queries
   377  	state.queries[qStr].refCount--
   378  	// remove the query if nobody else is using it
   379  	if state.queries[qStr].refCount == 0 {
   380  		delete(state.queries, qStr)
   381  	}
   382  }
   383  
   384  func (state *state) removeClient(clientID string, reason error) {
   385  	for qStr, clientSubscriptions := range state.subscriptions {
   386  		if _, ok := clientSubscriptions[clientID]; ok {
   387  			state.remove(clientID, qStr, reason)
   388  		}
   389  	}
   390  }
   391  
   392  func (state *state) removeAll(reason error) {
   393  	for qStr, clientSubscriptions := range state.subscriptions {
   394  		for clientID := range clientSubscriptions {
   395  			state.remove(clientID, qStr, reason)
   396  		}
   397  	}
   398  }
   399  
   400  func (state *state) send(msg interface{}, events map[string][]string) {
   401  	for qStr, clientSubscriptions := range state.subscriptions {
   402  		q := state.queries[qStr].q
   403  		if q.Matches(events) {
   404  			for clientID, subscription := range clientSubscriptions {
   405  				if cap(subscription.out) == 0 {
   406  					// block on unbuffered channel
   407  					subscription.out <- NewMessage(msg, events)
   408  				} else {
   409  					// don't block on buffered channels
   410  					select {
   411  					case subscription.out <- NewMessage(msg, events):
   412  					default:
   413  						state.remove(clientID, qStr, ErrOutOfCapacity)
   414  					}
   415  				}
   416  			}
   417  		}
   418  	}
   419  }