github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/internal/mdns/server.go (about)

     1  package mdns
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"net"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/miekg/dns"
    12  	log "github.com/volts-dev/volts/logger"
    13  	"golang.org/x/net/ipv4"
    14  	"golang.org/x/net/ipv6"
    15  )
    16  
    17  var (
    18  	mdnsGroupIPv4 = net.ParseIP("224.0.0.251")
    19  	mdnsGroupIPv6 = net.ParseIP("ff02::fb")
    20  
    21  	// mDNS wildcard addresses
    22  	mdnsWildcardAddrIPv4 = &net.UDPAddr{
    23  		IP:   net.ParseIP("224.0.0.0"),
    24  		Port: 5353,
    25  	}
    26  	mdnsWildcardAddrIPv6 = &net.UDPAddr{
    27  		IP:   net.ParseIP("ff02::"),
    28  		Port: 5353,
    29  	}
    30  
    31  	// mDNS endpoint addresses
    32  	ipv4Addr = &net.UDPAddr{
    33  		IP:   mdnsGroupIPv4,
    34  		Port: 5353,
    35  	}
    36  	ipv6Addr = &net.UDPAddr{
    37  		IP:   mdnsGroupIPv6,
    38  		Port: 5353,
    39  	}
    40  )
    41  
    42  // GetMachineIP is a func which returns the outbound IP of this machine.
    43  // Used by the server to determine whether to attempt send the response on a local address
    44  type GetMachineIP func() net.IP
    45  
    46  // Config is used to configure the mDNS server
    47  type Config struct {
    48  	// Zone must be provided to support responding to queries
    49  	Zone Zone
    50  
    51  	// Iface if provided binds the multicast listener to the given
    52  	// interface. If not provided, the system default multicase interface
    53  	// is used.
    54  	Iface *net.Interface
    55  
    56  	// Port If it is not 0, replace the port 5353 with this port number.
    57  	Port int
    58  
    59  	// GetMachineIP is a function to return the IP of the local machine
    60  	GetMachineIP GetMachineIP
    61  	// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP
    62  	// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports
    63  	LocalhostChecking bool
    64  }
    65  
    66  // Server is an mDNS server used to listen for mDNS queries and respond if we
    67  // have a matching local record
    68  type Server struct {
    69  	config *Config
    70  
    71  	ipv4List *net.UDPConn
    72  	ipv6List *net.UDPConn
    73  
    74  	shutdown     bool
    75  	shutdownCh   chan struct{}
    76  	shutdownLock sync.Mutex
    77  	wg           sync.WaitGroup
    78  
    79  	outboundIP net.IP
    80  }
    81  
    82  // NewServer is used to create a new mDNS server from a config
    83  func NewServer(config *Config) (*Server, error) {
    84  	setCustomPort(config.Port)
    85  
    86  	// Create the listeners
    87  	// Create wildcard connections (because :5353 can be already taken by other apps)
    88  	ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
    89  	ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
    90  	if ipv4List == nil && ipv6List == nil {
    91  		return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!")
    92  	}
    93  
    94  	if ipv4List == nil {
    95  		ipv4List = &net.UDPConn{}
    96  	}
    97  	if ipv6List == nil {
    98  		ipv6List = &net.UDPConn{}
    99  	}
   100  
   101  	// Join multicast groups to receive announcements
   102  	p1 := ipv4.NewPacketConn(ipv4List)
   103  	p2 := ipv6.NewPacketConn(ipv6List)
   104  	p1.SetMulticastLoopback(true)
   105  	p2.SetMulticastLoopback(true)
   106  
   107  	if config.Iface != nil {
   108  		if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   109  			return nil, err
   110  		}
   111  		if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   112  			return nil, err
   113  		}
   114  	} else {
   115  		ifaces, err := net.Interfaces()
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  		errCount1, errCount2 := 0, 0
   120  		for _, iface := range ifaces {
   121  			if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
   122  				errCount1++
   123  			}
   124  			if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
   125  				errCount2++
   126  			}
   127  		}
   128  		if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
   129  			return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
   130  		}
   131  	}
   132  
   133  	ipFunc := getOutboundIP
   134  	if config.GetMachineIP != nil {
   135  		ipFunc = config.GetMachineIP
   136  	}
   137  
   138  	s := &Server{
   139  		config:     config,
   140  		ipv4List:   ipv4List,
   141  		ipv6List:   ipv6List,
   142  		shutdownCh: make(chan struct{}),
   143  		outboundIP: ipFunc(),
   144  	}
   145  
   146  	go s.recv(s.ipv4List)
   147  	go s.recv(s.ipv6List)
   148  
   149  	s.wg.Add(1)
   150  	go s.probe()
   151  
   152  	return s, nil
   153  }
   154  
   155  // Shutdown is used to shutdown the listener
   156  func (s *Server) Shutdown() error {
   157  	s.shutdownLock.Lock()
   158  	defer s.shutdownLock.Unlock()
   159  
   160  	if s.shutdown {
   161  		return nil
   162  	}
   163  
   164  	s.shutdown = true
   165  	close(s.shutdownCh)
   166  	s.unregister()
   167  
   168  	if s.ipv4List != nil {
   169  		s.ipv4List.Close()
   170  	}
   171  	if s.ipv6List != nil {
   172  		s.ipv6List.Close()
   173  	}
   174  
   175  	s.wg.Wait()
   176  	return nil
   177  }
   178  
   179  // recv is a long running routine to receive packets from an interface
   180  func (s *Server) recv(c *net.UDPConn) {
   181  	if c == nil {
   182  		return
   183  	}
   184  	buf := make([]byte, 65536)
   185  	for {
   186  		s.shutdownLock.Lock()
   187  		if s.shutdown {
   188  			s.shutdownLock.Unlock()
   189  			return
   190  		}
   191  		s.shutdownLock.Unlock()
   192  		n, from, err := c.ReadFrom(buf)
   193  		if err != nil {
   194  			continue
   195  		}
   196  		if err := s.parsePacket(buf[:n], from); err != nil {
   197  			log.Errf("[ERR] mdns: Failed to handle query: %v", err)
   198  		}
   199  	}
   200  }
   201  
   202  // parsePacket is used to parse an incoming packet
   203  func (s *Server) parsePacket(packet []byte, from net.Addr) error {
   204  	var msg dns.Msg
   205  	if err := msg.Unpack(packet); err != nil {
   206  		log.Errf("[ERR] mdns: Failed to unpack packet: %v", err)
   207  		return err
   208  	}
   209  	// TODO: This is a bit of a hack
   210  	// We decided to ignore some mDNS answers for the time being
   211  	// See: https://tools.ietf.org/html/rfc6762#section-7.2
   212  	msg.Truncated = false
   213  	return s.handleQuery(&msg, from)
   214  }
   215  
   216  // handleQuery is used to handle an incoming query
   217  func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
   218  	if query.Opcode != dns.OpcodeQuery {
   219  		// "In both multicast query and multicast response messages, the OPCODE MUST
   220  		// be zero on transmission (only standard queries are currently supported
   221  		// over multicast).  Multicast DNS messages received with an OPCODE other
   222  		// than zero MUST be silently ignored."  Note: OpcodeQuery == 0
   223  		return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", query.Opcode, *query)
   224  	}
   225  	if query.Rcode != 0 {
   226  		// "In both multicast query and multicast response messages, the Response
   227  		// Code MUST be zero on transmission.  Multicast DNS messages received with
   228  		// non-zero Response Codes MUST be silently ignored."
   229  		return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", query.Rcode, *query)
   230  	}
   231  
   232  	// TODO(reddaly): Handle "TC (Truncated) Bit":
   233  	//    In query messages, if the TC bit is set, it means that additional
   234  	//    Known-Answer records may be following shortly.  A responder SHOULD
   235  	//    record this fact, and wait for those additional Known-Answer records,
   236  	//    before deciding whether to respond.  If the TC bit is clear, it means
   237  	//    that the querying host has no additional Known Answers.
   238  	if query.Truncated {
   239  		return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", *query)
   240  	}
   241  
   242  	var unicastAnswer, multicastAnswer []dns.RR
   243  
   244  	// Handle each question
   245  	for _, q := range query.Question {
   246  		mrecs, urecs := s.handleQuestion(q)
   247  		multicastAnswer = append(multicastAnswer, mrecs...)
   248  		unicastAnswer = append(unicastAnswer, urecs...)
   249  	}
   250  
   251  	// See section 18 of RFC 6762 for rules about DNS headers.
   252  	resp := func(unicast bool) *dns.Msg {
   253  		// 18.1: ID (Query Identifier)
   254  		// 0 for multicast response, query.Id for unicast response
   255  		id := uint16(0)
   256  		if unicast {
   257  			id = query.Id
   258  		}
   259  
   260  		var answer []dns.RR
   261  		if unicast {
   262  			answer = unicastAnswer
   263  		} else {
   264  			answer = multicastAnswer
   265  		}
   266  		if len(answer) == 0 {
   267  			return nil
   268  		}
   269  
   270  		return &dns.Msg{
   271  			MsgHdr: dns.MsgHdr{
   272  				Id: id,
   273  
   274  				// 18.2: QR (Query/Response) Bit - must be set to 1 in response.
   275  				Response: true,
   276  
   277  				// 18.3: OPCODE - must be zero in response (OpcodeQuery == 0)
   278  				Opcode: dns.OpcodeQuery,
   279  
   280  				// 18.4: AA (Authoritative Answer) Bit - must be set to 1
   281  				Authoritative: true,
   282  
   283  				// The following fields must all be set to 0:
   284  				// 18.5: TC (TRUNCATED) Bit
   285  				// 18.6: RD (Recursion Desired) Bit
   286  				// 18.7: RA (Recursion Available) Bit
   287  				// 18.8: Z (Zero) Bit
   288  				// 18.9: AD (Authentic Data) Bit
   289  				// 18.10: CD (Checking Disabled) Bit
   290  				// 18.11: RCODE (Response Code)
   291  			},
   292  			// 18.12 pertains to questions (handled by handleQuestion)
   293  			// 18.13 pertains to resource records (handled by handleQuestion)
   294  
   295  			// 18.14: Name Compression - responses should be compressed (though see
   296  			// caveats in the RFC), so set the Compress bit (part of the dns library
   297  			// API, not part of the DNS packet) to true.
   298  			Compress: true,
   299  			Question: query.Question,
   300  			Answer:   answer,
   301  		}
   302  	}
   303  
   304  	if mresp := resp(false); mresp != nil {
   305  		if err := s.sendResponse(mresp, from); err != nil {
   306  			return fmt.Errorf("mdns: error sending multicast response: %v", err)
   307  		}
   308  	}
   309  	if uresp := resp(true); uresp != nil {
   310  		if err := s.sendResponse(uresp, from); err != nil {
   311  			return fmt.Errorf("mdns: error sending unicast response: %v", err)
   312  		}
   313  	}
   314  	return nil
   315  }
   316  
   317  // handleQuestion is used to handle an incoming question
   318  //
   319  // The response to a question may be transmitted over multicast, unicast, or
   320  // both.  The return values are DNS records for each transmission type.
   321  func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
   322  	records := s.config.Zone.Records(q)
   323  	if len(records) == 0 {
   324  		return nil, nil
   325  	}
   326  
   327  	// Handle unicast and multicast responses.
   328  	// TODO(reddaly): The decision about sending over unicast vs. multicast is not
   329  	// yet fully compliant with RFC 6762.  For example, the unicast bit should be
   330  	// ignored if the records in question are close to TTL expiration.  For now,
   331  	// we just use the unicast bit to make the decision, as per the spec:
   332  	//     RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question
   333  	//     Section
   334  	//
   335  	//     In the Question Section of a Multicast DNS query, the top bit of the
   336  	//     qclass field is used to indicate that unicast responses are preferred
   337  	//     for this particular question.  (See Section 5.4.)
   338  	if q.Qclass&(1<<15) != 0 {
   339  		return nil, records
   340  	}
   341  	return records, nil
   342  }
   343  
   344  func (s *Server) probe() {
   345  	defer s.wg.Done()
   346  
   347  	sd, ok := s.config.Zone.(*MDNSService)
   348  	if !ok {
   349  		return
   350  	}
   351  
   352  	name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
   353  
   354  	q := new(dns.Msg)
   355  	q.SetQuestion(name, dns.TypePTR)
   356  	q.RecursionDesired = false
   357  
   358  	srv := &dns.SRV{
   359  		Hdr: dns.RR_Header{
   360  			Name:   name,
   361  			Rrtype: dns.TypeSRV,
   362  			Class:  dns.ClassINET,
   363  			Ttl:    defaultTTL,
   364  		},
   365  		Priority: 0,
   366  		Weight:   0,
   367  		Port:     uint16(sd.Port),
   368  		Target:   sd.HostName,
   369  	}
   370  	txt := &dns.TXT{
   371  		Hdr: dns.RR_Header{
   372  			Name:   name,
   373  			Rrtype: dns.TypeTXT,
   374  			Class:  dns.ClassINET,
   375  			Ttl:    defaultTTL,
   376  		},
   377  		Txt: sd.TXT,
   378  	}
   379  	q.Ns = []dns.RR{srv, txt}
   380  
   381  	randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))
   382  
   383  	for i := 0; i < 3; i++ {
   384  		if err := s.SendMulticast(q); err != nil {
   385  			log.Errf("[ERR] mdns: failed to send probe:", err.Error())
   386  		}
   387  		time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
   388  	}
   389  
   390  	resp := new(dns.Msg)
   391  	resp.MsgHdr.Response = true
   392  
   393  	// set for query
   394  	q.SetQuestion(name, dns.TypeANY)
   395  
   396  	resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
   397  
   398  	// reset
   399  	q.SetQuestion(name, dns.TypePTR)
   400  
   401  	// From RFC6762
   402  	//    The Multicast DNS responder MUST send at least two unsolicited
   403  	//    responses, one second apart. To provide increased robustness against
   404  	//    packet loss, a responder MAY send up to eight unsolicited responses,
   405  	//    provided that the interval between unsolicited responses increases by
   406  	//    at least a factor of two with every response sent.
   407  	timeout := 1 * time.Second
   408  	timer := time.NewTimer(timeout)
   409  	for i := 0; i < 3; i++ {
   410  		if err := s.SendMulticast(resp); err != nil {
   411  			log.Errf("[ERR] mdns: failed to send announcement:", err.Error())
   412  		}
   413  		select {
   414  		case <-timer.C:
   415  			timeout *= 2
   416  			timer.Reset(timeout)
   417  		case <-s.shutdownCh:
   418  			timer.Stop()
   419  			return
   420  		}
   421  	}
   422  }
   423  
   424  // SendMulticast us used to send a multicast response packet
   425  func (s *Server) SendMulticast(msg *dns.Msg) error {
   426  	buf, err := msg.Pack()
   427  	if err != nil {
   428  		return err
   429  	}
   430  	if s.ipv4List != nil {
   431  		s.ipv4List.WriteToUDP(buf, ipv4Addr)
   432  	}
   433  	if s.ipv6List != nil {
   434  		s.ipv6List.WriteToUDP(buf, ipv6Addr)
   435  	}
   436  	return nil
   437  }
   438  
   439  // sendResponse is used to send a response packet
   440  func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
   441  	// TODO(reddaly): Respect the unicast argument, and allow sending responses
   442  	// over multicast.
   443  	buf, err := resp.Pack()
   444  	if err != nil {
   445  		return err
   446  	}
   447  
   448  	// Determine the socket to send from
   449  	addr := from.(*net.UDPAddr)
   450  	conn := s.ipv4List
   451  	backupTarget := net.IPv4zero
   452  
   453  	if addr.IP.To4() == nil {
   454  		conn = s.ipv6List
   455  		backupTarget = net.IPv6zero
   456  	}
   457  	_, err = conn.WriteToUDP(buf, addr)
   458  	// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0
   459  	// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there
   460  	// Sending two responses is OK
   461  	if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) {
   462  		// ignore any errors, this is best efforts
   463  		conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port})
   464  	}
   465  	return err
   466  
   467  }
   468  
   469  func (s *Server) unregister() error {
   470  	sd, ok := s.config.Zone.(*MDNSService)
   471  	if !ok {
   472  		return nil
   473  	}
   474  
   475  	atomic.StoreUint32(&sd.TTL, 0)
   476  	name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
   477  
   478  	q := new(dns.Msg)
   479  	q.SetQuestion(name, dns.TypeANY)
   480  
   481  	resp := new(dns.Msg)
   482  	resp.MsgHdr.Response = true
   483  	resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
   484  
   485  	return s.SendMulticast(resp)
   486  }
   487  
   488  func setCustomPort(port int) {
   489  	if port != 0 {
   490  		if mdnsWildcardAddrIPv4.Port != port {
   491  			mdnsWildcardAddrIPv4.Port = port
   492  		}
   493  		if mdnsWildcardAddrIPv6.Port != port {
   494  			mdnsWildcardAddrIPv6.Port = port
   495  		}
   496  		if ipv4Addr.Port != port {
   497  			ipv4Addr.Port = port
   498  		}
   499  		if ipv6Addr.Port != port {
   500  			ipv6Addr.Port = port
   501  		}
   502  	}
   503  }
   504  
   505  // getOutboundIP returns the IP address of this machine as seen when dialling out
   506  func getOutboundIP() net.IP {
   507  	conn, err := net.Dial("udp", "8.8.8.8:80")
   508  	if err != nil {
   509  		// no net connectivity maybe so fallback
   510  		return nil
   511  	}
   512  	defer conn.Close()
   513  
   514  	localAddr := conn.LocalAddr().(*net.UDPAddr)
   515  
   516  	return localAddr.IP
   517  }