github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/client/rootd/dns/server.go (about) 1 package dns 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "slices" 9 "strings" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/miekg/dns" 15 "github.com/puzpuzpuz/xsync/v3" 16 "golang.org/x/exp/maps" 17 "google.golang.org/protobuf/types/known/durationpb" 18 19 "github.com/datawire/dlib/dcontext" 20 "github.com/datawire/dlib/dgroup" 21 "github.com/datawire/dlib/dlog" 22 "github.com/datawire/dlib/dtime" 23 rpc "github.com/telepresenceio/telepresence/rpc/v2/daemon" 24 "github.com/telepresenceio/telepresence/rpc/v2/manager" 25 "github.com/telepresenceio/telepresence/v2/pkg/client" 26 "github.com/telepresenceio/telepresence/v2/pkg/dnsproxy" 27 "github.com/telepresenceio/telepresence/v2/pkg/iputil" 28 "github.com/telepresenceio/telepresence/v2/pkg/slice" 29 "github.com/telepresenceio/telepresence/v2/pkg/vif" 30 ) 31 32 type Resolver func(context.Context, *dns.Question) (dnsproxy.RRs, int, error) 33 34 const ( 35 // defaultClusterDomain used unless traffic-manager reports otherwise. 36 defaultClusterDomain = "cluster.local." 37 38 // sanityCheck is the query used when verifying that a DNS query reaches our DNS server. It should result 39 // in an increase of the requestCount but always yield an NXDOMAIN reply. 40 santiyCheck = "jhfweoitnkgyeta." + tel2SubDomain 41 santiyCheckDot = santiyCheck + "." 42 43 // dnsTTL is the number of seconds that a found DNS record should be allowed to live in the callers cache. We 44 // keep this low to avoid such caching. 
45 dnsTTL = 4 46 ) 47 48 type FallbackPool interface { 49 Exchange(context.Context, *dns.Client, *dns.Msg) (*dns.Msg, time.Duration, error) 50 RemoteAddr() string 51 LocalAddrs() []*net.UDPAddr 52 Close() 53 } 54 55 const ( 56 _ = int32(iota) 57 recursionQueryNotYetReceived 58 recursionQueryReceived 59 recursionNotDetected 60 recursionDetected 61 ) 62 63 //nolint:gochecknoglobals // constant 64 var DefaultExcludeSuffixes = []string{ 65 ".com", 66 ".io", 67 ".net", 68 ".org", 69 ".ru", 70 } 71 72 type nsAndDomains struct { 73 domains []string 74 namespace string 75 } 76 77 // Server is a DNS server which implements the github.com/miekg/dns Handler interface. 78 type Server struct { 79 sync.RWMutex 80 ctx context.Context // necessary to make logging work in ServeDNS function 81 fallbackPool FallbackPool 82 resolve Resolver 83 requestCount int64 84 cache *xsync.MapOf[cacheKey, *cacheEntry] 85 recursive int32 // one of the recursionXXX constants declared above (unique type avoided because it just gets messy with the atomic calls) 86 87 // Suffixes to immediately drop from the query before processing. This list will always contain the tel2Search domain. 88 // The overriding resolver will also add the search path found in /etc/resolv.conf, because that search path is not 89 // intended for this resolver and will get reapplied when passing things on to the fallback resolver. 90 dropSuffixes []string 91 92 // routes are typically namespaces, accessible using <service-name>.<namespace-name>. 93 routes map[string]struct{} 94 95 // search are appended to a query to form new names that are then dispatched to the 96 // DNS resolver. The act of appending is not performed by this server, but rather 97 // by the system's DNS resolver before calling on this server. 98 search []string 99 100 // domains contains the sum of the include-suffixes and routes. It is currently only 101 // used by the darwin resolver to keep track of files to add or remove. 
102 domains map[string]struct{} 103 104 // nsAndDomainsCh receives requests to change the top level domains and the search path. 105 nsAndDomainsCh chan nsAndDomains 106 107 includeSuffixes []string 108 109 excludeSuffixes []string 110 111 excludes []string 112 mappings map[string]string 113 114 lookupTimeout time.Duration 115 116 localIP net.IP 117 remoteIP net.IP 118 119 // clusterDomain reported by the traffic-manager 120 clusterDomain string 121 122 // Function that sends a lookup request to the traffic-manager 123 clusterLookup Resolver 124 125 error string 126 127 // ready is closed when the DNS server is fully configured 128 ready chan struct{} 129 } 130 131 type cacheEntry struct { 132 created time.Time 133 currentQType int32 // will be set to the current qType during call to cluster 134 answer dnsproxy.RRs 135 rCode int 136 wait chan struct{} 137 } 138 139 // cacheTTL is the time to live for an entry in the local DNS cache. 140 const cacheTTL = 60 * time.Second 141 142 func (dv *cacheEntry) expired() bool { 143 return time.Since(dv.created) > cacheTTL 144 } 145 146 func (dv *cacheEntry) close() { 147 select { 148 case <-dv.wait: 149 default: 150 close(dv.wait) 151 } 152 } 153 154 func sliceToLower(ss []string) []string { 155 for i, s := range ss { 156 ss[i] = strings.ToLower(s) 157 } 158 return ss 159 } 160 161 // NewServer returns a new dns.Server. 
162 func NewServer(config *rpc.DNSConfig, clusterLookup Resolver) *Server { 163 if config == nil { 164 config = &rpc.DNSConfig{} 165 } 166 if len(config.ExcludeSuffixes) == 0 { 167 config.ExcludeSuffixes = DefaultExcludeSuffixes 168 } 169 if config.LookupTimeout.AsDuration() <= 0 { 170 config.LookupTimeout = durationpb.New(8 * time.Second) 171 } 172 s := &Server{ 173 cache: xsync.NewMapOf[cacheKey, *cacheEntry](), 174 routes: make(map[string]struct{}), 175 domains: make(map[string]struct{}), 176 excludes: sliceToLower(config.Excludes), 177 excludeSuffixes: sliceToLower(config.ExcludeSuffixes), 178 includeSuffixes: sliceToLower(config.IncludeSuffixes), 179 mappings: mappingsMap(config.Mappings), 180 localIP: config.LocalIp, 181 remoteIP: config.RemoteIp, 182 dropSuffixes: []string{tel2SubDomainDot}, 183 search: []string{tel2SubDomain}, 184 nsAndDomainsCh: make(chan nsAndDomains, 5), 185 clusterDomain: defaultClusterDomain, 186 clusterLookup: clusterLookup, 187 ready: make(chan struct{}), 188 } 189 if lt := config.LookupTimeout; lt != nil { 190 s.lookupTimeout = lt.AsDuration() 191 } 192 return s 193 } 194 195 // tel2SubDomain helps differentiate between single label and qualified DNS queries. 196 // 197 // Dealing with single label names is tricky because what we really want is to receive the 198 // name and then forward it verbatim to the DNS resolver in the cluster so that it can 199 // add whatever search paths to it that it sees fit, but in order to receive single name 200 // queries in the first place, our DNS resolver must have a search path that adds a domain 201 // that the DNS system knows that we will handle. 202 // 203 // Example flow: 204 // The user queries for the name "alpha". The DNS system on the host tries the search path 205 // of our DNS resolver which contains "tel2-search" and creates the name "alpha.tel2-search". 206 // The DNS system now discovers that our DNS resolver handles that domain, so we receive 207 // the query. 
We then strip the "tel2-search" and send the original single label name to the 208 // cluster, and we add it back before we forward the reply. 209 const ( 210 tel2SubDomain = "tel2-search" 211 tel2SubDomainDot = tel2SubDomain + "." 212 ) 213 214 // wpadDot is used when rejecting all WPAD (Wep Proxy Auto-Discovery) queries. 215 const wpadDot = "wpad." 216 217 var ( 218 localhostIPv4 = net.IP{127, 0, 0, 1} //nolint:gochecknoglobals // constant 219 localhostIPv6 = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} //nolint:gochecknoglobals // constant 220 ) 221 222 func (s *Server) shouldDoClusterLookup(query string) bool { 223 name := query[:len(query)-1] // skip last dot 224 if strings.HasPrefix(query, wpadDot) { 225 // Reject "wpad.*" 226 dlog.Debugf(s.ctx, `Cluster DNS excluded by exclude-prefix "wpad." for name %q`, name) 227 return false 228 } 229 230 if s.isExcluded(name) { 231 // Reject any host explicitly added to the exclude list. 232 dlog.Debugf(s.ctx, "Cluster DNS explicitly excluded for name %q", name) 233 return false 234 } 235 236 if !strings.ContainsRune(name, '.') { 237 // Single label names are always included. 238 dlog.Debugf(s.ctx, "Cluster DNS included for single label name %q", name) 239 return true 240 } 241 242 // Skip configured exclude-suffixes unless also matched by an include-suffix 243 // that is longer (i.e. more specific). 244 for _, es := range s.excludeSuffixes { 245 if strings.HasSuffix(name, es) { 246 // Exclude unless more specific include. 
247 for _, is := range s.includeSuffixes { 248 if len(is) >= len(es) && strings.HasSuffix(name, is) { 249 dlog.Debugf(s.ctx, 250 "Cluster DNS included by include-suffix %q (overriding exclude-suffix %q) for name %q", is, es, name) 251 return true 252 } 253 } 254 dlog.Debugf(s.ctx, "Cluster DNS excluded by exclude-suffix %q for name %q", es, name) 255 return false 256 } 257 } 258 259 // Always include configured search paths 260 for _, sfx := range s.search { 261 if strings.HasSuffix(name, sfx) { 262 dlog.Debugf(s.ctx, "Cluster DNS included by search %q of name %q", sfx, name) 263 return true 264 } 265 } 266 267 // Always include configured routes 268 for sfx := range s.routes { 269 if strings.HasSuffix(name, sfx) { 270 dlog.Debugf(s.ctx, "Cluster DNS included by namespace %q of name %q", sfx, name) 271 return true 272 } 273 } 274 275 // Always include queries for the cluster domain. 276 if strings.HasSuffix(query, "."+s.clusterDomain) { 277 dlog.Debugf(s.ctx, "Cluster DNS included by cluster domain %q of name %q", s.clusterDomain, name) 278 return true 279 } 280 281 // Always include configured includeSuffixes 282 for _, sfx := range s.includeSuffixes { 283 if strings.HasSuffix(name, sfx) { 284 dlog.Debugf(s.ctx, 285 "Cluster DNS included by include-suffix %q for name %q", sfx, name) 286 return true 287 } 288 } 289 290 // Pass any queries for the cluster domain. 291 dlog.Debugf(s.ctx, "Cluster DNS excluded for name %q. No inclusion rule was matched", name) 292 return false 293 } 294 295 func (s *Server) isExcluded(name string) bool { 296 if slice.Contains(s.excludes, name) { 297 return true 298 } 299 300 // When intercepting, this function will potentially receive the hostname of any search param, so their 301 // unqualified hostname should be evaluated too. 
302 qLen := len(name) 303 for _, sp := range s.search { 304 if strings.HasSuffix(name, "."+sp) && slice.Contains(s.excludes, name[:qLen-len(sp)-1]) { 305 return true 306 } 307 } 308 return false 309 } 310 311 func (s *Server) isDomainExcluded(name string) bool { 312 return slices.Contains(s.excludeSuffixes, "."+name) 313 } 314 315 func (s *Server) resolveInCluster(c context.Context, q *dns.Question) (result dnsproxy.RRs, rCode int, err error) { 316 query := q.Name 317 if query == "localhost." { 318 // BUG(lukeshu): I have no idea why a lookup 319 // for localhost even makes it to here on my 320 // home WiFi when connecting to a k3sctl 321 // cluster (but not a kubernaut.io cluster). 322 // But it does, so I need this in order to be 323 // productive at home. We should really 324 // root-cause this, because it's weird. 325 hdr := dns.RR_Header{ 326 Name: q.Name, 327 Rrtype: q.Qtype, 328 Class: q.Qclass, 329 } 330 switch q.Qtype { 331 case dns.TypeA: 332 return dnsproxy.RRs{&dns.A{ 333 Hdr: hdr, 334 A: localhostIPv4, 335 }}, dns.RcodeSuccess, nil 336 case dns.TypeAAAA: 337 return dnsproxy.RRs{&dns.AAAA{ 338 Hdr: hdr, 339 AAAA: localhostIPv6, 340 }}, dns.RcodeSuccess, nil 341 default: 342 return nil, dns.RcodeNameError, nil 343 } 344 } 345 346 if !s.shouldDoClusterLookup(query) { 347 return nil, dns.RcodeNameError, nil 348 } 349 350 // Give the cluster lookup a reasonable timeout. 351 c, cancel := context.WithTimeout(c, s.lookupTimeout) 352 defer cancel() 353 354 result, rCode, err = s.clusterLookup(c, q) 355 if err != nil { 356 return nil, rCode, client.CheckTimeout(c, err) 357 } 358 359 // Keep the TTLs of requests resolved in the cluster low. We 360 // cache them locally anyway, but our cache is flushed when things are 361 // intercepted or the namespaces change. 
362 for _, rr := range result { 363 if h := rr.Header(); h != nil { 364 h.Ttl = dnsTTL 365 } 366 } 367 return result, rCode, nil 368 } 369 370 func (s *Server) GetConfig() *rpc.DNSConfig { 371 s.RLock() 372 c := rpc.DNSConfig{ 373 LocalIp: s.localIP, 374 RemoteIp: s.remoteIP, 375 ExcludeSuffixes: s.excludeSuffixes, 376 IncludeSuffixes: s.includeSuffixes, 377 Excludes: s.excludes, 378 Error: s.error, 379 } 380 if s.lookupTimeout != 0 { 381 c.LookupTimeout = durationpb.New(s.lookupTimeout) 382 } 383 if len(s.mappings) > 0 { 384 ns := maps.Keys(s.mappings) 385 slices.Sort(ns) 386 c.Mappings = make([]*rpc.DNSMapping, len(s.mappings)) 387 for i, n := range ns { 388 c.Mappings[i] = &rpc.DNSMapping{ 389 Name: strings.TrimSuffix(n, "."), 390 AliasFor: strings.TrimSuffix(s.mappings[n], "."), 391 } 392 } 393 } 394 s.RUnlock() 395 return &c 396 } 397 398 func (s *Server) Ready() <-chan struct{} { 399 return s.ready 400 } 401 402 func (s *Server) Stop() { 403 // Close s.ready unless it's already closed 404 select { 405 case <-s.ready: 406 default: 407 close(s.ready) 408 } 409 } 410 411 func (s *Server) SetClusterDNS(dns *manager.DNS, remoteIP net.IP) { 412 s.Lock() 413 if s.remoteIP == nil { 414 s.remoteIP = remoteIP 415 } 416 if dns != nil { 417 if slices.Equal(s.excludeSuffixes, DefaultExcludeSuffixes) && len(dns.ExcludeSuffixes) > 0 { 418 s.excludeSuffixes = sliceToLower(dns.ExcludeSuffixes) 419 } 420 if len(s.includeSuffixes) == 0 { 421 s.includeSuffixes = sliceToLower(dns.IncludeSuffixes) 422 } 423 s.clusterDomain = strings.ToLower(dns.ClusterDomain) 424 } 425 s.Unlock() 426 } 427 428 // SetTopLevelDomainsAndSearchPath updates the DNS top level domains and the search path used by the resolver. 
429 func (s *Server) SetTopLevelDomainsAndSearchPath(ctx context.Context, domains []string, namespace string) { 430 das := nsAndDomains{ 431 domains: domains, 432 namespace: namespace, 433 } 434 select { 435 case <-ctx.Done(): 436 case s.nsAndDomainsCh <- das: 437 } 438 } 439 440 func (s *Server) purgeRecordsFromCache(keyName string) { 441 keyName = strings.TrimSuffix(keyName, ".") + "." 442 for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} { 443 toDeleteKey := cacheKey{name: keyName, qType: qType} 444 if old, ok := s.cache.LoadAndDelete(toDeleteKey); ok { 445 old.close() 446 } 447 } 448 } 449 450 // SetExcludes sets the excludes list in the config. 451 func (s *Server) SetExcludes(excludes []string) { 452 for i, e := range excludes { 453 excludes[i] = strings.ToLower(e) 454 } 455 s.Lock() 456 oldExcludes := s.excludes 457 s.excludes = excludes 458 s.Unlock() 459 460 for _, e := range slice.AppendUnique(oldExcludes, excludes...) { 461 s.purgeRecordsFromCache(e) 462 } 463 } 464 465 func mappingsMap(mappings []*rpc.DNSMapping) map[string]string { 466 if l := len(mappings); l > 0 { 467 mm := make(map[string]string, l) 468 for _, m := range mappings { 469 al := m.AliasFor 470 if ip := iputil.Parse(al); ip == nil { 471 al += "." 472 } 473 mm[strings.ToLower(m.Name+".")] = strings.ToLower(al) 474 } 475 return mm 476 } 477 return nil 478 } 479 480 // SetMappings sets the Mappings list in the config. 481 func (s *Server) SetMappings(mappings []*rpc.DNSMapping) { 482 mm := mappingsMap(mappings) 483 s.Lock() 484 s.mappings = mm 485 s.Unlock() 486 487 // Flush the mappings. 
488 for n := range mm { 489 s.purgeRecordsFromCache(n) 490 } 491 } 492 493 func newLocalUDPListener(c context.Context) (net.PacketConn, error) { 494 lc := &net.ListenConfig{} 495 return lc.ListenPacket(c, "udp", "127.0.0.1:0") 496 } 497 498 func (s *Server) processSearchPaths(g *dgroup.Group, processor func(context.Context, vif.Device) error, dev vif.Device) { 499 g.Go("SearchPaths", func(c context.Context) error { 500 s.performRecursionCheck(c) 501 prevDas := nsAndDomains{ 502 domains: []string{}, 503 namespace: "", 504 } 505 unchanged := func(das nsAndDomains) bool { 506 return das.namespace == prevDas.namespace && slices.Equal(das.domains, prevDas.domains) 507 } 508 509 for { 510 select { 511 case <-c.Done(): 512 return nil 513 case das := <-s.nsAndDomainsCh: 514 // Only interested in the last one, and only if it differs 515 if len(s.nsAndDomainsCh) > 0 || unchanged(das) { 516 continue 517 } 518 prevDas = das 519 520 routes := make(map[string]struct{}, len(das.domains)) 521 for _, domain := range das.domains { 522 if domain != "" && !s.isDomainExcluded(domain) { 523 routes[domain] = struct{}{} 524 } 525 } 526 if !s.isDomainExcluded("svc") { 527 routes["svc"] = struct{}{} 528 } 529 s.Lock() 530 s.routes = routes 531 532 // The connected namespace must be included as a search path for the cases 533 // where it's up to the traffic-manager to resolve. It cannot resolve a single 534 // label name intended for other namespaces. 535 s.search = []string{tel2SubDomain, das.namespace} 536 s.Unlock() 537 538 if err := processor(c, dev); err != nil { 539 return err 540 } 541 } 542 } 543 }) 544 } 545 546 func (s *Server) flushDNS() { 547 s.cache.Range(func(key cacheKey, _ *cacheEntry) bool { 548 if old, ok := s.cache.LoadAndDelete(key); ok { 549 old.close() 550 } 551 return true 552 }) 553 } 554 555 // splitToUDPAddr splits the given address into an UDPAddr. It's 556 // an error if the address is based on a hostname rather than an IP. 
func splitToUDPAddr(netAddr net.Addr) (*net.UDPAddr, error) {
	ip, port, err := iputil.SplitToIPPort(netAddr)
	if err != nil {
		return nil, err
	}
	return &net.UDPAddr{IP: ip, Port: int(port)}, nil
}

// RequestCount returns the number of requests that this server has received.
func (s *Server) RequestCount() int {
	return int(atomic.LoadInt64(&s.requestCount))
}

// copyRRs returns deep copies of the records in rrs whose type is listed in qTypes. An
// empty rrs is returned as is.
func copyRRs(rrs dnsproxy.RRs, qTypes []uint16) dnsproxy.RRs {
	if len(rrs) == 0 {
		return rrs
	}
	cp := make(dnsproxy.RRs, 0, len(rrs))
	for _, rr := range rrs {
		if slice.Contains(qTypes, rr.Header().Rrtype) {
			cp = append(cp, dns.Copy(rr))
		}
	}
	return cp
}

// cacheKey identifies an entry in the local DNS cache by query name and query type.
type cacheKey struct {
	name  string
	qType uint16
}

// String returns the key as "<type> <name>", e.g. "A example.com.".
func (c *cacheKey) String() string {
	return fmt.Sprintf("%s %s", dns.TypeToString[c.qType], c.name)
}

const (
	// recursionCheck is a special host name in a well known namespace that isn't expected to exist. It
	// is used once for determining if the cluster's DNS resolver will call the Telepresence DNS resolver
	// recursively. This is common when the cluster is running on the local host (k3s in docker for instance).
	//
	// The check is performed using the following steps.
	// 1. A lookup is made for "tel2-recursion-check.tel2-search".
	// 2. When our DNS-resolver receives this lookup, it modifies the name to "tel2-recursion-check.kube-system."
	//    and sends that on to the cluster.
	// 3. If our DNS-resolver now encounters a query for the "tel2-recursion-check.kube-system.", then we know
	//    that a recursion took place.
	// 4. If no request for "tel2-recursion-check.kube-system." is received, then it's assumed that the resolver
	//    is not recursive.
	recursionCheck  = "tel2-recursion-check."
	recursionCheck2 = "tel2-recursion-check.kube-system."
)

// resolveWithRecursionCheck handles the two recursion check queries described at the
// recursionCheck constant, and resolves all other queries using resolveThruCache.
//
// A failed A or AAAA query is turned into a successful reply (with whatever answer the
// failed lookup returned) when the cache holds a qualifying entry for the other address
// type of the same name.
func (s *Server) resolveWithRecursionCheck(q *dns.Question) (dnsproxy.RRs, int, error) {
	if strings.HasPrefix(q.Name, recursionCheck) {
		if strings.HasPrefix(q.Name, recursionCheck2) {
			// The second-stage name came back to us. It's a recursion if we're the ones that sent it.
			if atomic.CompareAndSwapInt32(&s.recursive, recursionQueryReceived, recursionDetected) {
				dlog.Debug(s.ctx, "DNS resolver is recursive")
			}
			return nil, dns.RcodeNameError, nil
		}

		// Only the first query for the first-stage name triggers the check.
		if atomic.CompareAndSwapInt32(&s.recursive, recursionQueryNotYetReceived, recursionQueryReceived) {
			// tc is done when the lookup below returns (cancel) or when recursionTestTimeout
			// has elapsed, whichever happens first.
			// NOTE(review): the lookup itself is passed s.ctx rather than tc, so it isn't
			// interrupted when tc times out — confirm that this is intended.
			tc, cancel := context.WithTimeout(s.ctx, recursionTestTimeout)
			go func() {
				defer cancel()
				nq := *q // by value copy
				nq.Name = recursionCheck2
				_, _, _ = s.resolveInCluster(s.ctx, &nq) // We really don't care about the reply here.
			}()
			<-tc.Done()

			// When we've gotten the reply from the cluster, we know if recursion did occur.
			if atomic.CompareAndSwapInt32(&s.recursive, recursionQueryReceived, recursionNotDetected) {
				dlog.Debug(s.ctx, "DNS resolver is not recursive")
			}
		}
		return localHostReply(q), dns.RcodeSuccess, nil
	}

	answer, rCode, err := s.resolveThruCache(q)
	if err != nil || rCode != dns.RcodeSuccess {
		// For A and AAAA queries, we check if we have a successful counterpart in the cache. If we
		// do, then this query must return NOERROR EMPTY
		ck := cacheKey{name: q.Name, qType: dns.TypeNone}
		switch q.Qtype {
		case dns.TypeA:
			ck.qType = dns.TypeAAAA
		case dns.TypeAAAA:
			ck.qType = dns.TypeA
		}
		if ck.qType != dns.TypeNone {
			if ce, ok := s.cache.Load(ck); ok {
				<-ce.wait
				// NOTE(review): resolveThruCache resets currentQType to TypeNone before it closes
				// wait, so for an entry that completed normally the currentQType comparison below
				// looks like it can never hold. Also, rCode keeps its zero value (which equals
				// RcodeSuccess) for entries that failed — confirm the intended condition.
				if !ce.expired() && ce.rCode == dns.RcodeSuccess && atomic.LoadInt32(&ce.currentQType) == int32(ck.qType) {
					dlog.Debugf(s.ctx, "found counterpart for %s %s", dns.TypeToString[uint16(ce.currentQType)], ce.answer)
					err = nil
					rCode = dns.RcodeSuccess
				}
			}
		}
	}
	return answer, rCode, err
}

// resolveThruCache resolves the given query by first performing a cache lookup. If a cached
// entry is found that hasn't expired, it's returned. If not, this function will call
// s.resolve to resolve the query and store the result in the cache.
//
// Only successful lookups remain in the cache. Concurrent queries for the same key wait
// for the query that created the entry to finish, and then use its result.
func (s *Server) resolveThruCache(q *dns.Question) (answer dnsproxy.RRs, rCode int, err error) {
	dv := &cacheEntry{wait: make(chan struct{}), created: time.Now()}
	key := cacheKey{name: q.Name, qType: q.Qtype}
	if oldDv, loaded := s.cache.LoadOrStore(key, dv); loaded {
		// An entry exists. It's either resolved already, or some other query is resolving it right now.
		if atomic.LoadInt32(&s.recursive) == recursionDetected && atomic.LoadInt32(&oldDv.currentQType) == int32(q.Qtype) {
			// We have to assume that this is a recursion from the cluster.
			dlog.Debugf(s.ctx, "returning error for query %q: assumed to be recursive", key.String())
			return nil, dns.RcodeNameError, nil
		}
		<-oldDv.wait
		if !oldDv.expired() {
			qTypes := []uint16{q.Qtype}
			if q.Qtype != dns.TypeCNAME {
				// Allow additional CNAME records if they are present.
				for _, rr := range oldDv.answer {
					if rr.Header().Rrtype == dns.TypeCNAME {
						qTypes = append(qTypes, dns.TypeCNAME)
						break
					}
				}
			}
			return copyRRs(oldDv.answer, qTypes), oldDv.rCode, nil
		}
		// The old entry has expired. Replace it with the new one and resolve again.
		s.cache.Store(key, dv)
	}

	// Mark the entry as being resolved for this query type until the deferred function runs.
	atomic.StoreInt32(&dv.currentQType, int32(q.Qtype))
	defer func() {
		if rCode != dns.RcodeSuccess {
			s.cache.Delete(key) // Don't cache unless the lookup succeeded.
		} else {
			dv.answer = answer
			dv.rCode = rCode

			// Return a result for the correct query type. The result will be nil (nxdomain) if nothing was found. It might
			// also be empty if no RRs were found for the given query type and that is OK.
			// See https://datatracker.ietf.org/doc/html/rfc4074#section-3
			answer = copyRRs(answer, []uint16{q.Qtype})
		}
		// Reset the query type before releasing the waiters.
		atomic.StoreInt32(&dv.currentQType, int32(dns.TypeNone))
		dv.close()
	}()
	return s.resolve(s.ctx, q)
}

// dfs is a func that implements the fmt.Stringer interface. Used in log statements to ensure
// that the function isn't evaluated until the log output is formatted (which will happen only
// if the given loglevel is enabled).
712 type dfs func() string 713 714 func (d dfs) String() string { 715 return d() 716 } 717 718 func (s *Server) performRecursionCheck(c context.Context) { 719 s.Lock() 720 if _, ok := s.routes["kube-system"]; !ok { 721 s.routes["kube-system"] = struct{}{} 722 nl := len(s.routes) 723 defer func() { 724 s.Lock() 725 if nl == len(s.routes) { 726 delete(s.routes, "kube-system") 727 } 728 s.Unlock() 729 }() 730 } 731 s.Unlock() 732 defer func() { 733 dlog.Debug(c, "Recursion check finished") 734 close(s.ready) 735 }() 736 rc := recursionCheck + tel2SubDomain 737 dlog.Debugf(c, "Performing initial recursion check with %s", rc) 738 i := 0 739 atomic.StoreInt32(&s.recursive, recursionQueryNotYetReceived) 740 for ; c.Err() == nil && i < maxRecursionTestRetries && atomic.LoadInt32(&s.recursive) == recursionQueryNotYetReceived; i++ { 741 go func() { 742 _, _ = net.DefaultResolver.LookupIP(c, "ip4", rc) 743 }() 744 time.Sleep(500 * time.Millisecond) 745 } 746 if i == maxRecursionTestRetries { 747 msg := "DNS doesn't seem to work properly" 748 dlog.Error(c, msg) 749 s.Lock() 750 s.error = msg 751 s.Unlock() 752 return 753 } 754 // Await result 755 for c.Err() == nil { 756 rc := atomic.LoadInt32(&s.recursive) 757 if rc == recursionDetected || rc == recursionNotDetected { 758 break 759 } 760 dtime.SleepWithContext(c, 10*time.Millisecond) 761 } 762 } 763 764 func localHostReply(q *dns.Question) dnsproxy.RRs { 765 switch q.Qtype { 766 case dns.TypeA: 767 return dnsproxy.RRs{&dns.A{ 768 Hdr: dns.RR_Header{ 769 Name: q.Name, 770 Rrtype: q.Qtype, 771 Class: q.Qclass, 772 }, 773 A: localhostIPv4, 774 }} 775 case dns.TypeAAAA: 776 return dnsproxy.RRs{&dns.AAAA{ 777 Hdr: dns.RR_Header{ 778 Name: q.Name, 779 Rrtype: q.Qtype, 780 Class: q.Qclass, 781 }, 782 AAAA: localhostIPv6, 783 }} 784 default: 785 return nil 786 } 787 } 788 789 // ServeDNS is an implementation of github.com/miekg/dns Handler.ServeDNS. 
790 func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 791 c := s.ctx 792 atomic.AddInt64(&s.requestCount, 1) 793 794 q := &r.Question[0] 795 qts := dns.TypeToString[q.Qtype] 796 dlog.Debugf(c, "ServeDNS %5d %-6s %s", r.Id, qts, q.Name) 797 798 msg := new(dns.Msg) 799 var pfx dfs = func() string { return "" } 800 var txt dfs = func() string { return "" } 801 var rct dfs = func() string { return dns.RcodeToString[msg.Rcode] } 802 803 defer func() { 804 dlog.Debugf(c, "%s%5d %-6s %s -> %s %s", pfx, r.Id, qts, q.Name, rct, txt) 805 _ = w.WriteMsg(msg) 806 807 // Closing the response tells the DNS service to terminate 808 if c.Err() != nil { 809 _ = w.Close() 810 } 811 }() 812 813 // The sanity-check query is sent during the configuration phase of the DNS server and then 814 // never again. It must reply with localhost. 815 // 816 // NOTE! The sanity-check will always use the tel2-search subdomain, so the check made here 817 // must be made before the tel2-search is removed. 818 if q.Name == santiyCheckDot { 819 answer := localHostReply(q) 820 if answer == nil { 821 msg.SetRcode(r, dns.RcodeNotImplemented) 822 return 823 } 824 msg.SetReply(r) 825 msg.Answer = answer 826 msg.Authoritative = true 827 txt = func() string { return answer.String() } 828 dlog.Debug(c, "sanity-check OK") 829 return 830 } 831 832 if !dnsproxy.SupportedType(q.Qtype) { 833 msg.SetRcode(r, dns.RcodeNotImplemented) 834 return 835 } 836 837 // We make changes to the query name, so we better restore it prior to writing an 838 // answer back, or the caller will get confused. 839 origName := q.Name 840 defer func() { 841 qs := msg.Question 842 if len(qs) > 0 { 843 mq := &qs[0] // Important to use a pointer here. We don't want to change a by-value copied struct. 
844 if mq.Name == q.Name { 845 mq.Name = origName 846 } 847 } 848 for _, rr := range msg.Answer { 849 h := rr.Header() 850 if h.Name == q.Name { 851 h.Name = origName 852 } 853 } 854 q.Name = origName 855 }() 856 857 // We're all about lowercase in here 858 q.Name = strings.ToLower(origName) 859 860 // The tel2SubDomain serves one purpose and one purpose alone. It's there to coerce the 861 // system DNS resolver to direct requests to this resolver. The system configuration to 862 // make this happen vary depending on OS, but the purpose is always the same. Given that, 863 // the first step in the resolution is to remove this domain-suffix if it exists. 864 ln := len(q.Name) 865 for _, dropSuffix := range s.dropSuffixes { 866 if strings.HasSuffix(q.Name, dropSuffix) { 867 // Remove the suffix and ensure that the name still ends 868 // with a dot after the removal. If it doesn't, then this 869 // was not really a domain suffix but rather a partial 870 // domain name. 871 n := q.Name[:ln-len(dropSuffix)] 872 if last := len(n) - 1; last > 0 && n[last] == '.' { 873 q.Name = n 874 break 875 } 876 } 877 } 878 879 if strings.Contains(q.Name, tel2SubDomainDot) { 880 // This is a bogus name because it has some domain after 881 // the tel2-search domain. Should normally never happen, but 882 // will happen if someone queries for the tel2-search domain 883 // as a single label name. 884 msg.SetRcode(r, dns.RcodeNameError) 885 return 886 } 887 888 // try and resolve any mappings before consulting the cache, so that mapping hits don't 889 // end up in the cache. 
890 answer, rCode, err := s.resolveMapping(q) 891 if err == errNoMapping { 892 answer, rCode, err = s.resolveWithRecursionCheck(q) 893 } 894 895 if err == nil && rCode == dns.RcodeSuccess { 896 if rCode != dns.RcodeSuccess { 897 msg.SetRcode(r, rCode) 898 } else { 899 msg.SetReply(r) 900 } 901 msg.Answer = answer 902 msg.Authoritative = true 903 // mac dns seems to fallback if you don't 904 // support recursion, if you have more than a 905 // single dns server, this will prevent us 906 // from intercepting all queries 907 msg.RecursionAvailable = true 908 txt = func() string { return answer.String() } 909 return 910 } 911 912 // The recursion check query, or queries that end with the cluster domain name, are not dispatched to the 913 // fallback DNS-server. 914 s.RLock() 915 cd := s.clusterDomain 916 s.RUnlock() 917 if s.fallbackPool == nil || 918 strings.HasPrefix(q.Name, recursionCheck2) || 919 strings.HasSuffix(q.Name, cd) || 920 strings.HasSuffix(origName, tel2SubDomainDot) { 921 if err == nil { 922 rCode = dns.RcodeNameError 923 } else { 924 rCode = dns.RcodeServerFailure 925 if errors.Is(err, context.DeadlineExceeded) { 926 txt = func() string { return "timeout" } 927 } else { 928 txt = err.Error 929 } 930 } 931 msg.SetRcode(r, rCode) 932 return 933 } 934 935 // Use original query name when sending things to the fallback resolver. 
936 q.Name = origName 937 pfx = func() string { return fmt.Sprintf("(%s) ", s.fallbackPool.RemoteAddr()) } 938 dc := &dns.Client{Net: "udp", Timeout: s.lookupTimeout} 939 var poolMsg *dns.Msg 940 poolMsg, _, err = s.fallbackPool.Exchange(c, dc, r) 941 if err != nil { 942 rCode = dns.RcodeServerFailure 943 txt = err.Error 944 if netErr, ok := err.(net.Error); ok { 945 switch { 946 case netErr.Timeout(): 947 txt = func() string { return "timeout" } 948 case netErr.Temporary(): //nolint:staticcheck // err.Temporary is deprecated 949 rCode = dns.RcodeRefused 950 default: 951 } 952 } 953 msg.SetRcode(r, rCode) 954 } else { 955 msg = poolMsg 956 txt = func() string { return dnsproxy.RRs(msg.Answer).String() } 957 } 958 } 959 960 var errNoMapping = errors.New("no mapping") //nolint:gochecknoglobals // constant 961 962 func (s *Server) resolveMapping(q *dns.Question) (dnsproxy.RRs, int, error) { 963 switch q.Qtype { 964 case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME: 965 default: 966 return nil, dns.RcodeNameError, errNoMapping 967 } 968 969 s.RLock() 970 mappingAlias, ok := s.mappings[q.Name] 971 s.RUnlock() 972 973 if !ok { 974 return nil, dns.RcodeNameError, errNoMapping 975 } 976 if ip := iputil.Parse(mappingAlias); ip != nil { 977 // The name resolves to an A or AAAA record known by this DNS server. 
978 var rrs dnsproxy.RRs 979 if q.Qtype == dns.TypeA && len(ip) == 4 { 980 rrs = dnsproxy.RRs{&dns.A{ 981 Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: dnsTTL}, 982 A: ip, 983 }} 984 } else if q.Qtype == dns.TypeAAAA && len(ip) == 16 { 985 rrs = dnsproxy.RRs{&dns.AAAA{ 986 Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: dnsTTL}, 987 AAAA: ip, 988 }} 989 } 990 return rrs, dns.RcodeSuccess, nil 991 } 992 993 cnameRRs := dnsproxy.RRs{&dns.CNAME{ 994 Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: dnsTTL}, 995 Target: mappingAlias, 996 }} 997 998 if q.Qtype == dns.TypeCNAME { 999 // A query for the CNAME must only return the CNAME. 1000 return cnameRRs, dns.RcodeSuccess, nil 1001 } 1002 1003 // A query for an A or AAAA must resolve the CNAME and then return both the result and the 1004 // CNAME that resolved to it. 1005 answer, rCode, err := s.resolveWithRecursionCheck(&dns.Question{ 1006 Name: mappingAlias, 1007 Qtype: q.Qtype, 1008 Qclass: q.Qclass, 1009 }) 1010 if err == nil { 1011 answer = append(cnameRRs, answer...) 1012 } 1013 return answer, rCode, err 1014 } 1015 1016 // Run starts the DNS server(s) and waits for them to end. 1017 func (s *Server) Run(c context.Context, initDone chan<- struct{}, listeners []net.PacketConn, fallbackPool FallbackPool, resolve Resolver) error { 1018 s.ctx = c 1019 s.fallbackPool = fallbackPool 1020 s.resolve = resolve 1021 1022 g := dgroup.NewGroup(c, dgroup.GroupConfig{}) 1023 for _, listener := range listeners { 1024 srv := &dns.Server{PacketConn: listener, Handler: s, ReadTimeout: time.Second} 1025 g.Go(listener.LocalAddr().String(), func(c context.Context) error { 1026 go func() { 1027 <-c.Done() 1028 dlog.Debugf(c, "Shutting down DNS server") 1029 _ = srv.ShutdownContext(dcontext.HardContext(c)) 1030 }() 1031 return srv.ActivateAndServe() 1032 }) 1033 } 1034 close(initDone) 1035 return g.Wait() 1036 }