github.com/emitter-io/go/v2@v2.1.0/emitter.go (about)

     1  package emitter
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"log"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	mqtt "github.com/eclipse/paho.mqtt.golang"
    13  )
    14  
    15  // Various emitter errors
    16  var (
    17  	ErrTimeout   = errors.New("emitter: operation has timed out")
    18  	ErrUnmarshal = errors.New("emitter: unable to unmarshal the response")
    19  )
    20  
    21  // Message defines the externals that a message implementation must support
    22  // these are received messages that are passed to the callbacks, not internal
    23  // messages
    24  type Message interface {
    25  	Topic() string
    26  	Payload() []byte
    27  }
    28  
    29  // Client represents an emitter client which holds the connection.
    30  type Client struct {
    31  	sync.RWMutex
    32  	guid       string              // Emiter's client ID
    33  	conn       mqtt.Client         // MQTT client
    34  	opts       *mqtt.ClientOptions // MQTT options
    35  	store      *store              // In-flight requests store
    36  	handlers   *trie               // The registry for handlers
    37  	timeout    time.Duration       // Default timeout
    38  	message    MessageHandler      // User-defined message handler
    39  	connect    ConnectHandler      // User-defined connect handler
    40  	disconnect DisconnectHandler   // User-defined disconnect handler
    41  	presence   PresenceHandler     // User-defined presence handler
    42  	errors     ErrorHandler        // User-defined error handler
    43  }
    44  
    45  // Connect is a convenience function which sets a broker and connects to it.
    46  func Connect(host string, handler MessageHandler, options ...func(*Client)) (*Client, error) {
    47  	if len(host) > 0 {
    48  		options = append(options, WithBrokers(host))
    49  	}
    50  
    51  	// Create the client and handlers
    52  	client := NewClient(options...)
    53  	client.OnMessage(handler)
    54  
    55  	// Connect to the broker
    56  	err := client.Connect()
    57  	return client, err
    58  }
    59  
    60  // NewClient will create an MQTT v3.1.1 client with all of the options specified
    61  // in the provided ClientOptions. The client must have the Connect method called
    62  // on it before it may be used. This is to make sure resources (such as a net
    63  // connection) are created before the application is actually ready.
    64  func NewClient(options ...func(*Client)) *Client {
    65  	c := &Client{
    66  		opts:     mqtt.NewClientOptions(),
    67  		timeout:  60 * time.Second,
    68  		store:    new(store),
    69  		handlers: NewTrie(),
    70  	}
    71  
    72  	// Set handlers
    73  	c.opts.SetOnConnectHandler(c.onConnect)
    74  	c.opts.SetConnectionLostHandler(c.onConnectionLost)
    75  	c.opts.SetDefaultPublishHandler(c.onMessage)
    76  	c.opts.SetClientID(uuid())
    77  	c.opts.SetStore(c.store)
    78  
    79  	// Apply default configuration
    80  	WithBrokers("tcp://api.emitter.io:8080")(c)
    81  
    82  	// Apply options
    83  	for _, opt := range options {
    84  		opt(c)
    85  	}
    86  
    87  	// Create the underlying MQTT client and set the options
    88  	c.conn = mqtt.NewClient(c.opts)
    89  	return c
    90  }
    91  
    92  // OnMessage sets the MessageHandler that will be called when a message
    93  // is received that does not match any known subscriptions.
    94  func (c *Client) OnMessage(handler MessageHandler) {
    95  	c.message = handler
    96  }
    97  
    98  // OnConnect sets the function to be called when the client is connected. Both
    99  // at initial connection time and upon automatic reconnect.
   100  func (c *Client) OnConnect(handler ConnectHandler) {
   101  	c.connect = handler
   102  }
   103  
   104  // OnDisconnect will set the function callback to be executed
   105  // in the case where the client unexpectedly loses connection with the MQTT broker.
   106  func (c *Client) OnDisconnect(handler DisconnectHandler) {
   107  	c.disconnect = handler
   108  }
   109  
   110  // OnPresence sets the function that will be called when a presence event is received.
   111  func (c *Client) OnPresence(handler PresenceHandler) {
   112  	c.presence = handler
   113  }
   114  
   115  // onConnect occurs when MQTT client is connected
   116  func (c *Client) onConnect(_ mqtt.Client) {
   117  	if c.connect != nil {
   118  		c.connect(c)
   119  	}
   120  }
   121  
   122  // onConnectionLost occurs when MQTT client is disconnected
   123  func (c *Client) onConnectionLost(_ mqtt.Client, e error) {
   124  	if c.disconnect != nil {
   125  		c.disconnect(c, e)
   126  	} else {
   127  		log.Println("emitter: connection lost, due to", e.Error())
   128  	}
   129  }
   130  
   131  // OnError will set the function callback to be executed if an emitter-specific
   132  // error occurs.
   133  func (c *Client) OnError(handler ErrorHandler) {
   134  
   135  	c.errors = handler
   136  }
   137  
   138  // onMessage occurs when MQTT client receives a message
   139  func (c *Client) onMessage(_ mqtt.Client, m mqtt.Message) {
   140  	if !strings.HasPrefix(m.Topic(), "emitter/") {
   141  		handlers := c.handlers.Lookup(m.Topic())
   142  		if len(handlers) == 0 && c.message != nil { // Invoke the default message handler
   143  			c.message(c, m)
   144  		}
   145  
   146  		// Call each handler
   147  		for _, h := range handlers {
   148  			h(c, m)
   149  		}
   150  		return
   151  	}
   152  
   153  	// `onError` and `onResponse` read the callbacks store when calling
   154  	// the `NotifyResponse`. See the comments in the `request` function.
   155  	c.RLock()
   156  	defer c.RUnlock()
   157  
   158  	switch {
   159  
   160  	// Dispatch presence handler
   161  	case c.presence != nil && strings.HasPrefix(m.Topic(), "emitter/presence/"):
   162  		var msg presenceMessage
   163  		if err := json.Unmarshal(m.Payload(), &msg); err != nil {
   164  			log.Println("emitter:", err.Error())
   165  		}
   166  
   167  		r := PresenceEvent{msg, make([]PresenceInfo, 0)}
   168  		if msg.Event == "status" {
   169  			if err := json.Unmarshal([]byte(msg.Who), &r.Who); err != nil {
   170  				log.Println("emitter:", err.Error())
   171  			}
   172  		} else {
   173  			r.Who = append(r.Who, PresenceInfo{})
   174  			if err := json.Unmarshal([]byte(msg.Who), &r.Who[0]); err != nil {
   175  				log.Println("emitter:", err.Error())
   176  			}
   177  		}
   178  
   179  		c.presence(c, r)
   180  
   181  	// Dispatch errors handler
   182  	case strings.HasPrefix(m.Topic(), "emitter/error/"):
   183  		c.onError(m)
   184  
   185  	// Dispatch keygen handler
   186  	case strings.HasPrefix(m.Topic(), "emitter/keygen/"):
   187  		c.onResponse(m, new(keyGenResponse))
   188  
   189  	// Dispatch keyban handler
   190  	case strings.HasPrefix(m.Topic(), "emitter/keyban/"):
   191  		c.onResponse(m, new(keyBanResponse))
   192  
   193  	// Dispatch link handler
   194  	case strings.HasPrefix(m.Topic(), "emitter/link/"):
   195  		c.onResponse(m, new(Link))
   196  
   197  	// Dispatch me handler
   198  	case strings.HasPrefix(m.Topic(), "emitter/me/"):
   199  		c.onResponse(m, new(meResponse))
   200  
   201  	// Dispatch history handler
   202  	case strings.HasPrefix(m.Topic(), "emitter/history/"):
   203  		c.onResponse(m, new(historyResponse))
   204  
   205  	default:
   206  
   207  	}
   208  }
   209  
   210  // OnResponse handles the incoming response for emitter messages.
   211  func (c *Client) onResponse(m mqtt.Message, resp Response) bool {
   212  
   213  	// Check if we've got an error response
   214  	var errResponse Error
   215  	if err := json.Unmarshal(m.Payload(), &errResponse); err == nil && errResponse.Error() != "" {
   216  		return c.store.NotifyResponse(errResponse.RequestID(), &errResponse)
   217  	}
   218  
   219  	// If it's not an error, try to unmarshal the response
   220  	if err := json.Unmarshal(m.Payload(), &resp); err == nil && resp.RequestID() > 0 {
   221  		return c.store.NotifyResponse(resp.RequestID(), resp)
   222  	}
   223  	return false
   224  }
   225  
   226  // OnError handles the incoming error.
   227  func (c *Client) onError(m mqtt.Message) {
   228  	var resp Error
   229  	if err := json.Unmarshal(m.Payload(), &resp); err != nil {
   230  		return
   231  	}
   232  
   233  	if c.errors == nil {
   234  		log.Println("emitter:", resp.Error())
   235  	}
   236  
   237  	if c.errors != nil && !c.store.NotifyResponse(resp.RequestID(), &resp) {
   238  		c.errors(c, resp)
   239  	}
   240  }
   241  
   242  // IsConnected returns a bool signifying whether the client is connected or not.
   243  func (c *Client) IsConnected() bool {
   244  	return c.conn.IsConnected()
   245  }
   246  
   247  // Connect initiates a connection to the broker.
   248  func (c *Client) Connect() error {
   249  	return c.do(c.conn.Connect())
   250  }
   251  
   252  // ID retrieves information about the client.
   253  func (c *Client) ID() string {
   254  	if c.guid != "" {
   255  		return c.guid
   256  	}
   257  
   258  	// Query the remote GUID, cast the response and store it
   259  	if resp, err := c.request("me", nil); err == nil {
   260  		if result, ok := resp.(*meResponse); ok {
   261  			c.guid = result.ID
   262  		}
   263  	}
   264  
   265  	return c.guid
   266  }
   267  
   268  // Disconnect will end the connection with the server, but not before waiting
   269  // the specified number of milliseconds to wait for existing work to be
   270  // completed.
   271  func (c *Client) Disconnect(waitTime time.Duration) {
   272  	c.conn.Disconnect(uint(waitTime.Nanoseconds() / 1000000))
   273  }
   274  
   275  // Publish will publish a message with the specified QoS and content to the specified topic.
   276  // Returns a token to track delivery of the message to the broker
   277  func (c *Client) Publish(key string, channel string, payload interface{}, options ...Option) error {
   278  	qos, retain := getHeader(options)
   279  	token := c.conn.Publish(formatTopic(key, channel, options), qos, retain, payload)
   280  	return c.do(token)
   281  }
   282  
   283  // PublishWithTTL publishes a message with a specified Time-To-Live option
   284  func (c *Client) PublishWithTTL(key string, channel string, payload interface{}, ttl int) error {
   285  	return c.Publish(key, channel, payload, WithTTL(ttl))
   286  }
   287  
   288  // PublishWithRetain publishes a message with a retain flag set to true
   289  func (c *Client) PublishWithRetain(key string, channel string, payload interface{}, options ...Option) error {
   290  	options = append(options, WithRetain())
   291  	return c.Publish(key, channel, payload, options...)
   292  }
   293  
   294  // PublishWithLink publishes a message with a specified link name instead of a channel key.
   295  func (c *Client) PublishWithLink(name string, payload interface{}, options ...Option) error {
   296  	qos, retain := getHeader(options)
   297  	token := c.conn.Publish(name, qos, retain, payload)
   298  	return c.do(token)
   299  }
   300  
   301  // Subscribe starts a new subscription. Provide a MessageHandler to be executed when
   302  // a message is published on the topic provided.
   303  func (c *Client) Subscribe(key string, channel string, optionalHandler MessageHandler, options ...Option) error {
   304  	if optionalHandler != nil {
   305  		c.handlers.AddHandler(channel, optionalHandler)
   306  	}
   307  
   308  	// https://github.com/eclipse/paho.mqtt.golang/blob/master/topic.go#L78
   309  	topic := strings.ReplaceAll(formatTopic(key, channel, options), "#/", "#")
   310  
   311  	// Issue subscribe
   312  	token := c.conn.Subscribe(topic, 0, nil)
   313  	return c.do(token)
   314  }
   315  
   316  // SubscribeWithGroup creates a shared subscription to a share group.
   317  func (c *Client) SubscribeWithGroup(key, channel, shareGroup string, optionalHandler MessageHandler, options ...Option) error {
   318  	if optionalHandler != nil {
   319  		c.handlers.AddHandler(channel, optionalHandler)
   320  	}
   321  
   322  	// Issue subscribe
   323  	token := c.conn.Subscribe(formatShare(key, shareGroup, channel, options), 0, nil)
   324  	return c.do(token)
   325  }
   326  
   327  // SubscribeWithHistory performs a subscribe with an option to retrieve the specified number
   328  // of messages that were already published in the channel.
   329  func (c *Client) SubscribeWithHistory(key string, channel string, last int, optionalHandler MessageHandler) error {
   330  	return c.Subscribe(key, channel, optionalHandler, WithLast(last))
   331  }
   332  
   333  // Unsubscribe will end the subscription from each of the topics provided.
   334  // Messages published to those topics from other clients will no longer be
   335  // received.
   336  func (c *Client) Unsubscribe(key string, channel string) error {
   337  
   338  	// Remove the handler if we have one
   339  	c.handlers.RemoveHandler(channel)
   340  
   341  	// Issue the unsubscribe
   342  	token := c.conn.Unsubscribe(formatTopic(key, channel, nil))
   343  	return c.do(token)
   344  }
   345  
   346  // Presence sends a presence request to the broker.
   347  func (c *Client) Presence(key, channel string, status, changes bool) error {
   348  	req, err := json.Marshal(&presenceRequest{
   349  		Key:     key,
   350  		Channel: channel,
   351  		Status:  status,
   352  		Changes: changes,
   353  	})
   354  	if err != nil {
   355  		return err
   356  	}
   357  
   358  	return c.do(c.conn.Publish("emitter/presence/", 1, false, req))
   359  }
   360  
   361  // GenerateKey sends a key generation request to the broker
   362  func (c *Client) GenerateKey(key, channel, permissions string, ttl int) (string, string, error) {
   363  	resp, err := c.request("keygen", &keygenRequest{
   364  		Key:     key,
   365  		Channel: channel,
   366  		Type:    permissions,
   367  		TTL:     ttl,
   368  	})
   369  	if err != nil {
   370  		return "", "", err
   371  	}
   372  
   373  	// Cast the response and return it
   374  	if result, ok := resp.(*keyGenResponse); ok {
   375  		return result.Key, result.Channel, nil
   376  	}
   377  	return "", "", ErrUnmarshal
   378  }
   379  
   380  // BlockKey sends a request to block a key.
   381  func (c *Client) BlockKey(secretKey, targetKey string) (bool, error) {
   382  	resp, err := c.request("keyban", &keybanRequest{
   383  		Secret: secretKey,
   384  		Target: targetKey,
   385  		Banned: true,
   386  	})
   387  	if err != nil {
   388  		return false, err
   389  	}
   390  
   391  	// Cast the response and return it
   392  	if result, ok := resp.(*keyBanResponse); ok {
   393  		return result.Banned == true, nil
   394  	}
   395  	return false, ErrUnmarshal
   396  }
   397  
   398  // AllowKey sends a request to allow a previously blocked key.
   399  func (c *Client) AllowKey(secretKey, targetKey string) (bool, error) {
   400  	resp, err := c.request("keyban", &keybanRequest{
   401  		Secret: secretKey,
   402  		Target: targetKey,
   403  		Banned: false,
   404  	})
   405  	if err != nil {
   406  		return false, err
   407  	}
   408  
   409  	// Cast the response and return it
   410  	if result, ok := resp.(*keyBanResponse); ok {
   411  		return result.Banned == false, nil
   412  	}
   413  	return false, ErrUnmarshal
   414  }
   415  
   416  // CreateLink sends a request to create a default link.
   417  func (c *Client) CreateLink(key, channel, name string, optionalHandler MessageHandler, options ...Option) (*Link, error) {
   418  	resp, err := c.request("link", &linkRequest{
   419  		Name:      name,
   420  		Key:       key,
   421  		Channel:   formatTopic("", channel, options),
   422  		Subscribe: optionalHandler != nil,
   423  	})
   424  
   425  	if err != nil {
   426  		return nil, err
   427  	}
   428  
   429  	// Cast the response and return it
   430  	if result, ok := resp.(*Link); ok {
   431  		if optionalHandler != nil {
   432  			c.handlers.AddHandler(result.Channel, optionalHandler)
   433  		}
   434  
   435  		return result, nil
   436  	}
   437  	return nil, ErrUnmarshal
   438  }
   439  
   440  func (c *Client) History(key, channel string, from, until int64, limit int) func(func(m HistoryMessage, err error) bool) {
   441  	return func(yield func(m HistoryMessage, err error) bool) {
   442  		{
   443  			var startFromID MessageID = nil
   444  			nMsgRetrieved := 0
   445  
   446  			for {
   447  				req := &historyRequest{
   448  					// As an MQTT message size is limited, we need to paginate the history request.
   449  					// If we didn't receive all messages (limit) we request limit - nMsgRetrieved messages.
   450  					// We also set the startFromID to the message ID of the last message we received to
   451  					// continue retrieving them from this one.
   452  					Channel:     formatTopic(key, channel, []Option{WithLast(limit - nMsgRetrieved), WithFrom(time.Unix(from, 0)), WithUntil(time.Unix(until, 0))}),
   453  					StartFromID: startFromID,
   454  				}
   455  
   456  				resp, err := c.request("history", req)
   457  				if err != nil {
   458  					yield(HistoryMessage{}, err)
   459  				}
   460  
   461  				// Cast the response.
   462  				result, _ := resp.(*historyResponse)
   463  
   464  				// If no messages left in the history then return.
   465  				if len(result.Messages) == 0 {
   466  					return
   467  				}
   468  
   469  				// Yield each message returned by the history request.
   470  				for i := 0; i <= len(result.Messages)-1; i++ {
   471  					if !yield(result.Messages[i], nil) {
   472  						return
   473  					}
   474  				}
   475  
   476  				nMsgRetrieved += len(result.Messages)
   477  				// If we received the number of messages requested then return.
   478  				if nMsgRetrieved >= limit {
   479  					return
   480  				}
   481  
   482  				// As an MQTT message size is limited, we need to paginate the history request.
   483  				// In case we have more messages to retrieve, set the startFromID to the message ID
   484  				// of the last message we received to continue retrieving them from this one.
   485  				startFromID = result.Messages[0].ID
   486  			}
   487  		}
   488  	}
   489  }
   490  
   491  // Makes a request
   492  func (c *Client) request(operation string, req interface{}) (Response, error) {
   493  	request, err := json.Marshal(req)
   494  	if err != nil {
   495  		panic("unable to encode the request")
   496  	}
   497  
   498  	// Publish and wait for an error, response or puback
   499  	// The client is locked until the callback is stored, so the response
   500  	// cannot arrive before and be lost
   501  	c.Lock()
   502  	token := c.conn.Publish(fmt.Sprintf("emitter/%s/", operation), 1, false, request)
   503  	respChan := c.store.PutCallback(token.(*mqtt.PublishToken).MessageID())
   504  	c.Unlock()
   505  	if err := c.do(token); err != nil {
   506  		return nil, err
   507  	}
   508  	resp := <-respChan
   509  	if err, ok := resp.(error); ok {
   510  		return nil, err
   511  	}
   512  	return resp, nil
   513  }
   514  
   515  // do waits for the operation to complete
   516  func (c *Client) do(t mqtt.Token) error {
   517  	if !t.WaitTimeout(c.timeout) {
   518  		return ErrTimeout
   519  	}
   520  
   521  	return t.Error()
   522  }
   523  
   524  // Makes a topic name from the key/channel pair
   525  func formatTopic(key, channel string, options []Option) string {
   526  	key = trim(key)
   527  	channel = trim(channel)
   528  	opts := formatOptions(options)
   529  	if len(key) == 0 {
   530  		return fmt.Sprintf("%s/%s", channel, opts)
   531  	}
   532  
   533  	return fmt.Sprintf("%s/%s/%s", key, channel, opts)
   534  }
   535  
   536  // formatShare creates a shared topic subscription
   537  func formatShare(key, shareGroup, channel string, options []Option) string {
   538  	return fmt.Sprintf("%s/$share/%s/%s/%s", trim(key), trim(shareGroup), trim(channel), formatOptions(options))
   539  }
   540  
   541  // getHeader gets the header fields from options.
   542  func getHeader(options []Option) (qos byte, retain bool) {
   543  	for _, o := range options {
   544  		switch o {
   545  		case withRetain:
   546  			retain = true
   547  		case withQos0:
   548  			qos = 0
   549  		case withQos1:
   550  			qos = 1
   551  		}
   552  	}
   553  	return
   554  }
   555  
   556  // formatOptions formats a set of options, ignoring the reserved ones
   557  func formatOptions(options []Option) string {
   558  	opts, hasOpts := "", false
   559  	if options != nil && len(options) > 0 {
   560  		for _, option := range options {
   561  			opt := option.String()
   562  			if opt[0] == '+' {
   563  				continue
   564  			}
   565  
   566  			if !hasOpts {
   567  				hasOpts = true
   568  				opts += "?"
   569  			} else {
   570  				opts += "&"
   571  			}
   572  
   573  			opts += opt
   574  		}
   575  	}
   576  	return opts
   577  }
   578  
   579  // Trim removes both suffix and prefix
   580  func trim(v string) string {
   581  	return strings.TrimSuffix(strings.TrimPrefix(v, "/"), "/")
   582  }