github.com/mdempsky/go@v0.0.0-20151201204031-5dd372bd1e70/src/net/dnsclient_unix.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build darwin dragonfly freebsd linux netbsd openbsd solaris
     6  
     7  // DNS client: see RFC 1035.
     8  // Has to be linked into package net for Dial.
     9  
    10  // TODO(rsc):
    11  //	Could potentially handle many outstanding lookups faster.
    12  //	Could have a small cache.
    13  //	Random UDP source port (net.Dial should do that for us).
    14  //	Random request IDs.
    15  
    16  package net
    17  
    18  import (
    19  	"errors"
    20  	"io"
    21  	"math/rand"
    22  	"os"
    23  	"sync"
    24  	"time"
    25  )
    26  
    27  // A dnsConn represents a DNS transport endpoint.
    28  type dnsConn interface {
    29  	Conn
    30  
    31  	// readDNSResponse reads a DNS response message from the DNS
    32  	// transport endpoint and returns the received DNS response
    33  	// message.
    34  	readDNSResponse() (*dnsMsg, error)
    35  
    36  	// writeDNSQuery writes a DNS query message to the DNS
    37  	// connection endpoint.
    38  	writeDNSQuery(*dnsMsg) error
    39  }
    40  
    41  func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
    42  	b := make([]byte, 512) // see RFC 1035
    43  	n, err := c.Read(b)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	msg := &dnsMsg{}
    48  	if !msg.Unpack(b[:n]) {
    49  		return nil, errors.New("cannot unmarshal DNS message")
    50  	}
    51  	return msg, nil
    52  }
    53  
    54  func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
    55  	b, ok := msg.Pack()
    56  	if !ok {
    57  		return errors.New("cannot marshal DNS message")
    58  	}
    59  	if _, err := c.Write(b); err != nil {
    60  		return err
    61  	}
    62  	return nil
    63  }
    64  
    65  func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
    66  	b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
    67  	if _, err := io.ReadFull(c, b[:2]); err != nil {
    68  		return nil, err
    69  	}
    70  	l := int(b[0])<<8 | int(b[1])
    71  	if l > len(b) {
    72  		b = make([]byte, l)
    73  	}
    74  	n, err := io.ReadFull(c, b[:l])
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	msg := &dnsMsg{}
    79  	if !msg.Unpack(b[:n]) {
    80  		return nil, errors.New("cannot unmarshal DNS message")
    81  	}
    82  	return msg, nil
    83  }
    84  
    85  func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
    86  	b, ok := msg.Pack()
    87  	if !ok {
    88  		return errors.New("cannot marshal DNS message")
    89  	}
    90  	l := uint16(len(b))
    91  	b = append([]byte{byte(l >> 8), byte(l)}, b...)
    92  	if _, err := c.Write(b); err != nil {
    93  		return err
    94  	}
    95  	return nil
    96  }
    97  
    98  func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
    99  	switch network {
   100  	case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
   101  	default:
   102  		return nil, UnknownNetworkError(network)
   103  	}
   104  	// Calling Dial here is scary -- we have to be sure not to
   105  	// dial a name that will require a DNS lookup, or Dial will
   106  	// call back here to translate it. The DNS config parser has
   107  	// already checked that all the cfg.servers[i] are IP
   108  	// addresses, which Dial will use without a DNS lookup.
   109  	c, err := d.Dial(network, server)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	switch network {
   114  	case "tcp", "tcp4", "tcp6":
   115  		return c.(*TCPConn), nil
   116  	case "udp", "udp4", "udp6":
   117  		return c.(*UDPConn), nil
   118  	}
   119  	panic("unreachable")
   120  }
   121  
   122  // exchange sends a query on the connection and hopes for a response.
   123  func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
   124  	d := Dialer{Timeout: timeout}
   125  	out := dnsMsg{
   126  		dnsMsgHdr: dnsMsgHdr{
   127  			recursion_desired: true,
   128  		},
   129  		question: []dnsQuestion{
   130  			{name, qtype, dnsClassINET},
   131  		},
   132  	}
   133  	for _, network := range []string{"udp", "tcp"} {
   134  		c, err := d.dialDNS(network, server)
   135  		if err != nil {
   136  			return nil, err
   137  		}
   138  		defer c.Close()
   139  		if timeout > 0 {
   140  			c.SetDeadline(time.Now().Add(timeout))
   141  		}
   142  		out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
   143  		if err := c.writeDNSQuery(&out); err != nil {
   144  			return nil, err
   145  		}
   146  		in, err := c.readDNSResponse()
   147  		if err != nil {
   148  			return nil, err
   149  		}
   150  		if in.id != out.id {
   151  			return nil, errors.New("DNS message ID mismatch")
   152  		}
   153  		if in.truncated { // see RFC 5966
   154  			continue
   155  		}
   156  		return in, nil
   157  	}
   158  	return nil, errors.New("no answer from DNS server")
   159  }
   160  
   161  // Do a lookup for a single name, which must be rooted
   162  // (otherwise answer will not find the answers).
   163  func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
   164  	if len(cfg.servers) == 0 {
   165  		return "", nil, &DNSError{Err: "no DNS servers", Name: name}
   166  	}
   167  	timeout := time.Duration(cfg.timeout) * time.Second
   168  	var lastErr error
   169  	for i := 0; i < cfg.attempts; i++ {
   170  		for _, server := range cfg.servers {
   171  			server = JoinHostPort(server, "53")
   172  			msg, err := exchange(server, name, qtype, timeout)
   173  			if err != nil {
   174  				lastErr = &DNSError{
   175  					Err:    err.Error(),
   176  					Name:   name,
   177  					Server: server,
   178  				}
   179  				if nerr, ok := err.(Error); ok && nerr.Timeout() {
   180  					lastErr.(*DNSError).IsTimeout = true
   181  				}
   182  				continue
   183  			}
   184  			cname, rrs, err := answer(name, server, msg, qtype)
   185  			// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
   186  			// it means the response in msg was not useful and trying another
   187  			// server probably won't help. Return now in those cases.
   188  			// TODO: indicate this in a more obvious way, such as a field on DNSError?
   189  			if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
   190  				return cname, rrs, err
   191  			}
   192  			lastErr = err
   193  		}
   194  	}
   195  	return "", nil, lastErr
   196  }
   197  
   198  // addrRecordList converts and returns a list of IP addresses from DNS
   199  // address records (both A and AAAA). Other record types are ignored.
   200  func addrRecordList(rrs []dnsRR) []IPAddr {
   201  	addrs := make([]IPAddr, 0, 4)
   202  	for _, rr := range rrs {
   203  		switch rr := rr.(type) {
   204  		case *dnsRR_A:
   205  			addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
   206  		case *dnsRR_AAAA:
   207  			ip := make(IP, IPv6len)
   208  			copy(ip, rr.AAAA[:])
   209  			addrs = append(addrs, IPAddr{IP: ip})
   210  		}
   211  	}
   212  	return addrs
   213  }
   214  
   215  // A resolverConfig represents a DNS stub resolver configuration.
   216  type resolverConfig struct {
   217  	initOnce sync.Once // guards init of resolverConfig
   218  
   219  	// ch is used as a semaphore that only allows one lookup at a
   220  	// time to recheck resolv.conf.
   221  	ch          chan struct{} // guards lastChecked and modTime
   222  	lastChecked time.Time     // last time resolv.conf was checked
   223  	modTime     time.Time     // time of resolv.conf modification
   224  
   225  	mu        sync.RWMutex // protects dnsConfig
   226  	dnsConfig *dnsConfig   // parsed resolv.conf structure used in lookups
   227  }
   228  
   229  var resolvConf resolverConfig
   230  
   231  // init initializes conf and is only called via conf.initOnce.
   232  func (conf *resolverConfig) init() {
   233  	// Set dnsConfig, modTime, and lastChecked so we don't parse
   234  	// resolv.conf twice the first time.
   235  	conf.dnsConfig = systemConf().resolv
   236  	if conf.dnsConfig == nil {
   237  		conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
   238  	}
   239  
   240  	if fi, err := os.Stat("/etc/resolv.conf"); err == nil {
   241  		conf.modTime = fi.ModTime()
   242  	}
   243  	conf.lastChecked = time.Now()
   244  
   245  	// Prepare ch so that only one update of resolverConfig may
   246  	// run at once.
   247  	conf.ch = make(chan struct{}, 1)
   248  }
   249  
   250  // tryUpdate tries to update conf with the named resolv.conf file.
   251  // The name variable only exists for testing. It is otherwise always
   252  // "/etc/resolv.conf".
   253  func (conf *resolverConfig) tryUpdate(name string) {
   254  	conf.initOnce.Do(conf.init)
   255  
   256  	// Ensure only one update at a time checks resolv.conf.
   257  	if !conf.tryAcquireSema() {
   258  		return
   259  	}
   260  	defer conf.releaseSema()
   261  
   262  	now := time.Now()
   263  	if conf.lastChecked.After(now.Add(-5 * time.Second)) {
   264  		return
   265  	}
   266  	conf.lastChecked = now
   267  
   268  	if fi, err := os.Stat(name); err == nil {
   269  		if fi.ModTime().Equal(conf.modTime) {
   270  			return
   271  		}
   272  		conf.modTime = fi.ModTime()
   273  	} else {
   274  		// If modTime wasn't set prior, assume nothing has changed.
   275  		if conf.modTime.IsZero() {
   276  			return
   277  		}
   278  		conf.modTime = time.Time{}
   279  	}
   280  
   281  	dnsConf := dnsReadConfig(name)
   282  	conf.mu.Lock()
   283  	conf.dnsConfig = dnsConf
   284  	conf.mu.Unlock()
   285  }
   286  
   287  func (conf *resolverConfig) tryAcquireSema() bool {
   288  	select {
   289  	case conf.ch <- struct{}{}:
   290  		return true
   291  	default:
   292  		return false
   293  	}
   294  }
   295  
   296  func (conf *resolverConfig) releaseSema() {
   297  	<-conf.ch
   298  }
   299  
   300  func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
   301  	if !isDomainName(name) {
   302  		return "", nil, &DNSError{Err: "invalid domain name", Name: name}
   303  	}
   304  	resolvConf.tryUpdate("/etc/resolv.conf")
   305  	resolvConf.mu.RLock()
   306  	conf := resolvConf.dnsConfig
   307  	resolvConf.mu.RUnlock()
   308  	for _, fqdn := range conf.nameList(name) {
   309  		cname, rrs, err = tryOneName(conf, fqdn, qtype)
   310  		if err == nil {
   311  			break
   312  		}
   313  	}
   314  	if err, ok := err.(*DNSError); ok {
   315  		// Show original name passed to lookup, not suffixed one.
   316  		// In general we might have tried many suffixes; showing
   317  		// just one is misleading. See also golang.org/issue/6324.
   318  		err.Name = name
   319  	}
   320  	return
   321  }
   322  
   323  // nameList returns a list of names for sequential DNS queries.
   324  func (conf *dnsConfig) nameList(name string) []string {
   325  	// If name is rooted (trailing dot), try only that name.
   326  	rooted := len(name) > 0 && name[len(name)-1] == '.'
   327  	if rooted {
   328  		return []string{name}
   329  	}
   330  	// Build list of search choices.
   331  	names := make([]string, 0, 1+len(conf.search))
   332  	// If name has enough dots, try unsuffixed first.
   333  	if count(name, '.') >= conf.ndots {
   334  		names = append(names, name+".")
   335  	}
   336  	// Try suffixes.
   337  	for _, suffix := range conf.search {
   338  		suffixed := name + "." + suffix
   339  		if suffixed[len(suffixed)-1] != '.' {
   340  			suffixed += "."
   341  		}
   342  		names = append(names, suffixed)
   343  	}
   344  	// Try unsuffixed, if not tried first above.
   345  	if count(name, '.') < conf.ndots {
   346  		names = append(names, name+".")
   347  	}
   348  	return names
   349  }
   350  
   351  // hostLookupOrder specifies the order of LookupHost lookup strategies.
   352  // It is basically a simplified representation of nsswitch.conf.
   353  // "files" means /etc/hosts.
   354  type hostLookupOrder int
   355  
   356  const (
   357  	// hostLookupCgo means defer to cgo.
   358  	hostLookupCgo      hostLookupOrder = iota
   359  	hostLookupFilesDNS                 // files first
   360  	hostLookupDNSFiles                 // dns first
   361  	hostLookupFiles                    // only files
   362  	hostLookupDNS                      // only DNS
   363  )
   364  
   365  var lookupOrderName = map[hostLookupOrder]string{
   366  	hostLookupCgo:      "cgo",
   367  	hostLookupFilesDNS: "files,dns",
   368  	hostLookupDNSFiles: "dns,files",
   369  	hostLookupFiles:    "files",
   370  	hostLookupDNS:      "dns",
   371  }
   372  
   373  func (o hostLookupOrder) String() string {
   374  	if s, ok := lookupOrderName[o]; ok {
   375  		return s
   376  	}
   377  	return "hostLookupOrder=" + itoa(int(o)) + "??"
   378  }
   379  
   380  // goLookupHost is the native Go implementation of LookupHost.
   381  // Used only if cgoLookupHost refuses to handle the request
   382  // (that is, only if cgoLookupHost is the stub in cgo_stub.go).
   383  // Normally we let cgo use the C library resolver instead of
   384  // depending on our lookup code, so that Go and C get the same
   385  // answers.
   386  func goLookupHost(name string) (addrs []string, err error) {
   387  	return goLookupHostOrder(name, hostLookupFilesDNS)
   388  }
   389  
   390  func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) {
   391  	if order == hostLookupFilesDNS || order == hostLookupFiles {
   392  		// Use entries from /etc/hosts if they match.
   393  		addrs = lookupStaticHost(name)
   394  		if len(addrs) > 0 || order == hostLookupFiles {
   395  			return
   396  		}
   397  	}
   398  	ips, err := goLookupIPOrder(name, order)
   399  	if err != nil {
   400  		return
   401  	}
   402  	addrs = make([]string, 0, len(ips))
   403  	for _, ip := range ips {
   404  		addrs = append(addrs, ip.String())
   405  	}
   406  	return
   407  }
   408  
   409  // lookup entries from /etc/hosts
   410  func goLookupIPFiles(name string) (addrs []IPAddr) {
   411  	for _, haddr := range lookupStaticHost(name) {
   412  		haddr, zone := splitHostZone(haddr)
   413  		if ip := ParseIP(haddr); ip != nil {
   414  			addr := IPAddr{IP: ip, Zone: zone}
   415  			addrs = append(addrs, addr)
   416  		}
   417  	}
   418  	sortByRFC6724(addrs)
   419  	return
   420  }
   421  
   422  // goLookupIP is the native Go implementation of LookupIP.
   423  // The libc versions are in cgo_*.go.
   424  func goLookupIP(name string) (addrs []IPAddr, err error) {
   425  	return goLookupIPOrder(name, hostLookupFilesDNS)
   426  }
   427  
   428  func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) {
   429  	if order == hostLookupFilesDNS || order == hostLookupFiles {
   430  		addrs = goLookupIPFiles(name)
   431  		if len(addrs) > 0 || order == hostLookupFiles {
   432  			return addrs, nil
   433  		}
   434  	}
   435  	if !isDomainName(name) {
   436  		return nil, &DNSError{Err: "invalid domain name", Name: name}
   437  	}
   438  	resolvConf.tryUpdate("/etc/resolv.conf")
   439  	resolvConf.mu.RLock()
   440  	conf := resolvConf.dnsConfig
   441  	resolvConf.mu.RUnlock()
   442  	type racer struct {
   443  		rrs []dnsRR
   444  		error
   445  	}
   446  	lane := make(chan racer, 1)
   447  	qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA}
   448  	var lastErr error
   449  	for _, fqdn := range conf.nameList(name) {
   450  		for _, qtype := range qtypes {
   451  			go func(qtype uint16) {
   452  				_, rrs, err := tryOneName(conf, fqdn, qtype)
   453  				lane <- racer{rrs, err}
   454  			}(qtype)
   455  		}
   456  		for range qtypes {
   457  			racer := <-lane
   458  			if racer.error != nil {
   459  				lastErr = racer.error
   460  				continue
   461  			}
   462  			addrs = append(addrs, addrRecordList(racer.rrs)...)
   463  		}
   464  		if len(addrs) > 0 {
   465  			break
   466  		}
   467  	}
   468  	if lastErr, ok := lastErr.(*DNSError); ok {
   469  		// Show original name passed to lookup, not suffixed one.
   470  		// In general we might have tried many suffixes; showing
   471  		// just one is misleading. See also golang.org/issue/6324.
   472  		lastErr.Name = name
   473  	}
   474  	sortByRFC6724(addrs)
   475  	if len(addrs) == 0 {
   476  		if lastErr != nil {
   477  			return nil, lastErr
   478  		}
   479  		if order == hostLookupDNSFiles {
   480  			addrs = goLookupIPFiles(name)
   481  		}
   482  	}
   483  	return addrs, nil
   484  }
   485  
   486  // goLookupCNAME is the native Go implementation of LookupCNAME.
   487  // Used only if cgoLookupCNAME refuses to handle the request
   488  // (that is, only if cgoLookupCNAME is the stub in cgo_stub.go).
   489  // Normally we let cgo use the C library resolver instead of
   490  // depending on our lookup code, so that Go and C get the same
   491  // answers.
   492  func goLookupCNAME(name string) (cname string, err error) {
   493  	_, rrs, err := lookup(name, dnsTypeCNAME)
   494  	if err != nil {
   495  		return
   496  	}
   497  	cname = rrs[0].(*dnsRR_CNAME).Cname
   498  	return
   499  }
   500  
   501  // goLookupPTR is the native Go implementation of LookupAddr.
   502  // Used only if cgoLookupPTR refuses to handle the request (that is,
   503  // only if cgoLookupPTR is the stub in cgo_stub.go).
   504  // Normally we let cgo use the C library resolver instead of depending
   505  // on our lookup code, so that Go and C get the same answers.
   506  func goLookupPTR(addr string) ([]string, error) {
   507  	names := lookupStaticAddr(addr)
   508  	if len(names) > 0 {
   509  		return names, nil
   510  	}
   511  	arpa, err := reverseaddr(addr)
   512  	if err != nil {
   513  		return nil, err
   514  	}
   515  	_, rrs, err := lookup(arpa, dnsTypePTR)
   516  	if err != nil {
   517  		return nil, err
   518  	}
   519  	ptrs := make([]string, len(rrs))
   520  	for i, rr := range rrs {
   521  		ptrs[i] = rr.(*dnsRR_PTR).Ptr
   522  	}
   523  	return ptrs, nil
   524  }