go-micro.dev/v5@v5.12.0/util/mdns/client.go (about)

     1  package mdns
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/miekg/dns"
    12  	"go-micro.dev/v5/logger"
    13  	"golang.org/x/net/ipv4"
    14  	"golang.org/x/net/ipv6"
    15  )
    16  
    17  // ServiceEntry is returned after we query for a service.
    18  type ServiceEntry struct {
    19  	Name       string
    20  	Host       string
    21  	Info       string
    22  	AddrV4     net.IP
    23  	AddrV6     net.IP
    24  	InfoFields []string
    25  
    26  	Addr net.IP // @Deprecated
    27  
    28  	Port int
    29  	TTL  int
    30  	Type uint16
    31  
    32  	hasTXT bool
    33  	sent   bool
    34  }
    35  
    36  // complete is used to check if we have all the info we need.
    37  func (s *ServiceEntry) complete() bool {
    38  	return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
    39  }
    40  
    41  // QueryParam is used to customize how a Lookup is performed.
    42  type QueryParam struct {
    43  	Context             context.Context      // Context
    44  	Interface           *net.Interface       // Multicast interface to use
    45  	Entries             chan<- *ServiceEntry // Entries Channel
    46  	Service             string               // Service to lookup
    47  	Domain              string               // Lookup domain, default "local"
    48  	Timeout             time.Duration        // Lookup timeout, default 1 second. Ignored if Context is provided
    49  	Type                uint16               // Lookup type, defaults to dns.TypePTR
    50  	WantUnicastResponse bool                 // Unicast response desired, as per 5.4 in RFC
    51  }
    52  
    53  // DefaultParams is used to return a default set of QueryParam's.
    54  func DefaultParams(service string) *QueryParam {
    55  	return &QueryParam{
    56  		Service:             service,
    57  		Domain:              "local",
    58  		Timeout:             time.Second,
    59  		Entries:             make(chan *ServiceEntry),
    60  		WantUnicastResponse: false, // TODO(reddaly): Change this default.
    61  	}
    62  }
    63  
    64  // Query looks up a given service, in a domain, waiting at most
    65  // for a timeout before finishing the query. The results are streamed
    66  // to a channel. Sends will not block, so clients should make sure to
    67  // either read or buffer.
    68  func Query(params *QueryParam) error {
    69  	// Create a new client
    70  	client, err := newClient()
    71  	if err != nil {
    72  		return err
    73  	}
    74  	defer client.Close()
    75  
    76  	// Set the multicast interface
    77  	if params.Interface != nil {
    78  		if err := client.setInterface(params.Interface, false); err != nil {
    79  			return err
    80  		}
    81  	}
    82  
    83  	// Ensure defaults are set
    84  	if params.Domain == "" {
    85  		params.Domain = "local"
    86  	}
    87  
    88  	if params.Context == nil {
    89  		if params.Timeout == 0 {
    90  			params.Timeout = time.Second
    91  		}
    92  		params.Context, _ = context.WithTimeout(context.Background(), params.Timeout)
    93  		if err != nil {
    94  			return err
    95  		}
    96  	}
    97  
    98  	// Run the query
    99  	return client.query(params)
   100  }
   101  
   102  // Listen listens indefinitely for multicast updates.
   103  func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
   104  	// Create a new client
   105  	client, err := newClient()
   106  	if err != nil {
   107  		return err
   108  	}
   109  	defer client.Close()
   110  
   111  	client.setInterface(nil, true)
   112  
   113  	// Start listening for response packets
   114  	msgCh := make(chan *dns.Msg, 32)
   115  
   116  	go client.recv(client.ipv4UnicastConn, msgCh)
   117  	go client.recv(client.ipv6UnicastConn, msgCh)
   118  	go client.recv(client.ipv4MulticastConn, msgCh)
   119  	go client.recv(client.ipv6MulticastConn, msgCh)
   120  
   121  	ip := make(map[string]*ServiceEntry)
   122  
   123  	for {
   124  		select {
   125  		case <-exit:
   126  			return nil
   127  		case <-client.closedCh:
   128  			return nil
   129  		case m := <-msgCh:
   130  			e := messageToEntry(m, ip)
   131  			if e == nil {
   132  				continue
   133  			}
   134  
   135  			// Check if this entry is complete
   136  			if e.complete() {
   137  				if e.sent {
   138  					continue
   139  				}
   140  				e.sent = true
   141  				entries <- e
   142  				ip = make(map[string]*ServiceEntry)
   143  			} else {
   144  				// Fire off a node specific query
   145  				m := new(dns.Msg)
   146  				m.SetQuestion(e.Name, dns.TypePTR)
   147  				m.RecursionDesired = false
   148  				if err := client.sendQuery(m); err != nil {
   149  					logger.Logf(logger.ErrorLevel, "[mdns] failed to query instance %s: %v", e.Name, err)
   150  				}
   151  			}
   152  		}
   153  	}
   154  
   155  	return nil
   156  }
   157  
   158  // Lookup is the same as Query, however it uses all the default parameters.
   159  func Lookup(service string, entries chan<- *ServiceEntry) error {
   160  	params := DefaultParams(service)
   161  	params.Entries = entries
   162  	return Query(params)
   163  }
   164  
   165  // Client provides a query interface that can be used to
   166  // search for service providers using mDNS.
   167  type client struct {
   168  	ipv4UnicastConn *net.UDPConn
   169  	ipv6UnicastConn *net.UDPConn
   170  
   171  	ipv4MulticastConn *net.UDPConn
   172  	ipv6MulticastConn *net.UDPConn
   173  
   174  	closedCh  chan struct{} // TODO(reddaly): This doesn't appear to be used.
   175  	closeLock sync.Mutex
   176  
   177  	closed bool
   178  }
   179  
   180  // NewClient creates a new mdns Client that can be used to query
   181  // for records.
   182  func newClient() (*client, error) {
   183  	// TODO(reddaly): At least attempt to bind to the port required in the spec.
   184  	// Create a IPv4 listener
   185  	uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
   186  	uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
   187  	if err4 != nil && err6 != nil {
   188  		logger.Logf(logger.ErrorLevel, "[mdns] failed to bind to udp port: %v %v", err4, err6)
   189  	}
   190  
   191  	if uconn4 == nil && uconn6 == nil {
   192  		return nil, fmt.Errorf("failed to bind to any unicast udp port")
   193  	}
   194  
   195  	if uconn4 == nil {
   196  		uconn4 = &net.UDPConn{}
   197  	}
   198  
   199  	if uconn6 == nil {
   200  		uconn6 = &net.UDPConn{}
   201  	}
   202  
   203  	mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
   204  	mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
   205  	if err4 != nil && err6 != nil {
   206  		logger.Logf(logger.ErrorLevel, "[mdns] failed to bind to udp port: %v %v", err4, err6)
   207  	}
   208  
   209  	if mconn4 == nil && mconn6 == nil {
   210  		return nil, fmt.Errorf("failed to bind to any multicast udp port")
   211  	}
   212  
   213  	if mconn4 == nil {
   214  		mconn4 = &net.UDPConn{}
   215  	}
   216  
   217  	if mconn6 == nil {
   218  		mconn6 = &net.UDPConn{}
   219  	}
   220  
   221  	p1 := ipv4.NewPacketConn(mconn4)
   222  	p2 := ipv6.NewPacketConn(mconn6)
   223  	p1.SetMulticastLoopback(true)
   224  	p2.SetMulticastLoopback(true)
   225  
   226  	ifaces, err := net.Interfaces()
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	var errCount1, errCount2 int
   232  
   233  	for _, iface := range ifaces {
   234  		if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   235  			errCount1++
   236  		}
   237  		if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   238  			errCount2++
   239  		}
   240  	}
   241  
   242  	if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
   243  		return nil, fmt.Errorf("failed to join multicast group on all interfaces")
   244  	}
   245  
   246  	c := &client{
   247  		ipv4MulticastConn: mconn4,
   248  		ipv6MulticastConn: mconn6,
   249  		ipv4UnicastConn:   uconn4,
   250  		ipv6UnicastConn:   uconn6,
   251  		closedCh:          make(chan struct{}),
   252  	}
   253  	return c, nil
   254  }
   255  
   256  // Close is used to cleanup the client.
   257  func (c *client) Close() error {
   258  	c.closeLock.Lock()
   259  	defer c.closeLock.Unlock()
   260  
   261  	if c.closed {
   262  		return nil
   263  	}
   264  	c.closed = true
   265  
   266  	close(c.closedCh)
   267  
   268  	if c.ipv4UnicastConn != nil {
   269  		c.ipv4UnicastConn.Close()
   270  	}
   271  	if c.ipv6UnicastConn != nil {
   272  		c.ipv6UnicastConn.Close()
   273  	}
   274  	if c.ipv4MulticastConn != nil {
   275  		c.ipv4MulticastConn.Close()
   276  	}
   277  	if c.ipv6MulticastConn != nil {
   278  		c.ipv6MulticastConn.Close()
   279  	}
   280  
   281  	return nil
   282  }
   283  
   284  // setInterface is used to set the query interface, uses system
   285  // default if not provided.
   286  func (c *client) setInterface(iface *net.Interface, loopback bool) error {
   287  	p := ipv4.NewPacketConn(c.ipv4UnicastConn)
   288  	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   289  		return err
   290  	}
   291  	p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
   292  	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   293  		return err
   294  	}
   295  	p = ipv4.NewPacketConn(c.ipv4MulticastConn)
   296  	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   297  		return err
   298  	}
   299  	p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
   300  	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   301  		return err
   302  	}
   303  
   304  	if loopback {
   305  		p.SetMulticastLoopback(true)
   306  		p2.SetMulticastLoopback(true)
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  // query is used to perform a lookup and stream results.
   313  func (c *client) query(params *QueryParam) error {
   314  	// Create the service name
   315  	serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))
   316  
   317  	// Start listening for response packets
   318  	msgCh := make(chan *dns.Msg, 32)
   319  	go c.recv(c.ipv4UnicastConn, msgCh)
   320  	go c.recv(c.ipv6UnicastConn, msgCh)
   321  	go c.recv(c.ipv4MulticastConn, msgCh)
   322  	go c.recv(c.ipv6MulticastConn, msgCh)
   323  
   324  	// Send the query
   325  	m := new(dns.Msg)
   326  	if params.Type == dns.TypeNone {
   327  		m.SetQuestion(serviceAddr, dns.TypePTR)
   328  	} else {
   329  		m.SetQuestion(serviceAddr, params.Type)
   330  	}
   331  	// RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question
   332  	// Section
   333  	//
   334  	// In the Question Section of a Multicast DNS query, the top bit of the qclass
   335  	// field is used to indicate that unicast responses are preferred for this
   336  	// particular question.  (See Section 5.4.)
   337  	if params.WantUnicastResponse {
   338  		m.Question[0].Qclass |= 1 << 15
   339  	}
   340  	m.RecursionDesired = false
   341  	if err := c.sendQuery(m); err != nil {
   342  		return err
   343  	}
   344  
   345  	// Map the in-progress responses
   346  	inprogress := make(map[string]*ServiceEntry)
   347  
   348  	for {
   349  		select {
   350  		case resp := <-msgCh:
   351  			inp := messageToEntry(resp, inprogress)
   352  
   353  			if inp == nil {
   354  				continue
   355  			}
   356  			if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
   357  				// discard anything which we've not asked for
   358  				continue
   359  			}
   360  
   361  			// Check if this entry is complete
   362  			if inp.complete() {
   363  				if inp.sent {
   364  					continue
   365  				}
   366  
   367  				inp.sent = true
   368  				select {
   369  				case params.Entries <- inp:
   370  				case <-params.Context.Done():
   371  					return nil
   372  				}
   373  			} else {
   374  				// Fire off a node specific query
   375  				m := new(dns.Msg)
   376  				m.SetQuestion(inp.Name, inp.Type)
   377  				m.RecursionDesired = false
   378  				if err := c.sendQuery(m); err != nil {
   379  					logger.Logf(logger.ErrorLevel, "[mdns] failed to query instance %s: %v", inp.Name, err)
   380  				}
   381  			}
   382  		case <-params.Context.Done():
   383  			return nil
   384  		}
   385  	}
   386  }
   387  
   388  // sendQuery is used to multicast a query out.
   389  func (c *client) sendQuery(q *dns.Msg) error {
   390  	buf, err := q.Pack()
   391  	if err != nil {
   392  		return err
   393  	}
   394  	if c.ipv4UnicastConn != nil {
   395  		c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
   396  	}
   397  	if c.ipv6UnicastConn != nil {
   398  		c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
   399  	}
   400  	return nil
   401  }
   402  
   403  // recv is used to receive until we get a shutdown.
   404  func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
   405  	if l == nil {
   406  		return
   407  	}
   408  	buf := make([]byte, 65536)
   409  	for {
   410  		c.closeLock.Lock()
   411  		if c.closed {
   412  			c.closeLock.Unlock()
   413  			return
   414  		}
   415  		c.closeLock.Unlock()
   416  		n, err := l.Read(buf)
   417  		if err != nil {
   418  			continue
   419  		}
   420  		msg := new(dns.Msg)
   421  		if err := msg.Unpack(buf[:n]); err != nil {
   422  			continue
   423  		}
   424  		select {
   425  		case msgCh <- msg:
   426  		case <-c.closedCh:
   427  			return
   428  		}
   429  	}
   430  }
   431  
   432  // ensureName is used to ensure the named node is in progress.
   433  func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry {
   434  	if inp, ok := inprogress[name]; ok {
   435  		return inp
   436  	}
   437  	inp := &ServiceEntry{
   438  		Name: name,
   439  		Type: typ,
   440  	}
   441  	inprogress[name] = inp
   442  	return inp
   443  }
   444  
   445  // alias is used to setup an alias between two entries.
   446  func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) {
   447  	srcEntry := ensureName(inprogress, src, typ)
   448  	inprogress[dst] = srcEntry
   449  }
   450  
   451  func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry {
   452  	var inp *ServiceEntry
   453  
   454  	for _, answer := range append(m.Answer, m.Extra...) {
   455  		// TODO(reddaly): Check that response corresponds to serviceAddr?
   456  		switch rr := answer.(type) {
   457  		case *dns.PTR:
   458  			// Create new entry for this
   459  			inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype)
   460  			if inp.complete() {
   461  				continue
   462  			}
   463  		case *dns.SRV:
   464  			// Check for a target mismatch
   465  			if rr.Target != rr.Hdr.Name {
   466  				alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype)
   467  			}
   468  
   469  			// Get the port
   470  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   471  			if inp.complete() {
   472  				continue
   473  			}
   474  			inp.Host = rr.Target
   475  			inp.Port = int(rr.Port)
   476  		case *dns.TXT:
   477  			// Pull out the txt
   478  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   479  			if inp.complete() {
   480  				continue
   481  			}
   482  			inp.Info = strings.Join(rr.Txt, "|")
   483  			inp.InfoFields = rr.Txt
   484  			inp.hasTXT = true
   485  		case *dns.A:
   486  			// Pull out the IP
   487  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   488  			if inp.complete() {
   489  				continue
   490  			}
   491  			inp.Addr = rr.A // @Deprecated
   492  			inp.AddrV4 = rr.A
   493  		case *dns.AAAA:
   494  			// Pull out the IP
   495  			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
   496  			if inp.complete() {
   497  				continue
   498  			}
   499  			inp.Addr = rr.AAAA // @Deprecated
   500  			inp.AddrV6 = rr.AAAA
   501  		}
   502  
   503  		if inp != nil {
   504  			inp.TTL = int(answer.Header().Ttl)
   505  		}
   506  	}
   507  
   508  	return inp
   509  }