github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/client/rootd/dns/server.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"slices"
     9  	"strings"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/miekg/dns"
    15  	"github.com/puzpuzpuz/xsync/v3"
    16  	"golang.org/x/exp/maps"
    17  	"google.golang.org/protobuf/types/known/durationpb"
    18  
    19  	"github.com/datawire/dlib/dcontext"
    20  	"github.com/datawire/dlib/dgroup"
    21  	"github.com/datawire/dlib/dlog"
    22  	"github.com/datawire/dlib/dtime"
    23  	rpc "github.com/telepresenceio/telepresence/rpc/v2/daemon"
    24  	"github.com/telepresenceio/telepresence/rpc/v2/manager"
    25  	"github.com/telepresenceio/telepresence/v2/pkg/client"
    26  	"github.com/telepresenceio/telepresence/v2/pkg/dnsproxy"
    27  	"github.com/telepresenceio/telepresence/v2/pkg/iputil"
    28  	"github.com/telepresenceio/telepresence/v2/pkg/slice"
    29  	"github.com/telepresenceio/telepresence/v2/pkg/vif"
    30  )
    31  
// Resolver answers a single DNS question. It returns the resource records found, a DNS response code
// (one of the dns.RcodeXXX constants), and an error.
type Resolver func(context.Context, *dns.Question) (dnsproxy.RRs, int, error)

const (
	// defaultClusterDomain used unless traffic-manager reports otherwise.
	defaultClusterDomain = "cluster.local."

	// santiyCheck (the identifier is misspelled; "sanity check" is meant) is the query used when verifying
	// that a DNS query reaches our DNS server. Like any other query it increases the requestCount. ServeDNS
	// answers it with a localhost record for A and AAAA queries, and with NOTIMP for other query types.
	santiyCheck    = "jhfweoitnkgyeta." + tel2SubDomain
	santiyCheckDot = santiyCheck + "."

	// dnsTTL is the number of seconds that a found DNS record should be allowed to live in the callers cache. We
	// keep this low to avoid such caching.
	dnsTTL = 4
)

// FallbackPool abstracts a DNS client that exchanges messages with a remote DNS server. It isn't used in
// this part of the file; presumably it forwards queries that this server doesn't handle — TODO confirm.
type FallbackPool interface {
	Exchange(context.Context, *dns.Client, *dns.Msg) (*dns.Msg, time.Duration, error)
	RemoteAddr() string
	LocalAddrs() []*net.UDPAddr
	Close()
}

// States of the recursion check, stored atomically in Server.recursive. The zero value is skipped, so a
// check that hasn't started yet can be told apart from recursionQueryNotYetReceived.
const (
	_ = int32(iota)
	recursionQueryNotYetReceived
	recursionQueryReceived
	recursionNotDetected
	recursionDetected
)

// DefaultExcludeSuffixes are the exclude-suffixes used by NewServer when the DNS config provides none.
// SetClusterDNS replaces them with the traffic-manager's suffixes only while they still equal this default.
//
//nolint:gochecknoglobals // constant
var DefaultExcludeSuffixes = []string{
	".com",
	".io",
	".net",
	".org",
	".ru",
}

// nsAndDomains is a request, sent on Server.nsAndDomainsCh by SetTopLevelDomainsAndSearchPath, to change
// the routed top-level domains and the connected namespace.
type nsAndDomains struct {
	domains   []string
	namespace string
}
    76  
// Server is a DNS server which implements the github.com/miekg/dns Handler interface.
//
// Concurrency: the embedded RWMutex guards the configuration fields that are replaced at runtime.
// SetClusterDNS, SetExcludes, SetMappings, processSearchPaths and performRecursionCheck write them while
// holding the write lock, and GetConfig reads them under the read lock. requestCount and recursive are
// accessed through sync/atomic, and cache is a concurrent map.
// NOTE(review): shouldDoClusterLookup and isExcluded read routes, search, excludes, excludeSuffixes,
// includeSuffixes and clusterDomain without taking the lock — confirm that this is acceptable.
//
// ctx, fallbackPool and resolve are not set by NewServer; presumably they are assigned when the server
// is started — TODO confirm.
type Server struct {
	sync.RWMutex
	ctx          context.Context // necessary to make logging work in ServeDNS function
	fallbackPool FallbackPool
	resolve      Resolver
	requestCount int64
	cache        *xsync.MapOf[cacheKey, *cacheEntry]
	recursive    int32 // one of the recursionXXX constants declared above (unique type avoided because it just gets messy with the atomic calls)

	// Suffixes to immediately drop from the query before processing. This list will always contain the tel2Search domain.
	// The overriding resolver will also add the search path found in /etc/resolv.conf, because that search path is not
	// intended for this resolver and will get reapplied when passing things on to the fallback resolver.
	dropSuffixes []string

	// routes are typically namespaces, accessible using <service-name>.<namespace-name>.
	routes map[string]struct{}

	// search are appended to a query to form new names that are then dispatched to the
	// DNS resolver. The act of appending is not performed by this server, but rather
	// by the system's DNS resolver before calling on this server.
	search []string

	// domains contains the sum of the include-suffixes and routes. It is currently only
	// used by the darwin resolver to keep track of files to add or remove.
	domains map[string]struct{}

	// nsAndDomainsCh receives requests to change the top level domains and the search path.
	nsAndDomainsCh chan nsAndDomains

	// includeSuffixes are lowercased name suffixes that are sent to the cluster (see shouldDoClusterLookup).
	includeSuffixes []string

	// excludeSuffixes are lowercased name suffixes that are not sent to the cluster unless an equally long
	// or longer include-suffix also matches.
	excludeSuffixes []string

	// excludes are lowercased names that are never sent to the cluster, and mappings maps lowercased,
	// fully qualified names to their aliases (see mappingsMap).
	excludes []string
	mappings map[string]string

	// lookupTimeout bounds each cluster lookup made by resolveInCluster.
	lookupTimeout time.Duration

	// localIP and remoteIP are reported by GetConfig. SetClusterDNS sets remoteIP only while it is nil.
	localIP  net.IP
	remoteIP net.IP

	// clusterDomain reported by the traffic-manager
	clusterDomain string

	// Function that sends a lookup request to the traffic-manager
	clusterLookup Resolver

	// error is set when the initial recursion check fails and is reported by GetConfig.
	error string

	// ready is closed when the DNS server is fully configured
	ready chan struct{}
}
   130  
   131  type cacheEntry struct {
   132  	created      time.Time
   133  	currentQType int32 // will be set to the current qType during call to cluster
   134  	answer       dnsproxy.RRs
   135  	rCode        int
   136  	wait         chan struct{}
   137  }
   138  
   139  // cacheTTL is the time to live for an entry in the local DNS cache.
   140  const cacheTTL = 60 * time.Second
   141  
   142  func (dv *cacheEntry) expired() bool {
   143  	return time.Since(dv.created) > cacheTTL
   144  }
   145  
   146  func (dv *cacheEntry) close() {
   147  	select {
   148  	case <-dv.wait:
   149  	default:
   150  		close(dv.wait)
   151  	}
   152  }
   153  
   154  func sliceToLower(ss []string) []string {
   155  	for i, s := range ss {
   156  		ss[i] = strings.ToLower(s)
   157  	}
   158  	return ss
   159  }
   160  
   161  // NewServer returns a new dns.Server.
   162  func NewServer(config *rpc.DNSConfig, clusterLookup Resolver) *Server {
   163  	if config == nil {
   164  		config = &rpc.DNSConfig{}
   165  	}
   166  	if len(config.ExcludeSuffixes) == 0 {
   167  		config.ExcludeSuffixes = DefaultExcludeSuffixes
   168  	}
   169  	if config.LookupTimeout.AsDuration() <= 0 {
   170  		config.LookupTimeout = durationpb.New(8 * time.Second)
   171  	}
   172  	s := &Server{
   173  		cache:           xsync.NewMapOf[cacheKey, *cacheEntry](),
   174  		routes:          make(map[string]struct{}),
   175  		domains:         make(map[string]struct{}),
   176  		excludes:        sliceToLower(config.Excludes),
   177  		excludeSuffixes: sliceToLower(config.ExcludeSuffixes),
   178  		includeSuffixes: sliceToLower(config.IncludeSuffixes),
   179  		mappings:        mappingsMap(config.Mappings),
   180  		localIP:         config.LocalIp,
   181  		remoteIP:        config.RemoteIp,
   182  		dropSuffixes:    []string{tel2SubDomainDot},
   183  		search:          []string{tel2SubDomain},
   184  		nsAndDomainsCh:  make(chan nsAndDomains, 5),
   185  		clusterDomain:   defaultClusterDomain,
   186  		clusterLookup:   clusterLookup,
   187  		ready:           make(chan struct{}),
   188  	}
   189  	if lt := config.LookupTimeout; lt != nil {
   190  		s.lookupTimeout = lt.AsDuration()
   191  	}
   192  	return s
   193  }
   194  
// tel2SubDomain helps differentiate between single label and qualified DNS queries.
//
// Dealing with single label names is tricky because what we really want is to receive the
// name and then forward it verbatim to the DNS resolver in the cluster so that it can
// add whatever search paths to it that it sees fit, but in order to receive single name
// queries in the first place, our DNS resolver must have a search path that adds a domain
// that the DNS system knows that we will handle.
//
// Example flow:
// The user queries for the name "alpha". The DNS system on the host tries the search path
// of our DNS resolver which contains "tel2-search" and creates the name "alpha.tel2-search".
// The DNS system now discovers that our DNS resolver handles that domain, so we receive
// the query. We then strip the "tel2-search" and send the original single label name to the
// cluster, and we add it back before we forward the reply.
const (
	tel2SubDomain    = "tel2-search"
	tel2SubDomainDot = tel2SubDomain + "."
)

// wpadDot is used when rejecting all WPAD (Web Proxy Auto-Discovery) queries.
const wpadDot = "wpad."

// localhostIPv4 and localhostIPv6 are the loopback addresses used in the A and AAAA records built by
// localHostReply and by the "localhost." handling in resolveInCluster.
var (
	localhostIPv4 = net.IP{127, 0, 0, 1}                                   //nolint:gochecknoglobals // constant
	localhostIPv6 = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} //nolint:gochecknoglobals // constant
)
   221  
   222  func (s *Server) shouldDoClusterLookup(query string) bool {
   223  	name := query[:len(query)-1] // skip last dot
   224  	if strings.HasPrefix(query, wpadDot) {
   225  		// Reject "wpad.*"
   226  		dlog.Debugf(s.ctx, `Cluster DNS excluded by exclude-prefix "wpad." for name %q`, name)
   227  		return false
   228  	}
   229  
   230  	if s.isExcluded(name) {
   231  		// Reject any host explicitly added to the exclude list.
   232  		dlog.Debugf(s.ctx, "Cluster DNS explicitly excluded for name %q", name)
   233  		return false
   234  	}
   235  
   236  	if !strings.ContainsRune(name, '.') {
   237  		// Single label names are always included.
   238  		dlog.Debugf(s.ctx, "Cluster DNS included for single label name %q", name)
   239  		return true
   240  	}
   241  
   242  	// Skip configured exclude-suffixes unless also matched by an include-suffix
   243  	// that is longer (i.e. more specific).
   244  	for _, es := range s.excludeSuffixes {
   245  		if strings.HasSuffix(name, es) {
   246  			// Exclude unless more specific include.
   247  			for _, is := range s.includeSuffixes {
   248  				if len(is) >= len(es) && strings.HasSuffix(name, is) {
   249  					dlog.Debugf(s.ctx,
   250  						"Cluster DNS included by include-suffix %q (overriding exclude-suffix %q) for name %q", is, es, name)
   251  					return true
   252  				}
   253  			}
   254  			dlog.Debugf(s.ctx, "Cluster DNS excluded by exclude-suffix %q for name %q", es, name)
   255  			return false
   256  		}
   257  	}
   258  
   259  	// Always include configured search paths
   260  	for _, sfx := range s.search {
   261  		if strings.HasSuffix(name, sfx) {
   262  			dlog.Debugf(s.ctx, "Cluster DNS included by search %q of name %q", sfx, name)
   263  			return true
   264  		}
   265  	}
   266  
   267  	// Always include configured routes
   268  	for sfx := range s.routes {
   269  		if strings.HasSuffix(name, sfx) {
   270  			dlog.Debugf(s.ctx, "Cluster DNS included by namespace %q of name %q", sfx, name)
   271  			return true
   272  		}
   273  	}
   274  
   275  	// Always include queries for the cluster domain.
   276  	if strings.HasSuffix(query, "."+s.clusterDomain) {
   277  		dlog.Debugf(s.ctx, "Cluster DNS included by cluster domain %q of name %q", s.clusterDomain, name)
   278  		return true
   279  	}
   280  
   281  	// Always include configured includeSuffixes
   282  	for _, sfx := range s.includeSuffixes {
   283  		if strings.HasSuffix(name, sfx) {
   284  			dlog.Debugf(s.ctx,
   285  				"Cluster DNS included by include-suffix %q for name %q", sfx, name)
   286  			return true
   287  		}
   288  	}
   289  
   290  	// Pass any queries for the cluster domain.
   291  	dlog.Debugf(s.ctx, "Cluster DNS excluded for name %q. No inclusion rule was matched", name)
   292  	return false
   293  }
   294  
   295  func (s *Server) isExcluded(name string) bool {
   296  	if slice.Contains(s.excludes, name) {
   297  		return true
   298  	}
   299  
   300  	// When intercepting, this function will potentially receive the hostname of any search param, so their
   301  	// unqualified hostname should be evaluated too.
   302  	qLen := len(name)
   303  	for _, sp := range s.search {
   304  		if strings.HasSuffix(name, "."+sp) && slice.Contains(s.excludes, name[:qLen-len(sp)-1]) {
   305  			return true
   306  		}
   307  	}
   308  	return false
   309  }
   310  
   311  func (s *Server) isDomainExcluded(name string) bool {
   312  	return slices.Contains(s.excludeSuffixes, "."+name)
   313  }
   314  
   315  func (s *Server) resolveInCluster(c context.Context, q *dns.Question) (result dnsproxy.RRs, rCode int, err error) {
   316  	query := q.Name
   317  	if query == "localhost." {
   318  		// BUG(lukeshu): I have no idea why a lookup
   319  		// for localhost even makes it to here on my
   320  		// home WiFi when connecting to a k3sctl
   321  		// cluster (but not a kubernaut.io cluster).
   322  		// But it does, so I need this in order to be
   323  		// productive at home.  We should really
   324  		// root-cause this, because it's weird.
   325  		hdr := dns.RR_Header{
   326  			Name:   q.Name,
   327  			Rrtype: q.Qtype,
   328  			Class:  q.Qclass,
   329  		}
   330  		switch q.Qtype {
   331  		case dns.TypeA:
   332  			return dnsproxy.RRs{&dns.A{
   333  				Hdr: hdr,
   334  				A:   localhostIPv4,
   335  			}}, dns.RcodeSuccess, nil
   336  		case dns.TypeAAAA:
   337  			return dnsproxy.RRs{&dns.AAAA{
   338  				Hdr:  hdr,
   339  				AAAA: localhostIPv6,
   340  			}}, dns.RcodeSuccess, nil
   341  		default:
   342  			return nil, dns.RcodeNameError, nil
   343  		}
   344  	}
   345  
   346  	if !s.shouldDoClusterLookup(query) {
   347  		return nil, dns.RcodeNameError, nil
   348  	}
   349  
   350  	// Give the cluster lookup a reasonable timeout.
   351  	c, cancel := context.WithTimeout(c, s.lookupTimeout)
   352  	defer cancel()
   353  
   354  	result, rCode, err = s.clusterLookup(c, q)
   355  	if err != nil {
   356  		return nil, rCode, client.CheckTimeout(c, err)
   357  	}
   358  
   359  	// Keep the TTLs of requests resolved in the cluster low. We
   360  	// cache them locally anyway, but our cache is flushed when things are
   361  	// intercepted or the namespaces change.
   362  	for _, rr := range result {
   363  		if h := rr.Header(); h != nil {
   364  			h.Ttl = dnsTTL
   365  		}
   366  	}
   367  	return result, rCode, nil
   368  }
   369  
   370  func (s *Server) GetConfig() *rpc.DNSConfig {
   371  	s.RLock()
   372  	c := rpc.DNSConfig{
   373  		LocalIp:         s.localIP,
   374  		RemoteIp:        s.remoteIP,
   375  		ExcludeSuffixes: s.excludeSuffixes,
   376  		IncludeSuffixes: s.includeSuffixes,
   377  		Excludes:        s.excludes,
   378  		Error:           s.error,
   379  	}
   380  	if s.lookupTimeout != 0 {
   381  		c.LookupTimeout = durationpb.New(s.lookupTimeout)
   382  	}
   383  	if len(s.mappings) > 0 {
   384  		ns := maps.Keys(s.mappings)
   385  		slices.Sort(ns)
   386  		c.Mappings = make([]*rpc.DNSMapping, len(s.mappings))
   387  		for i, n := range ns {
   388  			c.Mappings[i] = &rpc.DNSMapping{
   389  				Name:     strings.TrimSuffix(n, "."),
   390  				AliasFor: strings.TrimSuffix(s.mappings[n], "."),
   391  			}
   392  		}
   393  	}
   394  	s.RUnlock()
   395  	return &c
   396  }
   397  
   398  func (s *Server) Ready() <-chan struct{} {
   399  	return s.ready
   400  }
   401  
   402  func (s *Server) Stop() {
   403  	// Close s.ready unless it's already closed
   404  	select {
   405  	case <-s.ready:
   406  	default:
   407  		close(s.ready)
   408  	}
   409  }
   410  
   411  func (s *Server) SetClusterDNS(dns *manager.DNS, remoteIP net.IP) {
   412  	s.Lock()
   413  	if s.remoteIP == nil {
   414  		s.remoteIP = remoteIP
   415  	}
   416  	if dns != nil {
   417  		if slices.Equal(s.excludeSuffixes, DefaultExcludeSuffixes) && len(dns.ExcludeSuffixes) > 0 {
   418  			s.excludeSuffixes = sliceToLower(dns.ExcludeSuffixes)
   419  		}
   420  		if len(s.includeSuffixes) == 0 {
   421  			s.includeSuffixes = sliceToLower(dns.IncludeSuffixes)
   422  		}
   423  		s.clusterDomain = strings.ToLower(dns.ClusterDomain)
   424  	}
   425  	s.Unlock()
   426  }
   427  
   428  // SetTopLevelDomainsAndSearchPath updates the DNS top level domains and the search path used by the resolver.
   429  func (s *Server) SetTopLevelDomainsAndSearchPath(ctx context.Context, domains []string, namespace string) {
   430  	das := nsAndDomains{
   431  		domains:   domains,
   432  		namespace: namespace,
   433  	}
   434  	select {
   435  	case <-ctx.Done():
   436  	case s.nsAndDomainsCh <- das:
   437  	}
   438  }
   439  
   440  func (s *Server) purgeRecordsFromCache(keyName string) {
   441  	keyName = strings.TrimSuffix(keyName, ".") + "."
   442  	for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
   443  		toDeleteKey := cacheKey{name: keyName, qType: qType}
   444  		if old, ok := s.cache.LoadAndDelete(toDeleteKey); ok {
   445  			old.close()
   446  		}
   447  	}
   448  }
   449  
   450  // SetExcludes sets the excludes list in the config.
   451  func (s *Server) SetExcludes(excludes []string) {
   452  	for i, e := range excludes {
   453  		excludes[i] = strings.ToLower(e)
   454  	}
   455  	s.Lock()
   456  	oldExcludes := s.excludes
   457  	s.excludes = excludes
   458  	s.Unlock()
   459  
   460  	for _, e := range slice.AppendUnique(oldExcludes, excludes...) {
   461  		s.purgeRecordsFromCache(e)
   462  	}
   463  }
   464  
   465  func mappingsMap(mappings []*rpc.DNSMapping) map[string]string {
   466  	if l := len(mappings); l > 0 {
   467  		mm := make(map[string]string, l)
   468  		for _, m := range mappings {
   469  			al := m.AliasFor
   470  			if ip := iputil.Parse(al); ip == nil {
   471  				al += "."
   472  			}
   473  			mm[strings.ToLower(m.Name+".")] = strings.ToLower(al)
   474  		}
   475  		return mm
   476  	}
   477  	return nil
   478  }
   479  
   480  // SetMappings sets the Mappings list in the config.
   481  func (s *Server) SetMappings(mappings []*rpc.DNSMapping) {
   482  	mm := mappingsMap(mappings)
   483  	s.Lock()
   484  	s.mappings = mm
   485  	s.Unlock()
   486  
   487  	// Flush the mappings.
   488  	for n := range mm {
   489  		s.purgeRecordsFromCache(n)
   490  	}
   491  }
   492  
   493  func newLocalUDPListener(c context.Context) (net.PacketConn, error) {
   494  	lc := &net.ListenConfig{}
   495  	return lc.ListenPacket(c, "udp", "127.0.0.1:0")
   496  }
   497  
// processSearchPaths starts a "SearchPaths" goroutine in g. The goroutine first runs the initial recursion
// check. Then, for each request received on s.nsAndDomainsCh (see SetTopLevelDomainsAndSearchPath), it
// replaces s.routes and s.search and calls processor(c, dev).
//
// A request is skipped while more requests are queued, or when it equals the previously applied one. The
// goroutine returns nil when c is done, or the first error returned by processor.
func (s *Server) processSearchPaths(g *dgroup.Group, processor func(context.Context, vif.Device) error, dev vif.Device) {
	g.Go("SearchPaths", func(c context.Context) error {
		s.performRecursionCheck(c)
		prevDas := nsAndDomains{
			domains:   []string{},
			namespace: "",
		}
		unchanged := func(das nsAndDomains) bool {
			return das.namespace == prevDas.namespace && slices.Equal(das.domains, prevDas.domains)
		}

		for {
			select {
			case <-c.Done():
				return nil
			case das := <-s.nsAndDomainsCh:
				// Only interested in the last one, and only if it differs
				if len(s.nsAndDomainsCh) > 0 || unchanged(das) {
					continue
				}
				prevDas = das

				// Route every requested domain except empty and excluded ones. "svc" is always
				// routed unless it is excluded.
				routes := make(map[string]struct{}, len(das.domains))
				for _, domain := range das.domains {
					if domain != "" && !s.isDomainExcluded(domain) {
						routes[domain] = struct{}{}
					}
				}
				if !s.isDomainExcluded("svc") {
					routes["svc"] = struct{}{}
				}
				s.Lock()
				s.routes = routes

				// The connected namespace must be included as a search path for the cases
				// where it's up to the traffic-manager to resolve. It cannot resolve a single
				// label name intended for other namespaces.
				// NOTE(review): when das.namespace is empty, an empty string becomes a search entry;
				// verify how shouldDoClusterLookup and isExcluded treat an empty search entry.
				s.search = []string{tel2SubDomain, das.namespace}
				s.Unlock()

				if err := processor(c, dev); err != nil {
					return err
				}
			}
		}
	})
}
   545  
   546  func (s *Server) flushDNS() {
   547  	s.cache.Range(func(key cacheKey, _ *cacheEntry) bool {
   548  		if old, ok := s.cache.LoadAndDelete(key); ok {
   549  			old.close()
   550  		}
   551  		return true
   552  	})
   553  }
   554  
   555  // splitToUDPAddr splits the given address into an UDPAddr. It's
   556  // an  error if the address is based on a hostname rather than an IP.
   557  func splitToUDPAddr(netAddr net.Addr) (*net.UDPAddr, error) {
   558  	ip, port, err := iputil.SplitToIPPort(netAddr)
   559  	if err != nil {
   560  		return nil, err
   561  	}
   562  	return &net.UDPAddr{IP: ip, Port: int(port)}, nil
   563  }
   564  
   565  // RequestCount returns the number of requests that this server has received.
   566  func (s *Server) RequestCount() int {
   567  	return int(atomic.LoadInt64(&s.requestCount))
   568  }
   569  
   570  func copyRRs(rrs dnsproxy.RRs, qTypes []uint16) dnsproxy.RRs {
   571  	if len(rrs) == 0 {
   572  		return rrs
   573  	}
   574  	cp := make(dnsproxy.RRs, 0, len(rrs))
   575  	for _, rr := range rrs {
   576  		if slice.Contains(qTypes, rr.Header().Rrtype) {
   577  			cp = append(cp, dns.Copy(rr))
   578  		}
   579  	}
   580  	return cp
   581  }
   582  
   583  type cacheKey struct {
   584  	name  string
   585  	qType uint16
   586  }
   587  
   588  func (c *cacheKey) String() string {
   589  	return fmt.Sprintf("%s %s", dns.TypeToString[c.qType], c.name)
   590  }
   591  
const (
	// recursionCheck and recursionCheck2 are special host names that aren't expected to exist (the latter
	// in the well known kube-system namespace). They are used once for determining if the cluster's DNS
	// resolver will call the Telepresence DNS resolver recursively. This is common when the cluster is
	// running on the local host (k3s in docker for instance).
	//
	// The check is performed using the following steps.
	// 1. A lookup is made for "tel2-recursion-check.tel2-search" (see performRecursionCheck).
	// 2. When our DNS-resolver receives this lookup (with the tel2-search suffix removed), it modifies the
	//    name to "tel2-recursion-check.kube-system." and sends that on to the cluster.
	// 3. If our DNS-resolver now encounters a query for the "tel2-recursion-check.kube-system.", then we know
	//    that a recursion took place.
	// 4. If no request for "tel2-recursion-check.kube-system." is received before the cluster lookup
	//    finishes or recursionTestTimeout expires, then it's assumed that the resolver is not recursive.
	recursionCheck  = "tel2-recursion-check."
	recursionCheck2 = "tel2-recursion-check.kube-system."
)
   608  
   609  func (s *Server) resolveWithRecursionCheck(q *dns.Question) (dnsproxy.RRs, int, error) {
   610  	if strings.HasPrefix(q.Name, recursionCheck) {
   611  		if strings.HasPrefix(q.Name, recursionCheck2) {
   612  			if atomic.CompareAndSwapInt32(&s.recursive, recursionQueryReceived, recursionDetected) {
   613  				dlog.Debug(s.ctx, "DNS resolver is recursive")
   614  			}
   615  			return nil, dns.RcodeNameError, nil
   616  		}
   617  
   618  		if atomic.CompareAndSwapInt32(&s.recursive, recursionQueryNotYetReceived, recursionQueryReceived) {
   619  			tc, cancel := context.WithTimeout(s.ctx, recursionTestTimeout)
   620  			go func() {
   621  				defer cancel()
   622  				nq := *q // by value copy
   623  				nq.Name = recursionCheck2
   624  				_, _, _ = s.resolveInCluster(s.ctx, &nq) // We really don't care about the reply here.
   625  			}()
   626  			<-tc.Done()
   627  
   628  			// When we've gotten the reply from the cluster, we know if recursion did occur.
   629  			if atomic.CompareAndSwapInt32(&s.recursive, recursionQueryReceived, recursionNotDetected) {
   630  				dlog.Debug(s.ctx, "DNS resolver is not recursive")
   631  			}
   632  		}
   633  		return localHostReply(q), dns.RcodeSuccess, nil
   634  	}
   635  
   636  	answer, rCode, err := s.resolveThruCache(q)
   637  	if err != nil || rCode != dns.RcodeSuccess {
   638  		// For A and AAAA queries, we check if we have a successful counterpart in the cache. If we
   639  		// do, then this query must return NOERROR EMPTY
   640  		ck := cacheKey{name: q.Name, qType: dns.TypeNone}
   641  		switch q.Qtype {
   642  		case dns.TypeA:
   643  			ck.qType = dns.TypeAAAA
   644  		case dns.TypeAAAA:
   645  			ck.qType = dns.TypeA
   646  		}
   647  		if ck.qType != dns.TypeNone {
   648  			if ce, ok := s.cache.Load(ck); ok {
   649  				<-ce.wait
   650  				if !ce.expired() && ce.rCode == dns.RcodeSuccess && atomic.LoadInt32(&ce.currentQType) == int32(ck.qType) {
   651  					dlog.Debugf(s.ctx, "found counterpart for %s %s", dns.TypeToString[uint16(ce.currentQType)], ce.answer)
   652  					err = nil
   653  					rCode = dns.RcodeSuccess
   654  				}
   655  			}
   656  		}
   657  	}
   658  	return answer, rCode, err
   659  }
   660  
// resolveThruCache resolves the given query by first performing a cache lookup. If a cached entry is found
// that hasn't expired, a copy of its answer is returned. The copy holds the records of q.Qtype, plus CNAME
// records when the cached answer contains any. If not, this function calls s.resolve to resolve the query
// and stores the result in the cache.
//
// Concurrent callers for the same key wait on the first caller's entry instead of resolving again. Suppose
// the recursion check has detected a recursive cluster resolver, and a lookup of the same name and type is
// already in progress. Then the query is assumed to be that recursion and is answered with NXDOMAIN.
// Only successful lookups stay in the cache. On a cache miss the returned answer holds only records of
// q.Qtype.
//
// NOTE(review): when an expired entry is replaced, concurrent callers may each Store their own entry. The
// deferred Delete on failure removes whatever entry is stored under key at that moment, and that entry may
// belong to another caller. Verify that this is acceptable.
func (s *Server) resolveThruCache(q *dns.Question) (answer dnsproxy.RRs, rCode int, err error) {
	dv := &cacheEntry{wait: make(chan struct{}), created: time.Now()}
	key := cacheKey{name: q.Name, qType: q.Qtype}
	if oldDv, loaded := s.cache.LoadOrStore(key, dv); loaded {
		if atomic.LoadInt32(&s.recursive) == recursionDetected && atomic.LoadInt32(&oldDv.currentQType) == int32(q.Qtype) {
			// We have to assume that this is a recursion from the cluster.
			dlog.Debugf(s.ctx, "returning error for query %q: assumed to be recursive", key.String())
			return nil, dns.RcodeNameError, nil
		}
		// Wait for the lookup that created the existing entry to finish (or for the entry to be purged).
		<-oldDv.wait
		if !oldDv.expired() {
			qTypes := []uint16{q.Qtype}
			if q.Qtype != dns.TypeCNAME {
				// Allow additional CNAME records if they are present.
				for _, rr := range oldDv.answer {
					if rr.Header().Rrtype == dns.TypeCNAME {
						qTypes = append(qTypes, dns.TypeCNAME)
						break
					}
				}
			}
			return copyRRs(oldDv.answer, qTypes), oldDv.rCode, nil
		}
		// The existing entry has expired; replace it with ours and resolve anew.
		s.cache.Store(key, dv)
	}

	// Mark the lookup as in progress so that a recursive query for the same name and type can be detected.
	atomic.StoreInt32(&dv.currentQType, int32(q.Qtype))
	defer func() {
		if rCode != dns.RcodeSuccess {
			s.cache.Delete(key) // Don't cache unless the lookup succeeded.
		} else {
			dv.answer = answer
			dv.rCode = rCode

			// Return a result for the correct query type. The result will be nil (nxdomain) if nothing was found. It might
			// also be empty if no RRs were found for the given query type and that is OK.
			// See https://datatracker.ietf.org/doc/html/rfc4074#section-3
			answer = copyRRs(answer, []uint16{q.Qtype})
		}
		// Mark the lookup as finished, then release everyone waiting on this entry.
		atomic.StoreInt32(&dv.currentQType, int32(dns.TypeNone))
		dv.close()
	}()
	return s.resolve(s.ctx, q)
}
   708  
   709  // dfs is a func that implements the fmt.Stringer interface. Used in log statements to ensure
   710  // that the function isn't evaluated until the log output is formatted (which will happen only
   711  // if the given loglevel is enabled).
   712  type dfs func() string
   713  
   714  func (d dfs) String() string {
   715  	return d()
   716  }
   717  
   718  func (s *Server) performRecursionCheck(c context.Context) {
   719  	s.Lock()
   720  	if _, ok := s.routes["kube-system"]; !ok {
   721  		s.routes["kube-system"] = struct{}{}
   722  		nl := len(s.routes)
   723  		defer func() {
   724  			s.Lock()
   725  			if nl == len(s.routes) {
   726  				delete(s.routes, "kube-system")
   727  			}
   728  			s.Unlock()
   729  		}()
   730  	}
   731  	s.Unlock()
   732  	defer func() {
   733  		dlog.Debug(c, "Recursion check finished")
   734  		close(s.ready)
   735  	}()
   736  	rc := recursionCheck + tel2SubDomain
   737  	dlog.Debugf(c, "Performing initial recursion check with %s", rc)
   738  	i := 0
   739  	atomic.StoreInt32(&s.recursive, recursionQueryNotYetReceived)
   740  	for ; c.Err() == nil && i < maxRecursionTestRetries && atomic.LoadInt32(&s.recursive) == recursionQueryNotYetReceived; i++ {
   741  		go func() {
   742  			_, _ = net.DefaultResolver.LookupIP(c, "ip4", rc)
   743  		}()
   744  		time.Sleep(500 * time.Millisecond)
   745  	}
   746  	if i == maxRecursionTestRetries {
   747  		msg := "DNS doesn't seem to work properly"
   748  		dlog.Error(c, msg)
   749  		s.Lock()
   750  		s.error = msg
   751  		s.Unlock()
   752  		return
   753  	}
   754  	// Await result
   755  	for c.Err() == nil {
   756  		rc := atomic.LoadInt32(&s.recursive)
   757  		if rc == recursionDetected || rc == recursionNotDetected {
   758  			break
   759  		}
   760  		dtime.SleepWithContext(c, 10*time.Millisecond)
   761  	}
   762  }
   763  
   764  func localHostReply(q *dns.Question) dnsproxy.RRs {
   765  	switch q.Qtype {
   766  	case dns.TypeA:
   767  		return dnsproxy.RRs{&dns.A{
   768  			Hdr: dns.RR_Header{
   769  				Name:   q.Name,
   770  				Rrtype: q.Qtype,
   771  				Class:  q.Qclass,
   772  			},
   773  			A: localhostIPv4,
   774  		}}
   775  	case dns.TypeAAAA:
   776  		return dnsproxy.RRs{&dns.AAAA{
   777  			Hdr: dns.RR_Header{
   778  				Name:   q.Name,
   779  				Rrtype: q.Qtype,
   780  				Class:  q.Qclass,
   781  			},
   782  			AAAA: localhostIPv6,
   783  		}}
   784  	default:
   785  		return nil
   786  	}
   787  }
   788  
   789  // ServeDNS is an implementation of github.com/miekg/dns Handler.ServeDNS.
   790  func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
   791  	c := s.ctx
   792  	atomic.AddInt64(&s.requestCount, 1)
   793  
   794  	q := &r.Question[0]
   795  	qts := dns.TypeToString[q.Qtype]
   796  	dlog.Debugf(c, "ServeDNS %5d %-6s %s", r.Id, qts, q.Name)
   797  
   798  	msg := new(dns.Msg)
   799  	var pfx dfs = func() string { return "" }
   800  	var txt dfs = func() string { return "" }
   801  	var rct dfs = func() string { return dns.RcodeToString[msg.Rcode] }
   802  
   803  	defer func() {
   804  		dlog.Debugf(c, "%s%5d %-6s %s -> %s %s", pfx, r.Id, qts, q.Name, rct, txt)
   805  		_ = w.WriteMsg(msg)
   806  
   807  		// Closing the response tells the DNS service to terminate
   808  		if c.Err() != nil {
   809  			_ = w.Close()
   810  		}
   811  	}()
   812  
   813  	// The sanity-check query is sent during the configuration phase of the DNS server and then
   814  	// never again. It must reply with localhost.
   815  	//
   816  	// NOTE! The sanity-check will always use the tel2-search subdomain, so the check made here
   817  	//       must be made before the tel2-search is removed.
   818  	if q.Name == santiyCheckDot {
   819  		answer := localHostReply(q)
   820  		if answer == nil {
   821  			msg.SetRcode(r, dns.RcodeNotImplemented)
   822  			return
   823  		}
   824  		msg.SetReply(r)
   825  		msg.Answer = answer
   826  		msg.Authoritative = true
   827  		txt = func() string { return answer.String() }
   828  		dlog.Debug(c, "sanity-check OK")
   829  		return
   830  	}
   831  
   832  	if !dnsproxy.SupportedType(q.Qtype) {
   833  		msg.SetRcode(r, dns.RcodeNotImplemented)
   834  		return
   835  	}
   836  
   837  	// We make changes to the query name, so we better restore it prior to writing an
   838  	// answer back, or the caller will get confused.
   839  	origName := q.Name
   840  	defer func() {
   841  		qs := msg.Question
   842  		if len(qs) > 0 {
   843  			mq := &qs[0] // Important to use a pointer here. We don't want to change a by-value copied struct.
   844  			if mq.Name == q.Name {
   845  				mq.Name = origName
   846  			}
   847  		}
   848  		for _, rr := range msg.Answer {
   849  			h := rr.Header()
   850  			if h.Name == q.Name {
   851  				h.Name = origName
   852  			}
   853  		}
   854  		q.Name = origName
   855  	}()
   856  
   857  	// We're all about lowercase in here
   858  	q.Name = strings.ToLower(origName)
   859  
   860  	// The tel2SubDomain serves one purpose and one purpose alone. It's there to coerce the
   861  	// system DNS resolver to direct requests to this resolver. The system configuration to
   862  	// make this happen vary depending on OS, but the purpose is always the same. Given that,
   863  	// the first step in the resolution is to remove this domain-suffix if it exists.
   864  	ln := len(q.Name)
   865  	for _, dropSuffix := range s.dropSuffixes {
   866  		if strings.HasSuffix(q.Name, dropSuffix) {
   867  			// Remove the suffix and ensure that the name still ends
   868  			// with a dot after the removal. If it doesn't, then this
   869  			// was not really a domain suffix but rather a partial
   870  			// domain name.
   871  			n := q.Name[:ln-len(dropSuffix)]
   872  			if last := len(n) - 1; last > 0 && n[last] == '.' {
   873  				q.Name = n
   874  				break
   875  			}
   876  		}
   877  	}
   878  
   879  	if strings.Contains(q.Name, tel2SubDomainDot) {
   880  		// This is a bogus name because it has some domain after
   881  		// the tel2-search domain. Should normally never happen, but
   882  		// will happen if someone queries for the tel2-search domain
   883  		// as a single label name.
   884  		msg.SetRcode(r, dns.RcodeNameError)
   885  		return
   886  	}
   887  
   888  	// try and resolve any mappings before consulting the cache, so that mapping hits don't
   889  	// end up in the cache.
   890  	answer, rCode, err := s.resolveMapping(q)
   891  	if err == errNoMapping {
   892  		answer, rCode, err = s.resolveWithRecursionCheck(q)
   893  	}
   894  
   895  	if err == nil && rCode == dns.RcodeSuccess {
   896  		if rCode != dns.RcodeSuccess {
   897  			msg.SetRcode(r, rCode)
   898  		} else {
   899  			msg.SetReply(r)
   900  		}
   901  		msg.Answer = answer
   902  		msg.Authoritative = true
   903  		// mac dns seems to fallback if you don't
   904  		// support recursion, if you have more than a
   905  		// single dns server, this will prevent us
   906  		// from intercepting all queries
   907  		msg.RecursionAvailable = true
   908  		txt = func() string { return answer.String() }
   909  		return
   910  	}
   911  
   912  	// The recursion check query, or queries that end with the cluster domain name, are not dispatched to the
   913  	// fallback DNS-server.
   914  	s.RLock()
   915  	cd := s.clusterDomain
   916  	s.RUnlock()
   917  	if s.fallbackPool == nil ||
   918  		strings.HasPrefix(q.Name, recursionCheck2) ||
   919  		strings.HasSuffix(q.Name, cd) ||
   920  		strings.HasSuffix(origName, tel2SubDomainDot) {
   921  		if err == nil {
   922  			rCode = dns.RcodeNameError
   923  		} else {
   924  			rCode = dns.RcodeServerFailure
   925  			if errors.Is(err, context.DeadlineExceeded) {
   926  				txt = func() string { return "timeout" }
   927  			} else {
   928  				txt = err.Error
   929  			}
   930  		}
   931  		msg.SetRcode(r, rCode)
   932  		return
   933  	}
   934  
   935  	// Use original query name when sending things to the fallback resolver.
   936  	q.Name = origName
   937  	pfx = func() string { return fmt.Sprintf("(%s) ", s.fallbackPool.RemoteAddr()) }
   938  	dc := &dns.Client{Net: "udp", Timeout: s.lookupTimeout}
   939  	var poolMsg *dns.Msg
   940  	poolMsg, _, err = s.fallbackPool.Exchange(c, dc, r)
   941  	if err != nil {
   942  		rCode = dns.RcodeServerFailure
   943  		txt = err.Error
   944  		if netErr, ok := err.(net.Error); ok {
   945  			switch {
   946  			case netErr.Timeout():
   947  				txt = func() string { return "timeout" }
   948  			case netErr.Temporary(): //nolint:staticcheck // err.Temporary is deprecated
   949  				rCode = dns.RcodeRefused
   950  			default:
   951  			}
   952  		}
   953  		msg.SetRcode(r, rCode)
   954  	} else {
   955  		msg = poolMsg
   956  		txt = func() string { return dnsproxy.RRs(msg.Answer).String() }
   957  	}
   958  }
   959  
   960  var errNoMapping = errors.New("no mapping") //nolint:gochecknoglobals // constant
   961  
   962  func (s *Server) resolveMapping(q *dns.Question) (dnsproxy.RRs, int, error) {
   963  	switch q.Qtype {
   964  	case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME:
   965  	default:
   966  		return nil, dns.RcodeNameError, errNoMapping
   967  	}
   968  
   969  	s.RLock()
   970  	mappingAlias, ok := s.mappings[q.Name]
   971  	s.RUnlock()
   972  
   973  	if !ok {
   974  		return nil, dns.RcodeNameError, errNoMapping
   975  	}
   976  	if ip := iputil.Parse(mappingAlias); ip != nil {
   977  		// The name resolves to an A or AAAA record known by this DNS server.
   978  		var rrs dnsproxy.RRs
   979  		if q.Qtype == dns.TypeA && len(ip) == 4 {
   980  			rrs = dnsproxy.RRs{&dns.A{
   981  				Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: dnsTTL},
   982  				A:   ip,
   983  			}}
   984  		} else if q.Qtype == dns.TypeAAAA && len(ip) == 16 {
   985  			rrs = dnsproxy.RRs{&dns.AAAA{
   986  				Hdr:  dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: dnsTTL},
   987  				AAAA: ip,
   988  			}}
   989  		}
   990  		return rrs, dns.RcodeSuccess, nil
   991  	}
   992  
   993  	cnameRRs := dnsproxy.RRs{&dns.CNAME{
   994  		Hdr:    dns.RR_Header{Name: q.Name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: dnsTTL},
   995  		Target: mappingAlias,
   996  	}}
   997  
   998  	if q.Qtype == dns.TypeCNAME {
   999  		// A query for the CNAME must only return the CNAME.
  1000  		return cnameRRs, dns.RcodeSuccess, nil
  1001  	}
  1002  
  1003  	// A query for an A or AAAA must resolve the CNAME and then return both the result and the
  1004  	// CNAME that resolved to it.
  1005  	answer, rCode, err := s.resolveWithRecursionCheck(&dns.Question{
  1006  		Name:   mappingAlias,
  1007  		Qtype:  q.Qtype,
  1008  		Qclass: q.Qclass,
  1009  	})
  1010  	if err == nil {
  1011  		answer = append(cnameRRs, answer...)
  1012  	}
  1013  	return answer, rCode, err
  1014  }
  1015  
  1016  // Run starts the DNS server(s) and waits for them to end.
  1017  func (s *Server) Run(c context.Context, initDone chan<- struct{}, listeners []net.PacketConn, fallbackPool FallbackPool, resolve Resolver) error {
  1018  	s.ctx = c
  1019  	s.fallbackPool = fallbackPool
  1020  	s.resolve = resolve
  1021  
  1022  	g := dgroup.NewGroup(c, dgroup.GroupConfig{})
  1023  	for _, listener := range listeners {
  1024  		srv := &dns.Server{PacketConn: listener, Handler: s, ReadTimeout: time.Second}
  1025  		g.Go(listener.LocalAddr().String(), func(c context.Context) error {
  1026  			go func() {
  1027  				<-c.Done()
  1028  				dlog.Debugf(c, "Shutting down DNS server")
  1029  				_ = srv.ShutdownContext(dcontext.HardContext(c))
  1030  			}()
  1031  			return srv.ActivateAndServe()
  1032  		})
  1033  	}
  1034  	close(initDone)
  1035  	return g.Wait()
  1036  }