github.com/vipernet-xyz/tendermint-core@v0.32.0/libs/pubsub/pubsub.go (about)

     1  // Package pubsub implements a pub-sub model with a single publisher (Server)
     2  // and multiple subscribers (clients).
     3  //
     4  // However, you can have multiple publishers by sharing a pointer to a server or
     5  // by giving the same channel to each publisher and publishing messages from
     6  // that channel (fan-in).
     7  //
     8  // Clients subscribe for messages, which could be of any type, using a query.
     9  // When some message is published, we match it with all queries. If there is a
    10  // match, this message will be pushed to all clients subscribed to that query.
    11  // See query subpackage for our implementation.
    12  //
    13  // Example:
    14  //
    15  //     q, err := query.New("account.name='John'")
    16  //     if err != nil {
    17  //         return err
    18  //     }
    19  //     ctx, cancel := context.WithTimeout(context.Background(), 1 * time.Second)
    20  //     defer cancel()
    21  //     subscription, err := pubsub.Subscribe(ctx, "johns-transactions", q)
    22  //     if err != nil {
    23  //         return err
    24  //     }
    25  //
    26  //     for {
    27  //         select {
    28  //         case msg := <-subscription.Out():
    29  //             // handle msg.Data() and msg.Events()
    30  //         case <-subscription.Cancelled():
    31  //             return subscription.Err()
    32  //         }
    33  //     }
    34  //
    35  package pubsub
    36  
    37  import (
    38  	"context"
    39  	"sync"
    40  
    41  	"github.com/pkg/errors"
    42  
    43  	"github.com/tendermint/tendermint/libs/service"
    44  )
    45  
    46  type operation int
    47  
    48  const (
    49  	sub operation = iota
    50  	pub
    51  	unsub
    52  	shutdown
    53  )
    54  
    55  var (
    56  	// ErrSubscriptionNotFound is returned when a client tries to unsubscribe
    57  	// from not existing subscription.
    58  	ErrSubscriptionNotFound = errors.New("subscription not found")
    59  
    60  	// ErrAlreadySubscribed is returned when a client tries to subscribe twice or
    61  	// more using the same query.
    62  	ErrAlreadySubscribed = errors.New("already subscribed")
    63  )
    64  
    65  // Query defines an interface for a query to be used for subscribing. A query
    66  // matches against a map of events. Each key in this map is a composite of the
    67  // even type and an attribute key (e.g. "{eventType}.{eventAttrKey}") and the
    68  // values are the event values that are contained under that relationship. This
    69  // allows event types to repeat themselves with the same set of keys and
    70  // different values.
    71  type Query interface {
    72  	Matches(events map[string][]string) (bool, error)
    73  	String() string
    74  }
    75  
    76  type cmd struct {
    77  	op operation
    78  
    79  	// subscribe, unsubscribe
    80  	query        Query
    81  	subscription *Subscription
    82  	clientID     string
    83  
    84  	// publish
    85  	msg    interface{}
    86  	events map[string][]string
    87  }
    88  
    89  // Server allows clients to subscribe/unsubscribe for messages, publishing
    90  // messages with or without events, and manages internal state.
    91  type Server struct {
    92  	service.BaseService
    93  
    94  	cmds    chan cmd
    95  	cmdsCap int
    96  
    97  	// check if we have subscription before
    98  	// subscribing or unsubscribing
    99  	mtx           sync.RWMutex
   100  	subscriptions map[string]map[string]struct{} // subscriber -> query (string) -> empty struct
   101  }
   102  
   103  // Option sets a parameter for the server.
   104  type Option func(*Server)
   105  
   106  // NewServer returns a new server. See the commentary on the Option functions
   107  // for a detailed description of how to configure buffering. If no options are
   108  // provided, the resulting server's queue is unbuffered.
   109  func NewServer(options ...Option) *Server {
   110  	s := &Server{
   111  		subscriptions: make(map[string]map[string]struct{}),
   112  	}
   113  	s.BaseService = *service.NewBaseService(nil, "PubSub", s)
   114  
   115  	for _, option := range options {
   116  		option(s)
   117  	}
   118  
   119  	// if BufferCapacity option was not set, the channel is unbuffered
   120  	s.cmds = make(chan cmd, s.cmdsCap)
   121  
   122  	return s
   123  }
   124  
   125  // BufferCapacity allows you to specify capacity for the internal server's
   126  // queue. Since the server, given Y subscribers, could only process X messages,
   127  // this option could be used to survive spikes (e.g. high amount of
   128  // transactions during peak hours).
   129  func BufferCapacity(cap int) Option {
   130  	return func(s *Server) {
   131  		if cap > 0 {
   132  			s.cmdsCap = cap
   133  		}
   134  	}
   135  }
   136  
   137  // BufferCapacity returns capacity of the internal server's queue.
   138  func (s *Server) BufferCapacity() int {
   139  	return s.cmdsCap
   140  }
   141  
   142  // Subscribe creates a subscription for the given client.
   143  //
   144  // An error will be returned to the caller if the context is canceled or if
   145  // subscription already exist for pair clientID and query.
   146  //
   147  // outCapacity can be used to set a capacity for Subscription#Out channel (1 by
   148  // default). Panics if outCapacity is less than or equal to zero. If you want
   149  // an unbuffered channel, use SubscribeUnbuffered.
   150  func (s *Server) Subscribe(
   151  	ctx context.Context,
   152  	clientID string,
   153  	query Query,
   154  	outCapacity ...int) (*Subscription, error) {
   155  	outCap := 1
   156  	if len(outCapacity) > 0 {
   157  		if outCapacity[0] <= 0 {
   158  			panic("Negative or zero capacity. Use SubscribeUnbuffered if you want an unbuffered channel")
   159  		}
   160  		outCap = outCapacity[0]
   161  	}
   162  
   163  	return s.subscribe(ctx, clientID, query, outCap)
   164  }
   165  
   166  // SubscribeUnbuffered does the same as Subscribe, except it returns a
   167  // subscription with unbuffered channel. Use with caution as it can freeze the
   168  // server.
   169  func (s *Server) SubscribeUnbuffered(ctx context.Context, clientID string, query Query) (*Subscription, error) {
   170  	return s.subscribe(ctx, clientID, query, 0)
   171  }
   172  
   173  func (s *Server) subscribe(ctx context.Context, clientID string, query Query, outCapacity int) (*Subscription, error) {
   174  	s.mtx.RLock()
   175  	clientSubscriptions, ok := s.subscriptions[clientID]
   176  	if ok {
   177  		_, ok = clientSubscriptions[query.String()]
   178  	}
   179  	s.mtx.RUnlock()
   180  	if ok {
   181  		return nil, ErrAlreadySubscribed
   182  	}
   183  
   184  	subscription := NewSubscription(outCapacity)
   185  	select {
   186  	case s.cmds <- cmd{op: sub, clientID: clientID, query: query, subscription: subscription}:
   187  		s.mtx.Lock()
   188  		if _, ok = s.subscriptions[clientID]; !ok {
   189  			s.subscriptions[clientID] = make(map[string]struct{})
   190  		}
   191  		s.subscriptions[clientID][query.String()] = struct{}{}
   192  		s.mtx.Unlock()
   193  		return subscription, nil
   194  	case <-ctx.Done():
   195  		return nil, ctx.Err()
   196  	case <-s.Quit():
   197  		return nil, nil
   198  	}
   199  }
   200  
   201  // Unsubscribe removes the subscription on the given query. An error will be
   202  // returned to the caller if the context is canceled or if subscription does
   203  // not exist.
   204  func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error {
   205  	s.mtx.RLock()
   206  	clientSubscriptions, ok := s.subscriptions[clientID]
   207  	if ok {
   208  		_, ok = clientSubscriptions[query.String()]
   209  	}
   210  	s.mtx.RUnlock()
   211  	if !ok {
   212  		return ErrSubscriptionNotFound
   213  	}
   214  
   215  	select {
   216  	case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}:
   217  		s.mtx.Lock()
   218  		delete(clientSubscriptions, query.String())
   219  		if len(clientSubscriptions) == 0 {
   220  			delete(s.subscriptions, clientID)
   221  		}
   222  		s.mtx.Unlock()
   223  		return nil
   224  	case <-ctx.Done():
   225  		return ctx.Err()
   226  	case <-s.Quit():
   227  		return nil
   228  	}
   229  }
   230  
   231  // UnsubscribeAll removes all client subscriptions. An error will be returned
   232  // to the caller if the context is canceled or if subscription does not exist.
   233  func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error {
   234  	s.mtx.RLock()
   235  	_, ok := s.subscriptions[clientID]
   236  	s.mtx.RUnlock()
   237  	if !ok {
   238  		return ErrSubscriptionNotFound
   239  	}
   240  
   241  	select {
   242  	case s.cmds <- cmd{op: unsub, clientID: clientID}:
   243  		s.mtx.Lock()
   244  		delete(s.subscriptions, clientID)
   245  		s.mtx.Unlock()
   246  		return nil
   247  	case <-ctx.Done():
   248  		return ctx.Err()
   249  	case <-s.Quit():
   250  		return nil
   251  	}
   252  }
   253  
   254  // NumClients returns the number of clients.
   255  func (s *Server) NumClients() int {
   256  	s.mtx.RLock()
   257  	defer s.mtx.RUnlock()
   258  	return len(s.subscriptions)
   259  }
   260  
   261  // NumClientSubscriptions returns the number of subscriptions the client has.
   262  func (s *Server) NumClientSubscriptions(clientID string) int {
   263  	s.mtx.RLock()
   264  	defer s.mtx.RUnlock()
   265  	return len(s.subscriptions[clientID])
   266  }
   267  
   268  // Publish publishes the given message. An error will be returned to the caller
   269  // if the context is canceled.
   270  func (s *Server) Publish(ctx context.Context, msg interface{}) error {
   271  	return s.PublishWithEvents(ctx, msg, make(map[string][]string))
   272  }
   273  
   274  // PublishWithEvents publishes the given message with the set of events. The set
   275  // is matched with clients queries. If there is a match, the message is sent to
   276  // the client.
   277  func (s *Server) PublishWithEvents(ctx context.Context, msg interface{}, events map[string][]string) error {
   278  	select {
   279  	case s.cmds <- cmd{op: pub, msg: msg, events: events}:
   280  		return nil
   281  	case <-ctx.Done():
   282  		return ctx.Err()
   283  	case <-s.Quit():
   284  		return nil
   285  	}
   286  }
   287  
   288  // OnStop implements Service.OnStop by shutting down the server.
   289  func (s *Server) OnStop() {
   290  	s.cmds <- cmd{op: shutdown}
   291  }
   292  
   293  // NOTE: not goroutine safe
   294  type state struct {
   295  	// query string -> client -> subscription
   296  	subscriptions map[string]map[string]*Subscription
   297  	// query string -> queryPlusRefCount
   298  	queries map[string]*queryPlusRefCount
   299  }
   300  
   301  // queryPlusRefCount holds a pointer to a query and reference counter. When
   302  // refCount is zero, query will be removed.
   303  type queryPlusRefCount struct {
   304  	q        Query
   305  	refCount int
   306  }
   307  
   308  // OnStart implements Service.OnStart by starting the server.
   309  func (s *Server) OnStart() error {
   310  	go s.loop(state{
   311  		subscriptions: make(map[string]map[string]*Subscription),
   312  		queries:       make(map[string]*queryPlusRefCount),
   313  	})
   314  	return nil
   315  }
   316  
   317  // OnReset implements Service.OnReset
   318  func (s *Server) OnReset() error {
   319  	return nil
   320  }
   321  
   322  func (s *Server) loop(state state) {
   323  loop:
   324  	for cmd := range s.cmds {
   325  		switch cmd.op {
   326  		case unsub:
   327  			if cmd.query != nil {
   328  				state.remove(cmd.clientID, cmd.query.String(), ErrUnsubscribed)
   329  			} else {
   330  				state.removeClient(cmd.clientID, ErrUnsubscribed)
   331  			}
   332  		case shutdown:
   333  			state.removeAll(nil)
   334  			break loop
   335  		case sub:
   336  			state.add(cmd.clientID, cmd.query, cmd.subscription)
   337  		case pub:
   338  			if err := state.send(cmd.msg, cmd.events); err != nil {
   339  				s.Logger.Error("Error querying for events", "err", err)
   340  			}
   341  		}
   342  	}
   343  }
   344  
   345  func (state *state) add(clientID string, q Query, subscription *Subscription) {
   346  	qStr := q.String()
   347  
   348  	// initialize subscription for this client per query if needed
   349  	if _, ok := state.subscriptions[qStr]; !ok {
   350  		state.subscriptions[qStr] = make(map[string]*Subscription)
   351  	}
   352  	// create subscription
   353  	state.subscriptions[qStr][clientID] = subscription
   354  
   355  	// initialize query if needed
   356  	if _, ok := state.queries[qStr]; !ok {
   357  		state.queries[qStr] = &queryPlusRefCount{q: q, refCount: 0}
   358  	}
   359  	// increment reference counter
   360  	state.queries[qStr].refCount++
   361  }
   362  
   363  func (state *state) remove(clientID string, qStr string, reason error) {
   364  	clientSubscriptions, ok := state.subscriptions[qStr]
   365  	if !ok {
   366  		return
   367  	}
   368  
   369  	subscription, ok := clientSubscriptions[clientID]
   370  	if !ok {
   371  		return
   372  	}
   373  
   374  	subscription.cancel(reason)
   375  
   376  	// remove client from query map.
   377  	// if query has no other clients subscribed, remove it.
   378  	delete(state.subscriptions[qStr], clientID)
   379  	if len(state.subscriptions[qStr]) == 0 {
   380  		delete(state.subscriptions, qStr)
   381  	}
   382  
   383  	// decrease ref counter in queries
   384  	state.queries[qStr].refCount--
   385  	// remove the query if nobody else is using it
   386  	if state.queries[qStr].refCount == 0 {
   387  		delete(state.queries, qStr)
   388  	}
   389  }
   390  
   391  func (state *state) removeClient(clientID string, reason error) {
   392  	for qStr, clientSubscriptions := range state.subscriptions {
   393  		if _, ok := clientSubscriptions[clientID]; ok {
   394  			state.remove(clientID, qStr, reason)
   395  		}
   396  	}
   397  }
   398  
   399  func (state *state) removeAll(reason error) {
   400  	for qStr, clientSubscriptions := range state.subscriptions {
   401  		for clientID := range clientSubscriptions {
   402  			state.remove(clientID, qStr, reason)
   403  		}
   404  	}
   405  }
   406  
   407  func (state *state) send(msg interface{}, events map[string][]string) error {
   408  	for qStr, clientSubscriptions := range state.subscriptions {
   409  		q := state.queries[qStr].q
   410  
   411  		match, err := q.Matches(events)
   412  		if err != nil {
   413  			return errors.Wrapf(err, "failed to match against query %s", q.String())
   414  		}
   415  
   416  		if match {
   417  			for clientID, subscription := range clientSubscriptions {
   418  				if cap(subscription.out) == 0 {
   419  					// block on unbuffered channel
   420  					subscription.out <- NewMessage(msg, events)
   421  				} else {
   422  					// don't block on buffered channels
   423  					select {
   424  					case subscription.out <- NewMessage(msg, events):
   425  					default:
   426  						state.remove(clientID, qStr, ErrOutOfCapacity)
   427  					}
   428  				}
   429  			}
   430  		}
   431  	}
   432  
   433  	return nil
   434  }