github.com/cilium/cilium@v1.16.2/pkg/fqdn/dnsproxy/proxy.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package dnsproxy
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"math"
    11  	"net"
    12  	"net/netip"
    13  	"regexp"
    14  	"strconv"
    15  	"strings"
    16  	"sync/atomic"
    17  	"syscall"
    18  
    19  	"github.com/cilium/dns"
    20  	"github.com/sirupsen/logrus"
    21  	"golang.org/x/sync/semaphore"
    22  	"golang.org/x/sys/unix"
    23  
    24  	"github.com/cilium/cilium/pkg/datapath/linux/linux_defaults"
    25  	"github.com/cilium/cilium/pkg/endpoint"
    26  	"github.com/cilium/cilium/pkg/fqdn/matchpattern"
    27  	"github.com/cilium/cilium/pkg/fqdn/proxy/ipfamily"
    28  	"github.com/cilium/cilium/pkg/fqdn/restore"
    29  	"github.com/cilium/cilium/pkg/identity"
    30  	ippkg "github.com/cilium/cilium/pkg/ip"
    31  	"github.com/cilium/cilium/pkg/ipcache"
    32  	"github.com/cilium/cilium/pkg/lock"
    33  	"github.com/cilium/cilium/pkg/logging"
    34  	"github.com/cilium/cilium/pkg/logging/logfields"
    35  	"github.com/cilium/cilium/pkg/option"
    36  	"github.com/cilium/cilium/pkg/policy"
    37  	"github.com/cilium/cilium/pkg/proxy/accesslog"
    38  	"github.com/cilium/cilium/pkg/spanstat"
    39  	"github.com/cilium/cilium/pkg/time"
    40  )
    41  
const (
	// ProxyForwardTimeout is the maximum time to wait for DNS responses to
	// forwarded DNS requests. This is needed since UDP queries have no way to
	// indicate that the client has stopped expecting a response.
	ProxyForwardTimeout = 10 * time.Second

	// ProxyBindTimeout is how long we wait for a successful bind to the bindaddr.
	// Note: ProxyBindRetryInterval is derived as one fifth of this value, so it
	// must be large enough that dividing it by 5 does not yield 0.
	ProxyBindTimeout = 20 * time.Second

	// ProxyBindRetryInterval is how long to wait between attempts to bind to the
	// proxy address:port. It is defined as ProxyBindTimeout / 5.
	ProxyBindRetryInterval = ProxyBindTimeout / 5
)
    56  
// DNSProxy is a L7 proxy for DNS traffic. It keeps a list of allowed DNS
// lookups that can be regexps and blocks lookups that are not allowed.
// A singleton is always running inside cilium-agent.
// Note: All public fields are read only and do not require locking
type DNSProxy struct {
	// BindPort is the port the proxy is bound to. It is set by StartDNSProxy
	// from the port returned by bindToAddr.
	BindPort uint16

	// LookupRegisteredEndpoint is a provided callback that returns the
	// endpoint for a given IP.
	// Note: this is a little pointless since this proxy is in-process but it is
	// intended to allow us to switch to an external proxy process by forcing the
	// design now.
	LookupRegisteredEndpoint LookupEndpointIDByIPFunc

	// LookupSecIDByIP is a provided callback that returns the IP's security ID
	// from the ipcache.
	// Note: this is a little pointless since this proxy is in-process but it is
	// intended to allow us to switch to an external proxy process by forcing the
	// design now.
	LookupSecIDByIP LookupSecIDByIPFunc

	// LookupIPsBySecID is a provided callback that returns the IPs by security ID
	// from the ipcache.
	LookupIPsBySecID LookupIPsBySecIDFunc

	// NotifyOnDNSMsg is a provided callback by which the proxy can emit DNS
	// response data. It is intended to wire into a DNS cache and a
	// fqdn.NameManager.
	// Note: this is a little pointless since this proxy is in-process but it is
	// intended to allow us to switch to an external proxy process by forcing the
	// design now.
	NotifyOnDNSMsg NotifyOnDNSMsgFunc

	// DNSServers are the cilium/dns server instances.
	// Depending on the configuration, these might be
	// TCPv4, UDPv4, TCPv6 and/or UDPv6.
	// They handle DNS parsing etc. for us.
	DNSServers []*dns.Server

	// EnableDNSCompression allows the DNS proxy to compress responses to
	// endpoints that are larger than 512 Bytes or the EDNS0 option, if present.
	EnableDNSCompression bool

	// ConcurrencyLimit limits the number of parallel goroutines that serve DNS.
	// It is only set by StartDNSProxy when the configured limit is > 0.
	ConcurrencyLimit *semaphore.Weighted
	// ConcurrencyGracePeriod is the grace period for waiting on
	// ConcurrencyLimit before timing out
	ConcurrencyGracePeriod time.Duration

	// logLimiter limits log msgs that could be bursty and too verbose.
	// Currently used when ConcurrencyLimit is set.
	logLimiter logging.Limiter

	// lookupTargetDNSServer extracts the originally intended target of a DNS
	// query. It is always set to lookupTargetDNSServer in
	// helpers.go but is modified during testing.
	lookupTargetDNSServer func(w dns.ResponseWriter) (serverIP net.IP, serverPort restore.PortProto, addrStr string, err error)

	// maxIPsPerRestoredDNSRule is the maximum number of IPs to maintain for each
	// restored DNS rule.
	maxIPsPerRestoredDNSRule int

	// this mutex protects variables below this point
	lock.RWMutex

	// DNSClients is a container for dns.SharedClient instances.
	DNSClients *SharedClients

	// usedServers is the set of DNS servers that have been allowed and used successfully.
	// This is used to limit the number of IPs we store for restored DNS rules.
	usedServers map[string]struct{}

	// allowed tracks all allowed L7 DNS rules by endpointID, destination port,
	// and L3 Selector. All must match for a query to be allowed.
	//
	// Note: Simple DNS names, e.g. bar.foo.com, will treat the "." as a literal.
	allowed perEPAllow

	// restored is a set of rules restored from a previous instance that can be
	// used until 'allowed' rules for an endpoint are first initialized after
	// a restart
	restored perEPRestored

	// cache is an internal structure to keep track of all the in use DNS rules. We do that
	// so that we avoid storing multiple similar versions of the same rules, so that we can improve
	// performance and reduce memory consumption when multiple endpoints or ports have similar rules.
	cache regexCache

	// restoredEPs maps restored endpoint IPs (both IPv4 and IPv6) to *Endpoint
	restoredEPs restoredEPs

	// rejectReply is the response code sent from the DNS-proxy to the endpoint
	// if the DNS request is invalid. StartDNSProxy initializes it to
	// dns.RcodeRefused.
	rejectReply atomic.Int32

	// unbindAddress unbinds dns servers from socket in order to stop serving DNS traffic before proxy shutdown
	unbindAddress func()
}
   156  
// regexCacheEntry is a lookup entry used to cache a compiled regex
// and how many references it has
type regexCacheEntry struct {
	// regex is the compiled expression shared by all holders of this entry.
	regex          *regexp.Regexp
	// referenceCount is the number of outstanding references handed out by
	// lookupOrCompileRegex / lookupOrInsertRegex and not yet given back via
	// releaseRegex.
	referenceCount int
}

// regexCache is a reference counted cache used for reusing the compiled regex when multiple policies
// have the same set of rules, or the same rule applies to multiple endpoints.
// Entries are keyed by the regex pattern string.
type regexCache map[string]*regexCacheEntry
   167  
// perEPAllow maps EndpointIDs to protocols + ports + selectors + rules
type perEPAllow map[uint64]portProtoToSelectorAllow

// portProtoToSelectorAllow maps protocol-port numbers to selectors + rules
type portProtoToSelectorAllow map[restore.PortProto]CachedSelectorREEntry

// CachedSelectorREEntry maps selectors to rules, mirroring
// policy.L7DataMap but the DNS rules are compiled into a regex
type CachedSelectorREEntry map[policy.CachedSelector]*regexp.Regexp

// perEPRestored is the structure for restored rules that can be used while
// Cilium agent is restoring endpoints. It maps endpoint IDs to protocol-port
// numbers to the restored IP rules.
type perEPRestored map[uint64]map[restore.PortProto][]restoredIPRule

// restoredIPRule is the dnsproxy internal way of representing a restored IPRule
// where we also store the actual compiled regular expression, as well
// as the original restored IPRule
type restoredIPRule struct {
	restore.IPRule
	regex *regexp.Regexp
}

// restoredEPs is a map from EP IPs to *Endpoint
type restoredEPs map[netip.Addr]*endpoint.Endpoint
   191  
   192  // asIPRule returns a new restore.IPRule representing the rules, including the provided IP map.
   193  func asIPRule(r *regexp.Regexp, IPs map[string]struct{}) restore.IPRule {
   194  	pattern := "^-$"
   195  	if r != nil {
   196  		pattern = r.String()
   197  	}
   198  	return restore.IPRule{IPs: IPs, Re: restore.RuleRegex{Pattern: &pattern}}
   199  }
   200  
   201  // CheckRestored checks endpointID, destPort, destIP, and name against the restored rules,
   202  // and only returns true if a restored rule matches.
   203  func (p *DNSProxy) checkRestored(endpointID uint64, destPortProto restore.PortProto, destIP string, name string) bool {
   204  	ipRules, exists := p.restored[endpointID][destPortProto]
   205  	if !exists && destPortProto.IsPortV2() {
   206  		// Check if there is a Version 1 restore.
   207  		ipRules, exists = p.restored[endpointID][destPortProto.ToV1()]
   208  		log.WithFields(logrus.Fields{
   209  			logfields.EndpointID: endpointID,
   210  			logfields.Port:       destPortProto.Port(),
   211  			logfields.Protocol:   destPortProto.Protocol(),
   212  		}).Debugf("Checking if restored V1 IP rules (exists: %t) for endpoint: %+v", exists, ipRules)
   213  		if !exists {
   214  			return false
   215  		}
   216  	}
   217  
   218  	for i := range ipRules {
   219  		ipRule := ipRules[i]
   220  		if _, exists := ipRule.IPs[destIP]; exists || len(ipRule.IPs) == 0 {
   221  			if ipRule.regex != nil && ipRule.regex.MatchString(name) {
   222  				return true
   223  			}
   224  		}
   225  	}
   226  	return false
   227  }
   228  
   229  // skipIPInRestorationRLocked skips IPs that are allowed but have never been used,
   230  // but only if at least one server has been used so far.
   231  // Requires the RLock to be held.
   232  func (p *DNSProxy) skipIPInRestorationRLocked(ip string) bool {
   233  	if len(p.usedServers) > 0 {
   234  		if _, used := p.usedServers[ip]; !used {
   235  			return true
   236  		}
   237  	}
   238  	return false
   239  }
   240  
// GetRules creates a fresh copy of EP's DNS rules to be stored
// for later restoration.
//
// For every port-proto of the endpoint one restore.IPRule is produced per
// selector. Wildcard selectors yield a rule without IPs; for all other
// selectors the IPs of the selected identities are resolved through
// LookupIPsBySecID, filtered by skipIPInRestorationRLocked and capped by
// maxIPsPerRestoredDNSRule. The returned error is always nil in this
// implementation.
func (p *DNSProxy) GetRules(endpointID uint16) (restore.DNSRules, error) {
	// Lock ordering note: Acquiring the IPCache read lock (as LookupIPsBySecID does) while holding
	// the proxy lock can lead to a deadlock. Avoid this by reading the state from DNSProxy while
	// holding the read lock, then perform the IPCache lookups.
	// Note that IPCache state may change in between calls to LookupIPsBySecID.
	p.RLock()

	// selRegex pairs a selector with the compiled regex registered for it.
	type selRegex struct {
		re *regexp.Regexp
		cs policy.CachedSelector
	}

	portProtoToSelRegex := make(map[restore.PortProto][]selRegex)
	for pp, entries := range p.allowed[uint64(endpointID)] {
		nidRules := make([]selRegex, 0, len(entries))
		// Copy the entries to avoid racy map accesses after we release the lock. We don't need
		// constant time access, hence a preallocated slice instead of another map.
		for cs, regex := range entries {
			nidRules = append(nidRules, selRegex{cs: cs, re: regex})
		}
		portProtoToSelRegex[pp] = nidRules
	}

	// We've read what we need from the proxy. The following IPCache lookups _must_ occur outside of
	// the critical section.
	p.RUnlock()

	restored := make(restore.DNSRules)
	for pp, selRegexes := range portProtoToSelRegex {
		var ipRules restore.IPRules
		for _, selRegex := range selRegexes {
			if selRegex.cs.IsWildcard() {
				// Wildcard selectors are stored without an IP set.
				ipRules = append(ipRules, asIPRule(selRegex.re, nil))
				continue
			}
			ips := make(map[string]struct{})
			// count tracks accepted IPs across all identities of this selector;
			// it is incremented per accepted IP even if that IP is already in ips.
			count := 0
			nids := selRegex.cs.GetSelections()
		Loop:
			for _, nid := range nids {
				// Note: p.RLock must not be held during this call to IPCache
				nidIPs := p.LookupIPsBySecID(nid)
				// The read lock is re-taken per identity because
				// skipIPInRestorationRLocked requires it.
				p.RLock()
				for _, ip := range nidIPs {
					if p.skipIPInRestorationRLocked(ip) {
						continue
					}
					ips[ip] = struct{}{}
					count++
					// NOTE(review): the limit is checked after the insert, so the
					// IP that crosses the limit is already stored in ips when the
					// loop stops — confirm whether that is intended.
					if count > p.maxIPsPerRestoredDNSRule {
						log.WithFields(logrus.Fields{
							logfields.EndpointID:            endpointID,
							logfields.Port:                  pp.Port(),
							logfields.Protocol:              pp.Protocol(),
							logfields.EndpointLabelSelector: selRegex.cs,
							logfields.Limit:                 p.maxIPsPerRestoredDNSRule,
							logfields.Count:                 len(nidIPs),
						}).Warning("Too many IPs for a DNS rule, skipping the rest")
						// Release the lock here: the break below skips the
						// RUnlock at the end of the outer loop body.
						p.RUnlock()
						break Loop
					}
				}
				p.RUnlock()
			}
			ipRules = append(ipRules, asIPRule(selRegex.re, ips))
		}
		restored[pp] = ipRules
	}

	return restored, nil
}
   314  
   315  // RestoreRules is used in the beginning of endpoint restoration to
   316  // install rules saved before the restart to be used before the endpoint
   317  // is regenerated.
   318  // 'ep' passed in is not fully functional yet, but just unmarshaled from JSON!
   319  func (p *DNSProxy) RestoreRules(ep *endpoint.Endpoint) {
   320  	p.Lock()
   321  	defer p.Unlock()
   322  	if ep.IPv4.IsValid() {
   323  		p.restoredEPs[ep.IPv4] = ep
   324  	}
   325  	if ep.IPv6.IsValid() {
   326  		p.restoredEPs[ep.IPv6] = ep
   327  	}
   328  	// Use V2 if it is populated, otherwise
   329  	// use V1.
   330  	dnsRules := ep.DNSRulesV2
   331  	if len(dnsRules) == 0 && len(ep.DNSRules) > 0 {
   332  		dnsRules = ep.DNSRules
   333  	}
   334  	restoredRules := make(map[restore.PortProto][]restoredIPRule, len(ep.DNSRules))
   335  	for pp, dnsRule := range dnsRules {
   336  		ipRules := make([]restoredIPRule, 0, len(dnsRule))
   337  		for _, ipRule := range dnsRule {
   338  			if ipRule.Re.Pattern == nil {
   339  				continue
   340  			}
   341  			regex, err := p.cache.lookupOrCompileRegex(*ipRule.Re.Pattern)
   342  			if err != nil {
   343  				log.WithFields(logrus.Fields{
   344  					logfields.EndpointID: ep.ID,
   345  					logfields.Rule:       *ipRule.Re.Pattern,
   346  				}).Info("Disregarding restored DNS rule due to failure in compiling regex. Traffic to the FQDN may be disrupted.")
   347  				continue
   348  			}
   349  			rule := restoredIPRule{
   350  				IPRule: ipRule,
   351  				regex:  regex,
   352  			}
   353  			ipRules = append(ipRules, rule)
   354  		}
   355  		restoredRules[pp] = ipRules
   356  	}
   357  	p.restored[uint64(ep.ID)] = restoredRules
   358  
   359  	log.Debugf("Restored rules for endpoint %d: %v", ep.ID, dnsRules)
   360  }
   361  
   362  // 'p' must be locked
   363  func (p *DNSProxy) removeRestoredRulesLocked(endpointID uint64) {
   364  	if _, exists := p.restored[endpointID]; exists {
   365  		// Remove IP->ID mappings for the restored EP
   366  		for ip, ep := range p.restoredEPs {
   367  			if ep.ID == uint16(endpointID) {
   368  				delete(p.restoredEPs, ip)
   369  			}
   370  		}
   371  		for _, rule := range p.restored[endpointID] {
   372  			for _, r := range rule {
   373  				p.cache.releaseRegex(r.regex)
   374  			}
   375  		}
   376  		delete(p.restored, endpointID)
   377  	}
   378  }
   379  
// RemoveRestoredRules removes all restored rules for 'endpointID'.
// It takes the proxy's write lock for the duration of the removal and
// delegates to removeRestoredRulesLocked.
func (p *DNSProxy) RemoveRestoredRules(endpointID uint16) {
	p.Lock()
	defer p.Unlock()
	p.removeRestoredRulesLocked(uint64(endpointID))
}
   386  
   387  // lookupOrCompileRegex will check if the pattern is already compiled and present in another policy, and
   388  // will reuse it in order to reduce memory consumption. The usage is reference counted, so all calls where
   389  // lookupOrCompileRegex returns no error, a subsequent call to release it via releaseRegex has to
   390  // be done when it's no longer being used by the policy.
   391  func (c regexCache) lookupOrCompileRegex(pattern string) (*regexp.Regexp, error) {
   392  	if entry, ok := c[pattern]; ok {
   393  		entry.referenceCount += 1
   394  		return entry.regex, nil
   395  	}
   396  	regex, err := regexp.Compile(pattern)
   397  	if err != nil {
   398  		return nil, err
   399  	}
   400  	c[pattern] = &regexCacheEntry{regex: regex, referenceCount: 1}
   401  	return regex, nil
   402  }
   403  
   404  // lookupOrInsertRegex is equivalent to lookupOrCompileRegex, but a compiled regex is provided
   405  // instead of the pattern. In case a compiled regex with the same pattern as the provided regex is already present in
   406  // the cache, the already present regex will be returned. By doing that, the duplicate can be garbage collected in case
   407  // there are no other references to it. Trying to insert a nil value is a noop and will return nil
   408  func (c regexCache) lookupOrInsertRegex(regex *regexp.Regexp) *regexp.Regexp {
   409  	if regex == nil {
   410  		return nil
   411  	}
   412  	pattern := regex.String()
   413  	if entry, ok := c[pattern]; ok {
   414  		entry.referenceCount += 1
   415  		return entry.regex
   416  	}
   417  	c[pattern] = &regexCacheEntry{regex: regex, referenceCount: 1}
   418  	return regex
   419  }
   420  
   421  // releaseRegex releases the provided regex. In case there are no longer any references to it,
   422  // it will be freed. Running release on a nil value is a noop.
   423  func (c regexCache) releaseRegex(regex *regexp.Regexp) {
   424  	if regex == nil {
   425  		return
   426  	}
   427  	pattern := regex.String()
   428  	if indexEntry, ok := c[pattern]; ok {
   429  		switch indexEntry.referenceCount {
   430  		case 1:
   431  			delete(c, pattern)
   432  		default:
   433  			indexEntry.referenceCount -= 1
   434  		}
   435  	}
   436  }
   437  
   438  // removeAndReleasePortRulesForID removes the old port rules for the given destPort on the given endpointID. It also
   439  // releases the regexes so that unused regex can be freed from memory.
   440  
   441  func (allow perEPAllow) removeAndReleasePortRulesForID(cache regexCache, endpointID uint64, destPortProto restore.PortProto) {
   442  	epPortProtos, hasEpPortProtos := allow[endpointID]
   443  	if !hasEpPortProtos {
   444  		return
   445  	}
   446  	for _, m := range epPortProtos[destPortProto] {
   447  		cache.releaseRegex(m)
   448  	}
   449  	delete(epPortProtos, destPortProto)
   450  	if len(epPortProtos) == 0 {
   451  		delete(allow, endpointID)
   452  	}
   453  }
   454  
   455  // setPortRulesForID sets the matching rules for endpointID and destPort for
   456  // later lookups. It converts newRules into a compiled regex
   457  func (allow perEPAllow) setPortRulesForID(cache regexCache, endpointID uint64, destPortProto restore.PortProto, newRules policy.L7DataMap) error {
   458  	if len(newRules) == 0 {
   459  		allow.removeAndReleasePortRulesForID(cache, endpointID, destPortProto)
   460  		return nil
   461  	}
   462  	cse := make(CachedSelectorREEntry, len(newRules))
   463  	var err error
   464  	for selector, newRuleset := range newRules {
   465  		pattern := GeneratePattern(newRuleset)
   466  
   467  		var regex *regexp.Regexp
   468  		regex, err = cache.lookupOrCompileRegex(pattern)
   469  		if err != nil {
   470  			break
   471  		}
   472  		cse[selector] = regex
   473  	}
   474  	if err != nil {
   475  		// Unregister the registered regexes before returning the error to avoid
   476  		// leaving unused references in the cache
   477  		for k, regex := range cse {
   478  			cache.releaseRegex(regex)
   479  			delete(cse, k)
   480  		}
   481  		return err
   482  	}
   483  	allow.removeAndReleasePortRulesForID(cache, endpointID, destPortProto)
   484  	epPortProtos, exist := allow[endpointID]
   485  	if !exist {
   486  		epPortProtos = make(portProtoToSelectorAllow)
   487  		allow[endpointID] = epPortProtos
   488  	}
   489  	epPortProtos[destPortProto] = cse
   490  	return nil
   491  }
   492  
   493  // setPortRulesForIDFromUnifiedFormat sets the matching rules for endpointID and destPort for
   494  // later lookups. It does not guarantee it will reuse all the provided regexes, since it will reuse
   495  // already existing regexes with the same pattern in case they are already in use.
   496  func (allow perEPAllow) setPortRulesForIDFromUnifiedFormat(cache regexCache, endpointID uint64, destPortProto restore.PortProto, newRules CachedSelectorREEntry) error {
   497  	if len(newRules) == 0 {
   498  		allow.removeAndReleasePortRulesForID(cache, endpointID, destPortProto)
   499  		return nil
   500  	}
   501  	cse := make(CachedSelectorREEntry, len(newRules))
   502  	for selector, providedRegex := range newRules {
   503  		// In case the regex is already compiled and in use in another regex, lookupOrInsertRegex
   504  		// will return a ref. to the existing regex, and use that one.
   505  		cse[selector] = cache.lookupOrInsertRegex(providedRegex)
   506  	}
   507  
   508  	allow.removeAndReleasePortRulesForID(cache, endpointID, destPortProto)
   509  	epPortProtos, exist := allow[endpointID]
   510  	if !exist {
   511  		epPortProtos = make(portProtoToSelectorAllow)
   512  		allow[endpointID] = epPortProtos
   513  	}
   514  	epPortProtos[destPortProto] = cse
   515  	return nil
   516  }
   517  
   518  // getPortRulesForID returns a precompiled regex representing DNS rules for the
   519  // passed-in endpointID and destPort with setPortRulesForID
   520  func (allow perEPAllow) getPortRulesForID(endpointID uint64, destPortProto restore.PortProto) (rules CachedSelectorREEntry, exists bool) {
   521  	rules, exists = allow[endpointID][destPortProto]
   522  	if !exists && destPortProto.Protocol() != 0 {
   523  		rules, exists = allow[endpointID][destPortProto.ToV1()]
   524  		log.WithFields(logrus.Fields{
   525  			logfields.EndpointID: endpointID,
   526  			logfields.Port:       destPortProto.Port(),
   527  			logfields.Protocol:   destPortProto.Protocol(),
   528  		}).Debugf("Checking for V1 port rule (exists: %t) for endpoint: %+v", exists, rules)
   529  	}
   530  	return
   531  }
   532  
// LookupEndpointIDByIPFunc wraps logic to lookup an endpoint with any backend.
// See DNSProxy.LookupRegisteredEndpoint for usage.
type LookupEndpointIDByIPFunc func(ip netip.Addr) (endpoint *endpoint.Endpoint, err error)

// LookupSecIDByIPFunc wraps logic to lookup an IP's security ID from the
// ipcache.
// See DNSProxy.LookupSecIDByIP for usage.
type LookupSecIDByIPFunc func(ip netip.Addr) (secID ipcache.Identity, exists bool)

// LookupIPsBySecIDFunc wraps logic to lookup IPs by security ID from the
// ipcache.
// See DNSProxy.LookupIPsBySecID for usage.
type LookupIPsBySecIDFunc func(nid identity.NumericIdentity) []string

// NotifyOnDNSMsgFunc handles propagating DNS response data
// See DNSProxy.NotifyOnDNSMsg for usage.
type NotifyOnDNSMsgFunc func(lookupTime time.Time, ep *endpoint.Endpoint, epIPPort string, serverID identity.NumericIdentity, serverAddr string, msg *dns.Msg, protocol string, allowed bool, stat *ProxyRequestContext) error
   549  
   550  // ErrFailedAcquireSemaphore is an an error representing the DNS proxy's
   551  // failure to acquire the semaphore. This is error is treated like a timeout.
   552  type ErrFailedAcquireSemaphore struct {
   553  	parallel int
   554  }
   555  
   556  func (e ErrFailedAcquireSemaphore) Timeout() bool { return true }
   557  
   558  // Temporary is deprecated. Return false.
   559  func (e ErrFailedAcquireSemaphore) Temporary() bool { return false }
   560  
   561  func (e ErrFailedAcquireSemaphore) Error() string {
   562  	return fmt.Sprintf(
   563  		"failed to acquire DNS proxy semaphore, %d parallel requests already in-flight",
   564  		e.parallel,
   565  	)
   566  }
   567  
// ErrTimedOutAcquireSemaphore is an error representing the DNS proxy timing
// out when acquiring the semaphore. It is treated the same as
// ErrFailedAcquireSemaphore, which it embeds (and whose Timeout and Temporary
// methods it thereby inherits).
type ErrTimedOutAcquireSemaphore struct {
	ErrFailedAcquireSemaphore

	// gracePeriod is the duration reported in the error message as the time
	// after which acquiring the semaphore timed out.
	gracePeriod time.Duration
}

// Error renders the timeout together with the grace period and the number of
// in-flight requests. It overrides the embedded type's Error method.
func (e ErrTimedOutAcquireSemaphore) Error() string {
	return fmt.Sprintf(
		"timed out after %v acquiring DNS proxy semaphore, %d parallel requests already in-flight",
		e.gracePeriod,
		e.parallel,
	)
}
   584  
// ErrDNSRequestNoEndpoint represents an error when the local daemon cannot
// find the corresponding endpoint that triggered a DNS request processed by
// the local DNS proxy (FQDN proxy).
type ErrDNSRequestNoEndpoint struct{}

// Error returns a fixed message; the type carries no further details.
func (ErrDNSRequestNoEndpoint) Error() string {
	return "DNS request cannot be associated with an existing endpoint"
}
   593  
// ProxyRequestContext is the DNS request context of the proxy that is sent in
// the callback (see NotifyOnDNSMsgFunc). It carries span statistics for the
// individual phases of a request plus the outcome of the request: Success,
// the error in Err (inspected by IsTimeout) and the DataSource.
type ProxyRequestContext struct {
	TotalTime      spanstat.SpanStat
	ProcessingTime spanstat.SpanStat // This is going to happen at the end of the second callback.
	// Error is an enum of [timeout, allow, denied, proxyerr].
	// NOTE(review): no field named Error exists here; presumably this refers
	// to how Err/Success are reported — confirm or remove.
	UpstreamTime         spanstat.SpanStat
	SemaphoreAcquireTime spanstat.SpanStat
	PolicyCheckTime      spanstat.SpanStat
	PolicyGenerationTime spanstat.SpanStat
	DataplaneTime        spanstat.SpanStat
	Success              bool
	Err                  error
	DataSource           accesslog.DNSDataSource
}
   608  
   609  // IsTimeout return true if the ProxyRequest timeout
   610  func (proxyStat *ProxyRequestContext) IsTimeout() bool {
   611  	var neterr net.Error
   612  	if errors.As(proxyStat.Err, &neterr) {
   613  		return neterr.Timeout()
   614  	}
   615  	return false
   616  }
   617  
// DNSProxyConfig is the configuration for the DNS proxy, consumed by
// StartDNSProxy:
//   - Address and Port are the bind address and port handed to bindToAddr.
//     A Port of 0 lets the OS assign a random port.
//   - IPv4 and IPv6 are passed to bindToAddr to select the address families.
//   - EnableDNSCompression is copied to DNSProxy.EnableDNSCompression.
//   - MaxRestoreDNSIPs becomes the maximum number of IPs kept per restored
//     DNS rule.
//   - ConcurrencyLimit, when > 0, sizes the semaphore that limits parallel
//     request handling; ConcurrencyGracePeriod is only applied in that case.
type DNSProxyConfig struct {
	Address                string
	Port                   uint16
	IPv4                   bool
	IPv6                   bool
	EnableDNSCompression   bool
	MaxRestoreDNSIPs       int
	ConcurrencyLimit       int
	ConcurrencyGracePeriod time.Duration
}
   629  
// StartDNSProxy starts a proxy used for DNS L7 redirects that listens on
// dnsProxyConfig.Address and dnsProxyConfig.Port on IPv4 and/or IPv6 depending
// on the values of dnsProxyConfig.IPv4/IPv6.
// Address is the bind address to listen on. Empty binds to all local
// addresses.
// Port is the port to bind to for both UDP and TCP. 0 causes the kernel to
// select a free port.
// lookupEPFunc will be called with the source IP of DNS requests, and expects
// the endpoint that made the request.
// lookupSecIDFunc and lookupIPsFunc are stored as the LookupSecIDByIP and
// LookupIPsBySecID callbacks; they are not checked for nil here.
// notifyFunc will be called with DNS response data that is returned to a
// requesting endpoint. Note that denied requests will not trigger this
// callback.
//
// An error is returned when lookupEPFunc or notifyFunc is nil, or when binding
// does not succeed within ProxyBindTimeout. Each bound server is then started
// in its own goroutine; a server that still fails to start after
// ProxyBindTimeout terminates the process via log.Fatalf.
func StartDNSProxy(
	dnsProxyConfig DNSProxyConfig,
	lookupEPFunc LookupEndpointIDByIPFunc,
	lookupSecIDFunc LookupSecIDByIPFunc,
	lookupIPsFunc LookupIPsBySecIDFunc,
	notifyFunc NotifyOnDNSMsgFunc,
) (*DNSProxy, error) {
	if dnsProxyConfig.Port == 0 {
		log.Debug("DNS Proxy port is configured to 0. A random port will be assigned by the OS.")
	}

	if lookupEPFunc == nil || notifyFunc == nil {
		return nil, errors.New("DNS proxy must have lookupEPFunc and notifyFunc provided")
	}

	p := &DNSProxy{
		LookupRegisteredEndpoint: lookupEPFunc,
		LookupSecIDByIP:          lookupSecIDFunc,
		LookupIPsBySecID:         lookupIPsFunc,
		NotifyOnDNSMsg:           notifyFunc,
		logLimiter:               logging.NewLimiter(10*time.Second, 1),
		lookupTargetDNSServer:    lookupTargetDNSServer,
		usedServers:              make(map[string]struct{}),
		allowed:                  make(perEPAllow),
		restored:                 make(perEPRestored),
		restoredEPs:              make(restoredEPs),
		cache:                    make(regexCache),
		EnableDNSCompression:     dnsProxyConfig.EnableDNSCompression,
		maxIPsPerRestoredDNSRule: dnsProxyConfig.MaxRestoreDNSIPs,
		DNSClients:               NewSharedClients(),
	}
	// The concurrency limit is optional: without it ConcurrencyLimit stays nil.
	if dnsProxyConfig.ConcurrencyLimit > 0 {
		p.ConcurrencyLimit = semaphore.NewWeighted(int64(dnsProxyConfig.ConcurrencyLimit))
		p.ConcurrencyGracePeriod = dnsProxyConfig.ConcurrencyGracePeriod
	}
	p.rejectReply.Store(dns.RcodeRefused)

	// Start the DNS listeners on UDP and TCP for IPv4 and/or IPv6
	var (
		dnsServers []*dns.Server
		bindPort   uint16
		err        error
	)

	// Retry binding every ProxyBindRetryInterval until it succeeds or
	// ProxyBindTimeout has elapsed; err holds the result of the last attempt.
	start := time.Now()
	for time.Since(start) < ProxyBindTimeout {
		dnsServers, bindPort, err = bindToAddr(dnsProxyConfig.Address, dnsProxyConfig.Port, p, dnsProxyConfig.IPv4, dnsProxyConfig.IPv6, p.DNSClients)
		if err == nil {
			break
		}
		log.WithError(err).Warnf("Attempt to bind DNS Proxy failed, retrying in %v", ProxyBindRetryInterval)
		time.Sleep(ProxyBindRetryInterval)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to bind DNS proxy: %w", err)
	}

	p.BindPort = bindPort
	p.DNSServers = dnsServers

	log.WithField("port", bindPort).WithField("addresses", len(dnsServers)).Debug("DNS Proxy bound to addresses")

	for _, s := range p.DNSServers {
		// The server is passed as an argument so each goroutine works on its
		// own server value.
		go func(server *dns.Server) {
			// try 5 times during a single ProxyBindTimeout period. We fatal here
			// because we have no other way to indicate failure this late.
			start := time.Now()
			var err error
			for time.Since(start) < ProxyBindTimeout {
				log.Debugf("Trying to start the %s DNS proxy on %s", server.Net, server.Addr)

				if err = server.ActivateAndServe(); err != nil {
					log.WithError(err).Errorf("Failed to start the %s DNS proxy on %s", server.Net, server.Addr)
					time.Sleep(ProxyBindRetryInterval)
					continue
				}
				break // successful shutdown before timeout
			}
			if err != nil {
				log.WithError(err).Fatalf("Failed to start the %s DNS proxy on %s", server.Net, server.Addr)
			}
		}(s)
	}

	// This function is called in proxy.Cleanup, which is added to Daemon cleanup module in bootstrapFQDN
	p.unbindAddress = func() { shutdownServers(p.DNSServers) }

	return p, nil
}
   730  
   731  func shutdownServers(dnsServers []*dns.Server) {
   732  	for _, s := range dnsServers {
   733  		if err := s.Shutdown(); err != nil {
   734  			log.WithError(err).Errorf("Failed to stop the %s DNS proxy on %s", s.Net, s.Addr)
   735  		}
   736  	}
   737  }
   738  
   739  // LookupEndpointByIP wraps LookupRegisteredEndpoint by falling back to an restored EP, if available
   740  func (p *DNSProxy) LookupEndpointByIP(ip netip.Addr) (endpoint *endpoint.Endpoint, err error) {
   741  	endpoint, err = p.LookupRegisteredEndpoint(ip)
   742  	if err != nil {
   743  		// Check restored endpoints
   744  		endpoint, found := p.restoredEPs[ip]
   745  		if found {
   746  			return endpoint, nil
   747  		}
   748  	}
   749  	return endpoint, err
   750  }
   751  
   752  // UpdateAllowed sets newRules for endpointID and destPort. It compiles the DNS
   753  // rules into regexes that are then used in CheckAllowed.
   754  func (p *DNSProxy) UpdateAllowed(endpointID uint64, destPortProto restore.PortProto, newRules policy.L7DataMap) error {
   755  	p.Lock()
   756  	defer p.Unlock()
   757  
   758  	err := p.allowed.setPortRulesForID(p.cache, endpointID, destPortProto, newRules)
   759  	if err == nil {
   760  		// Rules were updated based on policy, remove restored rules
   761  		p.removeRestoredRulesLocked(endpointID)
   762  	}
   763  	return err
   764  }
   765  
   766  // UpdateAllowedFromSelectorRegexes sets newRules for endpointID and destPort.
   767  func (p *DNSProxy) UpdateAllowedFromSelectorRegexes(endpointID uint64, destPortProto restore.PortProto, newRules CachedSelectorREEntry) error {
   768  	p.Lock()
   769  	defer p.Unlock()
   770  
   771  	err := p.allowed.setPortRulesForIDFromUnifiedFormat(p.cache, endpointID, destPortProto, newRules)
   772  	if err == nil {
   773  		// Rules were updated based on policy, remove restored rules
   774  		p.removeRestoredRulesLocked(endpointID)
   775  	}
   776  	return err
   777  }
   778  
   779  // CheckAllowed checks endpointID, destPortProto, destID, destIP, and name against the rules
   780  // added to the proxy or restored during restart, and only returns true if this all match
   781  // something that was added (via UpdateAllowed or RestoreRules) previously.
   782  func (p *DNSProxy) CheckAllowed(endpointID uint64, destPortProto restore.PortProto, destID identity.NumericIdentity, destIP net.IP, name string) (allowed bool, err error) {
   783  	name = strings.ToLower(dns.Fqdn(name))
   784  	p.RLock()
   785  	defer p.RUnlock()
   786  
   787  	epAllow, exists := p.allowed.getPortRulesForID(endpointID, destPortProto)
   788  	if !exists {
   789  		return p.checkRestored(endpointID, destPortProto, destIP.String(), name), nil
   790  	}
   791  
   792  	for selector, regex := range epAllow {
   793  		// The port was matched in getPortRulesForID, above.
   794  		if regex != nil && selector.Selects(destID) && (regex.String() == matchpattern.MatchAllAnchoredPattern || regex.MatchString(name)) {
   795  			return true, nil
   796  		}
   797  	}
   798  
   799  	return false, nil
   800  }
   801  
   802  // setSoMarks sets the socket options needed for a transparent proxy to integrate it's upstream
   803  // (forwarded) connection with Cilium datapath. Some considerations for this design:
   804  //
   805  //   - Since a transparent proxy must reuse the original source IP address (and we must also
   806  //     intercept the responses), we instruct the host networking namespace to allow binding the
   807  //     local address to a foreign address and to receive packets destined to a non-local (foreign)
   808  //     IP address of the source pod via the IP_TRANSPARENT socket option.
   809  //
   810  //   - In order to NOT hijack some random by-standing traffic going to the original pod, we must also
   811  //     use the original port number.
   812  //
   813  //   - (DNS) clients use ephemeral source ports, i.e., the port can be different in every
   814  //     request. Typically, a DNS resolver library uses the same ephemeral port only for requests
   815  //     from a single "gethostbyname" API call, or equivalent.
   816  //
   817  //   - To be able to receive responses to the ephemeral source port, we must have a socket bound to
   818  //     that address:port (for UDP), or a connection from that address:port to the DNS server
   819  //     address:port (for TCP).
   820  //
   821  //   - This leads to a new DNS client and socket for every different original source address -
   822  //     ephemeral port pair we see. We also need to make sure these were actually used to communicate
   823  //     with the DNS server, so we use the whole 5-tuple as a key.
   824  //
   825  // Why can't we keep DNS clients pooled and ready to receive traffic between client requests?
   826  //
   827  //   - We have no guarantees that the source pod will keep on using the same ephemeral port in
   828  //     future. We've had upstream socket bind errors (in Envoy, where we have operated in this mode
   829  //     for years already) when a client pod has rapidly cycled through its ephemeral port space,
   830  //     e.g. when performing netperf CRR or similar performance tests.
   831  //
   832  //   - We could try to keep the client and its bound socket around for some minimal time to save
   833  //     resources when a DNS resolver is enumerating through its domain suffix list, where it seems
   834  //     likely that the same source ephemeral port is going to be reused until the resolver gets an
   835  //     actual result with an IPv4/6 address or quits trying. It might be safe to close the client
   836  //     socket only after a response with the `A`/`AAAA` records have been passed back to the pod,
   837  //     or after a timeout of a few milliseconds. This would be something we currently don't do and
   838  //     is prone to socket bind errors, so this is left for a later exercise.
   839  //
   840  //   - So the client socket can not be left lingering around, as it causes network traffic destined
   841  //     for the source pod to be intercepted to the dnsproxy, which is exactly what we want but only
   842  //     until a DNS response has been received.
   843  func setSoMarks(fd int, ipFamily ipfamily.IPFamily, secId identity.NumericIdentity) error {
   844  	// Set SO_MARK to allow datapath to know these upstream packets from an egress proxy
   845  	mark := linux_defaults.MagicMarkEgress
   846  	mark |= int(uint32(secId&0xFFFF)<<16 | uint32((secId&0xFF0000)>>16))
   847  	err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, mark)
   848  	if err != nil {
   849  		return fmt.Errorf("error setting SO_MARK: %w", err)
   850  	}
   851  
   852  	// Rest of the options are only set in the transparent mode.
   853  	if !option.Config.DNSProxyEnableTransparentMode {
   854  		return nil
   855  	}
   856  
   857  	// Set IP_TRANSPARENT to be able to use a non-host address as the source address
   858  	if err := unix.SetsockoptInt(fd, ipFamily.SocketOptsFamily, ipFamily.SocketOptsTransparent, 1); err != nil {
   859  		return fmt.Errorf("setsockopt(IP_TRANSPARENT) for %s failed: %w", ipFamily.Name, err)
   860  	}
   861  
   862  	// Set SO_REUSEADDR to allow binding to an address that is already used by some other
   863  	// connection in a lingering state. This is needed in cases where we close a client
   864  	// connection but the client issues new requests re-using its source port. In that case we
   865  	// need to be able to reuse the address likely very soon after the prior close, which may
   866  	// not be allowed without this option.
   867  	if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
   868  		return fmt.Errorf("setsockopt(SO_REUSEADDR) failed: %w", err)
   869  	}
   870  
   871  	// Set SO_REUSEPORT to allow two active connections to bind to the same address and
   872  	// port. Normally this would not be needed, but is set to allow a new connection to be
   873  	// created on a port where the old connection may not yet be closed. If two UDP sockets
   874  	// using the same port due to this option were reading at the same time, the OS stack would
   875  	// distribute incoming packets to them essentially randomly. We do not want that, so we
   876  	// strive to avoid that situation. This may be helpful in avoiding bind errors in some cases
   877  	// regardless.
   878  	if !option.Config.EnableBPFTProxy {
   879  		if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
   880  			return fmt.Errorf("setsockopt(SO_REUSEPORT) failed: %w", err)
   881  		}
   882  	}
   883  
   884  	// Set SO_LINGER to ensure the TCP socket is closed and ready to be re-used in case
   885  	// the client reuses the same source port in short succession (this is e.g. the case
   886  	// with glibc). If SO_LINGER is not used, the old socket might have not yet reached
   887  	// the TIME_WAIT state by the time we are trying to reuse the port on a new socket.
   888  	// If that happens, the connect() call will fail with EADDRNOTAVAIL.
   889  	// Note that the linger timeout can also be set to 0, in which case the socket is
   890  	// terminated forcefully with a TCP RST and thus can also be reused immediately.
   891  	if linger := option.Config.DNSProxySocketLingerTimeout; linger >= 0 {
   892  		err = unix.SetsockoptLinger(fd, unix.SOL_SOCKET, unix.SO_LINGER, &unix.Linger{
   893  			Onoff:  1,
   894  			Linger: int32(linger),
   895  		})
   896  		if err != nil {
   897  			return fmt.Errorf("setsockopt(SO_LINGER) failed: %w", err)
   898  		}
   899  	}
   900  
   901  	return nil
   902  }
   903  
   904  // ServeDNS handles individual DNS requests forwarded to the proxy, and meets
   905  // the dns.Handler interface.
   906  // It will:
   907  //   - Look up the endpoint that sent the request by IP, via LookupEndpointByIP.
   908  //   - Look up the Sec ID of the destination server, via LookupSecIDByIP.
   909  //   - Check that the endpoint ID, destination Sec ID, destination port and the
   910  //     qname all match a rule. If not, the request is dropped.
   911  //   - The allowed request is forwarded to the originally intended DNS server IP
   912  //   - The response is shared via NotifyOnDNSMsg (this will go to a
   913  //     fqdn/NameManager instance).
   914  //   - Write the response to the endpoint.
   915  func (p *DNSProxy) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
   916  	stat := ProxyRequestContext{DataSource: accesslog.DNSSourceProxy}
   917  	stat.TotalTime.Start()
   918  	requestID := request.Id // keep the original request ID
   919  	qname := string(request.Question[0].Name)
   920  	protocol := w.LocalAddr().Network()
   921  	epIPPort := w.RemoteAddr().String()
   922  	scopedLog := log.WithFields(logrus.Fields{
   923  		logfields.DNSName:      qname,
   924  		logfields.IPAddr:       epIPPort,
   925  		logfields.DNSRequestID: requestID,
   926  	})
   927  
   928  	if p.ConcurrencyLimit != nil {
   929  		// TODO: Consider plumbing the daemon context here.
   930  		ctx, cancel := context.WithTimeout(context.TODO(), p.ConcurrencyGracePeriod)
   931  		defer cancel()
   932  
   933  		stat.SemaphoreAcquireTime.Start()
   934  		// Enforce the concurrency limit by attempting to acquire the
   935  		// semaphore.
   936  		if err := p.enforceConcurrencyLimit(ctx); err != nil {
   937  			stat.SemaphoreAcquireTime.End(false)
   938  			if p.logLimiter.Allow() {
   939  				scopedLog.WithError(err).Error("Dropping DNS request due to too many DNS requests already in-flight")
   940  			}
   941  			stat.Err = err
   942  			p.NotifyOnDNSMsg(time.Now(), nil, epIPPort, 0, "", request, protocol, false, &stat)
   943  			p.sendRefused(scopedLog, w, request)
   944  			return
   945  		}
   946  		stat.SemaphoreAcquireTime.End(true)
   947  		defer p.ConcurrencyLimit.Release(1)
   948  	}
   949  	stat.ProcessingTime.Start()
   950  
   951  	scopedLog.Debug("Handling DNS query from endpoint")
   952  
   953  	addrPort, err := netip.ParseAddrPort(epIPPort)
   954  	if err != nil {
   955  		scopedLog.WithError(err).Error("cannot extract endpoint IP from DNS request")
   956  		stat.Err = fmt.Errorf("Cannot extract endpoint IP from DNS request: %w", err)
   957  		stat.ProcessingTime.End(false)
   958  		p.NotifyOnDNSMsg(time.Now(), nil, epIPPort, 0, "", request, protocol, false, &stat)
   959  		p.sendRefused(scopedLog, w, request)
   960  		return
   961  	}
   962  	epAddr := addrPort.Addr()
   963  	ep, err := p.LookupEndpointByIP(epAddr)
   964  	if err != nil {
   965  		scopedLog.WithError(err).Error("cannot extract endpoint ID from DNS request")
   966  		stat.Err = fmt.Errorf("Cannot extract endpoint ID from DNS request: %w", err)
   967  		stat.ProcessingTime.End(false)
   968  		p.NotifyOnDNSMsg(time.Now(), nil, epIPPort, 0, "", request, protocol, false, &stat)
   969  		p.sendRefused(scopedLog, w, request)
   970  		return
   971  	}
   972  
   973  	scopedLog = scopedLog.WithFields(logrus.Fields{
   974  		logfields.EndpointID: ep.StringID(),
   975  		logfields.Identity:   ep.GetIdentity(),
   976  	})
   977  
   978  	targetServerIP, targetServerPortProto, targetServerAddrStr, err := p.lookupTargetDNSServer(w)
   979  	if err != nil {
   980  		log.WithError(err).Error("cannot extract destination IP:port from DNS request")
   981  		stat.Err = fmt.Errorf("Cannot extract destination IP:port from DNS request: %w", err)
   982  		stat.ProcessingTime.End(false)
   983  		p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, 0, targetServerAddrStr, request, protocol, false, &stat)
   984  		p.sendRefused(scopedLog, w, request)
   985  		return
   986  	}
   987  
   988  	// Ignore invalid IP - getter will handle invalid value.
   989  	targetServerAddr, _ := ippkg.AddrFromIP(targetServerIP)
   990  	targetServerID := identity.GetWorldIdentityFromIP(targetServerAddr)
   991  	if serverSecID, exists := p.LookupSecIDByIP(targetServerAddr); !exists {
   992  		scopedLog.WithField("server", targetServerAddrStr).Debug("cannot find server ip in ipcache, defaulting to WORLD")
   993  	} else {
   994  		targetServerID = serverSecID.ID
   995  		scopedLog.WithField("server", targetServerAddrStr).Debugf("Found target server to of DNS request secID %+v", serverSecID)
   996  	}
   997  
   998  	// The allowed check is first because we don't want to use DNS responses that
   999  	// endpoints are not allowed to see.
  1000  	// Note: The cache doesn't know about the source of the DNS data (yet) and so
  1001  	// it won't enforce any separation between results from different endpoints.
  1002  	// This isn't ideal but we are trusting the DNS responses anyway.
  1003  	stat.PolicyCheckTime.Start()
  1004  	allowed, err := p.CheckAllowed(uint64(ep.ID), targetServerPortProto, targetServerID, targetServerIP, qname)
  1005  	stat.PolicyCheckTime.End(err == nil)
  1006  	switch {
  1007  	case err != nil:
  1008  		scopedLog.WithError(err).Error("Rejecting DNS query from endpoint due to error")
  1009  		stat.Err = fmt.Errorf("Rejecting DNS query from endpoint due to error: %w", err)
  1010  		stat.ProcessingTime.End(false)
  1011  		p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
  1012  		p.sendRefused(scopedLog, w, request)
  1013  		return
  1014  
  1015  	case !allowed:
  1016  		scopedLog.Debug("Rejecting DNS query from endpoint due to policy")
  1017  		// Send refused msg before calling NotifyOnDNSMsg() because we know
  1018  		// that this DNS request is rejected anyway. NotifyOnDNSMsg depends on
  1019  		// stat.Err field to be set in order to propagate the correct
  1020  		// information for metrics.
  1021  		stat.Err = p.sendRefused(scopedLog, w, request)
  1022  		stat.ProcessingTime.End(true)
  1023  		p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
  1024  		return
  1025  	}
  1026  
  1027  	scopedLog.Debug("Forwarding DNS request for a name that is allowed")
  1028  	p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, true, &stat)
  1029  
  1030  	// Keep the same L4 protocol. This handles DNS re-requests over TCP, for
  1031  	// requests that were too large for UDP.
  1032  	switch protocol {
  1033  	case "udp":
  1034  	case "tcp":
  1035  	default:
  1036  		scopedLog.Error("Cannot parse DNS proxy client network to select forward client")
  1037  		stat.Err = fmt.Errorf("Cannot parse DNS proxy client network to select forward client: %w", err)
  1038  		stat.ProcessingTime.End(false)
  1039  		p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
  1040  		p.sendRefused(scopedLog, w, request)
  1041  		return
  1042  	}
  1043  	stat.ProcessingTime.End(true)
  1044  	stat.UpstreamTime.Start()
  1045  
  1046  	var ipFamily ipfamily.IPFamily
  1047  	if targetServerAddr.Is4() {
  1048  		ipFamily = ipfamily.IPv4()
  1049  	} else {
  1050  		ipFamily = ipfamily.IPv6()
  1051  	}
  1052  
  1053  	dialer := net.Dialer{
  1054  		Timeout: ProxyForwardTimeout,
  1055  		Control: func(network, address string, c syscall.RawConn) error {
  1056  			var soerr error
  1057  			if err := c.Control(func(su uintptr) {
  1058  				soerr = setSoMarks(int(su), ipFamily, ep.GetIdentity())
  1059  			}); err != nil {
  1060  				return err
  1061  			}
  1062  			return soerr
  1063  		},
  1064  	}
  1065  
  1066  	var key string
  1067  	// Do not use original source address if
  1068  	// - not configured, or if
  1069  	// - the source is known to be in the host networking namespace, or
  1070  	// - the destination is known to be outside of the cluster, or
  1071  	// - is the local host
  1072  	if option.Config.DNSProxyEnableTransparentMode && !ep.IsHost() && !epAddr.IsLoopback() && ep.ID != uint16(identity.ReservedIdentityHost) && targetServerID.IsCluster() && targetServerID != identity.ReservedIdentityHost {
  1073  		dialer.LocalAddr = w.RemoteAddr()
  1074  		key = sharedClientKey(protocol, epIPPort, targetServerAddrStr)
  1075  	}
  1076  
  1077  	conf := &dns.Client{
  1078  		Net:            protocol,
  1079  		Dialer:         &dialer,
  1080  		Timeout:        ProxyForwardTimeout,
  1081  		SingleInflight: false,
  1082  	}
  1083  
  1084  	request.Id = dns.Id() // force a random new ID for this request
  1085  	response, _, closer, err := p.DNSClients.Exchange(key, conf, request, targetServerAddrStr)
  1086  	defer closer()
  1087  
  1088  	stat.UpstreamTime.End(err == nil)
  1089  	if err != nil {
  1090  		stat.Err = err
  1091  		if stat.IsTimeout() {
  1092  			scopedLog.WithError(err).Warn("Timeout waiting for response to forwarded proxied DNS lookup")
  1093  			p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
  1094  			return
  1095  		}
  1096  		scopedLog.WithError(err).Error("Cannot forward proxied DNS lookup")
  1097  		stat.Err = fmt.Errorf("cannot forward proxied DNS lookup: %w", err)
  1098  		p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, request, protocol, false, &stat)
  1099  		p.sendRefused(scopedLog, w, request)
  1100  		return
  1101  	}
  1102  
  1103  	scopedLog.WithField(logfields.Response, response).Debug("Received DNS response to proxied lookup")
  1104  	stat.Success = true
  1105  
  1106  	scopedLog.Debug("Notifying with DNS response to original DNS query")
  1107  	p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, response, protocol, true, &stat)
  1108  
  1109  	scopedLog.Debug("Responding to original DNS query")
  1110  	// restore the ID to the one in the initial request so it matches what the requester expects.
  1111  	response.Id = requestID
  1112  	response.Compress = p.EnableDNSCompression && shouldCompressResponse(request, response)
  1113  	err = w.WriteMsg(response)
  1114  	if err != nil {
  1115  		scopedLog.WithError(err).Error("Cannot forward proxied DNS response")
  1116  		stat.Err = fmt.Errorf("Cannot forward proxied DNS response: %w", err)
  1117  		p.NotifyOnDNSMsg(time.Now(), ep, epIPPort, targetServerID, targetServerAddrStr, response, protocol, true, &stat)
  1118  	} else {
  1119  		p.Lock()
  1120  		// Add the server to the set of used DNS servers. This set is never GCd, but is limited by set
  1121  		// of DNS server IPs that are allowed by a policy and for which successful response was received.
  1122  		p.usedServers[targetServerIP.String()] = struct{}{}
  1123  		p.Unlock()
  1124  	}
  1125  }
  1126  
  1127  func (p *DNSProxy) enforceConcurrencyLimit(ctx context.Context) error {
  1128  	if p.ConcurrencyGracePeriod == 0 {
  1129  		// No grace time configured. Failing to acquire semaphore means
  1130  		// immediately give up.
  1131  		if !p.ConcurrencyLimit.TryAcquire(1) {
  1132  			return ErrFailedAcquireSemaphore{
  1133  				parallel: option.Config.DNSProxyConcurrencyLimit,
  1134  			}
  1135  		}
  1136  	} else if err := p.ConcurrencyLimit.Acquire(ctx, 1); err != nil && errors.Is(err, context.DeadlineExceeded) {
  1137  		// We ignore err because errTimedOutAcquireSemaphore implements the
  1138  		// net.Error interface deeming it a timeout error which will be
  1139  		// treated the same as context.DeadlineExceeded.
  1140  		return ErrTimedOutAcquireSemaphore{
  1141  			ErrFailedAcquireSemaphore: ErrFailedAcquireSemaphore{
  1142  				parallel: option.Config.DNSProxyConcurrencyLimit,
  1143  			},
  1144  			gracePeriod: p.ConcurrencyGracePeriod,
  1145  		}
  1146  	}
  1147  	return nil
  1148  }
  1149  
  1150  // sendRefused creates and sends a REFUSED response for request to w
  1151  // The returned error is logged with scopedLog and is returned for convenience
  1152  func (p *DNSProxy) sendRefused(scopedLog *logrus.Entry, w dns.ResponseWriter, request *dns.Msg) (err error) {
  1153  	refused := new(dns.Msg)
  1154  	refused.SetRcode(request, int(p.rejectReply.Load()))
  1155  
  1156  	if err = w.WriteMsg(refused); err != nil {
  1157  		scopedLog.WithError(err).Error("Cannot send REFUSED response")
  1158  		err = fmt.Errorf("cannot send REFUSED response: %w", err)
  1159  	}
  1160  	return err
  1161  }
  1162  
  1163  // SetRejectReply sets the default reject reply on denied dns responses.
  1164  func (p *DNSProxy) SetRejectReply(opt string) {
  1165  	switch strings.ToLower(opt) {
  1166  	case strings.ToLower(option.FQDNProxyDenyWithNameError):
  1167  		p.rejectReply.Store(dns.RcodeNameError)
  1168  	case strings.ToLower(option.FQDNProxyDenyWithRefused):
  1169  		p.rejectReply.Store(dns.RcodeRefused)
  1170  	default:
  1171  		log.Infof("DNS reject response '%s' is not valid, available options are '%v'",
  1172  			opt, option.FQDNRejectOptions)
  1173  		return
  1174  	}
  1175  }
  1176  
  1177  func (p *DNSProxy) GetBindPort() uint16 {
  1178  	return p.BindPort
  1179  }
  1180  
  1181  // ExtractMsgDetails extracts a canonical query name, any IPs in a response,
  1182  // the lowest applicable TTL, rcode, anwer rr types and question types
  1183  // When a CNAME is returned the chain is collapsed down, keeping the lowest TTL,
  1184  // and CNAME targets are returned.
  1185  func ExtractMsgDetails(msg *dns.Msg) (qname string, responseIPs []net.IP, TTL uint32, CNAMEs []string, rcode int, answerTypes []uint16, qTypes []uint16, err error) {
  1186  	if len(msg.Question) == 0 {
  1187  		return "", nil, 0, nil, 0, nil, nil, errors.New("Invalid DNS message")
  1188  	}
  1189  	qname = strings.ToLower(string(msg.Question[0].Name))
  1190  
  1191  	// rrName is the name the next RR should include.
  1192  	// This will change when we see CNAMEs.
  1193  	rrName := strings.ToLower(qname)
  1194  
  1195  	TTL = math.MaxUint32 // a TTL must exist in the RRs
  1196  
  1197  	answerTypes = make([]uint16, 0, len(msg.Answer))
  1198  	for _, ans := range msg.Answer {
  1199  		// Ensure we have records for DNS names we expect
  1200  		if strings.ToLower(ans.Header().Name) != rrName {
  1201  			return qname, nil, 0, nil, 0, nil, nil, fmt.Errorf("Unexpected name (%s) in RRs for %s (query for %s)", ans, rrName, qname)
  1202  		}
  1203  
  1204  		// Handle A, AAAA and CNAME records by accumulating IPs and lowest TTL
  1205  		switch ans := ans.(type) {
  1206  		case *dns.A:
  1207  			responseIPs = append(responseIPs, ans.A)
  1208  			if TTL > ans.Hdr.Ttl {
  1209  				TTL = ans.Hdr.Ttl
  1210  			}
  1211  		case *dns.AAAA:
  1212  			responseIPs = append(responseIPs, ans.AAAA)
  1213  			if TTL > ans.Hdr.Ttl {
  1214  				TTL = ans.Hdr.Ttl
  1215  			}
  1216  		case *dns.CNAME:
  1217  			// We still track the TTL because the lowest TTL in the chain
  1218  			// determines the valid caching time for the whole response.
  1219  			if TTL > ans.Hdr.Ttl {
  1220  				TTL = ans.Hdr.Ttl
  1221  			}
  1222  			rrName = strings.ToLower(ans.Target)
  1223  			CNAMEs = append(CNAMEs, ans.Target)
  1224  		}
  1225  		answerTypes = append(answerTypes, ans.Header().Rrtype)
  1226  	}
  1227  
  1228  	qTypes = make([]uint16, 0, len(msg.Question))
  1229  	for _, q := range msg.Question {
  1230  		qTypes = append(qTypes, q.Qtype)
  1231  	}
  1232  
  1233  	return qname, responseIPs, TTL, CNAMEs, msg.Rcode, answerTypes, qTypes, nil
  1234  }
  1235  
  1236  // bindToAddr attempts to bind to address and port for both UDP and TCP on IPv4 and/or IPv6.
  1237  // If address is empty it automatically binds to the loopback interfaces on IPv4 and/or IPv6.
  1238  // If port is 0 a random open port is assigned and the same one is used for UDP and TCP.
  1239  // Note: This mimics what the dns package does EXCEPT for setting reuseport.
  1240  // This is ok for now but it would simplify proxy management in the future to
  1241  // have it set.
  1242  func bindToAddr(address string, port uint16, handler dns.Handler, ipv4, ipv6 bool, sc *SharedClients) (dnsServers []*dns.Server, bindPort uint16, err error) {
  1243  	defer func() {
  1244  		if err != nil {
  1245  			shutdownServers(dnsServers)
  1246  		}
  1247  	}()
  1248  
  1249  	var ipFamilies []ipfamily.IPFamily
  1250  	if ipv4 {
  1251  		ipFamilies = append(ipFamilies, ipfamily.IPv4())
  1252  	}
  1253  	if ipv6 {
  1254  		ipFamilies = append(ipFamilies, ipfamily.IPv6())
  1255  	}
  1256  
  1257  	for _, ipFamily := range ipFamilies {
  1258  		lc := listenConfig(linux_defaults.MagicMarkEgress, ipFamily)
  1259  
  1260  		tcpListener, err := lc.Listen(context.Background(), ipFamily.TCPAddress, evaluateAddress(address, port, bindPort, ipFamily))
  1261  		if err != nil {
  1262  			return nil, 0, fmt.Errorf("failed to listen on %s: %w", ipFamily.TCPAddress, err)
  1263  		}
  1264  		if option.Config.DNSProxyEnableTransparentMode {
  1265  			// The wrapper is only necessary to forward the closing signal in transparent mode.
  1266  			tcpListener = &wrappedTCPListener{sc: sc, Listener: tcpListener}
  1267  		}
  1268  		dnsServers = append(dnsServers, &dns.Server{
  1269  			Listener: tcpListener, Handler: handler,
  1270  			// Explicitly set a noop factory to prevent data race detection when InitPool is called
  1271  			// multiple times on the default factory even for TCP (IPv4 & IPv6).
  1272  			SessionUDPFactory: &noopSessionUDPFactory{},
  1273  			// Net & Addr are only set for logging purposes and aren't used if using ActivateAndServe.
  1274  			Net: ipFamily.TCPAddress, Addr: tcpListener.Addr().String(),
  1275  		})
  1276  
  1277  		bindPort = uint16(tcpListener.Addr().(*net.TCPAddr).Port)
  1278  
  1279  		udpConn, err := lc.ListenPacket(context.Background(), ipFamily.UDPAddress, evaluateAddress(address, port, bindPort, ipFamily))
  1280  		if err != nil {
  1281  			return nil, 0, fmt.Errorf("failed to listen on %s: %w", ipFamily.UDPAddress, err)
  1282  		}
  1283  		sessionUDPFactory, ferr := NewSessionUDPFactory(ipFamily)
  1284  		if ferr != nil {
  1285  			return nil, 0, fmt.Errorf("failed to create UDP session factory for %s: %w", ipFamily.UDPAddress, err)
  1286  		}
  1287  		dnsServers = append(dnsServers, &dns.Server{
  1288  			PacketConn: udpConn, Handler: handler, SessionUDPFactory: sessionUDPFactory,
  1289  			// Net & Addr are only set for logging purposes and aren't used if using ActivateAndServe.
  1290  			Net: ipFamily.UDPAddress, Addr: udpConn.LocalAddr().String(),
  1291  		})
  1292  	}
  1293  
  1294  	return dnsServers, bindPort, nil
  1295  }
  1296  
  1297  type wrappedTCPListener struct {
  1298  	net.Listener
  1299  	sc *SharedClients
  1300  }
  1301  
  1302  func (w *wrappedTCPListener) Accept() (net.Conn, error) {
  1303  	c, err := w.Listener.Accept()
  1304  	if err != nil {
  1305  		return c, err
  1306  	}
  1307  
  1308  	wc := &wrappedTCPConn{c.(*net.TCPConn), w.sc}
  1309  	return wc, err
  1310  }
  1311  
  1312  type wrappedTCPConn struct {
  1313  	*net.TCPConn
  1314  	sc *SharedClients
  1315  }
  1316  
  1317  func (w *wrappedTCPConn) key() string {
  1318  	return sharedClientKey("tcp", w.RemoteAddr().String(), w.LocalAddr().String())
  1319  }
  1320  
  1321  func (w *wrappedTCPConn) Read(b []byte) (int, error) {
  1322  	n, err := w.TCPConn.Read(b)
  1323  	if err != nil {
  1324  		// Any error is reason enough to close the upstream conn.
  1325  		w.sc.ShutdownTCPClient(w.key())
  1326  	}
  1327  	return n, err
  1328  }
  1329  func (w *wrappedTCPConn) Write(b []byte) (int, error) {
  1330  	n, err := w.TCPConn.Write(b)
  1331  	if err != nil {
  1332  		w.sc.ShutdownTCPClient(w.key())
  1333  	}
  1334  	return n, err
  1335  }
  1336  
  1337  // Close closes the wrapped connection, but also forwards the closing signal to
  1338  // shared clients, so that the upstream connection is closed too.
  1339  func (w *wrappedTCPConn) Close() error {
  1340  	// It's possible that there is no shared client behind this key, as we don't
  1341  	// always use the original source address. That's okay, since ShutdownClient
  1342  	// does nothing for keys it doesn't know.
  1343  	w.sc.ShutdownTCPClient(w.key())
  1344  	return w.TCPConn.Close()
  1345  }
  1346  
  1347  func evaluateAddress(address string, port uint16, bindPort uint16, ipFamily ipfamily.IPFamily) string {
  1348  	// If the address is ever changed, ensure that the change is also reflected
  1349  	// where the proxy bind address is referenced in the iptables rules. See
  1350  	// (*IptablesManager).doGetProxyPort().
  1351  
  1352  	addr := ipFamily.Localhost
  1353  
  1354  	if address != "" {
  1355  		addr = address
  1356  	}
  1357  
  1358  	if bindPort == 0 {
  1359  		return net.JoinHostPort(addr, strconv.Itoa(int(port)))
  1360  	} else {
  1361  		// Already bound to a port by a previous server -> reuse same port
  1362  		return net.JoinHostPort(addr, strconv.Itoa(int(bindPort)))
  1363  	}
  1364  }
  1365  
  1366  // shouldCompressResponse returns true when the response needs to be compressed
  1367  // for a given request.
  1368  // Originally, DNS was limited to 512 byte responses. EDNS0 allows for larger
  1369  // sizes. In either case, responses can apply DNS compression, and the original
  1370  // RFCs require clients to accept this. In cilium/dns there is a comment that BIND
  1371  // does not support compression, so we retain the ability to suppress this.
  1372  func shouldCompressResponse(request, response *dns.Msg) bool {
  1373  	ednsOptions := request.IsEdns0()
  1374  	responseLenNoCompression := response.Len()
  1375  
  1376  	switch {
  1377  	case ednsOptions != nil && responseLenNoCompression > int(ednsOptions.UDPSize()): // uint16 -> int cast should always be safe
  1378  		return true
  1379  	case responseLenNoCompression > 512:
  1380  		return true
  1381  	}
  1382  
  1383  	return false
  1384  }
  1385  
  1386  // GeneratePattern takes a set of l7Rules and returns a regular expression pattern for matching the
  1387  // provided l7 rules.
  1388  func GeneratePattern(l7Rules *policy.PerSelectorPolicy) (pattern string) {
  1389  	if l7Rules == nil || len(l7Rules.DNS) == 0 {
  1390  		return matchpattern.MatchAllAnchoredPattern
  1391  	}
  1392  	reStrings := make([]string, 0, len(l7Rules.DNS))
  1393  	for _, dnsRule := range l7Rules.DNS {
  1394  		if len(dnsRule.MatchName) > 0 {
  1395  			dnsRuleName := strings.ToLower(dns.Fqdn(dnsRule.MatchName))
  1396  			dnsRuleNameAsRE := matchpattern.ToUnAnchoredRegexp(dnsRuleName)
  1397  			reStrings = append(reStrings, dnsRuleNameAsRE)
  1398  		}
  1399  		if len(dnsRule.MatchPattern) > 0 {
  1400  			dnsPattern := matchpattern.Sanitize(dnsRule.MatchPattern)
  1401  			dnsPatternAsRE := matchpattern.ToUnAnchoredRegexp(dnsPattern)
  1402  			if dnsPatternAsRE == matchpattern.MatchAllUnAnchoredPattern {
  1403  				return matchpattern.MatchAllAnchoredPattern
  1404  			}
  1405  			reStrings = append(reStrings, dnsPatternAsRE)
  1406  		}
  1407  	}
  1408  	return "^(?:" + strings.Join(reStrings, "|") + ")$"
  1409  }
  1410  
  1411  func (p *DNSProxy) Cleanup() {
  1412  	if p.unbindAddress != nil {
  1413  		p.unbindAddress()
  1414  	}
  1415  }