github.com/e154/smart-home@v0.17.2-0.20240311175135-e530a6e5cd45/system/mqtt/admin/admin.go (about)

     1  // This file is part of the Smart Home
     2  // Program complex distribution https://github.com/e154/smart-home
     3  // Copyright (C) 2016-2023, Filippov Alex
     4  //
     5  // This library is free software: you can redistribute it and/or
     6  // modify it under the terms of the GNU Lesser General Public
     7  // License as published by the Free Software Foundation; either
     8  // version 3 of the License, or (at your option) any later version.
     9  //
    10  // This library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    13  // Library General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public
    16  // License along with this library.  If not, see
    17  // <https://www.gnu.org/licenses/>.
    18  
    19  package admin
    20  
    21  import (
    22  	"context"
    23  
    24  	"github.com/e154/smart-home/common/logger"
    25  
    26  	"github.com/DrmagicE/gmqtt"
    27  	"github.com/DrmagicE/gmqtt/pkg/packets"
    28  	"github.com/DrmagicE/gmqtt/server"
    29  )
    30  
    31  var _ server.Plugin = (*Admin)(nil)
    32  
    33  var (
    34  	log = logger.MustGetLogger("mqtt.admin")
    35  )
    36  
    37  // Name ...
    38  const Name = "admin"
    39  
    40  // New ...
    41  func New() *Admin {
    42  	return &Admin{}
    43  }
    44  
    45  // Admin ...
    46  type Admin struct {
    47  	store               *store
    48  	statsReader         server.StatsReader
    49  	publisher           server.Publisher
    50  	clientService       server.ClientService
    51  	subscriptionService server.SubscriptionService
    52  }
    53  
    54  // HookWrapper ...
    55  func (a *Admin) HookWrapper() server.HookWrapper {
    56  	return server.HookWrapper{
    57  		OnSessionCreatedWrapper:    a.OnSessionCreatedWrapper,
    58  		OnSessionResumedWrapper:    a.OnSessionResumedWrapper,
    59  		OnClosedWrapper:            a.OnClosedWrapper,
    60  		OnSessionTerminatedWrapper: a.OnSessionTerminatedWrapper,
    61  		OnSubscribedWrapper:        a.OnSubscribedWrapper,
    62  		OnUnsubscribedWrapper:      a.OnUnsubscribedWrapper,
    63  	}
    64  }
    65  
    66  // Load ...
    67  func (a *Admin) Load(service server.Server) error {
    68  
    69  	a.store = newStore(service.StatsManager(), service.SubscriptionService(), service.ClientService())
    70  	a.statsReader = service.StatsManager()
    71  	a.publisher = service.Publisher()
    72  	a.clientService = service.ClientService()
    73  	a.subscriptionService = service.SubscriptionService()
    74  
    75  	log.Info("loaded ...")
    76  
    77  	return nil
    78  }
    79  
    80  // Unload ...
    81  func (a *Admin) Unload() error {
    82  	log.Info("unloaded ...")
    83  	return nil
    84  }
    85  
    86  // Name ...
    87  func (a *Admin) Name() string {
    88  	return Name
    89  }
    90  
    91  // OnSessionCreatedWrapper store the client when session created
    92  func (a *Admin) OnSessionCreatedWrapper(pre server.OnSessionCreated) server.OnSessionCreated {
    93  	return func(cs context.Context, client server.Client) {
    94  		pre(cs, client)
    95  		a.store.addClient(client)
    96  	}
    97  }
    98  
    99  // OnSessionResumedWrapper refresh the client when session resumed
   100  func (a *Admin) OnSessionResumedWrapper(pre server.OnSessionResumed) server.OnSessionResumed {
   101  	return func(cs context.Context, client server.Client) {
   102  		pre(cs, client)
   103  		a.store.addClient(client)
   104  	}
   105  }
   106  
   107  // OnClosedWrapper refresh the client when session resumed
   108  func (a *Admin) OnClosedWrapper(pre server.OnClosed) server.OnClosed {
   109  	return func(cs context.Context, client server.Client, err error) {
   110  		pre(cs, client, err)
   111  		a.store.setClientDisconnected(client.ClientOptions().ClientID)
   112  	}
   113  }
   114  
   115  // OnSessionTerminated remove the client when session terminated
   116  func (a *Admin) OnSessionTerminatedWrapper(pre server.OnSessionTerminated) server.OnSessionTerminated {
   117  	return func(cs context.Context, client string, reason server.SessionTerminatedReason) {
   118  		pre(cs, client, reason)
   119  		a.store.removeClient(client)
   120  	}
   121  }
   122  
   123  // OnSubscribedWrapper store the subscription
   124  func (a *Admin) OnSubscribedWrapper(pre server.OnSubscribed) server.OnSubscribed {
   125  	return func(cs context.Context, client server.Client, subscription *gmqtt.Subscription) {
   126  		pre(cs, client, subscription)
   127  		a.store.addSubscription(client.ClientOptions().ClientID, subscription)
   128  	}
   129  }
   130  
   131  // OnUnsubscribedWrapper remove the subscription
   132  func (a *Admin) OnUnsubscribedWrapper(pre server.OnUnsubscribed) server.OnUnsubscribed {
   133  	return func(cs context.Context, client server.Client, topicName string) {
   134  		pre(cs, client, topicName)
   135  		a.store.removeSubscription(client.ClientOptions().ClientID, topicName)
   136  	}
   137  }
   138  
   139  // GetClients ...
   140  func (a *Admin) GetClients(limit, offset uint) (list []*ClientInfo, total uint32, err error) {
   141  	list, total, err = a.store.GetClients(limit, offset)
   142  	return
   143  }
   144  
   145  // GetClient ...
   146  func (a *Admin) GetClient(clientId string) (client *ClientInfo, err error) {
   147  	client = a.store.GetClientByID(clientId)
   148  	return
   149  }
   150  
   151  // GetSessions ...
   152  func (a *Admin) GetSessions(limit, offset uint) (list []*SessionInfo, total int, err error) {
   153  	list, total, err = a.store.GetSessions(offset, limit)
   154  	return
   155  }
   156  
   157  // GetSession ...
   158  func (a *Admin) GetSession(clientId string) (session *SessionInfo, err error) {
   159  	session, err = a.store.GetSessionByID(clientId)
   160  	return
   161  }
   162  
   163  // GetSubscriptions ...
   164  func (a *Admin) GetSubscriptions(clientId string, limit, offset uint) (list []*SubscriptionInfo, total int, err error) {
   165  	list, total, err = a.store.GetClientSubscriptions(clientId, offset, limit)
   166  	return
   167  }
   168  
   169  // Subscribe ...
   170  func (a *Admin) Subscribe(clientId, topic string, qos int) (err error) {
   171  	if qos < 0 || qos > 2 {
   172  		err = ErrInvalidQos
   173  		return
   174  	}
   175  	if !packets.ValidTopicFilter(true, []byte(topic)) {
   176  		err = ErrInvalidTopicFilter
   177  		return
   178  	}
   179  	if clientId == "" {
   180  		err = ErrInvalidClientID
   181  		return
   182  	}
   183  	_, err = a.subscriptionService.Subscribe(clientId, &gmqtt.Subscription{
   184  		TopicFilter: topic,
   185  		QoS:         uint8(qos),
   186  	})
   187  	return
   188  }
   189  
   190  // Unsubscribe ...
   191  func (a *Admin) Unsubscribe(clientId, topic string) (err error) {
   192  	if !packets.ValidTopicFilter(true, []byte(topic)) {
   193  		err = ErrInvalidTopicFilter
   194  		return
   195  	}
   196  	if clientId == "" {
   197  		err = ErrInvalidClientID
   198  		return
   199  	}
   200  	_ = a.subscriptionService.Unsubscribe(clientId, topic)
   201  	return
   202  }
   203  
   204  // Publish ...
   205  func (a *Admin) Publish(topic string, qos int, payload []byte, retain bool) (err error) {
   206  	if qos < 0 || qos > 2 {
   207  		err = ErrInvalidQos
   208  		return
   209  	}
   210  	if !packets.ValidTopicFilter(true, []byte(topic)) {
   211  		err = ErrInvalidTopicFilter
   212  		return
   213  	}
   214  	if !packets.ValidUTF8(payload) {
   215  		err = ErrInvalidUtf8String
   216  		return
   217  	}
   218  	a.publisher.Publish(&gmqtt.Message{
   219  		QoS:      uint8(qos),
   220  		Retained: retain,
   221  		Topic:    topic,
   222  		Payload:  payload,
   223  	})
   224  	return
   225  }
   226  
   227  // CloseClient ...
   228  func (a *Admin) CloseClient(clientId string) (err error) {
   229  	if clientId == "" {
   230  		err = ErrInvalidClientID
   231  		return
   232  	}
   233  	client := a.clientService.GetClient(clientId)
   234  	if client != nil {
   235  		client.Close()
   236  	}
   237  
   238  	return
   239  }
   240  
   241  // SearchTopic ...
   242  func (a *Admin) SearchTopic(query string) (result []*SubscriptionInfo, err error) {
   243  	result, err = a.store.SearchTopic(query)
   244  	return
   245  }