github.com/jonasi/go@v0.0.0-20150930005915-e78e654c1de0/src/net/dnsclient_unix.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build darwin dragonfly freebsd linux netbsd openbsd solaris
     6  
     7  // DNS client: see RFC 1035.
     8  // Has to be linked into package net for Dial.
     9  
    10  // TODO(rsc):
    11  //	Could potentially handle many outstanding lookups faster.
    12  //	Could have a small cache.
    13  //	Random UDP source port (net.Dial should do that for us).
    14  //	Random request IDs.
    15  
    16  package net
    17  
    18  import (
    19  	"errors"
    20  	"io"
    21  	"math/rand"
    22  	"os"
    23  	"strconv"
    24  	"sync"
    25  	"time"
    26  )
    27  
    28  // A dnsConn represents a DNS transport endpoint.
    29  type dnsConn interface {
    30  	Conn
    31  
    32  	// readDNSResponse reads a DNS response message from the DNS
    33  	// transport endpoint and returns the received DNS response
    34  	// message.
    35  	readDNSResponse() (*dnsMsg, error)
    36  
    37  	// writeDNSQuery writes a DNS query message to the DNS
    38  	// connection endpoint.
    39  	writeDNSQuery(*dnsMsg) error
    40  }
    41  
    42  func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
    43  	b := make([]byte, 512) // see RFC 1035
    44  	n, err := c.Read(b)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	msg := &dnsMsg{}
    49  	if !msg.Unpack(b[:n]) {
    50  		return nil, errors.New("cannot unmarshal DNS message")
    51  	}
    52  	return msg, nil
    53  }
    54  
    55  func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
    56  	b, ok := msg.Pack()
    57  	if !ok {
    58  		return errors.New("cannot marshal DNS message")
    59  	}
    60  	if _, err := c.Write(b); err != nil {
    61  		return err
    62  	}
    63  	return nil
    64  }
    65  
    66  func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
    67  	b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
    68  	if _, err := io.ReadFull(c, b[:2]); err != nil {
    69  		return nil, err
    70  	}
    71  	l := int(b[0])<<8 | int(b[1])
    72  	if l > len(b) {
    73  		b = make([]byte, l)
    74  	}
    75  	n, err := io.ReadFull(c, b[:l])
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	msg := &dnsMsg{}
    80  	if !msg.Unpack(b[:n]) {
    81  		return nil, errors.New("cannot unmarshal DNS message")
    82  	}
    83  	return msg, nil
    84  }
    85  
    86  func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
    87  	b, ok := msg.Pack()
    88  	if !ok {
    89  		return errors.New("cannot marshal DNS message")
    90  	}
    91  	l := uint16(len(b))
    92  	b = append([]byte{byte(l >> 8), byte(l)}, b...)
    93  	if _, err := c.Write(b); err != nil {
    94  		return err
    95  	}
    96  	return nil
    97  }
    98  
    99  func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
   100  	switch network {
   101  	case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
   102  	default:
   103  		return nil, UnknownNetworkError(network)
   104  	}
   105  	// Calling Dial here is scary -- we have to be sure not to
   106  	// dial a name that will require a DNS lookup, or Dial will
   107  	// call back here to translate it. The DNS config parser has
   108  	// already checked that all the cfg.servers[i] are IP
   109  	// addresses, which Dial will use without a DNS lookup.
   110  	c, err := d.Dial(network, server)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  	switch network {
   115  	case "tcp", "tcp4", "tcp6":
   116  		return c.(*TCPConn), nil
   117  	case "udp", "udp4", "udp6":
   118  		return c.(*UDPConn), nil
   119  	}
   120  	panic("unreachable")
   121  }
   122  
   123  // exchange sends a query on the connection and hopes for a response.
   124  func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
   125  	d := Dialer{Timeout: timeout}
   126  	out := dnsMsg{
   127  		dnsMsgHdr: dnsMsgHdr{
   128  			recursion_desired: true,
   129  		},
   130  		question: []dnsQuestion{
   131  			{name, qtype, dnsClassINET},
   132  		},
   133  	}
   134  	for _, network := range []string{"udp", "tcp"} {
   135  		c, err := d.dialDNS(network, server)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  		defer c.Close()
   140  		if timeout > 0 {
   141  			c.SetDeadline(time.Now().Add(timeout))
   142  		}
   143  		out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
   144  		if err := c.writeDNSQuery(&out); err != nil {
   145  			return nil, err
   146  		}
   147  		in, err := c.readDNSResponse()
   148  		if err != nil {
   149  			return nil, err
   150  		}
   151  		if in.id != out.id {
   152  			return nil, errors.New("DNS message ID mismatch")
   153  		}
   154  		if in.truncated { // see RFC 5966
   155  			continue
   156  		}
   157  		return in, nil
   158  	}
   159  	return nil, errors.New("no answer from DNS server")
   160  }
   161  
   162  // Do a lookup for a single name, which must be rooted
   163  // (otherwise answer will not find the answers).
   164  func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
   165  	if len(cfg.servers) == 0 {
   166  		return "", nil, &DNSError{Err: "no DNS servers", Name: name}
   167  	}
   168  	timeout := time.Duration(cfg.timeout) * time.Second
   169  	var lastErr error
   170  	for i := 0; i < cfg.attempts; i++ {
   171  		for _, server := range cfg.servers {
   172  			server = JoinHostPort(server, "53")
   173  			msg, err := exchange(server, name, qtype, timeout)
   174  			if err != nil {
   175  				lastErr = &DNSError{
   176  					Err:    err.Error(),
   177  					Name:   name,
   178  					Server: server,
   179  				}
   180  				if nerr, ok := err.(Error); ok && nerr.Timeout() {
   181  					lastErr.(*DNSError).IsTimeout = true
   182  				}
   183  				continue
   184  			}
   185  			cname, rrs, err := answer(name, server, msg, qtype)
   186  			if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError && msg.recursion_available {
   187  				return cname, rrs, err
   188  			}
   189  			lastErr = err
   190  		}
   191  	}
   192  	return "", nil, lastErr
   193  }
   194  
   195  // addrRecordList converts and returns a list of IP addresses from DNS
   196  // address records (both A and AAAA). Other record types are ignored.
   197  func addrRecordList(rrs []dnsRR) []IPAddr {
   198  	addrs := make([]IPAddr, 0, 4)
   199  	for _, rr := range rrs {
   200  		switch rr := rr.(type) {
   201  		case *dnsRR_A:
   202  			addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
   203  		case *dnsRR_AAAA:
   204  			ip := make(IP, IPv6len)
   205  			copy(ip, rr.AAAA[:])
   206  			addrs = append(addrs, IPAddr{IP: ip})
   207  		}
   208  	}
   209  	return addrs
   210  }
   211  
   212  // A resolverConfig represents a DNS stub resolver configuration.
   213  type resolverConfig struct {
   214  	initOnce sync.Once // guards init of resolverConfig
   215  
   216  	// ch is used as a semaphore that only allows one lookup at a
   217  	// time to recheck resolv.conf.
   218  	ch          chan struct{} // guards lastChecked and modTime
   219  	lastChecked time.Time     // last time resolv.conf was checked
   220  	modTime     time.Time     // time of resolv.conf modification
   221  
   222  	mu        sync.RWMutex // protects dnsConfig
   223  	dnsConfig *dnsConfig   // parsed resolv.conf structure used in lookups
   224  }
   225  
   226  var resolvConf resolverConfig
   227  
   228  // init initializes conf and is only called via conf.initOnce.
   229  func (conf *resolverConfig) init() {
   230  	// Set dnsConfig, modTime, and lastChecked so we don't parse
   231  	// resolv.conf twice the first time.
   232  	conf.dnsConfig = systemConf().resolv
   233  	if conf.dnsConfig == nil {
   234  		conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
   235  	}
   236  
   237  	if fi, err := os.Stat("/etc/resolv.conf"); err == nil {
   238  		conf.modTime = fi.ModTime()
   239  	}
   240  	conf.lastChecked = time.Now()
   241  
   242  	// Prepare ch so that only one update of resolverConfig may
   243  	// run at once.
   244  	conf.ch = make(chan struct{}, 1)
   245  }
   246  
   247  // tryUpdate tries to update conf with the named resolv.conf file.
   248  // The name variable only exists for testing. It is otherwise always
   249  // "/etc/resolv.conf".
   250  func (conf *resolverConfig) tryUpdate(name string) {
   251  	conf.initOnce.Do(conf.init)
   252  
   253  	// Ensure only one update at a time checks resolv.conf.
   254  	if !conf.tryAcquireSema() {
   255  		return
   256  	}
   257  	defer conf.releaseSema()
   258  
   259  	now := time.Now()
   260  	if conf.lastChecked.After(now.Add(-5 * time.Second)) {
   261  		return
   262  	}
   263  	conf.lastChecked = now
   264  
   265  	if fi, err := os.Stat(name); err == nil {
   266  		if fi.ModTime().Equal(conf.modTime) {
   267  			return
   268  		}
   269  		conf.modTime = fi.ModTime()
   270  	} else {
   271  		// If modTime wasn't set prior, assume nothing has changed.
   272  		if conf.modTime.IsZero() {
   273  			return
   274  		}
   275  		conf.modTime = time.Time{}
   276  	}
   277  
   278  	dnsConf := dnsReadConfig(name)
   279  	conf.mu.Lock()
   280  	conf.dnsConfig = dnsConf
   281  	conf.mu.Unlock()
   282  }
   283  
   284  func (conf *resolverConfig) tryAcquireSema() bool {
   285  	select {
   286  	case conf.ch <- struct{}{}:
   287  		return true
   288  	default:
   289  		return false
   290  	}
   291  }
   292  
// releaseSema releases the semaphore by receiving from conf.ch the
// value that a successful tryAcquireSema sent.
func (conf *resolverConfig) releaseSema() {
	<-conf.ch
}
   296  
   297  func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
   298  	if !isDomainName(name) {
   299  		return "", nil, &DNSError{Err: "invalid domain name", Name: name}
   300  	}
   301  	resolvConf.tryUpdate("/etc/resolv.conf")
   302  	resolvConf.mu.RLock()
   303  	conf := resolvConf.dnsConfig
   304  	resolvConf.mu.RUnlock()
   305  	for _, fqdn := range conf.nameList(name) {
   306  		cname, rrs, err = tryOneName(conf, fqdn, qtype)
   307  		if err == nil {
   308  			break
   309  		}
   310  	}
   311  	if err, ok := err.(*DNSError); ok {
   312  		// Show original name passed to lookup, not suffixed one.
   313  		// In general we might have tried many suffixes; showing
   314  		// just one is misleading. See also golang.org/issue/6324.
   315  		err.Name = name
   316  	}
   317  	return
   318  }
   319  
   320  // nameList returns a list of names for sequential DNS queries.
   321  func (conf *dnsConfig) nameList(name string) []string {
   322  	// If name is rooted (trailing dot), try only that name.
   323  	rooted := len(name) > 0 && name[len(name)-1] == '.'
   324  	if rooted {
   325  		return []string{name}
   326  	}
   327  	// Build list of search choices.
   328  	names := make([]string, 0, 1+len(conf.search))
   329  	// If name has enough dots, try unsuffixed first.
   330  	if count(name, '.') >= conf.ndots {
   331  		names = append(names, name+".")
   332  	}
   333  	// Try suffixes.
   334  	for _, suffix := range conf.search {
   335  		suffixed := name + "." + suffix
   336  		if suffixed[len(suffixed)-1] != '.' {
   337  			suffixed += "."
   338  		}
   339  		names = append(names, suffixed)
   340  	}
   341  	// Try unsuffixed, if not tried first above.
   342  	if count(name, '.') < conf.ndots {
   343  		names = append(names, name+".")
   344  	}
   345  	return names
   346  }
   347  
   348  // hostLookupOrder specifies the order of LookupHost lookup strategies.
   349  // It is basically a simplified representation of nsswitch.conf.
   350  // "files" means /etc/hosts.
   351  type hostLookupOrder int
   352  
   353  const (
   354  	// hostLookupCgo means defer to cgo.
   355  	hostLookupCgo      hostLookupOrder = iota
   356  	hostLookupFilesDNS                 // files first
   357  	hostLookupDNSFiles                 // dns first
   358  	hostLookupFiles                    // only files
   359  	hostLookupDNS                      // only DNS
   360  )
   361  
   362  var lookupOrderName = map[hostLookupOrder]string{
   363  	hostLookupCgo:      "cgo",
   364  	hostLookupFilesDNS: "files,dns",
   365  	hostLookupDNSFiles: "dns,files",
   366  	hostLookupFiles:    "files",
   367  	hostLookupDNS:      "dns",
   368  }
   369  
   370  func (o hostLookupOrder) String() string {
   371  	if s, ok := lookupOrderName[o]; ok {
   372  		return s
   373  	}
   374  	return "hostLookupOrder=" + strconv.Itoa(int(o)) + "??"
   375  }
   376  
   377  // goLookupHost is the native Go implementation of LookupHost.
   378  // Used only if cgoLookupHost refuses to handle the request
   379  // (that is, only if cgoLookupHost is the stub in cgo_stub.go).
   380  // Normally we let cgo use the C library resolver instead of
   381  // depending on our lookup code, so that Go and C get the same
   382  // answers.
   383  func goLookupHost(name string) (addrs []string, err error) {
   384  	return goLookupHostOrder(name, hostLookupFilesDNS)
   385  }
   386  
   387  func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) {
   388  	if order == hostLookupFilesDNS || order == hostLookupFiles {
   389  		// Use entries from /etc/hosts if they match.
   390  		addrs = lookupStaticHost(name)
   391  		if len(addrs) > 0 || order == hostLookupFiles {
   392  			return
   393  		}
   394  	}
   395  	ips, err := goLookupIPOrder(name, order)
   396  	if err != nil {
   397  		return
   398  	}
   399  	addrs = make([]string, 0, len(ips))
   400  	for _, ip := range ips {
   401  		addrs = append(addrs, ip.String())
   402  	}
   403  	return
   404  }
   405  
   406  // lookup entries from /etc/hosts
   407  func goLookupIPFiles(name string) (addrs []IPAddr) {
   408  	for _, haddr := range lookupStaticHost(name) {
   409  		haddr, zone := splitHostZone(haddr)
   410  		if ip := ParseIP(haddr); ip != nil {
   411  			addr := IPAddr{IP: ip, Zone: zone}
   412  			addrs = append(addrs, addr)
   413  		}
   414  	}
   415  	sortByRFC6724(addrs)
   416  	return
   417  }
   418  
   419  // goLookupIP is the native Go implementation of LookupIP.
   420  // The libc versions are in cgo_*.go.
   421  func goLookupIP(name string) (addrs []IPAddr, err error) {
   422  	return goLookupIPOrder(name, hostLookupFilesDNS)
   423  }
   424  
   425  func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) {
   426  	if order == hostLookupFilesDNS || order == hostLookupFiles {
   427  		addrs = goLookupIPFiles(name)
   428  		if len(addrs) > 0 || order == hostLookupFiles {
   429  			return addrs, nil
   430  		}
   431  	}
   432  	if !isDomainName(name) {
   433  		return nil, &DNSError{Err: "invalid domain name", Name: name}
   434  	}
   435  	resolvConf.tryUpdate("/etc/resolv.conf")
   436  	resolvConf.mu.RLock()
   437  	conf := resolvConf.dnsConfig
   438  	resolvConf.mu.RUnlock()
   439  	type racer struct {
   440  		rrs []dnsRR
   441  		error
   442  	}
   443  	lane := make(chan racer, 1)
   444  	qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA}
   445  	var lastErr error
   446  	for _, fqdn := range conf.nameList(name) {
   447  		for _, qtype := range qtypes {
   448  			go func(qtype uint16) {
   449  				_, rrs, err := tryOneName(conf, fqdn, qtype)
   450  				lane <- racer{rrs, err}
   451  			}(qtype)
   452  		}
   453  		for range qtypes {
   454  			racer := <-lane
   455  			if racer.error != nil {
   456  				lastErr = racer.error
   457  				continue
   458  			}
   459  			addrs = append(addrs, addrRecordList(racer.rrs)...)
   460  		}
   461  		if len(addrs) > 0 {
   462  			break
   463  		}
   464  	}
   465  	if lastErr, ok := lastErr.(*DNSError); ok {
   466  		// Show original name passed to lookup, not suffixed one.
   467  		// In general we might have tried many suffixes; showing
   468  		// just one is misleading. See also golang.org/issue/6324.
   469  		lastErr.Name = name
   470  	}
   471  	sortByRFC6724(addrs)
   472  	if len(addrs) == 0 {
   473  		if lastErr != nil {
   474  			return nil, lastErr
   475  		}
   476  		if order == hostLookupDNSFiles {
   477  			addrs = goLookupIPFiles(name)
   478  		}
   479  	}
   480  	return addrs, nil
   481  }
   482  
   483  // goLookupCNAME is the native Go implementation of LookupCNAME.
   484  // Used only if cgoLookupCNAME refuses to handle the request
   485  // (that is, only if cgoLookupCNAME is the stub in cgo_stub.go).
   486  // Normally we let cgo use the C library resolver instead of
   487  // depending on our lookup code, so that Go and C get the same
   488  // answers.
   489  func goLookupCNAME(name string) (cname string, err error) {
   490  	_, rrs, err := lookup(name, dnsTypeCNAME)
   491  	if err != nil {
   492  		return
   493  	}
   494  	cname = rrs[0].(*dnsRR_CNAME).Cname
   495  	return
   496  }
   497  
   498  // goLookupPTR is the native Go implementation of LookupAddr.
   499  // Used only if cgoLookupPTR refuses to handle the request (that is,
   500  // only if cgoLookupPTR is the stub in cgo_stub.go).
   501  // Normally we let cgo use the C library resolver instead of depending
   502  // on our lookup code, so that Go and C get the same answers.
   503  func goLookupPTR(addr string) ([]string, error) {
   504  	names := lookupStaticAddr(addr)
   505  	if len(names) > 0 {
   506  		return names, nil
   507  	}
   508  	arpa, err := reverseaddr(addr)
   509  	if err != nil {
   510  		return nil, err
   511  	}
   512  	_, rrs, err := lookup(arpa, dnsTypePTR)
   513  	if err != nil {
   514  		return nil, err
   515  	}
   516  	ptrs := make([]string, len(rrs))
   517  	for i, rr := range rrs {
   518  		ptrs[i] = rr.(*dnsRR_PTR).Ptr
   519  	}
   520  	return ptrs, nil
   521  }