github.com/lei006/gmqtt-broker@v0.0.1/broker/lib/topics/memtopics.go (about) 1 package topics 2 3 import ( 4 "fmt" 5 "reflect" 6 "sync" 7 8 "github.com/eclipse/paho.mqtt.golang/packets" 9 ) 10 11 const ( 12 QosAtMostOnce byte = iota 13 QosAtLeastOnce 14 QosExactlyOnce 15 QosFailure = 0x80 16 ) 17 18 var _ TopicsProvider = (*memTopics)(nil) 19 20 type memTopics struct { 21 // Sub/unsub mutex 22 smu sync.RWMutex 23 // Subscription tree 24 sroot *snode 25 26 // Retained message mutex 27 rmu sync.RWMutex 28 // Retained messages topic tree 29 rroot *rnode 30 } 31 32 func init() { 33 Register("mem", NewMemProvider()) 34 } 35 36 // NewMemProvider returns an new instance of the memTopics, which is implements the 37 // TopicsProvider interface. memProvider is a hidden struct that stores the topic 38 // subscriptions and retained messages in memory. The content is not persistend so 39 // when the server goes, everything will be gone. Use with care. 40 func NewMemProvider() *memTopics { 41 return &memTopics{ 42 sroot: newSNode(), 43 rroot: newRNode(), 44 } 45 } 46 47 func ValidQos(qos byte) bool { 48 return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce 49 } 50 51 func (this *memTopics) Subscribe(topic []byte, qos byte, sub interface{}) (byte, error) { 52 if !ValidQos(qos) { 53 return QosFailure, fmt.Errorf("Invalid QoS %d", qos) 54 } 55 56 if sub == nil { 57 return QosFailure, fmt.Errorf("Subscriber cannot be nil") 58 } 59 60 this.smu.Lock() 61 defer this.smu.Unlock() 62 63 if qos > QosExactlyOnce { 64 qos = QosExactlyOnce 65 } 66 67 if err := this.sroot.sinsert(topic, qos, sub); err != nil { 68 return QosFailure, err 69 } 70 71 return qos, nil 72 } 73 74 func (this *memTopics) Unsubscribe(topic []byte, sub interface{}) error { 75 this.smu.Lock() 76 defer this.smu.Unlock() 77 78 return this.sroot.sremove(topic, sub) 79 } 80 81 // Subscribers Returned values will be invalidated by the next Subscribers call 82 func (this *memTopics) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { 83 if !ValidQos(qos) { 84 return fmt.Errorf("Invalid QoS %d", qos) 85 } 86 87 this.smu.RLock() 88 defer this.smu.RUnlock() 89 90 *subs = (*subs)[0:0] 91 *qoss = (*qoss)[0:0] 92 93 return this.sroot.smatch(topic, qos, subs, qoss) 94 } 95 96 func (this *memTopics) Retain(msg *packets.PublishPacket) error { 97 this.rmu.Lock() 98 defer this.rmu.Unlock() 99 100 // So apparently, at least according to the MQTT Conformance/Interoperability 101 // Testing, that a payload of 0 means delete the retain message. 102 // https://eclipse.org/paho/clients/testing/ 103 if len(msg.Payload) == 0 { 104 return this.rroot.rremove([]byte(msg.TopicName)) 105 } 106 107 return this.rroot.rinsertOrUpdate([]byte(msg.TopicName), msg) 108 } 109 110 func (this *memTopics) Retained(topic []byte, msgs *[]*packets.PublishPacket) error { 111 this.rmu.RLock() 112 defer this.rmu.RUnlock() 113 114 return this.rroot.rmatch(topic, msgs) 115 } 116 117 func (this *memTopics) Close() error { 118 this.sroot = nil 119 this.rroot = nil 120 return nil 121 } 122 123 // subscrition nodes 124 type snode struct { 125 // If this is the end of the topic string, then add subscribers here 126 subs []interface{} 127 qos []byte 128 129 // Otherwise add the next topic level here 130 snodes map[string]*snode 131 } 132 133 func newSNode() *snode { 134 return &snode{ 135 snodes: make(map[string]*snode), 136 } 137 } 138 139 func (this *snode) sinsert(topic []byte, qos byte, sub interface{}) error { 140 // If there's no more topic levels, that means we are at the matching snode 141 // to insert the subscriber. So let's see if there's such subscriber, 142 // if so, update it. Otherwise insert it. 143 if len(topic) == 0 { 144 // Let's see if the subscriber is already on the list. If yes, update 145 // QoS and then return. 146 for i := range this.subs { 147 if equal(this.subs[i], sub) { 148 this.qos[i] = qos 149 return nil 150 } 151 } 152 153 // Otherwise add. 154 this.subs = append(this.subs, sub) 155 this.qos = append(this.qos, qos) 156 157 return nil 158 } 159 160 // Not the last level, so let's find or create the next level snode, and 161 // recursively call it's insert(). 162 163 // ntl = next topic level 164 ntl, rem, err := nextTopicLevel(topic) 165 if err != nil { 166 return err 167 } 168 169 level := string(ntl) 170 171 // Add snode if it doesn't already exist 172 n, ok := this.snodes[level] 173 if !ok { 174 n = newSNode() 175 this.snodes[level] = n 176 } 177 178 return n.sinsert(rem, qos, sub) 179 } 180 181 // This remove implementation ignores the QoS, as long as the subscriber 182 // matches then it's removed 183 func (this *snode) sremove(topic []byte, sub interface{}) error { 184 // If the topic is empty, it means we are at the final matching snode. If so, 185 // let's find the matching subscribers and remove them. 186 if len(topic) == 0 { 187 // If subscriber == nil, then it's signal to remove ALL subscribers 188 if sub == nil { 189 this.subs = this.subs[0:0] 190 this.qos = this.qos[0:0] 191 return nil 192 } 193 194 // If we find the subscriber then remove it from the list. Technically 195 // we just overwrite the slot by shifting all other items up by one. 196 for i := range this.subs { 197 if equal(this.subs[i], sub) { 198 this.subs = append(this.subs[:i], this.subs[i+1:]...) 199 this.qos = append(this.qos[:i], this.qos[i+1:]...) 200 return nil 201 } 202 } 203 204 return fmt.Errorf("No topic found for subscriber") 205 } 206 207 // Not the last level, so let's find the next level snode, and recursively 208 // call it's remove(). 209 210 // ntl = next topic level 211 ntl, rem, err := nextTopicLevel(topic) 212 if err != nil { 213 return err 214 } 215 216 level := string(ntl) 217 218 // Find the snode that matches the topic level 219 n, ok := this.snodes[level] 220 if !ok { 221 return fmt.Errorf("No topic found") 222 } 223 224 // Remove the subscriber from the next level snode 225 if err := n.sremove(rem, sub); err != nil { 226 return err 227 } 228 229 // If there are no more subscribers and snodes to the next level we just visited 230 // let's remove it 231 if len(n.subs) == 0 && len(n.snodes) == 0 { 232 delete(this.snodes, level) 233 } 234 235 return nil 236 } 237 238 // smatch() returns all the subscribers that are subscribed to the topic. Given a topic 239 // with no wildcards (publish topic), it returns a list of subscribers that subscribes 240 // to the topic. For each of the level names, it's a match 241 // - if there are subscribers to '#', then all the subscribers are added to result set 242 func (this *snode) smatch(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { 243 // If the topic is empty, it means we are at the final matching snode. If so, 244 // let's find the subscribers that match the qos and append them to the list. 245 if len(topic) == 0 { 246 this.matchQos(qos, subs, qoss) 247 if mwcn, _ := this.snodes[MWC]; mwcn != nil { 248 mwcn.matchQos(qos, subs, qoss) 249 } 250 return nil 251 } 252 253 // ntl = next topic level 254 ntl, rem, err := nextTopicLevel(topic) 255 if err != nil { 256 return err 257 } 258 259 level := string(ntl) 260 261 for k, n := range this.snodes { 262 // If the key is "#", then these subscribers are added to the result set 263 if k == MWC { 264 n.matchQos(qos, subs, qoss) 265 } else if k == SWC || k == level { 266 if err := n.smatch(rem, qos, subs, qoss); err != nil { 267 return err 268 } 269 } 270 } 271 272 return nil 273 } 274 275 // retained message nodes 276 type rnode struct { 277 // If this is the end of the topic string, then add retained messages here 278 msg *packets.PublishPacket 279 // Otherwise add the next topic level here 280 rnodes map[string]*rnode 281 } 282 283 func newRNode() *rnode { 284 return &rnode{ 285 rnodes: make(map[string]*rnode), 286 } 287 } 288 289 func (this *rnode) rinsertOrUpdate(topic []byte, msg *packets.PublishPacket) error { 290 // If there's no more topic levels, that means we are at the matching rnode. 291 if len(topic) == 0 { 292 // Reuse the message if possible 293 this.msg = msg 294 295 return nil 296 } 297 298 // Not the last level, so let's find or create the next level snode, and 299 // recursively call it's insert(). 300 301 // ntl = next topic level 302 ntl, rem, err := nextTopicLevel(topic) 303 if err != nil { 304 return err 305 } 306 307 level := string(ntl) 308 309 // Add snode if it doesn't already exist 310 n, ok := this.rnodes[level] 311 if !ok { 312 n = newRNode() 313 this.rnodes[level] = n 314 } 315 316 return n.rinsertOrUpdate(rem, msg) 317 } 318 319 // Remove the retained message for the supplied topic 320 func (this *rnode) rremove(topic []byte) error { 321 // If the topic is empty, it means we are at the final matching rnode. If so, 322 // let's remove the buffer and message. 323 if len(topic) == 0 { 324 this.msg = nil 325 return nil 326 } 327 328 // Not the last level, so let's find the next level rnode, and recursively 329 // call it's remove(). 330 331 // ntl = next topic level 332 ntl, rem, err := nextTopicLevel(topic) 333 if err != nil { 334 return err 335 } 336 337 level := string(ntl) 338 339 // Find the rnode that matches the topic level 340 n, ok := this.rnodes[level] 341 if !ok { 342 return fmt.Errorf("No topic found") 343 } 344 345 // Remove the subscriber from the next level rnode 346 if err := n.rremove(rem); err != nil { 347 return err 348 } 349 350 // If there are no more rnodes to the next level we just visited let's remove it 351 if len(n.rnodes) == 0 { 352 delete(this.rnodes, level) 353 } 354 355 return nil 356 } 357 358 // rmatch() finds the retained messages for the topic and qos provided. It's somewhat 359 // of a reverse match compare to match() since the supplied topic can contain 360 // wildcards, whereas the retained message topic is a full (no wildcard) topic. 361 func (this *rnode) rmatch(topic []byte, msgs *[]*packets.PublishPacket) error { 362 // If the topic is empty, it means we are at the final matching rnode. If so, 363 // add the retained msg to the list. 364 if len(topic) == 0 { 365 if this.msg != nil { 366 *msgs = append(*msgs, this.msg) 367 } 368 return nil 369 } 370 371 // ntl = next topic level 372 ntl, rem, err := nextTopicLevel(topic) 373 if err != nil { 374 return err 375 } 376 377 level := string(ntl) 378 379 if level == MWC { 380 // If '#', add all retained messages starting this node 381 this.allRetained(msgs) 382 } else if level == SWC { 383 // If '+', check all nodes at this level. Next levels must be matched. 384 for _, n := range this.rnodes { 385 if err := n.rmatch(rem, msgs); err != nil { 386 return err 387 } 388 } 389 } else { 390 // Otherwise, find the matching node, go to the next level 391 if n, ok := this.rnodes[level]; ok { 392 if err := n.rmatch(rem, msgs); err != nil { 393 return err 394 } 395 } 396 } 397 398 return nil 399 } 400 401 func (this *rnode) allRetained(msgs *[]*packets.PublishPacket) { 402 if this.msg != nil { 403 *msgs = append(*msgs, this.msg) 404 } 405 406 for _, n := range this.rnodes { 407 n.allRetained(msgs) 408 } 409 } 410 411 const ( 412 stateCHR byte = iota // Regular character 413 stateMWC // Multi-level wildcard 414 stateSWC // Single-level wildcard 415 stateSEP // Topic level separator 416 stateSYS // System level topic ($) 417 ) 418 419 // Returns topic level, remaining topic levels and any errors 420 func nextTopicLevel(topic []byte) ([]byte, []byte, error) { 421 s := stateCHR 422 423 for i, c := range topic { 424 switch c { 425 case '/': 426 if s == stateMWC { 427 return nil, nil, fmt.Errorf("Multi-level wildcard found in topic and it's not at the last level") 428 } 429 430 if i == 0 { 431 return []byte(SWC), topic[i+1:], nil 432 } 433 434 return topic[:i], topic[i+1:], nil 435 436 case '#': 437 if i != 0 { 438 return nil, nil, fmt.Errorf("Wildcard character '#' must occupy entire topic level") 439 } 440 441 s = stateMWC 442 443 case '+': 444 if i != 0 { 445 return nil, nil, fmt.Errorf("Wildcard character '+' must occupy entire topic level") 446 } 447 448 s = stateSWC 449 450 // case '$': 451 // if i == 0 { 452 // return nil, nil, fmt.Errorf("Cannot publish to $ topics") 453 // } 454 455 // s = stateSYS 456 457 default: 458 if s == stateMWC || s == stateSWC { 459 return nil, nil, fmt.Errorf("Wildcard characters '#' and '+' must occupy entire topic level") 460 } 461 462 s = stateCHR 463 } 464 } 465 466 // If we got here that means we didn't hit the separator along the way, so the 467 // topic is either empty, or does not contain a separator. Either way, we return 468 // the full topic 469 return topic, nil, nil 470 } 471 472 // The QoS of the payload messages sent in response to a subscription must be the 473 // minimum of the QoS of the originally published message (in this case, it's the 474 // qos parameter) and the maximum QoS granted by the server (in this case, it's 475 // the QoS in the topic tree). 476 // 477 // It's also possible that even if the topic matches, the subscriber is not included 478 // due to the QoS granted is lower than the published message QoS. For example, 479 // if the client is granted only QoS 0, and the publish message is QoS 1, then this 480 // client is not to be send the published message. 481 func (this *snode) matchQos(qos byte, subs *[]interface{}, qoss *[]byte) { 482 for _, sub := range this.subs { 483 // If the published QoS is higher than the subscriber QoS, then we skip the 484 // subscriber. Otherwise, add to the list. 485 // if qos >= this.qos[i] { 486 *subs = append(*subs, sub) 487 *qoss = append(*qoss, qos) 488 // } 489 } 490 } 491 492 func equal(k1, k2 interface{}) bool { 493 if reflect.TypeOf(k1) != reflect.TypeOf(k2) { 494 return false 495 } 496 497 if reflect.ValueOf(k1).Kind() == reflect.Func { 498 return &k1 == &k2 499 } 500 501 if k1 == k2 { 502 return true 503 } 504 505 switch k1 := k1.(type) { 506 case string: 507 return k1 == k2.(string) 508 509 case int64: 510 return k1 == k2.(int64) 511 512 case int32: 513 return k1 == k2.(int32) 514 515 case int16: 516 return k1 == k2.(int16) 517 518 case int8: 519 return k1 == k2.(int8) 520 521 case int: 522 return k1 == k2.(int) 523 524 case float32: 525 return k1 == k2.(float32) 526 527 case float64: 528 return k1 == k2.(float64) 529 530 case uint: 531 return k1 == k2.(uint) 532 533 case uint8: 534 return k1 == k2.(uint8) 535 536 case uint16: 537 return k1 == k2.(uint16) 538 539 case uint32: 540 return k1 == k2.(uint32) 541 542 case uint64: 543 return k1 == k2.(uint64) 544 545 case uintptr: 546 return k1 == k2.(uintptr) 547 } 548 549 return false 550 }