github.com/annwntech/go-micro/v2@v2.9.5/util/mdns/client.go (about)

     1  package mdns
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"net"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/miekg/dns"
    13  	"golang.org/x/net/ipv4"
    14  	"golang.org/x/net/ipv6"
    15  )
    16  
    17  // ServiceEntry is returned after we query for a service
    18  type ServiceEntry struct {
    19  	Name       string
    20  	Host       string
    21  	AddrV4     net.IP
    22  	AddrV6     net.IP
    23  	Port       int
    24  	Info       string
    25  	InfoFields []string
    26  	TTL        int
    27  	Type       uint16
    28  
    29  	Addr net.IP // @Deprecated
    30  
    31  	hasTXT bool
    32  	sent   bool
    33  }
    34  
    35  // complete is used to check if we have all the info we need
    36  func (s *ServiceEntry) complete() bool {
    37  
    38  	return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
    39  }
    40  
    41  // QueryParam is used to customize how a Lookup is performed
    42  type QueryParam struct {
    43  	Service             string               // Service to lookup
    44  	Domain              string               // Lookup domain, default "local"
    45  	Type                uint16               // Lookup type, defaults to dns.TypePTR
    46  	Context             context.Context      // Context
    47  	Timeout             time.Duration        // Lookup timeout, default 1 second. Ignored if Context is provided
    48  	Interface           *net.Interface       // Multicast interface to use
    49  	Entries             chan<- *ServiceEntry // Entries Channel
    50  	WantUnicastResponse bool                 // Unicast response desired, as per 5.4 in RFC
    51  }
    52  
    53  // DefaultParams is used to return a default set of QueryParam's
    54  func DefaultParams(service string) *QueryParam {
    55  	return &QueryParam{
    56  		Service:             service,
    57  		Domain:              "local",
    58  		Timeout:             time.Second,
    59  		Entries:             make(chan *ServiceEntry),
    60  		WantUnicastResponse: false, // TODO(reddaly): Change this default.
    61  	}
    62  }
    63  
    64  // Query looks up a given service, in a domain, waiting at most
    65  // for a timeout before finishing the query. The results are streamed
    66  // to a channel. Sends will not block, so clients should make sure to
    67  // either read or buffer.
    68  func Query(params *QueryParam) error {
    69  	// Create a new client
    70  	client, err := newClient()
    71  	if err != nil {
    72  		return err
    73  	}
    74  	defer client.Close()
    75  
    76  	// Set the multicast interface
    77  	if params.Interface != nil {
    78  		if err := client.setInterface(params.Interface, false); err != nil {
    79  			return err
    80  		}
    81  	}
    82  
    83  	// Ensure defaults are set
    84  	if params.Domain == "" {
    85  		params.Domain = "local"
    86  	}
    87  
    88  	if params.Context == nil {
    89  		if params.Timeout == 0 {
    90  			params.Timeout = time.Second
    91  		}
    92  		params.Context, _ = context.WithTimeout(context.Background(), params.Timeout)
    93  		if err != nil {
    94  			return err
    95  		}
    96  	}
    97  
    98  	// Run the query
    99  	return client.query(params)
   100  }
   101  
   102  // Listen listens indefinitely for multicast updates
   103  func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
   104  	// Create a new client
   105  	client, err := newClient()
   106  	if err != nil {
   107  		return err
   108  	}
   109  	defer client.Close()
   110  
   111  	client.setInterface(nil, true)
   112  
   113  	// Start listening for response packets
   114  	msgCh := make(chan *dns.Msg, 32)
   115  
   116  	go client.recv(client.ipv4UnicastConn, msgCh)
   117  	go client.recv(client.ipv6UnicastConn, msgCh)
   118  	go client.recv(client.ipv4MulticastConn, msgCh)
   119  	go client.recv(client.ipv6MulticastConn, msgCh)
   120  
   121  	ip := make(map[string]*ServiceEntry)
   122  
   123  	for {
   124  		select {
   125  		case <-exit:
   126  			return nil
   127  		case <-client.closedCh:
   128  			return nil
   129  		case m := <-msgCh:
   130  			e := messageToEntry(m, ip)
   131  			if e == nil {
   132  				continue
   133  			}
   134  
   135  			// Check if this entry is complete
   136  			if e.complete() {
   137  				if e.sent {
   138  					continue
   139  				}
   140  				e.sent = true
   141  				entries <- e
   142  				ip = make(map[string]*ServiceEntry)
   143  			} else {
   144  				// Fire off a node specific query
   145  				m := new(dns.Msg)
   146  				m.SetQuestion(e.Name, dns.TypePTR)
   147  				m.RecursionDesired = false
   148  				if err := client.sendQuery(m); err != nil {
   149  					log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err)
   150  				}
   151  			}
   152  		}
   153  	}
   154  
   155  	return nil
   156  }
   157  
   158  // Lookup is the same as Query, however it uses all the default parameters
   159  func Lookup(service string, entries chan<- *ServiceEntry) error {
   160  	params := DefaultParams(service)
   161  	params.Entries = entries
   162  	return Query(params)
   163  }
   164  
   165  // Client provides a query interface that can be used to
   166  // search for service providers using mDNS
   167  type client struct {
   168  	ipv4UnicastConn *net.UDPConn
   169  	ipv6UnicastConn *net.UDPConn
   170  
   171  	ipv4MulticastConn *net.UDPConn
   172  	ipv6MulticastConn *net.UDPConn
   173  
   174  	closed    bool
   175  	closedCh  chan struct{} // TODO(reddaly): This doesn't appear to be used.
   176  	closeLock sync.Mutex
   177  }
   178  
   179  // NewClient creates a new mdns Client that can be used to query
   180  // for records
   181  func newClient() (*client, error) {
   182  	// TODO(reddaly): At least attempt to bind to the port required in the spec.
   183  	// Create a IPv4 listener
   184  	uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
   185  	uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
   186  	if err4 != nil && err6 != nil {
   187  		log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
   188  	}
   189  
   190  	if uconn4 == nil && uconn6 == nil {
   191  		return nil, fmt.Errorf("failed to bind to any unicast udp port")
   192  	}
   193  
   194  	if uconn4 == nil {
   195  		uconn4 = &net.UDPConn{}
   196  	}
   197  
   198  	if uconn6 == nil {
   199  		uconn6 = &net.UDPConn{}
   200  	}
   201  
   202  	mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
   203  	mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
   204  	if err4 != nil && err6 != nil {
   205  		log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
   206  	}
   207  
   208  	if mconn4 == nil && mconn6 == nil {
   209  		return nil, fmt.Errorf("failed to bind to any multicast udp port")
   210  	}
   211  
   212  	if mconn4 == nil {
   213  		mconn4 = &net.UDPConn{}
   214  	}
   215  
   216  	if mconn6 == nil {
   217  		mconn6 = &net.UDPConn{}
   218  	}
   219  
   220  	p1 := ipv4.NewPacketConn(mconn4)
   221  	p2 := ipv6.NewPacketConn(mconn6)
   222  	p1.SetMulticastLoopback(true)
   223  	p2.SetMulticastLoopback(true)
   224  
   225  	ifaces, err := net.Interfaces()
   226  	if err != nil {
   227  		return nil, err
   228  	}
   229  
   230  	var errCount1, errCount2 int
   231  
   232  	for _, iface := range ifaces {
   233  		if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   234  			errCount1++
   235  		}
   236  		if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   237  			errCount2++
   238  		}
   239  	}
   240  
   241  	if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
   242  		return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
   243  	}
   244  
   245  	c := &client{
   246  		ipv4MulticastConn: mconn4,
   247  		ipv6MulticastConn: mconn6,
   248  		ipv4UnicastConn:   uconn4,
   249  		ipv6UnicastConn:   uconn6,
   250  		closedCh:          make(chan struct{}),
   251  	}
   252  	return c, nil
   253  }
   254  
   255  // Close is used to cleanup the client
   256  func (c *client) Close() error {
   257  	c.closeLock.Lock()
   258  	defer c.closeLock.Unlock()
   259  
   260  	if c.closed {
   261  		return nil
   262  	}
   263  	c.closed = true
   264  
   265  	close(c.closedCh)
   266  
   267  	if c.ipv4UnicastConn != nil {
   268  		c.ipv4UnicastConn.Close()
   269  	}
   270  	if c.ipv6UnicastConn != nil {
   271  		c.ipv6UnicastConn.Close()
   272  	}
   273  	if c.ipv4MulticastConn != nil {
   274  		c.ipv4MulticastConn.Close()
   275  	}
   276  	if c.ipv6MulticastConn != nil {
   277  		c.ipv6MulticastConn.Close()
   278  	}
   279  
   280  	return nil
   281  }
   282  
   283  // setInterface is used to set the query interface, uses sytem
   284  // default if not provided
   285  func (c *client) setInterface(iface *net.Interface, loopback bool) error {
   286  	p := ipv4.NewPacketConn(c.ipv4UnicastConn)
   287  	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   288  		return err
   289  	}
   290  	p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
   291  	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   292  		return err
   293  	}
   294  	p = ipv4.NewPacketConn(c.ipv4MulticastConn)
   295  	if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   296  		return err
   297  	}
   298  	p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
   299  	if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   300  		return err
   301  	}
   302  
   303  	if loopback {
   304  		p.SetMulticastLoopback(true)
   305  		p2.SetMulticastLoopback(true)
   306  	}
   307  
   308  	return nil
   309  }
   310  
// query sends the question for params.Service in params.Domain and streams
// completed entries to params.Entries until params.Context is done.
//
// Responses from all four connections are collected on one channel and
// merged into in-progress entries by messageToEntry. For each entry that is
// still incomplete, a follow-up question for that entry's name (using the
// entry's Type) is sent. Each complete entry is delivered at most once. A
// blocked delivery is abandoned, and nil returned, when the context is done.
// It returns the error from sending the initial query, if any, and nil
// otherwise.
//
// NOTE(review): responses whose first question does not match the question
// sent here are discarded, including responses with no question section at
// all. RFC 6762 section 6 says mDNS responses must not contain questions, so
// this presumably relies on the responder echoing the question. Confirm
// against the server implementation.
func (c *client) query(params *QueryParam) error {
	// Create the service name
	serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))

	// Start listening for response packets
	msgCh := make(chan *dns.Msg, 32)
	go c.recv(c.ipv4UnicastConn, msgCh)
	go c.recv(c.ipv6UnicastConn, msgCh)
	go c.recv(c.ipv4MulticastConn, msgCh)
	go c.recv(c.ipv6MulticastConn, msgCh)

	// Send the query (PTR unless the caller chose another type)
	m := new(dns.Msg)
	if params.Type == dns.TypeNone {
		m.SetQuestion(serviceAddr, dns.TypePTR)
	} else {
		m.SetQuestion(serviceAddr, params.Type)
	}
	// RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question
	// Section
	//
	// In the Question Section of a Multicast DNS query, the top bit of the qclass
	// field is used to indicate that unicast responses are preferred for this
	// particular question.  (See Section 5.4.)
	if params.WantUnicastResponse {
		m.Question[0].Qclass |= 1 << 15
	}
	m.RecursionDesired = false
	if err := c.sendQuery(m); err != nil {
		return err
	}

	// Map the in-progress responses, keyed by DNS name
	inprogress := make(map[string]*ServiceEntry)

	for {
		select {
		case resp := <-msgCh:
			// Merged before the filter below, so even responses that end up
			// discarded still update inprogress.
			inp := messageToEntry(resp, inprogress)

			if inp == nil {
				continue
			}
			if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
				// discard anything which we've not asked for
				continue
			}

			// Check if this entry is complete
			if inp.complete() {
				if inp.sent {
					continue
				}

				inp.sent = true
				select {
				case params.Entries <- inp:
				case <-params.Context.Done():
					return nil
				}
			} else {
				// Fire off a node specific query. This m shadows the outer
				// question only within this branch; the filter above keeps
				// comparing against the original question.
				m := new(dns.Msg)
				m.SetQuestion(inp.Name, inp.Type)
				m.RecursionDesired = false
				if err := c.sendQuery(m); err != nil {
					log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
				}
			}
		case <-params.Context.Done():
			return nil
		}
	}
}
   386  
   387  // sendQuery is used to multicast a query out
   388  func (c *client) sendQuery(q *dns.Msg) error {
   389  	buf, err := q.Pack()
   390  	if err != nil {
   391  		return err
   392  	}
   393  	if c.ipv4UnicastConn != nil {
   394  		c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
   395  	}
   396  	if c.ipv6UnicastConn != nil {
   397  		c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
   398  	}
   399  	return nil
   400  }
   401  
   402  // recv is used to receive until we get a shutdown
   403  func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
   404  	if l == nil {
   405  		return
   406  	}
   407  	buf := make([]byte, 65536)
   408  	for {
   409  		c.closeLock.Lock()
   410  		if c.closed {
   411  			c.closeLock.Unlock()
   412  			return
   413  		}
   414  		c.closeLock.Unlock()
   415  		n, err := l.Read(buf)
   416  		if err != nil {
   417  			continue
   418  		}
   419  		msg := new(dns.Msg)
   420  		if err := msg.Unpack(buf[:n]); err != nil {
   421  			continue
   422  		}
   423  		select {
   424  		case msgCh <- msg:
   425  		case <-c.closedCh:
   426  			return
   427  		}
   428  	}
   429  }
   430  
   431  // ensureName is used to ensure the named node is in progress
   432  func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry {
   433  	if inp, ok := inprogress[name]; ok {
   434  		return inp
   435  	}
   436  	inp := &ServiceEntry{
   437  		Name: name,
   438  		Type: typ,
   439  	}
   440  	inprogress[name] = inp
   441  	return inp
   442  }
   443  
   444  // alias is used to setup an alias between two entries
   445  func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) {
   446  	srcEntry := ensureName(inprogress, src, typ)
   447  	inprogress[dst] = srcEntry
   448  }
   449  
// messageToEntry folds the PTR, SRV, TXT, A and AAAA records found in m's
// answer and additional sections into inprogress. It returns the entry
// looked up for the last such record, or nil if m contained none.
//
// How each record type is used:
//   - PTR: creates or looks up an entry named by the PTR target.
//   - SRV: supplies Host and Port. When its target differs from its owner
//     name, the target is aliased to the same entry, so later A/AAAA records
//     for that host land there.
//   - TXT: supplies Info and InfoFields.
//   - A and AAAA: supply the addresses.
//
// Records for an entry that is already complete leave it unchanged. Other
// record types do not create or fill entries.
//
// NOTE(review): after each record that is not skipped, the TTL of the
// current inp is overwritten from that record's header. This also happens
// for unhandled record types, which then stamp their TTL onto whichever
// entry was touched last. Confirm this is intended.
func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry {
	var inp *ServiceEntry

	// Answer records are processed first, then additional (Extra) records.
	// NOTE(review): append may write into spare capacity of m.Answer's
	// backing array (m.Answer's visible elements are unaffected).
	for _, answer := range append(m.Answer, m.Extra...) {
		// TODO(reddaly): Check that response corresponds to serviceAddr?
		switch rr := answer.(type) {
		case *dns.PTR:
			// Create new entry for this (keyed by the PTR target)
			inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype)
			if inp.complete() {
				continue
			}
		case *dns.SRV:
			// Check for a target mismatch: map the target host name onto
			// this entry so its address records are attributed here.
			if rr.Target != rr.Hdr.Name {
				alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype)
			}

			// Get the port
			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
			if inp.complete() {
				continue
			}
			inp.Host = rr.Target
			inp.Port = int(rr.Port)
		case *dns.TXT:
			// Pull out the txt
			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
			if inp.complete() {
				continue
			}
			inp.Info = strings.Join(rr.Txt, "|")
			inp.InfoFields = rr.Txt
			inp.hasTXT = true
		case *dns.A:
			// Pull out the IP
			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
			if inp.complete() {
				continue
			}
			inp.Addr = rr.A // @Deprecated
			inp.AddrV4 = rr.A
		case *dns.AAAA:
			// Pull out the IP
			inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
			if inp.complete() {
				continue
			}
			inp.Addr = rr.AAAA // @Deprecated
			inp.AddrV6 = rr.AAAA
		}

		// Stamp the TTL of this record onto the most recently touched entry.
		if inp != nil {
			inp.TTL = int(answer.Header().Ttl)
		}
	}

	return inp
}