github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns.go (about)

     1  // +build linux
     2  
     3  package dnsproxy
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"strconv"
     9  	"sync"
    10  	"syscall"
    11  	"time"
    12  
    13  	"github.com/miekg/dns"
    14  	"go.aporeto.io/trireme-lib/collector"
    15  	"go.aporeto.io/trireme-lib/controller/pkg/flowtracking"
    16  	"go.aporeto.io/trireme-lib/controller/pkg/ipsetmanager"
    17  	"go.aporeto.io/trireme-lib/controller/pkg/pucontext"
    18  	"go.aporeto.io/trireme-lib/policy"
    19  	"go.aporeto.io/trireme-lib/utils/cache"
    20  	"go.uber.org/zap"
    21  )
    22  
// Proxy struct represents the object for dns proxy.
//
// It keeps one dns.Server per context ID (contextIDToServer), forwards each
// redirected query to the resolver the client originally addressed, and feeds
// the resolved addresses into ipsets and the PU's application ACLs.
// The embedded RWMutex is what StartDNSServer/ShutdownDNS lock around
// contextIDToServer.
type Proxy struct {
	// puFromID looks up the PU context (a *pucontext.PUContext) by context ID.
	puFromID          cache.DataStore
	// conntrack recovers the original destination of a redirected DNS flow.
	conntrack         flowtracking.FlowClient
	// collector is the event collector; presumably consumed by the DNS
	// reporting code, which is not visible here — confirm.
	collector         collector.EventCollector
	// contextIDToServer maps a context ID to its running DNS server; guarded
	// by the embedded RWMutex.
	contextIDToServer map[string]*dns.Server
	// chreports is the channel handed to reportDNSRequests when the proxy is
	// created.
	chreports         chan dnsReport
	// updateIPsets receives the addresses resolved for policy FQDNs.
	updateIPsets      ipsetmanager.ACLManager
	sync.RWMutex
}
    33  
// serveDNS is the dns.Handler installed for a single PU: it pairs the shared
// Proxy with the context ID whose redirected queries this server answers.
type serveDNS struct {
	// contextID identifies the PU this handler serves.
	contextID string
	*Proxy
}
    38  
    39  const (
    40  	dnsRequestTimeout = 2 * time.Second
    41  	proxyMarkInt      = 0x40 //Duplicated from supervisor/iptablesctrl refer to it
    42  )
    43  
    44  func socketOptions(_, _ string, c syscall.RawConn) error {
    45  	var opErr error
    46  	err := c.Control(func(fd uintptr) {
    47  		if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, proxyMarkInt); err != nil {
    48  			zap.L().Error("Failed to mark connection", zap.Error(err))
    49  		}
    50  	})
    51  
    52  	if err != nil {
    53  		return err
    54  	}
    55  
    56  	return opErr
    57  }
    58  
    59  func listenUDP(network, addr string) (net.PacketConn, error) {
    60  	var lc net.ListenConfig
    61  
    62  	lc.Control = socketOptions
    63  
    64  	return lc.ListenPacket(context.Background(), network, addr)
    65  }
    66  
    67  func forwardDNSReq(r *dns.Msg, ip net.IP, port uint16) (*dns.Msg, []string, error) {
    68  	var ips []string
    69  	c := new(dns.Client)
    70  	c.Dialer = &net.Dialer{
    71  		Control: func(_, _ string, c syscall.RawConn) error {
    72  			return c.Control(func(fd uintptr) {
    73  				if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, proxyMarkInt); err != nil {
    74  					zap.L().Error("Failed to assing mark to socket", zap.Error(err))
    75  				}
    76  			})
    77  		},
    78  		Timeout: dnsRequestTimeout,
    79  	}
    80  
    81  	in, _, err := c.Exchange(r, net.JoinHostPort(ip.String(), strconv.Itoa(int(port))))
    82  	if err != nil {
    83  		return nil, nil, err
    84  	}
    85  
    86  	for _, ans := range in.Answer {
    87  		if ans.Header().Rrtype == dns.TypeA {
    88  			t, _ := ans.(*dns.A)
    89  			ips = append(ips, t.A.String())
    90  		}
    91  
    92  		if ans.Header().Rrtype == dns.TypeAAAA {
    93  			t, _ := ans.(*dns.AAAA)
    94  			ips = append(ips, t.AAAA.String())
    95  		}
    96  	}
    97  
    98  	return in, ips, nil
    99  }
   100  
   101  func (s *serveDNS) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
   102  	var err error
   103  	lAddr := w.LocalAddr().(*net.UDPAddr)
   104  	rAddr := w.RemoteAddr().(*net.UDPAddr)
   105  	var puCtx *pucontext.PUContext
   106  
   107  	defer func() {
   108  		if puCtx != nil {
   109  			s.reportDNSLookup(r.Question[0].Name, puCtx, rAddr.IP, "")
   110  		}
   111  	}()
   112  
   113  	origIP, origPort, _, err := s.conntrack.GetOriginalDest(net.ParseIP("127.0.0.1"), rAddr.IP, uint16(lAddr.Port), uint16(rAddr.Port), 17)
   114  	if err != nil {
   115  		zap.L().Error("Failed to find flow for the redirected dns traffic", zap.Error(err))
   116  		return
   117  	}
   118  
   119  	data, err := s.puFromID.Get(s.contextID)
   120  	if err != nil {
   121  		zap.L().Error("context not found for the PU with ID", zap.String("contextID", s.contextID))
   122  		return
   123  	}
   124  
   125  	dnsReply, ips, err := forwardDNSReq(r, origIP, origPort)
   126  	if err != nil {
   127  		zap.L().Debug("Forwarded dns request returned error", zap.Error(err))
   128  		return
   129  	}
   130  
   131  	puCtx = data.(*pucontext.PUContext)
   132  	ps, err1 := puCtx.GetPolicyFromFQDN(r.Question[0].Name)
   133  	if err1 == nil {
   134  		for _, p := range ps {
   135  			s.updateIPsets.UpdateIPsets(ips, p.Policy.ServiceID)
   136  			if err1 := puCtx.UpdateApplicationACLs(policy.IPRuleList{{Addresses: ips,
   137  				Ports:     p.Ports,
   138  				Protocols: p.Protocols,
   139  				Policy:    p.Policy,
   140  			}}); err1 != nil {
   141  				zap.L().Error("Adding IP rule returned error", zap.Error(err1))
   142  			}
   143  		}
   144  	}
   145  
   146  	if err = w.WriteMsg(dnsReply); err != nil {
   147  		zap.L().Error("Writing dns response back to the client returned error", zap.Error(err))
   148  	}
   149  }
   150  
   151  // StartDNSServer starts the dns server on the port provided for contextID
   152  func (p *Proxy) StartDNSServer(contextID, port string) error {
   153  	netPacketConn, err := listenUDP("udp", "127.0.0.1:"+port)
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	var server *dns.Server
   159  
   160  	storeInMap := func() {
   161  		p.Lock()
   162  		defer p.Unlock()
   163  
   164  		p.contextIDToServer[contextID] = server
   165  	}
   166  
   167  	server = &dns.Server{NotifyStartedFunc: storeInMap, PacketConn: netPacketConn, Handler: &serveDNS{contextID, p}}
   168  
   169  	go func() {
   170  		if err := server.ActivateAndServe(); err != nil {
   171  			zap.L().Error("Could not start DNS proxy server", zap.Error(err))
   172  		}
   173  	}()
   174  
   175  	return nil
   176  }
   177  
   178  // ShutdownDNS shuts down the dns server for contextID
   179  func (p *Proxy) ShutdownDNS(contextID string) {
   180  	p.Lock()
   181  	defer p.Unlock()
   182  	if s, ok := p.contextIDToServer[contextID]; ok {
   183  		if err := s.Shutdown(); err != nil {
   184  			zap.L().Error("shutdown of dns server returned error", zap.String("contextID", contextID), zap.Error(err))
   185  		}
   186  		delete(p.contextIDToServer, contextID)
   187  	}
   188  }
   189  
   190  // New creates an instance of the dns proxy
   191  func New(puFromID cache.DataStore, conntrack flowtracking.FlowClient, c collector.EventCollector, aclmanager ipsetmanager.ACLManager) *Proxy {
   192  	ch := make(chan dnsReport)
   193  	p := &Proxy{chreports: ch, puFromID: puFromID, collector: c, conntrack: conntrack, contextIDToServer: map[string]*dns.Server{}, updateIPsets: aclmanager}
   194  	go p.reportDNSRequests(ch)
   195  	return p
   196  }