github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns.go (about) 1 // +build linux 2 3 package dnsproxy 4 5 import ( 6 "context" 7 "net" 8 "strconv" 9 "sync" 10 "syscall" 11 "time" 12 13 "github.com/miekg/dns" 14 "go.aporeto.io/trireme-lib/collector" 15 "go.aporeto.io/trireme-lib/controller/pkg/flowtracking" 16 "go.aporeto.io/trireme-lib/controller/pkg/ipsetmanager" 17 "go.aporeto.io/trireme-lib/controller/pkg/pucontext" 18 "go.aporeto.io/trireme-lib/policy" 19 "go.aporeto.io/trireme-lib/utils/cache" 20 "go.uber.org/zap" 21 ) 22 23 // Proxy struct represents the object for dns proxy 24 type Proxy struct { 25 puFromID cache.DataStore 26 conntrack flowtracking.FlowClient 27 collector collector.EventCollector 28 contextIDToServer map[string]*dns.Server 29 chreports chan dnsReport 30 updateIPsets ipsetmanager.ACLManager 31 sync.RWMutex 32 } 33 34 type serveDNS struct { 35 contextID string 36 *Proxy 37 } 38 39 const ( 40 dnsRequestTimeout = 2 * time.Second 41 proxyMarkInt = 0x40 //Duplicated from supervisor/iptablesctrl refer to it 42 ) 43 44 func socketOptions(_, _ string, c syscall.RawConn) error { 45 var opErr error 46 err := c.Control(func(fd uintptr) { 47 if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, proxyMarkInt); err != nil { 48 zap.L().Error("Failed to mark connection", zap.Error(err)) 49 } 50 }) 51 52 if err != nil { 53 return err 54 } 55 56 return opErr 57 } 58 59 func listenUDP(network, addr string) (net.PacketConn, error) { 60 var lc net.ListenConfig 61 62 lc.Control = socketOptions 63 64 return lc.ListenPacket(context.Background(), network, addr) 65 } 66 67 func forwardDNSReq(r *dns.Msg, ip net.IP, port uint16) (*dns.Msg, []string, error) { 68 var ips []string 69 c := new(dns.Client) 70 c.Dialer = &net.Dialer{ 71 Control: func(_, _ string, c syscall.RawConn) error { 72 return c.Control(func(fd uintptr) { 73 if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, proxyMarkInt); err != nil { 74 zap.L().Error("Failed to assing mark to socket", zap.Error(err)) 75 } 76 }) 77 }, 78 Timeout: dnsRequestTimeout, 79 } 80 81 in, _, err := c.Exchange(r, net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) 82 if err != nil { 83 return nil, nil, err 84 } 85 86 for _, ans := range in.Answer { 87 if ans.Header().Rrtype == dns.TypeA { 88 t, _ := ans.(*dns.A) 89 ips = append(ips, t.A.String()) 90 } 91 92 if ans.Header().Rrtype == dns.TypeAAAA { 93 t, _ := ans.(*dns.AAAA) 94 ips = append(ips, t.AAAA.String()) 95 } 96 } 97 98 return in, ips, nil 99 } 100 101 func (s *serveDNS) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 102 var err error 103 lAddr := w.LocalAddr().(*net.UDPAddr) 104 rAddr := w.RemoteAddr().(*net.UDPAddr) 105 var puCtx *pucontext.PUContext 106 107 defer func() { 108 if puCtx != nil { 109 s.reportDNSLookup(r.Question[0].Name, puCtx, rAddr.IP, "") 110 } 111 }() 112 113 origIP, origPort, _, err := s.conntrack.GetOriginalDest(net.ParseIP("127.0.0.1"), rAddr.IP, uint16(lAddr.Port), uint16(rAddr.Port), 17) 114 if err != nil { 115 zap.L().Error("Failed to find flow for the redirected dns traffic", zap.Error(err)) 116 return 117 } 118 119 data, err := s.puFromID.Get(s.contextID) 120 if err != nil { 121 zap.L().Error("context not found for the PU with ID", zap.String("contextID", s.contextID)) 122 return 123 } 124 125 dnsReply, ips, err := forwardDNSReq(r, origIP, origPort) 126 if err != nil { 127 zap.L().Debug("Forwarded dns request returned error", zap.Error(err)) 128 return 129 } 130 131 puCtx = data.(*pucontext.PUContext) 132 ps, err1 := puCtx.GetPolicyFromFQDN(r.Question[0].Name) 133 if err1 == nil { 134 for _, p := range ps { 135 s.updateIPsets.UpdateIPsets(ips, p.Policy.ServiceID) 136 if err1 := puCtx.UpdateApplicationACLs(policy.IPRuleList{{Addresses: ips, 137 Ports: p.Ports, 138 Protocols: p.Protocols, 139 Policy: p.Policy, 140 }}); err1 != nil { 141 zap.L().Error("Adding IP rule returned error", zap.Error(err1)) 142 } 143 } 144 } 145 146 if err = w.WriteMsg(dnsReply); err != nil { 147 zap.L().Error("Writing dns response back to the client returned error", zap.Error(err)) 148 } 149 } 150 151 // StartDNSServer starts the dns server on the port provided for contextID 152 func (p *Proxy) StartDNSServer(contextID, port string) error { 153 netPacketConn, err := listenUDP("udp", "127.0.0.1:"+port) 154 if err != nil { 155 return err 156 } 157 158 var server *dns.Server 159 160 storeInMap := func() { 161 p.Lock() 162 defer p.Unlock() 163 164 p.contextIDToServer[contextID] = server 165 } 166 167 server = &dns.Server{NotifyStartedFunc: storeInMap, PacketConn: netPacketConn, Handler: &serveDNS{contextID, p}} 168 169 go func() { 170 if err := server.ActivateAndServe(); err != nil { 171 zap.L().Error("Could not start DNS proxy server", zap.Error(err)) 172 } 173 }() 174 175 return nil 176 } 177 178 // ShutdownDNS shuts down the dns server for contextID 179 func (p *Proxy) ShutdownDNS(contextID string) { 180 p.Lock() 181 defer p.Unlock() 182 if s, ok := p.contextIDToServer[contextID]; ok { 183 if err := s.Shutdown(); err != nil { 184 zap.L().Error("shutdown of dns server returned error", zap.String("contextID", contextID), zap.Error(err)) 185 } 186 delete(p.contextIDToServer, contextID) 187 } 188 } 189 190 // New creates an instance of the dns proxy 191 func New(puFromID cache.DataStore, conntrack flowtracking.FlowClient, c collector.EventCollector, aclmanager ipsetmanager.ACLManager) *Proxy { 192 ch := make(chan dnsReport) 193 p := &Proxy{chreports: ch, puFromID: puFromID, collector: c, conntrack: conntrack, contextIDToServer: map[string]*dns.Server{}, updateIPsets: aclmanager} 194 go p.reportDNSRequests(ch) 195 return p 196 }