github.com/e154/smart-home@v0.17.2-0.20240311175135-e530a6e5cd45/system/mqtt/mqtt.go (about) 1 // This file is part of the Smart Home 2 // Program complex distribution https://github.com/e154/smart-home 3 // Copyright (C) 2016-2023, Filippov Alex 4 // 5 // This library is free software: you can redistribute it and/or 6 // modify it under the terms of the GNU Lesser General Public 7 // License as published by the Free Software Foundation; either 8 // version 3 of the License, or (at your option) any later version. 9 // 10 // This library is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 13 // Library General Public License for more details. 14 // 15 // You should have received a copy of the GNU Lesser General Public 16 // License along with this library. If not, see 17 // <https://www.gnu.org/licenses/>. 18 19 package mqtt 20 21 import ( 22 "context" 23 "fmt" 24 "net" 25 "os" 26 "sync" 27 "time" 28 29 "github.com/DrmagicE/gmqtt" 30 _ "github.com/DrmagicE/gmqtt/persistence" 31 "github.com/DrmagicE/gmqtt/pkg/codes" 32 "github.com/DrmagicE/gmqtt/pkg/packets" 33 "github.com/DrmagicE/gmqtt/server" 34 _ "github.com/DrmagicE/gmqtt/topicalias/fifo" 35 "go.uber.org/fx" 36 "go.uber.org/zap" 37 "go.uber.org/zap/zapcore" 38 39 "github.com/e154/smart-home/common" 40 "github.com/e154/smart-home/common/events" 41 "github.com/e154/smart-home/common/logger" 42 "github.com/e154/smart-home/system/bus" 43 "github.com/e154/smart-home/system/logging" 44 "github.com/e154/smart-home/system/mqtt/admin" 45 "github.com/e154/smart-home/system/mqtt_authenticator" 46 "github.com/e154/smart-home/system/scripts" 47 ) 48 49 var ( 50 log = logger.MustGetLogger("mqtt") 51 ) 52 53 // Mqtt ... 54 type Mqtt struct { 55 cfg *Config 56 server GMqttServer 57 authenticator mqtt_authenticator.MqttAuthenticator 58 isStarted bool 59 clientsLock *sync.Mutex 60 clients map[string]MqttCli 61 admin *admin.Admin 62 scriptService scripts.ScriptService 63 eventBus bus.Bus 64 } 65 66 // NewMqtt ... 67 func NewMqtt(lc fx.Lifecycle, 68 cfg *Config, 69 authenticator mqtt_authenticator.MqttAuthenticator, 70 scriptService scripts.ScriptService, 71 eventBus bus.Bus) (mqtt MqttServ) { 72 73 mqtt = &Mqtt{ 74 cfg: cfg, 75 authenticator: authenticator, 76 clientsLock: &sync.Mutex{}, 77 clients: make(map[string]MqttCli), 78 admin: admin.New(), 79 scriptService: scriptService, 80 eventBus: eventBus, 81 } 82 83 lc.Append(fx.Hook{ 84 OnStart: func(ctx context.Context) (err error) { 85 mqtt.Start() 86 return nil 87 }, 88 OnStop: func(ctx context.Context) (err error) { 89 return mqtt.Shutdown() 90 }, 91 }) 92 93 return 94 } 95 96 // Shutdown ... 97 func (m *Mqtt) Shutdown() (err error) { 98 if !m.isStarted { 99 return 100 } 101 102 log.Info("Server exiting") 103 104 m.scriptService.PopStruct("Mqtt") 105 106 m.clientsLock.Lock() 107 for name, cli := range m.clients { 108 cli.UnsubscribeAll() 109 delete(m.clients, name) 110 } 111 m.clientsLock.Unlock() 112 113 if m.server != nil { 114 ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) 115 err = m.server.Stop(ctx) 116 } 117 118 m.eventBus.Publish("system/services/mqtt", events.EventServiceStopped{Service: "Mqtt"}) 119 return 120 } 121 122 // Start ... 123 func (m *Mqtt) Start() { 124 125 if m.isStarted { 126 return 127 } 128 129 ln, err := net.Listen("tcp", fmt.Sprintf(":%d", m.cfg.Port)) 130 if err != nil { 131 log.Error(err.Error()) 132 } 133 134 defer func() { 135 if err == nil { 136 m.isStarted = true 137 } 138 }() 139 140 options := []server.Options{ 141 server.WithTCPListener(ln), 142 server.WithPlugin(m.admin), 143 server.WithHook(server.Hooks{ 144 OnBasicAuth: m.onBasicAuth, 145 OnMsgArrived: m.onMsgArrived, 146 OnConnected: func(ctx context.Context, client server.Client) { 147 m.eventBus.Publish("system/services/mqtt", events.EventMqttNewClient{ 148 ClientId: client.ClientOptions().ClientID, 149 }) 150 }, 151 }), 152 } 153 154 if m.cfg.Logging { 155 options = append(options, server.WithLogger(m.logging())) 156 } 157 158 // Create a new server 159 m.server = server.New(options...) 160 161 log.Infof("Serving MQTT server at tcp://[::]:%d", m.cfg.Port) 162 163 m.scriptService.PushStruct("Mqtt", NewMqttBind(m)) 164 165 go func() { 166 if err = m.server.Run(); err != nil { 167 log.Error(err.Error()) 168 } 169 }() 170 171 m.eventBus.Publish("system/services/mqtt", events.EventServiceStarted{Service: "Mqtt"}) 172 } 173 174 // OnMsgArrived ... 175 func (m *Mqtt) onMsgArrived(ctx context.Context, client server.Client, msg *server.MsgArrivedRequest) (err error) { 176 m.clientsLock.Lock() 177 defer m.clientsLock.Unlock() 178 179 for _, cli := range m.clients { 180 cli.OnMsgArrived(ctx, client, msg) 181 } 182 183 return 184 } 185 186 // OnConnect ... 187 func (m *Mqtt) onBasicAuth(ctx context.Context, client server.Client, req *server.ConnectRequest) (err error) { 188 log.Debugf("connect client version %v ...", client.Version()) 189 190 username := string(req.Connect.Username) 191 password := string(req.Connect.Password) 192 193 //authentication 194 if err = m.authenticator.Authenticate(username, password); err == nil { 195 return 196 } 197 198 // check the client version, return a compatible reason code. 199 switch client.Version() { 200 case packets.Version5: 201 return codes.NewError(codes.BadUserNameOrPassword) 202 case packets.Version311: 203 return codes.NewError(codes.V3BadUsernameorPassword) 204 } 205 // return nil if pass authentication. 206 return nil 207 } 208 209 // Admin ... 210 func (m *Mqtt) Admin() Admin { 211 return m.admin 212 } 213 214 // Publish ... 215 func (m *Mqtt) Publish(topic string, payload []byte, qos uint8, retain bool) (err error) { 216 if qos < 0 || qos > 2 { 217 err = ErrInvalidQos 218 return 219 } 220 if !packets.ValidTopicFilter(true, []byte(topic)) { 221 err = ErrInvalidTopicFilter 222 return 223 } 224 if !packets.ValidUTF8(payload) { 225 err = ErrInvalidUtf8String 226 return 227 } 228 229 m.server.Publisher().Publish(&gmqtt.Message{ 230 QoS: qos, 231 Retained: retain, 232 Topic: topic, 233 Payload: payload, 234 }) 235 236 // send to local subscribers 237 _ = m.onMsgArrived(context.TODO(), nil, &server.MsgArrivedRequest{ 238 Message: &gmqtt.Message{ 239 QoS: qos, 240 Retained: retain, 241 Topic: topic, 242 Payload: payload, 243 }, 244 }) 245 return 246 } 247 248 // NewClient ... 249 func (m *Mqtt) NewClient(name string) (client MqttCli) { 250 m.clientsLock.Lock() 251 defer m.clientsLock.Unlock() 252 253 var ok bool 254 if client, ok = m.clients[name]; ok { 255 return 256 } 257 client = NewClient(m, name) 258 m.clients[name] = client 259 log.Infof("new mqtt client '%s'", name) 260 return 261 } 262 263 // RemoveClient ... 264 func (m *Mqtt) RemoveClient(name string) { 265 m.clientsLock.Lock() 266 defer m.clientsLock.Unlock() 267 268 var ok bool 269 if _, ok = m.clients[name]; !ok { 270 return 271 } 272 delete(m.clients, name) 273 } 274 275 func (m *Mqtt) logging() *zap.Logger { 276 277 // First, define our level-handling logic. 278 highPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { 279 return lvl >= zapcore.ErrorLevel 280 }) 281 282 lowLevel := zapcore.ErrorLevel 283 if m.cfg.DebugMode == common.ReleaseMode { 284 lowLevel = zapcore.DebugLevel 285 } 286 lowPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { 287 return lvl < lowLevel 288 }) 289 290 // High-priority output should also go to standard error, and low-priority 291 // output should also go to standard out. 292 consoleDebugging := zapcore.Lock(os.Stdout) 293 consoleErrors := zapcore.Lock(os.Stderr) 294 295 var encConfig zapcore.EncoderConfig 296 if m.cfg.DebugMode == common.ReleaseMode { 297 encConfig = zap.NewProductionEncoderConfig() 298 } else { 299 encConfig = zap.NewDevelopmentEncoderConfig() 300 } 301 302 encConfig.EncodeTime = nil 303 encConfig.EncodeName = logging.CustomNameEncoder 304 encConfig.EncodeCaller = logging.CustomCallerEncoder 305 consoleEncoder := zapcore.NewConsoleEncoder(encConfig) 306 307 // Join the outputs, encoders, and level-handling functions into 308 // zapcore.Cores, then tee the four cores together. 309 core := zapcore.NewTee( 310 zapcore.NewCore(consoleEncoder, consoleErrors, highPriority), 311 zapcore.NewCore(consoleEncoder, consoleDebugging, lowPriority), 312 ) 313 314 // From a zapcore.Core, it's easy to construct a Logger. 315 return zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)).Named("mqtt") 316 } 317 318 // Authenticator ... 319 func (m *Mqtt) Authenticator() mqtt_authenticator.MqttAuthenticator { 320 return m.authenticator 321 }