github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/internal/mdns/client.go (about)

     1  package mdns
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"net"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/miekg/dns"
    13  	"golang.org/x/net/ipv4"
    14  	"golang.org/x/net/ipv6"
    15  )
    16  
    17  // ServiceEntry is returned after we query for a service
    18  type ServiceEntry struct {
    19  	Name       string
    20  	Host       string
    21  	AddrV4     net.IP
    22  	AddrV6     net.IP
    23  	Port       int
    24  	Info       string
    25  	InfoFields []string
    26  	TTL        int
    27  	Type       uint16
    28  
    29  	Addr net.IP // @Deprecated
    30  
    31  	hasTXT bool
    32  	sent   bool
    33  }
    34  
    35  // complete is used to check if we have all the info we need
    36  func (s *ServiceEntry) complete() bool {
    37  
    38  	return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
    39  }
    40  
    41  // QueryParam is used to customize how a Lookup is performed
    42  type QueryParam struct {
    43  	Service             string               // Service to lookup
    44  	Domain              string               // Lookup domain, default "local"
    45  	Type                uint16               // Lookup type, defaults to dns.TypePTR
    46  	Context             context.Context      // Context
    47  	Timeout             time.Duration        // Lookup timeout, default 1 second. Ignored if Context is provided
    48  	Interface           *net.Interface       // Multicast interface to use
    49  	Entries             chan<- *ServiceEntry // Entries Channel
    50  	WantUnicastResponse bool                 // Unicast response desired, as per 5.4 in RFC
    51  }
    52  
    53  // DefaultParams is used to return a default set of QueryParam's
    54  func DefaultParams(service string) *QueryParam {
    55  	return &QueryParam{
    56  		Service:             service,
    57  		Domain:              "local",
    58  		Timeout:             time.Second,
    59  		Entries:             make(chan *ServiceEntry),
    60  		WantUnicastResponse: false, // TODO(reddaly): Change this default.
    61  	}
    62  }
    63  
    64  // Query looks up a given service, in a domain, waiting at most
    65  // for a timeout before finishing the query. The results are streamed
    66  // to a channel. Sends will not block, so clients should make sure to
    67  // either read or buffer.
    68  func Query(params *QueryParam) error {
    69  	// Create a new client
    70  	client, err := newClient()
    71  	if err != nil {
    72  		return err
    73  	}
    74  	defer client.Close()
    75  
    76  	// Set the multicast interface
    77  	if params.Interface != nil {
    78  		if err := client.setInterface(params.Interface, false); err != nil {
    79  			return err
    80  		}
    81  	}
    82  
    83  	// Ensure defaults are set
    84  	if params.Domain == "" {
    85  		params.Domain = "local"
    86  	}
    87  
    88  	if params.Context == nil {
    89  		if params.Timeout == 0 {
    90  			params.Timeout = time.Second
    91  		}
    92  		var cancel context.CancelFunc
    93  		params.Context, cancel = context.WithTimeout(context.Background(), params.Timeout)
    94  		defer cancel()
    95  	}
    96  
    97  	// Run the query
    98  	return client.query(params)
    99  }
   100  
   101  // Listen listens indefinitely for multicast updates
   102  func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
   103  	// Create a new client
   104  	client, err := newClient()
   105  	if err != nil {
   106  		return err
   107  	}
   108  	defer client.Close()
   109  
   110  	client.setInterface(nil, true)
   111  
   112  	// Start listening for response packets
   113  	msgCh := make(chan *dns.Msg, 32)
   114  
   115  	go client.recv(client.ipv4UnicastConn, msgCh)
   116  	go client.recv(client.ipv6UnicastConn, msgCh)
   117  	go client.recv(client.ipv4MulticastConn, msgCh)
   118  	go client.recv(client.ipv6MulticastConn, msgCh)
   119  
   120  	ip := make(map[string]*ServiceEntry)
   121  
   122  	for {
   123  		select {
   124  		case <-exit:
   125  			return nil
   126  		case <-client.closedCh:
   127  			return nil
   128  		case m := <-msgCh:
   129  			e := messageToEntry(m, ip)
   130  			if e == nil {
   131  				continue
   132  			}
   133  
   134  			// Check if this entry is complete
   135  			if e.complete() {
   136  				if e.sent {
   137  					continue
   138  				}
   139  				e.sent = true
   140  				entries <- e
   141  				ip = make(map[string]*ServiceEntry)
   142  			} else {
   143  				// Fire off a node specific query
   144  				m := new(dns.Msg)
   145  				m.SetQuestion(e.Name, dns.TypePTR)
   146  				m.RecursionDesired = false
   147  				if err := client.sendQuery(m); err != nil {
   148  					log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err)
   149  				}
   150  			}
   151  		}
   152  	}
   153  }
   154  
   155  // Lookup is the same as Query, however it uses all the default parameters
   156  func Lookup(service string, entries chan<- *ServiceEntry) error {
   157  	params := DefaultParams(service)
   158  	params.Entries = entries
   159  	return Query(params)
   160  }
   161  
   162  // Client provides a query interface that can be used to
   163  // search for service providers using mDNS
   164  type client struct {
   165  	ipv4UnicastConn *net.UDPConn
   166  	ipv6UnicastConn *net.UDPConn
   167  
   168  	ipv4MulticastConn *net.UDPConn
   169  	ipv6MulticastConn *net.UDPConn
   170  
   171  	closed    bool
   172  	closedCh  chan struct{} // TODO(reddaly): This doesn't appear to be used.
   173  	closeLock sync.Mutex
   174  }
   175  
   176  // NewClient creates a new mdns Client that can be used to query
   177  // for records
   178  func newClient() (*client, error) {
   179  	// TODO(reddaly): At least attempt to bind to the port required in the spec.
   180  	// Create a IPv4 listener
   181  	uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
   182  	uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
   183  	if err4 != nil && err6 != nil {
   184  		log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
   185  	}
   186  
   187  	if uconn4 == nil && uconn6 == nil {
   188  		return nil, fmt.Errorf("failed to bind to any unicast udp port")
   189  	}
   190  
   191  	if uconn4 == nil {
   192  		uconn4 = &net.UDPConn{}
   193  	}
   194  
   195  	if uconn6 == nil {
   196  		uconn6 = &net.UDPConn{}
   197  	}
   198  
   199  	mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
   200  	mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
   201  	if err4 != nil && err6 != nil {
   202  		log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
   203  	}
   204  
   205  	if mconn4 == nil && mconn6 == nil {
   206  		return nil, fmt.Errorf("failed to bind to any multicast udp port")
   207  	}
   208  
   209  	if mconn4 == nil {
   210  		mconn4 = &net.UDPConn{}
   211  	}
   212  
   213  	if mconn6 == nil {
   214  		mconn6 = &net.UDPConn{}
   215  	}
   216  
   217  	p1 := ipv4.NewPacketConn(mconn4)
   218  	p2 := ipv6.NewPacketConn(mconn6)
   219  	p1.SetMulticastLoopback(true)
   220  	p2.SetMulticastLoopback(true)
   221  
   222  	ifaces, err := net.Interfaces()
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	var errCount1, errCount2 int
   228  
   229  	for _, iface := range ifaces {
   230  		if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   231  			errCount1++
   232  		}
   233  		if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   234  			errCount2++
   235  		}
   236  	}
   237  
   238  	if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
   239  		return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
   240  	}
   241  
   242  	c := &client{
   243  		ipv4MulticastConn: mconn4,
   244  		ipv6MulticastConn: mconn6,
   245  		ipv4UnicastConn:   uconn4,
   246  		ipv6UnicastConn:   uconn6,
   247  		closedCh:          make(chan struct{}),
   248  	}
   249  	return c, nil
   250  }
   251  
   252  // Close is used to cleanup the client
   253  func (c *client) Close() error {
   254  	c.closeLock.Lock()
   255  	defer c.closeLock.Unlock()
   256  
   257  	if c.closed {
   258  		return nil
   259  	}
   260  	c.closed = true
   261  
   262  	close(c.closedCh)
   263  
   264  	if c.ipv4UnicastConn != nil {
   265  		c.ipv4UnicastConn.Close()
   266  	}
   267  	if c.ipv6UnicastConn != nil {
   268  		c.ipv6UnicastConn.Close()
   269  	}
   270  	if c.ipv4MulticastConn != nil {
   271  		c.ipv4MulticastConn.Close()
   272  	}
   273  	if c.ipv6MulticastConn != nil {
   274  		c.ipv6MulticastConn.Close()
   275  	}
   276  
   277  	return nil
   278  }
   279  
   280  // setInterface is used to set the query interface, uses system
   281  // default if not provided
   282  func (c *client) setInterface(iface *net.Interface, loopback bool) error {
   283  	p := ipv4.NewPacketConn(c.ipv4UnicastConn)
   284  	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   285  		return err
   286  	}
   287  	p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
   288  	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   289  		return err
   290  	}
   291  	p = ipv4.NewPacketConn(c.ipv4MulticastConn)
   292  	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   293  		return err
   294  	}
   295  	p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
   296  	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   297  		return err
   298  	}
   299  
   300  	if loopback {
   301  		p.SetMulticastLoopback(true)
   302  		p2.SetMulticastLoopback(true)
   303  	}
   304  
   305  	return nil
   306  }
   307  
// query performs a lookup and streams the results.
//
// It sends one question for "<Service>.<Domain>." (dns.TypePTR when
// params.Type is dns.TypeNone), then merges every received message into the
// in-progress entries. Complete, not-yet-sent entries go to params.Entries;
// incomplete ones trigger a follow-up query for the entry's own name and
// type. It returns the error from sending the initial question, or nil once
// params.Context is done.
// Assumes params.Context is non-nil; Query guarantees that before calling.
func (c *client) query(params *QueryParam) error {
	// Create the service name
	serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))

	// Start listening for response packets
	msgCh := make(chan *dns.Msg, 32)
	go c.recv(c.ipv4UnicastConn, msgCh)
	go c.recv(c.ipv6UnicastConn, msgCh)
	go c.recv(c.ipv4MulticastConn, msgCh)
	go c.recv(c.ipv6MulticastConn, msgCh)

	// Send the query
	m := new(dns.Msg)
	if params.Type == dns.TypeNone {
		m.SetQuestion(serviceAddr, dns.TypePTR)
	} else {
		m.SetQuestion(serviceAddr, params.Type)
	}
	// RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question
	// Section
	//
	// In the Question Section of a Multicast DNS query, the top bit of the qclass
	// field is used to indicate that unicast responses are preferred for this
	// particular question.  (See Section 5.4.)
	if params.WantUnicastResponse {
		m.Question[0].Qclass |= 1 << 15
	}
	m.RecursionDesired = false
	if err := c.sendQuery(m); err != nil {
		return err
	}

	// Map the in-progress responses
	inprogress := make(map[string]*ServiceEntry)

	for {
		select {
		case resp := <-msgCh:
			// Note: records are merged into inprogress before the question
			// filter below, so even discarded responses update entries.
			inp := messageToEntry(resp, inprogress)

			if inp == nil {
				continue
			}
			// NOTE(review): RFC 6762 §6 says mDNS responses carry no questions,
			// so only responders that echo the question pass this filter.
			// Replies to the follow-up queries below echo a different name
			// and are dropped here as well. m is the outer, initial query.
			if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
				// discard anything which we've not asked for
				continue
			}

			// Check if this entry is complete
			if inp.complete() {
				if inp.sent {
					continue
				}

				inp.sent = true
				// Wait for the consumer, but give up when the context ends.
				select {
				case params.Entries <- inp:
				case <-params.Context.Done():
					return nil
				}
			} else {
				// Fire off a node specific query
				// (this m shadows the initial query only inside this branch).
				m := new(dns.Msg)
				m.SetQuestion(inp.Name, inp.Type)
				m.RecursionDesired = false
				if err := c.sendQuery(m); err != nil {
					log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
				}
			}
		case <-params.Context.Done():
			return nil
		}
	}
}
   383  
   384  // sendQuery is used to multicast a query out
   385  func (c *client) sendQuery(q *dns.Msg) error {
   386  	buf, err := q.Pack()
   387  	if err != nil {
   388  		return err
   389  	}
   390  	if c.ipv4UnicastConn != nil {
   391  		c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
   392  	}
   393  	if c.ipv6UnicastConn != nil {
   394  		c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
   395  	}
   396  	return nil
   397  }
   398  
   399  // recv is used to receive until we get a shutdown
   400  func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
   401  	if l == nil {
   402  		return
   403  	}
   404  	buf := make([]byte, 65536)
   405  	for {
   406  		c.closeLock.Lock()
   407  		if c.closed {
   408  			c.closeLock.Unlock()
   409  			return
   410  		}
   411  		c.closeLock.Unlock()
   412  		n, err := l.Read(buf)
   413  		if err != nil {
   414  			continue
   415  		}
   416  		msg := new(dns.Msg)
   417  		if err := msg.Unpack(buf[:n]); err != nil {
   418  			continue
   419  		}
   420  		select {
   421  		case msgCh <- msg:
   422  		case <-c.closedCh:
   423  			return
   424  		}
   425  	}
   426  }
   427  
   428  // ensureName is used to ensure the named node is in progress
   429  func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry {
   430  	if inp, ok := inprogress[name]; ok {
   431  		return inp
   432  	}
   433  	inp := &ServiceEntry{
   434  		Name: name,
   435  		Type: typ,
   436  	}
   437  	inprogress[name] = inp
   438  	return inp
   439  }
   440  
   441  // alias is used to setup an alias between two entries
   442  func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) {
   443  	srcEntry := ensureName(inprogress, src, typ)
   444  	inprogress[dst] = srcEntry
   445  }
   446  
   447  func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry {
   448  	var inp *ServiceEntry
   449  
   450  	for _, answer := range append(m.Answer, m.Extra...) {
   451  		// TODO(reddaly): Check that response corresponds to serviceAddr?
   452  		switch rr := answer.(type) {
   453  		case *dns.PTR:
   454  			// Create new entry for this
   455  			inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype)
   456  			if inp.complete() {
   457  				continue
   458  			}
   459  		case *dns.SRV:
   460  			// Check for a target mismatch
   461  			if rr.Target != rr.Hdr.Name {
   462  				alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype)
   463  			}
   464  
   465  			// Get the port
   466  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   467  			if inp.complete() {
   468  				continue
   469  			}
   470  			inp.Host = rr.Target
   471  			inp.Port = int(rr.Port)
   472  		case *dns.TXT:
   473  			// Pull out the txt
   474  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   475  			if inp.complete() {
   476  				continue
   477  			}
   478  			inp.Info = strings.Join(rr.Txt, "|")
   479  			inp.InfoFields = rr.Txt
   480  			inp.hasTXT = true
   481  		case *dns.A:
   482  			// Pull out the IP
   483  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   484  			if inp.complete() {
   485  				continue
   486  			}
   487  			inp.Addr = rr.A // @Deprecated
   488  			inp.AddrV4 = rr.A
   489  		case *dns.AAAA:
   490  			// Pull out the IP
   491  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   492  			if inp.complete() {
   493  				continue
   494  			}
   495  			inp.Addr = rr.AAAA // @Deprecated
   496  			inp.AddrV6 = rr.AAAA
   497  		}
   498  
   499  		if inp != nil {
   500  			inp.TTL = int(answer.Header().Ttl)
   501  		}
   502  	}
   503  
   504  	return inp
   505  }