github.com/e154/smart-home@v0.17.2-0.20240311175135-e530a6e5cd45/system/mqtt/mqtt.go (about)

     1  // This file is part of the Smart Home
     2  // Program complex distribution https://github.com/e154/smart-home
     3  // Copyright (C) 2016-2023, Filippov Alex
     4  //
     5  // This library is free software: you can redistribute it and/or
     6  // modify it under the terms of the GNU Lesser General Public
     7  // License as published by the Free Software Foundation; either
     8  // version 3 of the License, or (at your option) any later version.
     9  //
    10  // This library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    13  // Library General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public
    16  // License along with this library.  If not, see
    17  // <https://www.gnu.org/licenses/>.
    18  
    19  package mqtt
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"net"
    25  	"os"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/DrmagicE/gmqtt"
    30  	_ "github.com/DrmagicE/gmqtt/persistence"
    31  	"github.com/DrmagicE/gmqtt/pkg/codes"
    32  	"github.com/DrmagicE/gmqtt/pkg/packets"
    33  	"github.com/DrmagicE/gmqtt/server"
    34  	_ "github.com/DrmagicE/gmqtt/topicalias/fifo"
    35  	"go.uber.org/fx"
    36  	"go.uber.org/zap"
    37  	"go.uber.org/zap/zapcore"
    38  
    39  	"github.com/e154/smart-home/common"
    40  	"github.com/e154/smart-home/common/events"
    41  	"github.com/e154/smart-home/common/logger"
    42  	"github.com/e154/smart-home/system/bus"
    43  	"github.com/e154/smart-home/system/logging"
    44  	"github.com/e154/smart-home/system/mqtt/admin"
    45  	"github.com/e154/smart-home/system/mqtt_authenticator"
    46  	"github.com/e154/smart-home/system/scripts"
    47  )
    48  
    49  var (
    50  	log = logger.MustGetLogger("mqtt")
    51  )
    52  
    53  // Mqtt ...
    54  type Mqtt struct {
    55  	cfg           *Config
    56  	server        GMqttServer
    57  	authenticator mqtt_authenticator.MqttAuthenticator
    58  	isStarted     bool
    59  	clientsLock   *sync.Mutex
    60  	clients       map[string]MqttCli
    61  	admin         *admin.Admin
    62  	scriptService scripts.ScriptService
    63  	eventBus      bus.Bus
    64  }
    65  
    66  // NewMqtt ...
    67  func NewMqtt(lc fx.Lifecycle,
    68  	cfg *Config,
    69  	authenticator mqtt_authenticator.MqttAuthenticator,
    70  	scriptService scripts.ScriptService,
    71  	eventBus bus.Bus) (mqtt MqttServ) {
    72  
    73  	mqtt = &Mqtt{
    74  		cfg:           cfg,
    75  		authenticator: authenticator,
    76  		clientsLock:   &sync.Mutex{},
    77  		clients:       make(map[string]MqttCli),
    78  		admin:         admin.New(),
    79  		scriptService: scriptService,
    80  		eventBus:      eventBus,
    81  	}
    82  
    83  	lc.Append(fx.Hook{
    84  		OnStart: func(ctx context.Context) (err error) {
    85  			mqtt.Start()
    86  			return nil
    87  		},
    88  		OnStop: func(ctx context.Context) (err error) {
    89  			return mqtt.Shutdown()
    90  		},
    91  	})
    92  
    93  	return
    94  }
    95  
    96  // Shutdown ...
    97  func (m *Mqtt) Shutdown() (err error) {
    98  	if !m.isStarted {
    99  		return
   100  	}
   101  
   102  	log.Info("Server exiting")
   103  
   104  	m.scriptService.PopStruct("Mqtt")
   105  
   106  	m.clientsLock.Lock()
   107  	for name, cli := range m.clients {
   108  		cli.UnsubscribeAll()
   109  		delete(m.clients, name)
   110  	}
   111  	m.clientsLock.Unlock()
   112  
   113  	if m.server != nil {
   114  		ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
   115  		err = m.server.Stop(ctx)
   116  	}
   117  
   118  	m.eventBus.Publish("system/services/mqtt", events.EventServiceStopped{Service: "Mqtt"})
   119  	return
   120  }
   121  
   122  // Start ...
   123  func (m *Mqtt) Start() {
   124  
   125  	if m.isStarted {
   126  		return
   127  	}
   128  
   129  	ln, err := net.Listen("tcp", fmt.Sprintf(":%d", m.cfg.Port))
   130  	if err != nil {
   131  		log.Error(err.Error())
   132  	}
   133  
   134  	defer func() {
   135  		if err == nil {
   136  			m.isStarted = true
   137  		}
   138  	}()
   139  
   140  	options := []server.Options{
   141  		server.WithTCPListener(ln),
   142  		server.WithPlugin(m.admin),
   143  		server.WithHook(server.Hooks{
   144  			OnBasicAuth:  m.onBasicAuth,
   145  			OnMsgArrived: m.onMsgArrived,
   146  			OnConnected: func(ctx context.Context, client server.Client) {
   147  				m.eventBus.Publish("system/services/mqtt", events.EventMqttNewClient{
   148  					ClientId: client.ClientOptions().ClientID,
   149  				})
   150  			},
   151  		}),
   152  	}
   153  
   154  	if m.cfg.Logging {
   155  		options = append(options, server.WithLogger(m.logging()))
   156  	}
   157  
   158  	// Create a new server
   159  	m.server = server.New(options...)
   160  
   161  	log.Infof("Serving MQTT server at tcp://[::]:%d", m.cfg.Port)
   162  
   163  	m.scriptService.PushStruct("Mqtt", NewMqttBind(m))
   164  
   165  	go func() {
   166  		if err = m.server.Run(); err != nil {
   167  			log.Error(err.Error())
   168  		}
   169  	}()
   170  
   171  	m.eventBus.Publish("system/services/mqtt", events.EventServiceStarted{Service: "Mqtt"})
   172  }
   173  
   174  // OnMsgArrived ...
   175  func (m *Mqtt) onMsgArrived(ctx context.Context, client server.Client, msg *server.MsgArrivedRequest) (err error) {
   176  	m.clientsLock.Lock()
   177  	defer m.clientsLock.Unlock()
   178  
   179  	for _, cli := range m.clients {
   180  		cli.OnMsgArrived(ctx, client, msg)
   181  	}
   182  
   183  	return
   184  }
   185  
   186  // OnConnect ...
   187  func (m *Mqtt) onBasicAuth(ctx context.Context, client server.Client, req *server.ConnectRequest) (err error) {
   188  	log.Debugf("connect client version %v ...", client.Version())
   189  
   190  	username := string(req.Connect.Username)
   191  	password := string(req.Connect.Password)
   192  
   193  	//authentication
   194  	if err = m.authenticator.Authenticate(username, password); err == nil {
   195  		return
   196  	}
   197  
   198  	// check the client version, return a compatible reason code.
   199  	switch client.Version() {
   200  	case packets.Version5:
   201  		return codes.NewError(codes.BadUserNameOrPassword)
   202  	case packets.Version311:
   203  		return codes.NewError(codes.V3BadUsernameorPassword)
   204  	}
   205  	// return nil if pass authentication.
   206  	return nil
   207  }
   208  
   209  // Admin ...
   210  func (m *Mqtt) Admin() Admin {
   211  	return m.admin
   212  }
   213  
   214  // Publish ...
   215  func (m *Mqtt) Publish(topic string, payload []byte, qos uint8, retain bool) (err error) {
   216  	if qos < 0 || qos > 2 {
   217  		err = ErrInvalidQos
   218  		return
   219  	}
   220  	if !packets.ValidTopicFilter(true, []byte(topic)) {
   221  		err = ErrInvalidTopicFilter
   222  		return
   223  	}
   224  	if !packets.ValidUTF8(payload) {
   225  		err = ErrInvalidUtf8String
   226  		return
   227  	}
   228  
   229  	m.server.Publisher().Publish(&gmqtt.Message{
   230  		QoS:      qos,
   231  		Retained: retain,
   232  		Topic:    topic,
   233  		Payload:  payload,
   234  	})
   235  
   236  	// send to local subscribers
   237  	_ = m.onMsgArrived(context.TODO(), nil, &server.MsgArrivedRequest{
   238  		Message: &gmqtt.Message{
   239  			QoS:      qos,
   240  			Retained: retain,
   241  			Topic:    topic,
   242  			Payload:  payload,
   243  		},
   244  	})
   245  	return
   246  }
   247  
   248  // NewClient ...
   249  func (m *Mqtt) NewClient(name string) (client MqttCli) {
   250  	m.clientsLock.Lock()
   251  	defer m.clientsLock.Unlock()
   252  
   253  	var ok bool
   254  	if client, ok = m.clients[name]; ok {
   255  		return
   256  	}
   257  	client = NewClient(m, name)
   258  	m.clients[name] = client
   259  	log.Infof("new mqtt client '%s'", name)
   260  	return
   261  }
   262  
   263  // RemoveClient ...
   264  func (m *Mqtt) RemoveClient(name string) {
   265  	m.clientsLock.Lock()
   266  	defer m.clientsLock.Unlock()
   267  
   268  	var ok bool
   269  	if _, ok = m.clients[name]; !ok {
   270  		return
   271  	}
   272  	delete(m.clients, name)
   273  }
   274  
   275  func (m *Mqtt) logging() *zap.Logger {
   276  
   277  	// First, define our level-handling logic.
   278  	highPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
   279  		return lvl >= zapcore.ErrorLevel
   280  	})
   281  
   282  	lowLevel := zapcore.ErrorLevel
   283  	if m.cfg.DebugMode == common.ReleaseMode {
   284  		lowLevel = zapcore.DebugLevel
   285  	}
   286  	lowPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
   287  		return lvl < lowLevel
   288  	})
   289  
   290  	// High-priority output should also go to standard error, and low-priority
   291  	// output should also go to standard out.
   292  	consoleDebugging := zapcore.Lock(os.Stdout)
   293  	consoleErrors := zapcore.Lock(os.Stderr)
   294  
   295  	var encConfig zapcore.EncoderConfig
   296  	if m.cfg.DebugMode == common.ReleaseMode {
   297  		encConfig = zap.NewProductionEncoderConfig()
   298  	} else {
   299  		encConfig = zap.NewDevelopmentEncoderConfig()
   300  	}
   301  
   302  	encConfig.EncodeTime = nil
   303  	encConfig.EncodeName = logging.CustomNameEncoder
   304  	encConfig.EncodeCaller = logging.CustomCallerEncoder
   305  	consoleEncoder := zapcore.NewConsoleEncoder(encConfig)
   306  
   307  	// Join the outputs, encoders, and level-handling functions into
   308  	// zapcore.Cores, then tee the four cores together.
   309  	core := zapcore.NewTee(
   310  		zapcore.NewCore(consoleEncoder, consoleErrors, highPriority),
   311  		zapcore.NewCore(consoleEncoder, consoleDebugging, lowPriority),
   312  	)
   313  
   314  	// From a zapcore.Core, it's easy to construct a Logger.
   315  	return zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)).Named("mqtt")
   316  }
   317  
   318  // Authenticator ...
   319  func (m *Mqtt) Authenticator() mqtt_authenticator.MqttAuthenticator {
   320  	return m.authenticator
   321  }