github.com/Jeffail/benthos/v3@v3.65.0/lib/input/reader/mqtt.go (about)

     1  package reader
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/Jeffail/benthos/v3/internal/mqttconf"
    12  	"github.com/Jeffail/benthos/v3/lib/log"
    13  	"github.com/Jeffail/benthos/v3/lib/message"
    14  	"github.com/Jeffail/benthos/v3/lib/metrics"
    15  	"github.com/Jeffail/benthos/v3/lib/types"
    16  	"github.com/Jeffail/benthos/v3/lib/util/tls"
    17  	mqtt "github.com/eclipse/paho.mqtt.golang"
    18  	gonanoid "github.com/matoous/go-nanoid/v2"
    19  )
    20  
    21  //------------------------------------------------------------------------------
    22  
// MQTTConfig contains configuration fields for the MQTT input type.
type MQTTConfig struct {
	// URLs lists the broker addresses. NewMQTT also splits every entry on
	// commas, so a single entry may carry several addresses.
	URLs                   []string      `json:"urls" yaml:"urls"`
	// QoS is the quality of service level requested for every subscribed
	// topic.
	QoS                    uint8         `json:"qos" yaml:"qos"`
	// Topics lists the topics subscribed to once a connection is established.
	Topics                 []string      `json:"topics" yaml:"topics"`
	// ClientID is the identifier presented to the broker.
	ClientID               string        `json:"client_id" yaml:"client_id"`
	// DynamicClientIDSuffix selects a generated suffix appended to ClientID by
	// NewMQTT. Accepted values are "" (no suffix) and "nanoid"; anything else
	// is rejected.
	DynamicClientIDSuffix  string        `json:"dynamic_client_id_suffix" yaml:"dynamic_client_id_suffix"`
	// Will describes the last-will message; it is validated by NewMQTT and
	// registered with the client only when Will.Enabled is set.
	Will                   mqttconf.Will `json:"will" yaml:"will"`
	// CleanSession is passed through to the client's clean-session option.
	CleanSession           bool          `json:"clean_session" yaml:"clean_session"`
	// User is the username sent to the broker; left unset when empty.
	User                   string        `json:"user" yaml:"user"`
	// Password is the password sent to the broker; left unset when empty.
	Password               string        `json:"password" yaml:"password"`
	// ConnectTimeout is a duration string (time.ParseDuration syntax) handed
	// to the client as its connect timeout. It must parse successfully.
	ConnectTimeout         string        `json:"connect_timeout" yaml:"connect_timeout"`
	// StaleConnectionTimeout is an optional duration string. When it parses
	// to a positive duration, a read that receives nothing for that long
	// drops the connection so that it gets re-established. Empty disables
	// the check.
	StaleConnectionTimeout string        `json:"stale_connection_timeout" yaml:"stale_connection_timeout"`
	// KeepAlive is the client keep-alive interval, expressed in seconds.
	KeepAlive              int64         `json:"keepalive" yaml:"keepalive"`
	// TLS holds the TLS settings, applied only when TLS.Enabled is set.
	TLS                    tls.Config    `json:"tls" yaml:"tls"`
}
    39  
    40  // NewMQTTConfig creates a new MQTTConfig with default values.
    41  func NewMQTTConfig() MQTTConfig {
    42  	return MQTTConfig{
    43  		URLs:                   []string{"tcp://localhost:1883"},
    44  		QoS:                    1,
    45  		Topics:                 []string{"benthos_topic"},
    46  		ClientID:               "benthos_input",
    47  		Will:                   mqttconf.EmptyWill(),
    48  		CleanSession:           true,
    49  		User:                   "",
    50  		Password:               "",
    51  		ConnectTimeout:         "30s",
    52  		StaleConnectionTimeout: "",
    53  		KeepAlive:              30,
    54  		TLS:                    tls.NewConfig(),
    55  	}
    56  }
    57  
    58  //------------------------------------------------------------------------------
    59  
// MQTT is an input type that reads MQTT Pub/Sub messages.
type MQTT struct {
	// client is the active broker client, or nil while disconnected.
	client  mqtt.Client
	// msgChan carries messages from the subscription handler to
	// ReadWithContext; it is nil while disconnected.
	msgChan chan mqtt.Message
	// cMut guards client and msgChan.
	cMut    sync.Mutex

	// connectTimeout and staleConnectionTimeout are the parsed forms of the
	// corresponding duration strings in conf. A zero staleConnectionTimeout
	// means the stale check is disabled.
	connectTimeout         time.Duration
	staleConnectionTimeout time.Duration

	// conf is the configuration given to NewMQTT, with ClientID possibly
	// extended by a generated suffix.
	conf MQTTConfig

	// interruptChan is closed by CloseAsync to release anything blocked on
	// message delivery or reading.
	interruptChan chan struct{}

	// urls is the flattened list of broker addresses derived from conf.URLs.
	urls []string

	// stats and log are the metrics aggregator and logger given to NewMQTT.
	stats metrics.Type
	log   log.Modular
}
    78  
    79  // NewMQTT creates a new MQTT input type.
    80  func NewMQTT(
    81  	conf MQTTConfig, log log.Modular, stats metrics.Type,
    82  ) (*MQTT, error) {
    83  	m := &MQTT{
    84  		conf:          conf,
    85  		interruptChan: make(chan struct{}),
    86  		stats:         stats,
    87  		log:           log,
    88  	}
    89  
    90  	var err error
    91  	if m.connectTimeout, err = time.ParseDuration(conf.ConnectTimeout); err != nil {
    92  		return nil, fmt.Errorf("unable to parse connect timeout duration string: %w", err)
    93  	}
    94  	if len(conf.StaleConnectionTimeout) > 0 {
    95  		if m.staleConnectionTimeout, err = time.ParseDuration(conf.StaleConnectionTimeout); err != nil {
    96  			return nil, fmt.Errorf("unable to parse stale connection timeout duration string: %w", err)
    97  		}
    98  	}
    99  
   100  	switch m.conf.DynamicClientIDSuffix {
   101  	case "nanoid":
   102  		nid, err := gonanoid.New()
   103  		if err != nil {
   104  			return nil, fmt.Errorf("failed to generate nanoid: %w", err)
   105  		}
   106  		m.conf.ClientID += nid
   107  	case "":
   108  	default:
   109  		return nil, fmt.Errorf("unknown dynamic_client_id_suffix: %v", m.conf.DynamicClientIDSuffix)
   110  	}
   111  
   112  	if err := m.conf.Will.Validate(); err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	for _, u := range conf.URLs {
   117  		for _, splitURL := range strings.Split(u, ",") {
   118  			if len(splitURL) > 0 {
   119  				m.urls = append(m.urls, splitURL)
   120  			}
   121  		}
   122  	}
   123  
   124  	return m, nil
   125  }
   126  
   127  //------------------------------------------------------------------------------
   128  
   129  // Connect establishes a connection to an MQTT server.
   130  func (m *MQTT) Connect() error {
   131  	return m.ConnectWithContext(context.Background())
   132  }
   133  
   134  // ConnectWithContext establishes a connection to an MQTT server.
   135  func (m *MQTT) ConnectWithContext(ctx context.Context) error {
   136  	m.cMut.Lock()
   137  	defer m.cMut.Unlock()
   138  
   139  	if m.client != nil {
   140  		return nil
   141  	}
   142  
   143  	var msgMut sync.Mutex
   144  	msgChan := make(chan mqtt.Message)
   145  
   146  	closeMsgChan := func() bool {
   147  		msgMut.Lock()
   148  		chanOpen := msgChan != nil
   149  		if chanOpen {
   150  			close(msgChan)
   151  			msgChan = nil
   152  		}
   153  		msgMut.Unlock()
   154  		return chanOpen
   155  	}
   156  
   157  	conf := mqtt.NewClientOptions().
   158  		SetAutoReconnect(false).
   159  		SetClientID(m.conf.ClientID).
   160  		SetCleanSession(m.conf.CleanSession).
   161  		SetConnectTimeout(m.connectTimeout).
   162  		SetKeepAlive(time.Duration(m.conf.KeepAlive) * time.Second).
   163  		SetConnectionLostHandler(func(client mqtt.Client, reason error) {
   164  			client.Disconnect(0)
   165  			closeMsgChan()
   166  			m.log.Errorf("Connection lost due to: %v\n", reason)
   167  		}).
   168  		SetOnConnectHandler(func(c mqtt.Client) {
   169  			topics := make(map[string]byte)
   170  			for _, topic := range m.conf.Topics {
   171  				topics[topic] = m.conf.QoS
   172  			}
   173  
   174  			tok := c.SubscribeMultiple(topics, func(c mqtt.Client, msg mqtt.Message) {
   175  				msgMut.Lock()
   176  				if msgChan != nil {
   177  					select {
   178  					case msgChan <- msg:
   179  					case <-m.interruptChan:
   180  					}
   181  				}
   182  				msgMut.Unlock()
   183  			})
   184  			tok.Wait()
   185  			if err := tok.Error(); err != nil {
   186  				m.log.Errorf("Failed to subscribe to topics '%v': %v\n", m.conf.Topics, err)
   187  				m.log.Errorln("Shutting connection down.")
   188  				closeMsgChan()
   189  			}
   190  		})
   191  
   192  	if m.conf.Will.Enabled {
   193  		conf = conf.SetWill(m.conf.Will.Topic, m.conf.Will.Payload, m.conf.Will.QoS, m.conf.Will.Retained)
   194  	}
   195  
   196  	if m.conf.TLS.Enabled {
   197  		tlsConf, err := m.conf.TLS.Get()
   198  		if err != nil {
   199  			return err
   200  		}
   201  		conf.SetTLSConfig(tlsConf)
   202  	}
   203  
   204  	if m.conf.User != "" {
   205  		conf.SetUsername(m.conf.User)
   206  	}
   207  
   208  	if m.conf.Password != "" {
   209  		conf.SetPassword(m.conf.Password)
   210  	}
   211  
   212  	for _, u := range m.urls {
   213  		conf = conf.AddBroker(u)
   214  	}
   215  
   216  	client := mqtt.NewClient(conf)
   217  
   218  	tok := client.Connect()
   219  	tok.Wait()
   220  	if err := tok.Error(); err != nil {
   221  		return err
   222  	}
   223  
   224  	m.log.Infof("Receiving MQTT messages from topics: %v\n", m.conf.Topics)
   225  
   226  	if m.staleConnectionTimeout == 0 {
   227  		go func() {
   228  			for {
   229  				select {
   230  				case <-time.After(time.Second):
   231  					if !client.IsConnected() {
   232  						if closeMsgChan() {
   233  							m.log.Errorln("Connection lost for unknown reasons.")
   234  						}
   235  						return
   236  					}
   237  				case <-m.interruptChan:
   238  					return
   239  				}
   240  			}
   241  		}()
   242  	}
   243  
   244  	m.client = client
   245  	m.msgChan = msgChan
   246  	return nil
   247  }
   248  
   249  // ReadWithContext attempts to read a new message from an MQTT broker.
   250  func (m *MQTT) ReadWithContext(ctx context.Context) (types.Message, AsyncAckFn, error) {
   251  	m.cMut.Lock()
   252  	msgChan := m.msgChan
   253  	m.cMut.Unlock()
   254  
   255  	if msgChan == nil {
   256  		return nil, nil, types.ErrNotConnected
   257  	}
   258  
   259  	var staleTimer *time.Timer
   260  	var staleChan <-chan time.Time
   261  	if m.staleConnectionTimeout > 0 {
   262  		staleTimer = time.NewTimer(m.staleConnectionTimeout)
   263  		staleChan = staleTimer.C
   264  		defer staleTimer.Stop()
   265  	}
   266  
   267  	select {
   268  	case <-staleChan:
   269  		m.log.Errorln("Stale connection timeout triggered, re-establishing connection to broker.")
   270  		m.cMut.Lock()
   271  		m.client.Disconnect(0)
   272  		m.msgChan = nil
   273  		m.client = nil
   274  		m.cMut.Unlock()
   275  		return nil, nil, types.ErrNotConnected
   276  	case msg, open := <-msgChan:
   277  		if !open {
   278  			m.cMut.Lock()
   279  			m.msgChan = nil
   280  			m.client = nil
   281  			m.cMut.Unlock()
   282  			return nil, nil, types.ErrNotConnected
   283  		}
   284  
   285  		message := message.New([][]byte{msg.Payload()})
   286  
   287  		meta := message.Get(0).Metadata()
   288  		meta.Set("mqtt_duplicate", strconv.FormatBool(msg.Duplicate()))
   289  		meta.Set("mqtt_qos", strconv.Itoa(int(msg.Qos())))
   290  		meta.Set("mqtt_retained", strconv.FormatBool(msg.Retained()))
   291  		meta.Set("mqtt_topic", msg.Topic())
   292  		meta.Set("mqtt_message_id", strconv.Itoa(int(msg.MessageID())))
   293  
   294  		return message, func(ctx context.Context, res types.Response) error {
   295  			if res.Error() == nil {
   296  				msg.Ack()
   297  			}
   298  			return nil
   299  		}, nil
   300  	case <-ctx.Done():
   301  	case <-m.interruptChan:
   302  		return nil, nil, types.ErrTypeClosed
   303  	}
   304  	return nil, nil, types.ErrTimeout
   305  }
   306  
   307  // Read attempts to read a new message from an MQTT broker.
   308  func (m *MQTT) Read() (types.Message, error) {
   309  	msg, _, err := m.ReadWithContext(context.Background())
   310  	return msg, err
   311  }
   312  
// Acknowledge instructs whether messages have been successfully propagated.
//
// This implementation ignores err and always returns nil; messages are acked
// through the function returned by ReadWithContext instead.
func (m *MQTT) Acknowledge(err error) error {
	return nil
}
   317  
   318  // CloseAsync shuts down the MQTT input and stops processing requests.
   319  func (m *MQTT) CloseAsync() {
   320  	m.cMut.Lock()
   321  	if m.client != nil {
   322  		m.client.Disconnect(0)
   323  		m.client = nil
   324  		close(m.interruptChan)
   325  	}
   326  	m.cMut.Unlock()
   327  }
   328  
// WaitForClose blocks until the MQTT input has closed down.
//
// This implementation does not wait for anything: timeout is ignored and nil
// is returned immediately.
func (m *MQTT) WaitForClose(timeout time.Duration) error {
	return nil
}
   333  
   334  //------------------------------------------------------------------------------