github.com/bluenviron/gomavlib/v2@v2.2.1-0.20240308101627-2c07e3da629c/node.go (about)

     1  /*
     2  Package gomavlib is a library that implements Mavlink 2.0 and 1.0 in the Go
     3  programming language. It can power UGVs, UAVs, ground stations, monitoring
     4  systems or routers acting in a Mavlink network.
     5  
Mavlink is a lightweight and transport-independent protocol that is mostly used
     7  to communicate with unmanned ground vehicles (UGV) and unmanned aerial vehicles
     8  (UAV, drones, quadcopters, multirotors). It is supported by the most common
     9  open-source flight controllers (Ardupilot and PX4).
    10  
    11  Examples are available at https://github.com/bluenviron/gomavlib/tree/main/examples
    12  */
    13  package gomavlib
    14  
    15  import (
    16  	"fmt"
    17  	"sync"
    18  	"time"
    19  
    20  	"github.com/bluenviron/gomavlib/v2/pkg/dialect"
    21  	"github.com/bluenviron/gomavlib/v2/pkg/frame"
    22  	"github.com/bluenviron/gomavlib/v2/pkg/message"
    23  )
    24  
    25  var errTerminated = fmt.Errorf("terminated")
    26  
    27  type writeToReq struct {
    28  	ch   *Channel
    29  	what interface{}
    30  }
    31  
    32  type writeExceptReq struct {
    33  	except *Channel
    34  	what   interface{}
    35  }
    36  
// NodeConf allows to configure a Node.
//
// Zero-valued optional fields are replaced with their defaults by NewNode;
// mandatory fields (Endpoints, OutVersion, OutSystemID) are validated there.
type NodeConf struct {
	// the endpoints with which this node will
	// communicate. Each endpoint contains zero or more channels.
	// At least one endpoint is required.
	Endpoints []EndpointConf

	// (optional) the dialect which contains the messages that will be encoded and decoded.
	// If not provided, messages are decoded in the MessageRaw struct, and only
	// *message.MessageRaw messages can be written (encoding any other message
	// fails with "dialect is nil"). FixFrame always requires a dialect.
	Dialect *dialect.Dialect

	// (optional) the secret key used to validate incoming frames.
	// Non signed frames are discarded, as well as frames with a version < 2.0.
	InKey *frame.V2Key

	// Mavlink version used to encode messages. See Version
	// for the available options. Required: NewNode fails if it is zero.
	OutVersion Version
	// the system id, added to every outgoing frame and used to identify this
	// node in the network. Required: it must be at least 1.
	OutSystemID byte
	// (optional) the component id, added to every outgoing frame, defaults to 1.
	OutComponentID byte
	// (optional) the secret key used to sign outgoing frames.
	// This feature requires OutVersion to be V2, otherwise NewNode fails.
	// It is also used by FixFrame to regenerate signatures of V2 frames.
	OutKey *frame.V2Key

	// (optional) disables the periodic sending of heartbeats to open channels.
	HeartbeatDisable bool
	// (optional) the period between heartbeats. It defaults to 5 seconds.
	HeartbeatPeriod time.Duration
	// (optional) the system type advertised by heartbeats.
	// It defaults to 6 (MAV_TYPE_GCS). Since zero is replaced by the default,
	// a value of 0 cannot be selected.
	HeartbeatSystemType int
	// (optional) the autopilot type advertised by heartbeats.
	// It defaults to MAV_AUTOPILOT_GENERIC (the zero value).
	HeartbeatAutopilotType int

	// (optional) automatically request streams to detected Ardupilot devices,
	// that need an explicit request in order to emit telemetry stream.
	StreamRequestEnable bool
	// (optional) the requested stream frequency in Hz. It defaults to 4.
	StreamRequestFrequency int

	// (optional) read timeout.
	// It defaults to 10 seconds.
	ReadTimeout time.Duration
	// (optional) write timeout.
	// It defaults to 10 seconds.
	WriteTimeout time.Duration
	// (optional) timeout before closing idle connections.
	// It defaults to 60 seconds.
	IdleTimeout time.Duration
}
    90  
// Node is a high-level Mavlink encoder and decoder that works with endpoints.
//
// Once NewNode has started the run routine, channels and channelProviders are
// only handled by run; the public methods reach it through the "in" channels.
type Node struct {
	// configuration, with defaults applied by NewNode.
	conf NodeConf
	// nil when NodeConf.Dialect is nil; used to encode outgoing messages.
	dialectRW *dialect.ReadWriter
	// waited on by run during shutdown, before chEvent is closed.
	wg sync.WaitGroup
	// channel providers created from the endpoints by NewNode.
	channelProviders map[*channelProvider]struct{}
	// set of currently registered channels.
	channels map[*Channel]struct{}
	// helper routines; each may be nil, and is closed by run on termination.
	nodeHeartbeat     *nodeHeartbeat
	nodeStreamRequest *nodeStreamRequest

	// in
	// channel to register and start (sent by newChannel).
	chNewChannel chan *Channel
	// channel to forget (sent by closeChannel).
	chCloseChannel chan *Channel
	// write requests from WriteMessageTo / WriteFrameTo.
	chWriteTo chan writeToReq
	// write requests from WriteMessageAll / WriteFrameAll.
	chWriteAll chan interface{}
	// write requests from WriteMessageExcept / WriteFrameExcept.
	chWriteExcept chan writeExceptReq
	// closed by Close to stop run and unblock pending senders.
	terminate chan struct{}

	// out
	// returned by Events; closed by run once shutdown has completed.
	chEvent chan Event
	// closed when run returns; Close waits on it.
	done chan struct{}
}
   113  
   114  // NewNode allocates a Node. See NodeConf for the options.
   115  func NewNode(conf NodeConf) (*Node, error) {
   116  	if len(conf.Endpoints) == 0 {
   117  		return nil, fmt.Errorf("at least one endpoint must be provided")
   118  	}
   119  	if conf.HeartbeatPeriod == 0 {
   120  		conf.HeartbeatPeriod = 5 * time.Second
   121  	}
   122  	if conf.HeartbeatSystemType == 0 {
   123  		conf.HeartbeatSystemType = 6 // MAV_TYPE_GCS
   124  	}
   125  	if conf.HeartbeatAutopilotType == 0 {
   126  		conf.HeartbeatAutopilotType = 0 // MAV_AUTOPILOT_GENERIC
   127  	}
   128  	if conf.StreamRequestFrequency == 0 {
   129  		conf.StreamRequestFrequency = 4
   130  	}
   131  
   132  	// check Transceiver configuration here, since Transceiver is created dynamically
   133  	if conf.OutVersion == 0 {
   134  		return nil, fmt.Errorf("OutVersion not provided")
   135  	}
   136  	if conf.OutSystemID < 1 {
   137  		return nil, fmt.Errorf("OutSystemID must be greater than one")
   138  	}
   139  	if conf.OutComponentID < 1 {
   140  		conf.OutComponentID = 1
   141  	}
   142  	if conf.OutKey != nil && conf.OutVersion != V2 {
   143  		return nil, fmt.Errorf("OutKey requires V2 frames")
   144  	}
   145  
   146  	if conf.ReadTimeout == 0 {
   147  		conf.ReadTimeout = 10 * time.Second
   148  	}
   149  	if conf.WriteTimeout == 0 {
   150  		conf.WriteTimeout = 10 * time.Second
   151  	}
   152  	if conf.IdleTimeout == 0 {
   153  		conf.IdleTimeout = 60 * time.Second
   154  	}
   155  
   156  	dialectRW, err := func() (*dialect.ReadWriter, error) {
   157  		if conf.Dialect == nil {
   158  			return nil, nil
   159  		}
   160  		return dialect.NewReadWriter(conf.Dialect)
   161  	}()
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	n := &Node{
   167  		conf:             conf,
   168  		dialectRW:        dialectRW,
   169  		channelProviders: make(map[*channelProvider]struct{}),
   170  		channels:         make(map[*Channel]struct{}),
   171  		chNewChannel:     make(chan *Channel),
   172  		chCloseChannel:   make(chan *Channel),
   173  		chWriteTo:        make(chan writeToReq),
   174  		chWriteAll:       make(chan interface{}),
   175  		chWriteExcept:    make(chan writeExceptReq),
   176  		terminate:        make(chan struct{}),
   177  		chEvent:          make(chan Event),
   178  		done:             make(chan struct{}),
   179  	}
   180  
   181  	closeExisting := func() {
   182  		for ch := range n.channels {
   183  			ch.close()
   184  		}
   185  		for ca := range n.channelProviders {
   186  			ca.close()
   187  		}
   188  	}
   189  
   190  	// endpoints
   191  	for _, tconf := range conf.Endpoints {
   192  		tp, err := tconf.init(n)
   193  		if err != nil {
   194  			closeExisting()
   195  			return nil, err
   196  		}
   197  
   198  		switch ttp := tp.(type) {
   199  		case endpointChannelProvider:
   200  			ca, err := newChannelProvider(n, ttp)
   201  			if err != nil {
   202  				closeExisting()
   203  				return nil, err
   204  			}
   205  
   206  			n.channelProviders[ca] = struct{}{}
   207  
   208  		case endpointChannelSingle:
   209  			ch, err := newChannel(n, ttp, ttp.label(), ttp)
   210  			if err != nil {
   211  				closeExisting()
   212  				return nil, err
   213  			}
   214  
   215  			n.channels[ch] = struct{}{}
   216  
   217  		default:
   218  			panic(fmt.Errorf("endpoint %T does not implement any interface", tp))
   219  		}
   220  	}
   221  
   222  	n.nodeHeartbeat = newNodeHeartbeat(n)
   223  	n.nodeStreamRequest = newNodeStreamRequest(n)
   224  
   225  	if n.nodeHeartbeat != nil {
   226  		go n.nodeHeartbeat.run()
   227  	}
   228  
   229  	if n.nodeStreamRequest != nil {
   230  		go n.nodeStreamRequest.run()
   231  	}
   232  
   233  	for ch := range n.channels {
   234  		ch.start()
   235  	}
   236  
   237  	for ca := range n.channelProviders {
   238  		ca.start()
   239  	}
   240  
   241  	go n.run()
   242  
   243  	return n, nil
   244  }
   245  
   246  // Close halts node operations and waits for all routines to return.
   247  func (n *Node) Close() {
   248  	close(n.terminate)
   249  	<-n.done
   250  }
   251  
// run is the node routine, started by NewNode as its last step.
//
// Until n.terminate is closed, it serves the requests arriving on the "in"
// channels: it registers and starts new channels, forgets closed ones and
// dispatches writes. In this file, n.channels is accessed only by NewNode
// (before this routine starts) and by run itself.
//
// On termination it closes the heartbeat and stream-request helpers (when
// present), every channel provider and every channel, waits on n.wg, and then
// closes n.chEvent. n.done is closed on return, which unblocks Close.
func (n *Node) run() {
	defer close(n.done)

outer:
	for {
		select {
		case ch := <-n.chNewChannel:
			// a channel created after startup (see newChannel).
			n.channels[ch] = struct{}{}
			ch.start()

		case ch := <-n.chCloseChannel:
			// stop routing writes to a channel reported via closeChannel.
			delete(n.channels, ch)

		case req := <-n.chWriteTo:
			// writes to channels that are not registered are dropped;
			// continue resumes the outer for loop.
			if _, ok := n.channels[req.ch]; !ok {
				continue
			}
			req.ch.write(req.what)

		case what := <-n.chWriteAll:
			for ch := range n.channels {
				ch.write(what)
			}

		case req := <-n.chWriteExcept:
			for ch := range n.channels {
				if ch != req.except {
					ch.write(req.what)
				}
			}

		case <-n.terminate:
			break outer
		}
	}

	// shutdown sequence

	if n.nodeHeartbeat != nil {
		n.nodeHeartbeat.close()
	}

	if n.nodeStreamRequest != nil {
		n.nodeStreamRequest.close()
	}

	for ca := range n.channelProviders {
		ca.close()
	}

	for ch := range n.channels {
		ch.close()
	}

	// wait for the routines tracked by n.wg (added outside this file —
	// presumably channel/provider routines that may call pushEvent) before
	// closing chEvent, so no send on a closed channel can happen.
	n.wg.Wait()

	close(n.chEvent)
}
   308  
   309  // FixFrame recomputes the Frame checksum and signature.
   310  // This can be called on Frames whose content has been edited.
   311  func (n *Node) FixFrame(fr frame.Frame) error {
   312  	err := n.encodeFrame(fr)
   313  	if err != nil {
   314  		return err
   315  	}
   316  
   317  	if n.dialectRW == nil {
   318  		return fmt.Errorf("dialect is nil")
   319  	}
   320  
   321  	mp := n.dialectRW.GetMessage(fr.GetMessage().GetID())
   322  	if mp == nil {
   323  		return fmt.Errorf("message is not in the dialect")
   324  	}
   325  
   326  	// fill checksum
   327  	switch ff := fr.(type) {
   328  	case *frame.V1Frame:
   329  		ff.Checksum = ff.GenerateChecksum(mp.CRCExtra())
   330  	case *frame.V2Frame:
   331  		ff.Checksum = ff.GenerateChecksum(mp.CRCExtra())
   332  	}
   333  
   334  	// fill Signature if v2
   335  	if ff, ok := fr.(*frame.V2Frame); ok && n.conf.OutKey != nil {
   336  		ff.Signature = ff.GenerateSignature(n.conf.OutKey)
   337  	}
   338  
   339  	return nil
   340  }
   341  
   342  func (n *Node) encodeFrame(fr frame.Frame) error {
   343  	if _, ok := fr.GetMessage().(*message.MessageRaw); !ok {
   344  		if n.dialectRW == nil {
   345  			return fmt.Errorf("dialect is nil")
   346  		}
   347  
   348  		mp := n.dialectRW.GetMessage(fr.GetMessage().GetID())
   349  		if mp == nil {
   350  			return fmt.Errorf("message is not in the dialect")
   351  		}
   352  
   353  		_, isV2 := fr.(*frame.V2Frame)
   354  		msgRaw := mp.Write(fr.GetMessage(), isV2)
   355  
   356  		switch fr := fr.(type) {
   357  		case *frame.V1Frame:
   358  			fr.Message = msgRaw
   359  		case *frame.V2Frame:
   360  			fr.Message = msgRaw
   361  		}
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  func (n *Node) encodeMessage(msg message.Message) (message.Message, error) {
   368  	if _, ok := msg.(*message.MessageRaw); !ok {
   369  		if n.dialectRW == nil {
   370  			return nil, fmt.Errorf("dialect is nil")
   371  		}
   372  
   373  		mp := n.dialectRW.GetMessage(msg.GetID())
   374  		if mp == nil {
   375  			return nil, fmt.Errorf("message is not in the dialect")
   376  		}
   377  
   378  		msgRaw := mp.Write(msg, n.conf.OutVersion == V2)
   379  		return msgRaw, nil
   380  	}
   381  
   382  	return msg, nil
   383  }
   384  
   385  // Events returns a channel from which receiving events. Possible events are:
   386  //
   387  // * EventChannelOpen
   388  // * EventChannelClose
   389  // * EventFrame
   390  // * EventParseError
   391  // * EventStreamRequested
   392  //
   393  // See individual events for details.
   394  func (n *Node) Events() chan Event {
   395  	return n.chEvent
   396  }
   397  
   398  // WriteMessageTo writes a message to given channel.
   399  func (n *Node) WriteMessageTo(channel *Channel, m message.Message) error {
   400  	m, err := n.encodeMessage(m)
   401  	if err != nil {
   402  		return err
   403  	}
   404  
   405  	select {
   406  	case n.chWriteTo <- writeToReq{channel, m}:
   407  	case <-n.terminate:
   408  	}
   409  
   410  	return nil
   411  }
   412  
   413  // WriteMessageAll writes a message to all channels.
   414  func (n *Node) WriteMessageAll(m message.Message) error {
   415  	m, err := n.encodeMessage(m)
   416  	if err != nil {
   417  		return err
   418  	}
   419  
   420  	select {
   421  	case n.chWriteAll <- m:
   422  	case <-n.terminate:
   423  	}
   424  
   425  	return nil
   426  }
   427  
   428  // WriteMessageExcept writes a message to all channels except specified channel.
   429  func (n *Node) WriteMessageExcept(exceptChannel *Channel, m message.Message) error {
   430  	m, err := n.encodeMessage(m)
   431  	if err != nil {
   432  		return err
   433  	}
   434  
   435  	select {
   436  	case n.chWriteExcept <- writeExceptReq{exceptChannel, m}:
   437  	case <-n.terminate:
   438  	}
   439  
   440  	return nil
   441  }
   442  
   443  // WriteFrameTo writes a frame to given channel.
   444  // This function is intended only for routing pre-existing frames to other nodes,
   445  // since all frame fields must be filled manually.
   446  func (n *Node) WriteFrameTo(channel *Channel, fr frame.Frame) error {
   447  	err := n.encodeFrame(fr)
   448  	if err != nil {
   449  		return err
   450  	}
   451  
   452  	select {
   453  	case n.chWriteTo <- writeToReq{channel, fr}:
   454  	case <-n.terminate:
   455  	}
   456  
   457  	return nil
   458  }
   459  
   460  // WriteFrameAll writes a frame to all channels.
   461  // This function is intended only for routing pre-existing frames to other nodes,
   462  // since all frame fields must be filled manually.
   463  func (n *Node) WriteFrameAll(fr frame.Frame) error {
   464  	err := n.encodeFrame(fr)
   465  	if err != nil {
   466  		return err
   467  	}
   468  
   469  	select {
   470  	case n.chWriteAll <- fr:
   471  	case <-n.terminate:
   472  	}
   473  
   474  	return nil
   475  }
   476  
   477  // WriteFrameExcept writes a frame to all channels except specified channel.
   478  // This function is intended only for routing pre-existing frames to other nodes,
   479  // since all frame fields must be filled manually.
   480  func (n *Node) WriteFrameExcept(exceptChannel *Channel, fr frame.Frame) error {
   481  	err := n.encodeFrame(fr)
   482  	if err != nil {
   483  		return err
   484  	}
   485  
   486  	select {
   487  	case n.chWriteExcept <- writeExceptReq{exceptChannel, fr}:
   488  	case <-n.terminate:
   489  	}
   490  
   491  	return nil
   492  }
   493  
   494  func (n *Node) pushEvent(evt Event) {
   495  	select {
   496  	case n.chEvent <- evt:
   497  	case <-n.terminate:
   498  	}
   499  }
   500  
   501  func (n *Node) newChannel(ch *Channel) {
   502  	select {
   503  	case n.chNewChannel <- ch:
   504  	case <-n.terminate:
   505  		ch.close()
   506  	}
   507  }
   508  
   509  func (n *Node) closeChannel(ch *Channel) {
   510  	select {
   511  	case n.chCloseChannel <- ch:
   512  	case <-n.terminate:
   513  	}
   514  }