github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/internal/mdns/zone.go (about)

     1  package mdns
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"os"
     7  	"strings"
     8  	"sync/atomic"
     9  
    10  	"github.com/miekg/dns"
    11  )
    12  
    13  const (
    14  	// defaultTTL is the default TTL value in returned DNS records in seconds.
    15  	defaultTTL = 120
    16  )
    17  
    18  // Zone is the interface used to integrate with the server and
    19  // to serve records dynamically
    20  type Zone interface {
    21  	// Records returns DNS records in response to a DNS question.
    22  	Records(q dns.Question) []dns.RR
    23  }
    24  
    25  // MDNSService is used to export a named service by implementing a Zone
    26  type MDNSService struct {
    27  	Instance     string   // Instance name (e.g. "hostService name")
    28  	Service      string   // Service name (e.g. "_http._tcp.")
    29  	Domain       string   // If blank, assumes "local"
    30  	HostName     string   // Host machine DNS name (e.g. "mymachine.net.")
    31  	Port         int      // Service Port
    32  	IPs          []net.IP // IP addresses for the service's host
    33  	TXT          []string // Service TXT records
    34  	TTL          uint32
    35  	serviceAddr  string // Fully qualified service address
    36  	instanceAddr string // Fully qualified instance address
    37  	enumAddr     string // _services._dns-sd._udp.<domain>
    38  }
    39  
    40  // validateFQDN returns an error if the passed string is not a fully qualified
    41  // hdomain name (more specifically, a hostname).
    42  func validateFQDN(s string) error {
    43  	if len(s) == 0 {
    44  		return fmt.Errorf("FQDN must not be blank")
    45  	}
    46  	if s[len(s)-1] != '.' {
    47  		return fmt.Errorf("FQDN must end in period: %s", s)
    48  	}
    49  	// TODO(reddaly): Perform full validation.
    50  
    51  	return nil
    52  }
    53  
    54  // NewMDNSService returns a new instance of MDNSService.
    55  //
    56  // If domain, hostName, or ips is set to the zero value, then a default value
    57  // will be inferred from the operating system.
    58  //
    59  // TODO(reddaly): This interface may need to change to account for "unique
    60  // record" conflict rules of the mDNS protocol.  Upon startup, the server should
    61  // check to ensure that the instance name does not conflict with other instance
    62  // names, and, if required, select a new name.  There may also be conflicting
    63  // hostName A/AAAA records.
    64  func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) {
    65  	// Sanity check inputs
    66  	if instance == "" {
    67  		return nil, fmt.Errorf("missing service instance name")
    68  	}
    69  	if service == "" {
    70  		return nil, fmt.Errorf("missing service name")
    71  	}
    72  	if port == 0 {
    73  		return nil, fmt.Errorf("missing service port")
    74  	}
    75  
    76  	// Set default domain
    77  	if domain == "" {
    78  		domain = "local."
    79  	}
    80  	if err := validateFQDN(domain); err != nil {
    81  		return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err)
    82  	}
    83  
    84  	// Get host information if no host is specified.
    85  	if hostName == "" {
    86  		var err error
    87  		hostName, err = os.Hostname()
    88  		if err != nil {
    89  			return nil, fmt.Errorf("could not determine host: %v", err)
    90  		}
    91  		hostName = fmt.Sprintf("%s.", hostName)
    92  	}
    93  	if err := validateFQDN(hostName); err != nil {
    94  		return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err)
    95  	}
    96  
    97  	if len(ips) == 0 {
    98  		var err error
    99  		ips, err = net.LookupIP(trimDot(hostName))
   100  		if err != nil {
   101  			// Try appending the host domain suffix and lookup again
   102  			// (required for Linux-based hosts)
   103  			tmpHostName := fmt.Sprintf("%s%s", hostName, domain)
   104  
   105  			ips, err = net.LookupIP(trimDot(tmpHostName))
   106  
   107  			if err != nil {
   108  				return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName)
   109  			}
   110  		}
   111  	}
   112  	for _, ip := range ips {
   113  		if ip.To4() == nil && ip.To16() == nil {
   114  			return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip)
   115  		}
   116  	}
   117  
   118  	return &MDNSService{
   119  		Instance:     instance,
   120  		Service:      service,
   121  		Domain:       domain,
   122  		HostName:     hostName,
   123  		Port:         port,
   124  		IPs:          ips,
   125  		TXT:          txt,
   126  		TTL:          defaultTTL,
   127  		serviceAddr:  fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)),
   128  		instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)),
   129  		enumAddr:     fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)),
   130  	}, nil
   131  }
   132  
   133  // trimDot is used to trim the dots from the start or end of a string
   134  func trimDot(s string) string {
   135  	return strings.Trim(s, ".")
   136  }
   137  
   138  // Records returns DNS records in response to a DNS question.
   139  func (m *MDNSService) Records(q dns.Question) []dns.RR {
   140  	switch q.Name {
   141  	case m.enumAddr:
   142  		return m.serviceEnum(q)
   143  	case m.serviceAddr:
   144  		return m.serviceRecords(q)
   145  	case m.instanceAddr:
   146  		return m.instanceRecords(q)
   147  	case m.HostName:
   148  		if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA {
   149  			return m.instanceRecords(q)
   150  		}
   151  		fallthrough
   152  	default:
   153  		return nil
   154  	}
   155  }
   156  
   157  func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR {
   158  	switch q.Qtype {
   159  	case dns.TypeANY:
   160  		fallthrough
   161  	case dns.TypePTR:
   162  		rr := &dns.PTR{
   163  			Hdr: dns.RR_Header{
   164  				Name:   q.Name,
   165  				Rrtype: dns.TypePTR,
   166  				Class:  dns.ClassINET,
   167  				Ttl:    atomic.LoadUint32(&m.TTL),
   168  			},
   169  			Ptr: m.serviceAddr,
   170  		}
   171  		return []dns.RR{rr}
   172  	default:
   173  		return nil
   174  	}
   175  }
   176  
   177  // serviceRecords is called when the query matches the service name
   178  func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR {
   179  	switch q.Qtype {
   180  	case dns.TypeANY:
   181  		fallthrough
   182  	case dns.TypePTR:
   183  		// Build a PTR response for the service
   184  		rr := &dns.PTR{
   185  			Hdr: dns.RR_Header{
   186  				Name:   q.Name,
   187  				Rrtype: dns.TypePTR,
   188  				Class:  dns.ClassINET,
   189  				Ttl:    atomic.LoadUint32(&m.TTL),
   190  			},
   191  			Ptr: m.instanceAddr,
   192  		}
   193  		servRec := []dns.RR{rr}
   194  
   195  		// Get the instance records
   196  		instRecs := m.instanceRecords(dns.Question{
   197  			Name:  m.instanceAddr,
   198  			Qtype: dns.TypeANY,
   199  		})
   200  
   201  		// Return the service record with the instance records
   202  		return append(servRec, instRecs...)
   203  	default:
   204  		return nil
   205  	}
   206  }
   207  
   208  // serviceRecords is called when the query matches the instance name
   209  func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
   210  	switch q.Qtype {
   211  	case dns.TypeANY:
   212  		// Get the SRV, which includes A and AAAA
   213  		recs := m.instanceRecords(dns.Question{
   214  			Name:  m.instanceAddr,
   215  			Qtype: dns.TypeSRV,
   216  		})
   217  
   218  		// Add the TXT record
   219  		recs = append(recs, m.instanceRecords(dns.Question{
   220  			Name:  m.instanceAddr,
   221  			Qtype: dns.TypeTXT,
   222  		})...)
   223  		return recs
   224  
   225  	case dns.TypeA:
   226  		var rr []dns.RR
   227  		for _, ip := range m.IPs {
   228  			if ip4 := ip.To4(); ip4 != nil {
   229  				rr = append(rr, &dns.A{
   230  					Hdr: dns.RR_Header{
   231  						Name:   m.HostName,
   232  						Rrtype: dns.TypeA,
   233  						Class:  dns.ClassINET,
   234  						Ttl:    atomic.LoadUint32(&m.TTL),
   235  					},
   236  					A: ip4,
   237  				})
   238  			}
   239  		}
   240  		return rr
   241  
   242  	case dns.TypeAAAA:
   243  		var rr []dns.RR
   244  		for _, ip := range m.IPs {
   245  			if ip.To4() != nil {
   246  				// TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and
   247  				// putinto AAAA records, but the current logic puts ipv4-encodable
   248  				// addresses into the A records exclusively.  Perhaps this should be
   249  				// configurable?
   250  				continue
   251  			}
   252  
   253  			if ip16 := ip.To16(); ip16 != nil {
   254  				rr = append(rr, &dns.AAAA{
   255  					Hdr: dns.RR_Header{
   256  						Name:   m.HostName,
   257  						Rrtype: dns.TypeAAAA,
   258  						Class:  dns.ClassINET,
   259  						Ttl:    atomic.LoadUint32(&m.TTL),
   260  					},
   261  					AAAA: ip16,
   262  				})
   263  			}
   264  		}
   265  		return rr
   266  
   267  	case dns.TypeSRV:
   268  		// Create the SRV Record
   269  		srv := &dns.SRV{
   270  			Hdr: dns.RR_Header{
   271  				Name:   q.Name,
   272  				Rrtype: dns.TypeSRV,
   273  				Class:  dns.ClassINET,
   274  				Ttl:    atomic.LoadUint32(&m.TTL),
   275  			},
   276  			Priority: 10,
   277  			Weight:   1,
   278  			Port:     uint16(m.Port),
   279  			Target:   m.HostName,
   280  		}
   281  		recs := []dns.RR{srv}
   282  
   283  		// Add the A record
   284  		recs = append(recs, m.instanceRecords(dns.Question{
   285  			Name:  m.instanceAddr,
   286  			Qtype: dns.TypeA,
   287  		})...)
   288  
   289  		// Add the AAAA record
   290  		recs = append(recs, m.instanceRecords(dns.Question{
   291  			Name:  m.instanceAddr,
   292  			Qtype: dns.TypeAAAA,
   293  		})...)
   294  		return recs
   295  
   296  	case dns.TypeTXT:
   297  		txt := &dns.TXT{
   298  			Hdr: dns.RR_Header{
   299  				Name:   q.Name,
   300  				Rrtype: dns.TypeTXT,
   301  				Class:  dns.ClassINET,
   302  				Ttl:    atomic.LoadUint32(&m.TTL),
   303  			},
   304  			Txt: m.TXT,
   305  		}
   306  		return []dns.RR{txt}
   307  	}
   308  	return nil
   309  }