github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns_linux.go (about)

     1  // +build linux
     2  
     3  package dnsproxy
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"net"
     9  	"strconv"
    10  	"sync"
    11  	"syscall"
    12  	"time"
    13  
    14  	"github.com/miekg/dns"
    15  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    16  	"go.aporeto.io/enforcerd/trireme-lib/controller/constants"
    17  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/counters"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/flowtracking"
    19  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/ipsetmanager"
    20  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext"
    21  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    22  	"go.aporeto.io/enforcerd/trireme-lib/utils/cache"
    23  	"go.uber.org/zap"
    24  )
    25  
    26  // removeExpiredEntryFunc is the type of the function that gets called when an IP entry expires (when its TTL hits 0)
    27  type removeExpiredEntryFunc func(string)
    28  
    29  // Proxy struct represents the object for dns proxy
    30  type Proxy struct {
    31  	puFromID                 cache.DataStore
    32  	conntrack                flowtracking.FlowClient
    33  	collector                collector.EventCollector
    34  	contextIDToServer        map[string]*dns.Server
    35  	chreports                chan dnsReport
    36  	contextIDToDNSNames      *cache.Cache
    37  	contextIDToDNSNamesLocks *mutexMap
    38  	IPToTTL                  *cache.Cache
    39  	IPToTTLLocks             *mutexMap
    40  	removeExpiredEntry       removeExpiredEntryFunc
    41  	sync.Mutex
    42  }
    43  type dnsNamesToIP struct {
    44  	nameToIP     map[string][]string
    45  	dnsNamesLock sync.Mutex
    46  }
    47  type dnsttlinfo struct {
    48  	ipaddress string
    49  	ttl       uint32
    50  }
    51  
    52  type iptottlinfo struct {
    53  	ipaddress  string
    54  	expiryTime time.Time
    55  	timer      *time.Timer
    56  	contextIDs map[string]struct{}
    57  	fqdns      map[string]struct{}
    58  }
    59  type serveDNS struct {
    60  	contextID string
    61  	*Proxy
    62  }
    63  
    64  const (
    65  	dnsRequestTimeout = 2 * time.Second
    66  )
    67  
    68  func socketOptions(_, _ string, c syscall.RawConn) error {
    69  	var opErr error
    70  	err := c.Control(func(fd uintptr) {
    71  		if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, constants.ProxyMarkInt); err != nil {
    72  			zap.L().Error("dnsproxy: failed to mark connection", zap.Error(err))
    73  		}
    74  	})
    75  
    76  	if err != nil {
    77  		return err
    78  	}
    79  
    80  	return opErr
    81  }
    82  
    83  func listenUDP(ctx context.Context, network, addr string) (net.PacketConn, error) {
    84  	var lc net.ListenConfig
    85  
    86  	lc.Control = socketOptions
    87  
    88  	return lc.ListenPacket(ctx, network, addr)
    89  }
    90  
    91  func forwardDNSReq(r *dns.Msg, ip net.IP, port uint16) ([]byte, []string, []*dnsttlinfo, error) {
    92  	var ips []string
    93  	var resp []byte
    94  	var msg *dns.Msg
    95  	var conn *dns.Conn
    96  	var err error
    97  
    98  	c := new(dns.Client)
    99  
   100  	dial := func(address string) (*dns.Conn, error) {
   101  		c.Dialer = &net.Dialer{
   102  			Control: func(_, _ string, c syscall.RawConn) error {
   103  				return c.Control(func(fd uintptr) {
   104  					if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, constants.ProxyMarkInt); err != nil {
   105  						zap.L().Error("dnsproxy: failed to assing mark to socket", zap.Error(err))
   106  					}
   107  				})
   108  			},
   109  			Timeout: dnsRequestTimeout,
   110  		}
   111  
   112  		conn, err := c.Dial(address)
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  
   117  		return conn, nil
   118  	}
   119  
   120  	sendRequest := func(r *dns.Msg, conn *dns.Conn) error {
   121  		opt := r.IsEdns0()
   122  		// If EDNS0 is used use that for size.
   123  		if opt != nil && opt.UDPSize() >= dns.MinMsgSize {
   124  			conn.UDPSize = opt.UDPSize()
   125  		}
   126  		// Otherwise use the client's configured UDP size.
   127  		if opt == nil && c.UDPSize >= dns.MinMsgSize {
   128  			conn.UDPSize = c.UDPSize
   129  		}
   130  
   131  		t := time.Now()
   132  		// write with the appropriate write timeout
   133  		if err = conn.SetWriteDeadline(t.Add(c.Dialer.Timeout)); err != nil {
   134  			return err
   135  		}
   136  
   137  		if err = conn.WriteMsg(r); err != nil {
   138  			return err
   139  		}
   140  
   141  		return nil
   142  	}
   143  
   144  	readResponse := func(conn *dns.Conn) ([]byte, *dns.Msg, error) {
   145  		if err := conn.SetReadDeadline(time.Now().Add(c.Dialer.Timeout)); err != nil {
   146  			return nil, nil, err
   147  		}
   148  
   149  		p, err := conn.ReadMsgHeader(nil)
   150  		if err != nil {
   151  			return nil, nil, err
   152  		}
   153  
   154  		m := new(dns.Msg)
   155  		if err := m.Unpack(p); err != nil {
   156  			// If an error was returned, we still want to allow the user to use
   157  			// the message, but naively they can just check err if they don't want
   158  			// to use an erroneous message
   159  			return nil, nil, err
   160  		}
   161  
   162  		return p, m, nil
   163  	}
   164  
   165  	if conn, err = dial(net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))); err != nil {
   166  		return nil, nil, nil, err
   167  	}
   168  
   169  	defer conn.Close() // nolint: errcheck
   170  
   171  	if err := sendRequest(r, conn); err != nil {
   172  		return nil, nil, nil, err
   173  	}
   174  
   175  	if resp, msg, err = readResponse(conn); err != nil {
   176  		return nil, nil, nil, err
   177  	}
   178  	dnsttlinfolist := []*dnsttlinfo{}
   179  
   180  	for _, ans := range msg.Answer {
   181  		if ans.Header().Rrtype == dns.TypeA {
   182  			t, _ := ans.(*dns.A)
   183  
   184  			ips = append(ips, t.A.String())
   185  			dnsttlinfolist = append(dnsttlinfolist, &dnsttlinfo{
   186  				ipaddress: t.A.String(),
   187  				ttl:       ans.Header().Ttl,
   188  			})
   189  		}
   190  
   191  		if ans.Header().Rrtype == dns.TypeAAAA {
   192  			t, _ := ans.(*dns.AAAA)
   193  			ips = append(ips, t.AAAA.String())
   194  
   195  			dnsttlinfolist = append(dnsttlinfolist, &dnsttlinfo{
   196  				ipaddress: t.AAAA.String(),
   197  				ttl:       ans.Header().Ttl,
   198  			})
   199  		}
   200  	}
   201  	return resp, ips, dnsttlinfolist, nil
   202  }
   203  
   204  const (
   205  	strInvalidDNSRequest = "invalid DNS request"
   206  )
   207  
// ServeDNS implements dns.Handler for the DNS proxy server of s.contextID.
//
// For every query it:
//  1. resolves the PU context for s.contextID,
//  2. rejects requests without a question section,
//  3. asks conntrack for the original (pre-redirect) destination of the query,
//  4. forwards the query to that destination with forwardDNSReq,
//  5. if the PU has DNS policies for the queried name, programs ipsets and
//     enforcer ApplicationACLs for the returned addresses, and records them in
//     the FQDN map and in the TTL-expiry bookkeeping,
//  6. calls configureDependentServices and writes the upstream reply back.
//
// If any of steps 1-4 fail, the handler returns without writing a reply to
// the client. Once a PU context has been found, the deferred function reports
// the lookup on return via reportDNSLookup. The report carries reportError
// when one of the steps above failed.
func (s *serveDNS) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	var err error
	// NOTE(review): unchecked assertions; the server is only started on a UDP
	// packet conn (StartDNSServer), so these are expected to be *net.UDPAddr.
	lAddr := w.LocalAddr().(*net.UDPAddr)
	rAddr := w.RemoteAddr().(*net.UDPAddr)
	// These are declared up front because the deferred report reads them.
	var pctx *pucontext.PUContext
	var ipsRaw []string
	var origIP net.IP
	var origPort uint16
	var reportError string

	defer func() {
		if pctx != nil {
			// if there is no question section, this was an invalid request
			name := "invalid"
			if len(r.Question) > 0 {
				name = r.Question[0].Name
			}
			s.reportDNSLookup(name, pctx, rAddr.IP, uint16(rAddr.Port), origIP, origPort, ipsRaw, reportError)
		}
	}()

	pctxRaw, err := s.puFromID.Get(s.contextID)
	if err != nil {
		zap.L().Error("dnsproxy: context not found for the PU with ID", zap.String("contextID", s.contextID), zap.Error(err))
		reportError = fmt.Sprintf("PU context: %s", err)
		return
	}
	pctx = pctxRaw.(*pucontext.PUContext)

	// check if the DNS request is actually valid
	// we have seen with the AWS resolve in the past that it *does* respond with an empty Question section
	if len(r.Question) <= 0 {
		pctx.Counters().IncrementCounter(counters.ErrDNSInvalidRequest)
		zap.L().Debug("dnsproxy: invalid DNS request received (missing question section)", zap.String("contextID", s.contextID))
		reportError = strInvalidDNSRequest
		return
	}

	// TODO: shouldn't we let the lookup go regardless of our problems?
	// The query was redirected to this local proxy; look up the destination the
	// client originally sent it to. 17 is the IP protocol number for UDP.
	origIP, origPort, _, err = s.conntrack.GetOriginalDest(net.ParseIP("127.0.0.1"), rAddr.IP, uint16(lAddr.Port), uint16(rAddr.Port), 17)
	if err != nil {
		zap.L().Error("dnsproxy: failed to find flow for the redirected DNS traffic", zap.String("contextID", s.contextID), zap.Error(err))
		reportError = fmt.Sprintf("conntrack: DNS request flow: %s", err)
		return
	}

	// perform the upstream DNS lookup
	// (ipsRaw is assigned here, not redeclared: same scope as its var above)
	dnsReply, ipsRaw, dnsttlinfolistRaw, err := forwardDNSReq(r, origIP, origPort)
	if err != nil {
		pctx.Counters().IncrementCounter(counters.ErrDNSForwardFailed)
		zap.L().Debug("dnsproxy: forwarded DNS request returned error", zap.String("contextID", s.contextID), zap.Error(err))
		reportError = fmt.Sprintf("DNS request failed: %s", err)
		return
	}

	// get all policies associated with the FQDN from
	policies, policyName, err1 := pctx.GetPolicyFromFQDN(r.Question[0].Name)

	// if they exist, then err1 is nil, and we need to update
	// - the ipsets
	// - the applicationacls inside of the enforcer
	// - the internal cache
	if err1 == nil {
		// ipDetail remembers the first TTL seen for an address and whether it
		// only needs its expiry extended (updateOnly) instead of a new entry.
		type ipDetail struct {
			ttl        uint32
			updateOnly bool
		}
		ipsToProcess := make(map[string]ipDetail, len(ipsRaw))
		for _, pol := range policies {
			for _, i := range dnsttlinfolistRaw {
				// TODO: this does not work yet - will come in a separate PR
				//if checkIfACLExists(pctx, pol, i.ipaddress) {
				//	// no need to program ACLs and ipsets
				//	// however, there are two cases here:
				//	// 1. this was a static entry in the external network, and we truly want to skip it
				//	// 2. this comes from us programming it down below
				//	// In case (2) we need to actually call handleTTLInfoList but with updateOnly to extend the TTL
				//	// This way no new expiry entries will be made when not necessary as for static entries,
				//	// but they will be extended when necessary as well.
				//	if _, ok := ipsToProcess[i.ipaddress]; !ok {
				//		ipsToProcess[i.ipaddress] = ipDetail{ttl: i.ttl, updateOnly: true}
				//	}
				//	continue
				//}

				if _, ok := ipsToProcess[i.ipaddress]; !ok {
					ipsToProcess[i.ipaddress] = ipDetail{ttl: i.ttl, updateOnly: false}
				}
				ips := []string{i.ipaddress}

				// makes sure to update any ipsets related to the serviceID of the policy
				// this matches the case when the destinations are *not* overlapping with the target networks and the decision is not done
				// in the enforcer
				zap.L().Debug("dnsproxy: ipset: adding IP addresses", zap.String("contextID", s.contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.Strings("ipaddresses", ips))
				ipsetmanager.V4().UpdateACLIPsets(ips, pol.Policy.ServiceID)
				ipsetmanager.V6().UpdateACLIPsets(ips, pol.Policy.ServiceID)

				// makes sure to update the ApplicationACLs inside of the enforcer
				// this matches the case when the destination is overlapping with the target networks, and the decision is made inside
				// of the enforcer, and not with ipsets
				zap.L().Debug("dnsproxy: adding IP addresses to enforcer ApplicationACLs", zap.String("contextID", s.contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.Strings("ipaddresses", ips))
				if err1 := pctx.UpdateApplicationACLs(policy.IPRuleList{{
					Addresses: ips,
					Ports:     pol.Ports,
					Protocols: pol.Protocols,
					Policy:    pol.Policy,
				}}); err1 != nil {
					zap.L().Error("dnsproxy: adding IP rule returned error", zap.String("contextID", s.contextID), zap.Error(err1))
				}
			}
		}

		// for processing in our caches, we only care about the IPs that we needed to program ACLs/ipsets
		ips := make([]string, 0, len(ipsToProcess))
		var dnsttlinfolist, dnsttlinfolistUpdateOnly []*dnsttlinfo
		for ip, d := range ipsToProcess {
			if d.updateOnly {
				dnsttlinfolistUpdateOnly = append(dnsttlinfolistUpdateOnly, &dnsttlinfo{ipaddress: ip, ttl: d.ttl})
			} else {
				ips = append(ips, ip)
				dnsttlinfolist = append(dnsttlinfolist, &dnsttlinfo{ipaddress: ip, ttl: d.ttl})
			}
		}

		// update the cache/map for this FQDN with new
		s.updateFQDNWithIPs(s.contextID, policyName, ips)

		// add or update only if required the expiry entries
		s.handleTTLInfoList(s.contextID, policyName, dnsttlinfolist, false)
		s.handleTTLInfoList(s.contextID, policyName, dnsttlinfolistUpdateOnly, true)
	}

	// Runs whether or not a DNS policy matched the queried name.
	configureDependentServices(pctx, r.Question[0].Name, ipsRaw)

	// write the DNS reply back to the client
	if _, err = w.Write(dnsReply); err != nil {
		pctx.Counters().IncrementCounter(counters.ErrDNSResponseFailed)
		zap.L().Error("dnsproxy: writing DNS response back to the client returned error", zap.String("contextID", s.contextID), zap.Error(err))
	}
}
   348  
   349  // TODO: this does not work yet - will come in a separate PR
   350  // func checkIfACLExists(pctx *pucontext.PUContext, pol policy.PortProtocolPolicy, ipStr string) bool {
   351  // 	ip := net.ParseIP(ipStr)
   352  // 	if ip == nil {
   353  // 		return false
   354  // 	}
   355  // 	for _, protoStr := range pol.Protocols {
   356  // 		proto, err := strconv.Atoi(protoStr)
   357  // 		if err != nil {
   358  // 			continue
   359  // 		}
   360  // 		for _, portStr := range pol.Ports {
   361  // 			// it could be a range definition
   362  // 			var ports []uint16
   363  // 			if strings.Contains(portStr, ":") {
   364  // 				tmp := strings.SplitN(portStr, ":", 2)
   365  // 				if len(tmp) != 2 {
   366  // 					continue
   367  // 				}
   368  // 				startPort, err := strconv.Atoi(tmp[0])
   369  // 				if err != nil {
   370  // 					continue
   371  // 				}
   372  // 				endPort, err := strconv.Atoi(tmp[1])
   373  // 				if err != nil {
   374  // 					continue
   375  // 				}
   376  // 				ports = make([]uint16, 0, endPort-startPort+1)
   377  // 				for i := startPort; i <= endPort; i++ {
   378  // 					ports = append(ports, uint16(i))
   379  // 				}
   380  // 			} else {
   381  // 				port, err := strconv.Atoi(portStr)
   382  // 				if err != nil {
   383  // 					continue
   384  // 				}
   385  // 				ports = []uint16{uint16(port)}
   386  // 			}
   387  
   388  // 			for _, port := range ports {
   389  // 				reportPol, actionPol, err := pctx.ApplicationACLPolicyFromAddr(ip, port, uint8(proto))
   390  // 				if err == nil && (reportPol != nil || actionPol != nil) {
   391  // 					zap.L().Debug("dnsproxy: ACL already found for IP",
   392  // 						zap.String("contextID", pctx.ManagementID()),
   393  // 						zap.String("ipaddress", ipStr),
   394  // 						zap.Any("reportPol", reportPol),
   395  // 						zap.Any("actionPol", actionPol),
   396  // 					)
   397  // 					return true
   398  // 				}
   399  // 			}
   400  // 		}
   401  // 	}
   402  // 	return false
   403  // }
   404  
   405  func (p *Proxy) handleTTLInfoList(contextID, fqdn string, dnsttlinfolist []*dnsttlinfo, updateOnly bool) {
   406  	for _, dnsinfo := range dnsttlinfolist {
   407  		zap.L().Debug("handleTTLInfoList", zap.String("fqdn", fqdn), zap.String("ipaddress", dnsinfo.ipaddress), zap.Bool("updateOnly", updateOnly))
   408  		p.handleTTLInfo(contextID, fqdn, dnsinfo, updateOnly)
   409  	}
   410  }
   411  
   412  func (p *Proxy) handleTTLInfo(contextID, fqdn string, dnsinfo *dnsttlinfo, updateOnly bool) {
   413  	newEntryExpiryTime := time.Now().Add(time.Duration(dnsinfo.ttl) * time.Second)
   414  	ul := p.IPToTTLLocks.Lock(dnsinfo.ipaddress)
   415  	defer ul.Unlock()
   416  	ttlInfoRaw, err := p.IPToTTL.Get(dnsinfo.ipaddress)
   417  	if err != nil {
   418  		// if we are supposed to be updating only
   419  		// then skip this entry
   420  		if updateOnly {
   421  			return
   422  		}
   423  
   424  		// otherwise add a new entry
   425  		newEntry := iptottlinfo{
   426  			ipaddress:  dnsinfo.ipaddress,
   427  			expiryTime: newEntryExpiryTime,
   428  			contextIDs: map[string]struct{}{contextID: {}},
   429  			fqdns:      map[string]struct{}{fqdn: {}},
   430  		}
   431  		// NOTE: the dnsinfo.ipaddress is in a for loop
   432  		// so we need to make sure the IP address is on the stack when the callback is called
   433  		// hence the anonymous function wrapping of the timer
   434  		func(ipaddress string) {
   435  			newEntry.timer = time.AfterFunc(time.Duration(dnsinfo.ttl)*time.Second, func() {
   436  				if p.removeExpiredEntry != nil {
   437  					p.removeExpiredEntry(ipaddress)
   438  				}
   439  			})
   440  		}(dnsinfo.ipaddress)
   441  		if err := p.IPToTTL.Add(dnsinfo.ipaddress, newEntry); err != nil {
   442  			zap.L().Debug("dnsproxy: failed to add entry to IPToTTL cache", zap.String("contextID", contextID), zap.Any("iptottlinfo", newEntry), zap.Error(err))
   443  			return
   444  		}
   445  	} else {
   446  		// update TTL info and reset timer if necessary
   447  		ttlInfo := ttlInfoRaw.(iptottlinfo)
   448  		if newEntryExpiryTime.After(ttlInfo.expiryTime) {
   449  			ttlInfo.timer.Reset(time.Duration(dnsinfo.ttl) * time.Second)
   450  		}
   451  		ttlInfo.expiryTime = newEntryExpiryTime
   452  		ttlInfo.contextIDs[contextID] = struct{}{}
   453  		ttlInfo.fqdns[fqdn] = struct{}{}
   454  		p.IPToTTL.AddOrUpdate(dnsinfo.ipaddress, ttlInfo)
   455  	}
   456  }
   457  
   458  func (p *Proxy) defaultRemoveExpiredEntry(ipaddress string) {
   459  	// retrieve the IPtoTTLInfo
   460  	ul := p.IPToTTLLocks.Lock(ipaddress)
   461  	defer ul.Unlock()
   462  	ttlInfoRaw, err := p.IPToTTL.Get(ipaddress)
   463  	if err != nil {
   464  		zap.L().Debug("dnsproxy: entry already gone from IPToTTL cache", zap.String("ipaddress", ipaddress))
   465  		return
   466  	}
   467  	ttlInfo := ttlInfoRaw.(iptottlinfo)
   468  
   469  	for contextID := range ttlInfo.contextIDs {
   470  		pctxRaw, err := p.puFromID.Get(contextID)
   471  		if err != nil {
   472  			zap.L().Error("dnsproxy: context not found for the PU with ID", zap.String("contextID", contextID))
   473  			continue
   474  		}
   475  		pctx := pctxRaw.(*pucontext.PUContext)
   476  
   477  		for fqdn := range ttlInfo.fqdns {
   478  			policies, policyName, err := pctx.GetPolicyFromFQDN(fqdn)
   479  			if err != nil {
   480  				continue
   481  			}
   482  
   483  			// remove IP address from ipsets
   484  			for _, pol := range policies {
   485  				zap.L().Debug("dnsproxy: ipset: removing IP address", zap.String("contextID", contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.String("ipaddress", ipaddress))
   486  				ipsetmanager.V4().DeleteEntryFromIPset([]string{ipaddress}, pol.Policy.ServiceID)
   487  				ipsetmanager.V6().DeleteEntryFromIPset([]string{ipaddress}, pol.Policy.ServiceID)
   488  
   489  				// remove IP address from enforcer ApplicationACLs
   490  				zap.L().Debug("dnsproxy: removing IP address from enforcer ApplicationACL", zap.String("contextID", contextID), zap.String("ipaddress", ipaddress))
   491  				if err := pctx.RemoveApplicationACL(ipaddress, pol.Protocols, pol.Ports, pol.Policy); err != nil {
   492  					zap.L().Debug("dnsproxy: RemoveApplicationACL failed", zap.String("contextID", contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.String("ipaddress", ipaddress), zap.Error(err))
   493  				}
   494  			}
   495  
   496  			p.removeIPfromFQDN(contextID, policyName, ipaddress)
   497  		}
   498  	}
   499  
   500  	// clean up after ourselves and remove ourselves from the cache
   501  	if err := p.IPToTTL.Remove(ipaddress); err != nil {
   502  		zap.L().Debug("dnsproxy: failed to remove entry from IPToTTL cache", zap.String("ipaddress", ipaddress), zap.Error(err))
   503  	}
   504  	p.IPToTTLLocks.Remove(ipaddress)
   505  
   506  }
   507  
   508  // StartDNSServer starts the dns server on the port provided for contextID
   509  func (p *Proxy) StartDNSServer(ctx context.Context, contextID, port string) error {
   510  	netPacketConn, err := listenUDP(ctx, "udp", "127.0.0.1:"+port)
   511  	if err != nil {
   512  		return err
   513  	}
   514  
   515  	var server *dns.Server
   516  
   517  	storeInMap := func() {
   518  		p.Lock()
   519  		defer p.Unlock()
   520  
   521  		p.contextIDToServer[contextID] = server
   522  	}
   523  
   524  	server = &dns.Server{NotifyStartedFunc: storeInMap, PacketConn: netPacketConn, Handler: &serveDNS{contextID, p}}
   525  
   526  	go func() {
   527  		if err := server.ActivateAndServe(); err != nil {
   528  			zap.L().Error("dnsproxy: could not start DNS proxy server", zap.String("contextID", contextID), zap.Error(err))
   529  		}
   530  	}()
   531  
   532  	return nil
   533  }
   534  
   535  // shutdownDNS shuts down the dns server for contextID
   536  func (p *Proxy) shutdownDNS(contextID string) {
   537  
   538  	if s, ok := p.contextIDToServer[contextID]; ok {
   539  		if err := s.Shutdown(); err != nil {
   540  			zap.L().Error("dnsproxy: shutdown of DNS server returned error", zap.String("contextID", contextID), zap.Error(err))
   541  		}
   542  		delete(p.contextIDToServer, contextID)
   543  	}
   544  }
   545  
   546  // New creates an instance of the dns proxy
   547  func New(ctx context.Context, puFromID cache.DataStore, conntrack flowtracking.FlowClient, c collector.EventCollector) *Proxy {
   548  	ch := make(chan dnsReport)
   549  	p := &Proxy{
   550  		chreports:                ch,
   551  		puFromID:                 puFromID,
   552  		collector:                c,
   553  		conntrack:                conntrack,
   554  		contextIDToServer:        map[string]*dns.Server{},
   555  		contextIDToDNSNames:      cache.NewCache("contextIDtoDNSNames"),
   556  		contextIDToDNSNamesLocks: newMutexMap(),
   557  		IPToTTL:                  cache.NewCache("IPToTTL"),
   558  		IPToTTLLocks:             newMutexMap(),
   559  	}
   560  	p.removeExpiredEntry = p.defaultRemoveExpiredEntry
   561  	go p.reportDNSRequests(ctx, ch)
   562  	return p
   563  }
   564  
   565  // SyncWithPlatformCache is only needed in Windows currently
   566  func (p *Proxy) SyncWithPlatformCache(ctx context.Context, pctx *pucontext.PUContext) error {
   567  	return nil
   568  }
   569  
   570  // HandleDNSResponsePacket is only needed in Windows currently
   571  func (p *Proxy) HandleDNSResponsePacket(dnsPacketData []byte, sourceIP net.IP, sourcePort uint16, destIP net.IP, destPort uint16, puFromContextID func(string) (*pucontext.PUContext, error)) error {
   572  	return nil
   573  }
   574  
   575  // Enforce starts enforcing policies for the given policy.PUInfo.
   576  func (p *Proxy) Enforce(ctx context.Context, contextID string, puInfo *policy.PUInfo) error {
   577  	// during the first Enforce call, we still need to initialize map
   578  	// we do that and return
   579  	ul := p.contextIDToDNSNamesLocks.Lock(contextID)
   580  	defer ul.Unlock()
   581  	tmp, err := p.contextIDToDNSNames.Get(contextID)
   582  	if err != nil {
   583  		// this means that the map is not initialized yet, do so now
   584  		return p.doHandleCreate(ctx, contextID, puInfo)
   585  	}
   586  
   587  	// during a policy refresh, we will enter this part here:
   588  	// - iterate over all DNSACLs for this PU
   589  	// - for all already learned IPs for all DNS names: program ipsets and enforcer ApplicationACLs
   590  	dnsNames := tmp.(*dnsNamesToIP).Copy()
   591  	for fqdn, policies := range puInfo.Policy.DNSACLs {
   592  		ips, ok := dnsNames.nameToIP[fqdn]
   593  		if ok {
   594  			// we have already learned those DNS names
   595  			// make sure to reprogram ipsets and ApplicationACLs in the enforcer
   596  			// on a policy refresh
   597  			for _, pol := range policies {
   598  				zap.L().Debug("dnsproxy: ipset: adding IP addresses after policy refresh", zap.String("contextID", contextID), zap.String("fqdn", fqdn), zap.String("serviceID", pol.Policy.ServiceID), zap.Strings("ipaddresses", ips))
   599  				ipsetmanager.V4().UpdateACLIPsets(ips, pol.Policy.ServiceID)
   600  				ipsetmanager.V6().UpdateACLIPsets(ips, pol.Policy.ServiceID)
   601  
   602  				// makes sure to update the ApplicationACLs inside of the enforcer
   603  				// this matches the case when the destination is overlapping with the target networks, and the decision is made inside
   604  				// of the enforcer, and not with ipsets
   605  				data, err := p.puFromID.Get(contextID)
   606  				if err != nil {
   607  					zap.L().Error("dnsproxy: context not found for the PU with ID", zap.String("contextID", contextID))
   608  					continue
   609  				}
   610  				pctx := data.(*pucontext.PUContext)
   611  				if err1 := pctx.UpdateApplicationACLs(policy.IPRuleList{{
   612  					Addresses: ips,
   613  					Ports:     pol.Ports,
   614  					Protocols: pol.Protocols,
   615  					Policy:    pol.Policy,
   616  				}}); err1 != nil {
   617  					zap.L().Error("dnsproxy: adding IP rule returned error after policy refresh", zap.String("contextID", contextID), zap.Error(err1))
   618  				}
   619  			}
   620  			continue
   621  		}
   622  		// This is a new fqdn. DNS proxy will fix these IPs as it learns them
   623  		dnsNames.nameToIP[fqdn] = []string{}
   624  	}
   625  	// this is only necessary to add new FQDNs to the map - which is essential for the DNS proxy to know about
   626  	p.contextIDToDNSNames.AddOrUpdate(contextID, dnsNames)
   627  	return nil
   628  }
   629  
   630  func (p *Proxy) doHandleCreate(_ context.Context, contextID string, puInfo *policy.PUInfo) error {
   631  	nameToIP := &dnsNamesToIP{
   632  		nameToIP: map[string][]string{},
   633  	}
   634  	for name := range puInfo.Policy.DNSACLs {
   635  		nameToIP.nameToIP[name] = []string{}
   636  	}
   637  	if err := p.contextIDToDNSNames.Add(contextID, nameToIP); err != nil {
   638  		zap.L().Error("dnsproxy: contextID already enforced", zap.String("contextID", contextID))
   639  	}
   640  
   641  	return nil
   642  }
   643  
   644  // Unenforce stops enforcing policy for the given IP.
   645  func (p *Proxy) Unenforce(_ context.Context, contextID string) error {
   646  	p.Lock()
   647  	defer p.Unlock()
   648  	ul := p.contextIDToDNSNamesLocks.Lock(contextID)
   649  	if err := p.contextIDToDNSNames.Remove(contextID); err != nil {
   650  		zap.L().Error("dnsproxy: contextID already removed/unenforced", zap.String("contextID", contextID))
   651  	}
   652  	p.contextIDToDNSNamesLocks.Remove(contextID)
   653  	ul.Unlock()
   654  	p.shutdownDNS(contextID)
   655  	return nil
   656  }
   657  
   658  func (d *dnsNamesToIP) Copy() *dnsNamesToIP {
   659  	d.dnsNamesLock.Lock()
   660  	defer d.dnsNamesLock.Unlock()
   661  	newdns := &dnsNamesToIP{
   662  		nameToIP: make(map[string][]string, len(d.nameToIP)),
   663  	}
   664  	for key, value := range d.nameToIP {
   665  		newvalue := make([]string, len(value))
   666  		copy(newvalue, value)
   667  		newdns.nameToIP[key] = newvalue
   668  	}
   669  
   670  	return newdns
   671  }
   672  
   673  // updateFQDNWithIPs will add any new IPs in `ips` and add it to the internal map of `contextIDToDNSNames` for our s.contextID
   674  func (p *Proxy) updateFQDNWithIPs(contextID, fqdn string, ips []string) {
   675  	ul := p.contextIDToDNSNamesLocks.Lock(contextID)
   676  	defer ul.Unlock()
   677  	tmp, err := p.contextIDToDNSNames.Get(contextID)
   678  	if err != nil {
   679  		zap.L().Debug("dnsproxy: failed to get fqdn map for contextID in updateFQDNWithIPs", zap.String("contextID", contextID))
   680  		return
   681  	}
   682  	fqdntoIPs := tmp.(*dnsNamesToIP).Copy()
   683  	existingIPsMap := make(map[string]struct{}, len(fqdntoIPs.nameToIP[fqdn]))
   684  	for _, e := range fqdntoIPs.nameToIP[fqdn] {
   685  		existingIPsMap[e] = struct{}{}
   686  	}
   687  	toAdd := make([]string, 0, len(ips))
   688  	for _, newIP := range ips {
   689  		if _, ok := existingIPsMap[newIP]; ok {
   690  			continue
   691  		}
   692  		toAdd = append(toAdd, newIP)
   693  	}
   694  	fqdntoIPs.nameToIP[fqdn] = append(fqdntoIPs.nameToIP[fqdn], toAdd...)
   695  	_ = p.contextIDToDNSNames.AddOrUpdate(contextID, fqdntoIPs)
   696  	zap.L().Debug("dnsproxy: updating FQDN map after IP addresses were added", zap.String("contextID", contextID), zap.Any("fqdntoIPs", fqdntoIPs.nameToIP))
   697  }
   698  
   699  func (p *Proxy) removeIPfromFQDN(contextID, fqdn string, ipAddress string) {
   700  	ul := p.contextIDToDNSNamesLocks.Lock(contextID)
   701  	defer ul.Unlock()
   702  	tmp, err := p.contextIDToDNSNames.Get(contextID)
   703  	if err != nil {
   704  		zap.L().Debug("dnsproxy: failed to get fqdn map for contextID in removeIPfromFQDN", zap.String("contextID", contextID))
   705  		return
   706  	}
   707  	fqdntoIPs := tmp.(*dnsNamesToIP).Copy()
   708  	existingIPsMap := make(map[string]struct{}, len(fqdntoIPs.nameToIP[fqdn]))
   709  	for _, e := range fqdntoIPs.nameToIP[fqdn] {
   710  		existingIPsMap[e] = struct{}{}
   711  	}
   712  
   713  	// remove IP from map/cache
   714  	if len(fqdntoIPs.nameToIP[fqdn]) > 0 {
   715  		ips := make([]string, 0, len(fqdntoIPs.nameToIP[fqdn])-1)
   716  		var found bool
   717  		for _, ip := range fqdntoIPs.nameToIP[fqdn] {
   718  			if ip == ipAddress {
   719  				found = true
   720  				continue
   721  			}
   722  			ips = append(ips, ip)
   723  		}
   724  		if !found {
   725  			zap.L().Debug("dnsproxy: ipaddress was already removed from list", zap.String("contextID", contextID), zap.String("fqdn", fqdn), zap.String("ipaddress", ipAddress))
   726  		}
   727  		fqdntoIPs.nameToIP[fqdn] = ips
   728  	}
   729  
   730  	zap.L().Debug("dnsproxy: updating FQDN map after IP address was deleted", zap.String("contextID", contextID), zap.Any("iplist", fqdntoIPs.nameToIP))
   731  	_ = p.contextIDToDNSNames.AddOrUpdate(contextID, fqdntoIPs) // nolint
   732  }