github.com/lei006/gmqtt-broker@v0.0.1/broker/lib/topics/memtopics.go (about)

     1  package topics
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  
     8  	"github.com/eclipse/paho.mqtt.golang/packets"
     9  )
    10  
    11  const (
    12  	QosAtMostOnce byte = iota
    13  	QosAtLeastOnce
    14  	QosExactlyOnce
    15  	QosFailure = 0x80
    16  )
    17  
    18  var _ TopicsProvider = (*memTopics)(nil)
    19  
    20  type memTopics struct {
    21  	// Sub/unsub mutex
    22  	smu sync.RWMutex
    23  	// Subscription tree
    24  	sroot *snode
    25  
    26  	// Retained message mutex
    27  	rmu sync.RWMutex
    28  	// Retained messages topic tree
    29  	rroot *rnode
    30  }
    31  
    32  func init() {
    33  	Register("mem", NewMemProvider())
    34  }
    35  
    36  // NewMemProvider returns an new instance of the memTopics, which is implements the
    37  // TopicsProvider interface. memProvider is a hidden struct that stores the topic
    38  // subscriptions and retained messages in memory. The content is not persistend so
    39  // when the server goes, everything will be gone. Use with care.
    40  func NewMemProvider() *memTopics {
    41  	return &memTopics{
    42  		sroot: newSNode(),
    43  		rroot: newRNode(),
    44  	}
    45  }
    46  
    47  func ValidQos(qos byte) bool {
    48  	return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce
    49  }
    50  
    51  func (this *memTopics) Subscribe(topic []byte, qos byte, sub interface{}) (byte, error) {
    52  	if !ValidQos(qos) {
    53  		return QosFailure, fmt.Errorf("Invalid QoS %d", qos)
    54  	}
    55  
    56  	if sub == nil {
    57  		return QosFailure, fmt.Errorf("Subscriber cannot be nil")
    58  	}
    59  
    60  	this.smu.Lock()
    61  	defer this.smu.Unlock()
    62  
    63  	if qos > QosExactlyOnce {
    64  		qos = QosExactlyOnce
    65  	}
    66  
    67  	if err := this.sroot.sinsert(topic, qos, sub); err != nil {
    68  		return QosFailure, err
    69  	}
    70  
    71  	return qos, nil
    72  }
    73  
    74  func (this *memTopics) Unsubscribe(topic []byte, sub interface{}) error {
    75  	this.smu.Lock()
    76  	defer this.smu.Unlock()
    77  
    78  	return this.sroot.sremove(topic, sub)
    79  }
    80  
    81  // Subscribers Returned values will be invalidated by the next Subscribers call
    82  func (this *memTopics) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
    83  	if !ValidQos(qos) {
    84  		return fmt.Errorf("Invalid QoS %d", qos)
    85  	}
    86  
    87  	this.smu.RLock()
    88  	defer this.smu.RUnlock()
    89  
    90  	*subs = (*subs)[0:0]
    91  	*qoss = (*qoss)[0:0]
    92  
    93  	return this.sroot.smatch(topic, qos, subs, qoss)
    94  }
    95  
    96  func (this *memTopics) Retain(msg *packets.PublishPacket) error {
    97  	this.rmu.Lock()
    98  	defer this.rmu.Unlock()
    99  
   100  	// So apparently, at least according to the MQTT Conformance/Interoperability
   101  	// Testing, that a payload of 0 means delete the retain message.
   102  	// https://eclipse.org/paho/clients/testing/
   103  	if len(msg.Payload) == 0 {
   104  		return this.rroot.rremove([]byte(msg.TopicName))
   105  	}
   106  
   107  	return this.rroot.rinsertOrUpdate([]byte(msg.TopicName), msg)
   108  }
   109  
   110  func (this *memTopics) Retained(topic []byte, msgs *[]*packets.PublishPacket) error {
   111  	this.rmu.RLock()
   112  	defer this.rmu.RUnlock()
   113  
   114  	return this.rroot.rmatch(topic, msgs)
   115  }
   116  
   117  func (this *memTopics) Close() error {
   118  	this.sroot = nil
   119  	this.rroot = nil
   120  	return nil
   121  }
   122  
   123  // subscrition nodes
   124  type snode struct {
   125  	// If this is the end of the topic string, then add subscribers here
   126  	subs []interface{}
   127  	qos  []byte
   128  
   129  	// Otherwise add the next topic level here
   130  	snodes map[string]*snode
   131  }
   132  
   133  func newSNode() *snode {
   134  	return &snode{
   135  		snodes: make(map[string]*snode),
   136  	}
   137  }
   138  
   139  func (this *snode) sinsert(topic []byte, qos byte, sub interface{}) error {
   140  	// If there's no more topic levels, that means we are at the matching snode
   141  	// to insert the subscriber. So let's see if there's such subscriber,
   142  	// if so, update it. Otherwise insert it.
   143  	if len(topic) == 0 {
   144  		// Let's see if the subscriber is already on the list. If yes, update
   145  		// QoS and then return.
   146  		for i := range this.subs {
   147  			if equal(this.subs[i], sub) {
   148  				this.qos[i] = qos
   149  				return nil
   150  			}
   151  		}
   152  
   153  		// Otherwise add.
   154  		this.subs = append(this.subs, sub)
   155  		this.qos = append(this.qos, qos)
   156  
   157  		return nil
   158  	}
   159  
   160  	// Not the last level, so let's find or create the next level snode, and
   161  	// recursively call it's insert().
   162  
   163  	// ntl = next topic level
   164  	ntl, rem, err := nextTopicLevel(topic)
   165  	if err != nil {
   166  		return err
   167  	}
   168  
   169  	level := string(ntl)
   170  
   171  	// Add snode if it doesn't already exist
   172  	n, ok := this.snodes[level]
   173  	if !ok {
   174  		n = newSNode()
   175  		this.snodes[level] = n
   176  	}
   177  
   178  	return n.sinsert(rem, qos, sub)
   179  }
   180  
   181  // This remove implementation ignores the QoS, as long as the subscriber
   182  // matches then it's removed
   183  func (this *snode) sremove(topic []byte, sub interface{}) error {
   184  	// If the topic is empty, it means we are at the final matching snode. If so,
   185  	// let's find the matching subscribers and remove them.
   186  	if len(topic) == 0 {
   187  		// If subscriber == nil, then it's signal to remove ALL subscribers
   188  		if sub == nil {
   189  			this.subs = this.subs[0:0]
   190  			this.qos = this.qos[0:0]
   191  			return nil
   192  		}
   193  
   194  		// If we find the subscriber then remove it from the list. Technically
   195  		// we just overwrite the slot by shifting all other items up by one.
   196  		for i := range this.subs {
   197  			if equal(this.subs[i], sub) {
   198  				this.subs = append(this.subs[:i], this.subs[i+1:]...)
   199  				this.qos = append(this.qos[:i], this.qos[i+1:]...)
   200  				return nil
   201  			}
   202  		}
   203  
   204  		return fmt.Errorf("No topic found for subscriber")
   205  	}
   206  
   207  	// Not the last level, so let's find the next level snode, and recursively
   208  	// call it's remove().
   209  
   210  	// ntl = next topic level
   211  	ntl, rem, err := nextTopicLevel(topic)
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	level := string(ntl)
   217  
   218  	// Find the snode that matches the topic level
   219  	n, ok := this.snodes[level]
   220  	if !ok {
   221  		return fmt.Errorf("No topic found")
   222  	}
   223  
   224  	// Remove the subscriber from the next level snode
   225  	if err := n.sremove(rem, sub); err != nil {
   226  		return err
   227  	}
   228  
   229  	// If there are no more subscribers and snodes to the next level we just visited
   230  	// let's remove it
   231  	if len(n.subs) == 0 && len(n.snodes) == 0 {
   232  		delete(this.snodes, level)
   233  	}
   234  
   235  	return nil
   236  }
   237  
   238  // smatch() returns all the subscribers that are subscribed to the topic. Given a topic
   239  // with no wildcards (publish topic), it returns a list of subscribers that subscribes
   240  // to the topic. For each of the level names, it's a match
   241  // - if there are subscribers to '#', then all the subscribers are added to result set
   242  func (this *snode) smatch(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
   243  	// If the topic is empty, it means we are at the final matching snode. If so,
   244  	// let's find the subscribers that match the qos and append them to the list.
   245  	if len(topic) == 0 {
   246  		this.matchQos(qos, subs, qoss)
   247  		if mwcn, _ := this.snodes[MWC]; mwcn != nil {
   248  			mwcn.matchQos(qos, subs, qoss)
   249  		}
   250  		return nil
   251  	}
   252  
   253  	// ntl = next topic level
   254  	ntl, rem, err := nextTopicLevel(topic)
   255  	if err != nil {
   256  		return err
   257  	}
   258  
   259  	level := string(ntl)
   260  
   261  	for k, n := range this.snodes {
   262  		// If the key is "#", then these subscribers are added to the result set
   263  		if k == MWC {
   264  			n.matchQos(qos, subs, qoss)
   265  		} else if k == SWC || k == level {
   266  			if err := n.smatch(rem, qos, subs, qoss); err != nil {
   267  				return err
   268  			}
   269  		}
   270  	}
   271  
   272  	return nil
   273  }
   274  
   275  // retained message nodes
   276  type rnode struct {
   277  	// If this is the end of the topic string, then add retained messages here
   278  	msg *packets.PublishPacket
   279  	// Otherwise add the next topic level here
   280  	rnodes map[string]*rnode
   281  }
   282  
   283  func newRNode() *rnode {
   284  	return &rnode{
   285  		rnodes: make(map[string]*rnode),
   286  	}
   287  }
   288  
   289  func (this *rnode) rinsertOrUpdate(topic []byte, msg *packets.PublishPacket) error {
   290  	// If there's no more topic levels, that means we are at the matching rnode.
   291  	if len(topic) == 0 {
   292  		// Reuse the message if possible
   293  		this.msg = msg
   294  
   295  		return nil
   296  	}
   297  
   298  	// Not the last level, so let's find or create the next level snode, and
   299  	// recursively call it's insert().
   300  
   301  	// ntl = next topic level
   302  	ntl, rem, err := nextTopicLevel(topic)
   303  	if err != nil {
   304  		return err
   305  	}
   306  
   307  	level := string(ntl)
   308  
   309  	// Add snode if it doesn't already exist
   310  	n, ok := this.rnodes[level]
   311  	if !ok {
   312  		n = newRNode()
   313  		this.rnodes[level] = n
   314  	}
   315  
   316  	return n.rinsertOrUpdate(rem, msg)
   317  }
   318  
   319  // Remove the retained message for the supplied topic
   320  func (this *rnode) rremove(topic []byte) error {
   321  	// If the topic is empty, it means we are at the final matching rnode. If so,
   322  	// let's remove the buffer and message.
   323  	if len(topic) == 0 {
   324  		this.msg = nil
   325  		return nil
   326  	}
   327  
   328  	// Not the last level, so let's find the next level rnode, and recursively
   329  	// call it's remove().
   330  
   331  	// ntl = next topic level
   332  	ntl, rem, err := nextTopicLevel(topic)
   333  	if err != nil {
   334  		return err
   335  	}
   336  
   337  	level := string(ntl)
   338  
   339  	// Find the rnode that matches the topic level
   340  	n, ok := this.rnodes[level]
   341  	if !ok {
   342  		return fmt.Errorf("No topic found")
   343  	}
   344  
   345  	// Remove the subscriber from the next level rnode
   346  	if err := n.rremove(rem); err != nil {
   347  		return err
   348  	}
   349  
   350  	// If there are no more rnodes to the next level we just visited let's remove it
   351  	if len(n.rnodes) == 0 {
   352  		delete(this.rnodes, level)
   353  	}
   354  
   355  	return nil
   356  }
   357  
   358  // rmatch() finds the retained messages for the topic and qos provided. It's somewhat
   359  // of a reverse match compare to match() since the supplied topic can contain
   360  // wildcards, whereas the retained message topic is a full (no wildcard) topic.
   361  func (this *rnode) rmatch(topic []byte, msgs *[]*packets.PublishPacket) error {
   362  	// If the topic is empty, it means we are at the final matching rnode. If so,
   363  	// add the retained msg to the list.
   364  	if len(topic) == 0 {
   365  		if this.msg != nil {
   366  			*msgs = append(*msgs, this.msg)
   367  		}
   368  		return nil
   369  	}
   370  
   371  	// ntl = next topic level
   372  	ntl, rem, err := nextTopicLevel(topic)
   373  	if err != nil {
   374  		return err
   375  	}
   376  
   377  	level := string(ntl)
   378  
   379  	if level == MWC {
   380  		// If '#', add all retained messages starting this node
   381  		this.allRetained(msgs)
   382  	} else if level == SWC {
   383  		// If '+', check all nodes at this level. Next levels must be matched.
   384  		for _, n := range this.rnodes {
   385  			if err := n.rmatch(rem, msgs); err != nil {
   386  				return err
   387  			}
   388  		}
   389  	} else {
   390  		// Otherwise, find the matching node, go to the next level
   391  		if n, ok := this.rnodes[level]; ok {
   392  			if err := n.rmatch(rem, msgs); err != nil {
   393  				return err
   394  			}
   395  		}
   396  	}
   397  
   398  	return nil
   399  }
   400  
   401  func (this *rnode) allRetained(msgs *[]*packets.PublishPacket) {
   402  	if this.msg != nil {
   403  		*msgs = append(*msgs, this.msg)
   404  	}
   405  
   406  	for _, n := range this.rnodes {
   407  		n.allRetained(msgs)
   408  	}
   409  }
   410  
   411  const (
   412  	stateCHR byte = iota // Regular character
   413  	stateMWC             // Multi-level wildcard
   414  	stateSWC             // Single-level wildcard
   415  	stateSEP             // Topic level separator
   416  	stateSYS             // System level topic ($)
   417  )
   418  
   419  // Returns topic level, remaining topic levels and any errors
   420  func nextTopicLevel(topic []byte) ([]byte, []byte, error) {
   421  	s := stateCHR
   422  
   423  	for i, c := range topic {
   424  		switch c {
   425  		case '/':
   426  			if s == stateMWC {
   427  				return nil, nil, fmt.Errorf("Multi-level wildcard found in topic and it's not at the last level")
   428  			}
   429  
   430  			if i == 0 {
   431  				return []byte(SWC), topic[i+1:], nil
   432  			}
   433  
   434  			return topic[:i], topic[i+1:], nil
   435  
   436  		case '#':
   437  			if i != 0 {
   438  				return nil, nil, fmt.Errorf("Wildcard character '#' must occupy entire topic level")
   439  			}
   440  
   441  			s = stateMWC
   442  
   443  		case '+':
   444  			if i != 0 {
   445  				return nil, nil, fmt.Errorf("Wildcard character '+' must occupy entire topic level")
   446  			}
   447  
   448  			s = stateSWC
   449  
   450  		// case '$':
   451  		// 	if i == 0 {
   452  		// 		return nil, nil, fmt.Errorf("Cannot publish to $ topics")
   453  		// 	}
   454  
   455  		// 	s = stateSYS
   456  
   457  		default:
   458  			if s == stateMWC || s == stateSWC {
   459  				return nil, nil, fmt.Errorf("Wildcard characters '#' and '+' must occupy entire topic level")
   460  			}
   461  
   462  			s = stateCHR
   463  		}
   464  	}
   465  
   466  	// If we got here that means we didn't hit the separator along the way, so the
   467  	// topic is either empty, or does not contain a separator. Either way, we return
   468  	// the full topic
   469  	return topic, nil, nil
   470  }
   471  
   472  // The QoS of the payload messages sent in response to a subscription must be the
   473  // minimum of the QoS of the originally published message (in this case, it's the
   474  // qos parameter) and the maximum QoS granted by the server (in this case, it's
   475  // the QoS in the topic tree).
   476  //
   477  // It's also possible that even if the topic matches, the subscriber is not included
   478  // due to the QoS granted is lower than the published message QoS. For example,
   479  // if the client is granted only QoS 0, and the publish message is QoS 1, then this
   480  // client is not to be send the published message.
   481  func (this *snode) matchQos(qos byte, subs *[]interface{}, qoss *[]byte) {
   482  	for _, sub := range this.subs {
   483  		// If the published QoS is higher than the subscriber QoS, then we skip the
   484  		// subscriber. Otherwise, add to the list.
   485  		// if qos >= this.qos[i] {
   486  		*subs = append(*subs, sub)
   487  		*qoss = append(*qoss, qos)
   488  		// }
   489  	}
   490  }
   491  
   492  func equal(k1, k2 interface{}) bool {
   493  	if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
   494  		return false
   495  	}
   496  
   497  	if reflect.ValueOf(k1).Kind() == reflect.Func {
   498  		return &k1 == &k2
   499  	}
   500  
   501  	if k1 == k2 {
   502  		return true
   503  	}
   504  
   505  	switch k1 := k1.(type) {
   506  	case string:
   507  		return k1 == k2.(string)
   508  
   509  	case int64:
   510  		return k1 == k2.(int64)
   511  
   512  	case int32:
   513  		return k1 == k2.(int32)
   514  
   515  	case int16:
   516  		return k1 == k2.(int16)
   517  
   518  	case int8:
   519  		return k1 == k2.(int8)
   520  
   521  	case int:
   522  		return k1 == k2.(int)
   523  
   524  	case float32:
   525  		return k1 == k2.(float32)
   526  
   527  	case float64:
   528  		return k1 == k2.(float64)
   529  
   530  	case uint:
   531  		return k1 == k2.(uint)
   532  
   533  	case uint8:
   534  		return k1 == k2.(uint8)
   535  
   536  	case uint16:
   537  		return k1 == k2.(uint16)
   538  
   539  	case uint32:
   540  		return k1 == k2.(uint32)
   541  
   542  	case uint64:
   543  		return k1 == k2.(uint64)
   544  
   545  	case uintptr:
   546  		return k1 == k2.(uintptr)
   547  	}
   548  
   549  	return false
   550  }