github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns_linux.go (about) 1 // +build linux 2 3 package dnsproxy 4 5 import ( 6 "context" 7 "fmt" 8 "net" 9 "strconv" 10 "sync" 11 "syscall" 12 "time" 13 14 "github.com/miekg/dns" 15 "go.aporeto.io/enforcerd/trireme-lib/collector" 16 "go.aporeto.io/enforcerd/trireme-lib/controller/constants" 17 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/counters" 18 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/flowtracking" 19 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/ipsetmanager" 20 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext" 21 "go.aporeto.io/enforcerd/trireme-lib/policy" 22 "go.aporeto.io/enforcerd/trireme-lib/utils/cache" 23 "go.uber.org/zap" 24 ) 25 26 // removeExpiredEntryFunc is the type of the function that gets called when an IP entry expires (when its TTL hits 0) 27 type removeExpiredEntryFunc func(string) 28 29 // Proxy struct represents the object for dns proxy 30 type Proxy struct { 31 puFromID cache.DataStore 32 conntrack flowtracking.FlowClient 33 collector collector.EventCollector 34 contextIDToServer map[string]*dns.Server 35 chreports chan dnsReport 36 contextIDToDNSNames *cache.Cache 37 contextIDToDNSNamesLocks *mutexMap 38 IPToTTL *cache.Cache 39 IPToTTLLocks *mutexMap 40 removeExpiredEntry removeExpiredEntryFunc 41 sync.Mutex 42 } 43 type dnsNamesToIP struct { 44 nameToIP map[string][]string 45 dnsNamesLock sync.Mutex 46 } 47 type dnsttlinfo struct { 48 ipaddress string 49 ttl uint32 50 } 51 52 type iptottlinfo struct { 53 ipaddress string 54 expiryTime time.Time 55 timer *time.Timer 56 contextIDs map[string]struct{} 57 fqdns map[string]struct{} 58 } 59 type serveDNS struct { 60 contextID string 61 *Proxy 62 } 63 64 const ( 65 dnsRequestTimeout = 2 * time.Second 66 ) 67 68 func socketOptions(_, _ string, c syscall.RawConn) error { 69 var opErr error 70 err := c.Control(func(fd uintptr) { 71 if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, constants.ProxyMarkInt); err != nil { 72 zap.L().Error("dnsproxy: failed to mark connection", zap.Error(err)) 73 } 74 }) 75 76 if err != nil { 77 return err 78 } 79 80 return opErr 81 } 82 83 func listenUDP(ctx context.Context, network, addr string) (net.PacketConn, error) { 84 var lc net.ListenConfig 85 86 lc.Control = socketOptions 87 88 return lc.ListenPacket(ctx, network, addr) 89 } 90 91 func forwardDNSReq(r *dns.Msg, ip net.IP, port uint16) ([]byte, []string, []*dnsttlinfo, error) { 92 var ips []string 93 var resp []byte 94 var msg *dns.Msg 95 var conn *dns.Conn 96 var err error 97 98 c := new(dns.Client) 99 100 dial := func(address string) (*dns.Conn, error) { 101 c.Dialer = &net.Dialer{ 102 Control: func(_, _ string, c syscall.RawConn) error { 103 return c.Control(func(fd uintptr) { 104 if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, constants.ProxyMarkInt); err != nil { 105 zap.L().Error("dnsproxy: failed to assing mark to socket", zap.Error(err)) 106 } 107 }) 108 }, 109 Timeout: dnsRequestTimeout, 110 } 111 112 conn, err := c.Dial(address) 113 if err != nil { 114 return nil, err 115 } 116 117 return conn, nil 118 } 119 120 sendRequest := func(r *dns.Msg, conn *dns.Conn) error { 121 opt := r.IsEdns0() 122 // If EDNS0 is used use that for size. 123 if opt != nil && opt.UDPSize() >= dns.MinMsgSize { 124 conn.UDPSize = opt.UDPSize() 125 } 126 // Otherwise use the client's configured UDP size. 127 if opt == nil && c.UDPSize >= dns.MinMsgSize { 128 conn.UDPSize = c.UDPSize 129 } 130 131 t := time.Now() 132 // write with the appropriate write timeout 133 if err = conn.SetWriteDeadline(t.Add(c.Dialer.Timeout)); err != nil { 134 return err 135 } 136 137 if err = conn.WriteMsg(r); err != nil { 138 return err 139 } 140 141 return nil 142 } 143 144 readResponse := func(conn *dns.Conn) ([]byte, *dns.Msg, error) { 145 if err := conn.SetReadDeadline(time.Now().Add(c.Dialer.Timeout)); err != nil { 146 return nil, nil, err 147 } 148 149 p, err := conn.ReadMsgHeader(nil) 150 if err != nil { 151 return nil, nil, err 152 } 153 154 m := new(dns.Msg) 155 if err := m.Unpack(p); err != nil { 156 // If an error was returned, we still want to allow the user to use 157 // the message, but naively they can just check err if they don't want 158 // to use an erroneous message 159 return nil, nil, err 160 } 161 162 return p, m, nil 163 } 164 165 if conn, err = dial(net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))); err != nil { 166 return nil, nil, nil, err 167 } 168 169 defer conn.Close() // nolint: errcheck 170 171 if err := sendRequest(r, conn); err != nil { 172 return nil, nil, nil, err 173 } 174 175 if resp, msg, err = readResponse(conn); err != nil { 176 return nil, nil, nil, err 177 } 178 dnsttlinfolist := []*dnsttlinfo{} 179 180 for _, ans := range msg.Answer { 181 if ans.Header().Rrtype == dns.TypeA { 182 t, _ := ans.(*dns.A) 183 184 ips = append(ips, t.A.String()) 185 dnsttlinfolist = append(dnsttlinfolist, &dnsttlinfo{ 186 ipaddress: t.A.String(), 187 ttl: ans.Header().Ttl, 188 }) 189 } 190 191 if ans.Header().Rrtype == dns.TypeAAAA { 192 t, _ := ans.(*dns.AAAA) 193 ips = append(ips, t.AAAA.String()) 194 195 dnsttlinfolist = append(dnsttlinfolist, &dnsttlinfo{ 196 ipaddress: t.AAAA.String(), 197 ttl: ans.Header().Ttl, 198 }) 199 } 200 } 201 return resp, ips, dnsttlinfolist, nil 202 } 203 204 const ( 205 strInvalidDNSRequest = "invalid DNS request" 206 ) 207 208 func (s *serveDNS) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 209 var err error 210 lAddr := w.LocalAddr().(*net.UDPAddr) 211 rAddr := w.RemoteAddr().(*net.UDPAddr) 212 var pctx *pucontext.PUContext 213 var ipsRaw []string 214 var origIP net.IP 215 var origPort uint16 216 var reportError string 217 218 defer func() { 219 if pctx != nil { 220 // if there is no question section, this was an invalid request 221 name := "invalid" 222 if len(r.Question) > 0 { 223 name = r.Question[0].Name 224 } 225 s.reportDNSLookup(name, pctx, rAddr.IP, uint16(rAddr.Port), origIP, origPort, ipsRaw, reportError) 226 } 227 }() 228 229 pctxRaw, err := s.puFromID.Get(s.contextID) 230 if err != nil { 231 zap.L().Error("dnsproxy: context not found for the PU with ID", zap.String("contextID", s.contextID), zap.Error(err)) 232 reportError = fmt.Sprintf("PU context: %s", err) 233 return 234 } 235 pctx = pctxRaw.(*pucontext.PUContext) 236 237 // check if the DNS request is actually valid 238 // we have seen with the AWS resolve in the past that it *does* respond with an empty Question section 239 if len(r.Question) <= 0 { 240 pctx.Counters().IncrementCounter(counters.ErrDNSInvalidRequest) 241 zap.L().Debug("dnsproxy: invalid DNS request received (missing question section)", zap.String("contextID", s.contextID)) 242 reportError = strInvalidDNSRequest 243 return 244 } 245 246 // TODO: shouldn't we let the lookup go regardless of our problems? 247 origIP, origPort, _, err = s.conntrack.GetOriginalDest(net.ParseIP("127.0.0.1"), rAddr.IP, uint16(lAddr.Port), uint16(rAddr.Port), 17) 248 if err != nil { 249 zap.L().Error("dnsproxy: failed to find flow for the redirected DNS traffic", zap.String("contextID", s.contextID), zap.Error(err)) 250 reportError = fmt.Sprintf("conntrack: DNS request flow: %s", err) 251 return 252 } 253 254 // perform the upstream DNS lookup 255 dnsReply, ipsRaw, dnsttlinfolistRaw, err := forwardDNSReq(r, origIP, origPort) 256 if err != nil { 257 pctx.Counters().IncrementCounter(counters.ErrDNSForwardFailed) 258 zap.L().Debug("dnsproxy: forwarded DNS request returned error", zap.String("contextID", s.contextID), zap.Error(err)) 259 reportError = fmt.Sprintf("DNS request failed: %s", err) 260 return 261 } 262 263 // get all policies associated with the FQDN from 264 policies, policyName, err1 := pctx.GetPolicyFromFQDN(r.Question[0].Name) 265 266 // if they exist, then err1 is nil, and we need to update 267 // - the ipsets 268 // - the applicationacls inside of the enforcer 269 // - the internal cache 270 if err1 == nil { 271 type ipDetail struct { 272 ttl uint32 273 updateOnly bool 274 } 275 ipsToProcess := make(map[string]ipDetail, len(ipsRaw)) 276 for _, pol := range policies { 277 for _, i := range dnsttlinfolistRaw { 278 // TODO: this does not work yet - will come in a separate PR 279 //if checkIfACLExists(pctx, pol, i.ipaddress) { 280 // // no need to program ACLs and ipsets 281 // // however, there are two cases here: 282 // // 1. this was a static entry in the external network, and we truly want to skip it 283 // // 2. this comes from us programming it down below 284 // // In case (2) we need to actually call handleTTLInfoList but with updateOnly to extend the TTL 285 // // This way no new expiry entries will be made when not necessary as for static entries, 286 // // but they will be extended when necessary as well. 287 // if _, ok := ipsToProcess[i.ipaddress]; !ok { 288 // ipsToProcess[i.ipaddress] = ipDetail{ttl: i.ttl, updateOnly: true} 289 // } 290 // continue 291 //} 292 293 if _, ok := ipsToProcess[i.ipaddress]; !ok { 294 ipsToProcess[i.ipaddress] = ipDetail{ttl: i.ttl, updateOnly: false} 295 } 296 ips := []string{i.ipaddress} 297 298 // makes sure to update any ipsets related to the serviceID of the policy 299 // this matches the case when the destinations are *not* overlapping with the target networks and the decision is not done 300 // in the enforcer 301 zap.L().Debug("dnsproxy: ipset: adding IP addresses", zap.String("contextID", s.contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.Strings("ipaddresses", ips)) 302 ipsetmanager.V4().UpdateACLIPsets(ips, pol.Policy.ServiceID) 303 ipsetmanager.V6().UpdateACLIPsets(ips, pol.Policy.ServiceID) 304 305 // makes sure to update the ApplicationACLs inside of the enforcer 306 // this matches the case when the destination is overlapping with the target networks, and the decision is made inside 307 // of the enforcer, and not with ipsets 308 zap.L().Debug("dnsproxy: adding IP addresses to enforcer ApplicationACLs", zap.String("contextID", s.contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.Strings("ipaddresses", ips)) 309 if err1 := pctx.UpdateApplicationACLs(policy.IPRuleList{{ 310 Addresses: ips, 311 Ports: pol.Ports, 312 Protocols: pol.Protocols, 313 Policy: pol.Policy, 314 }}); err1 != nil { 315 zap.L().Error("dnsproxy: adding IP rule returned error", zap.String("contextID", s.contextID), zap.Error(err1)) 316 } 317 } 318 } 319 320 // for processing in our caches, we only care about the IPs that we needed to program ACLs/ipsets 321 ips := make([]string, 0, len(ipsToProcess)) 322 var dnsttlinfolist, dnsttlinfolistUpdateOnly []*dnsttlinfo 323 for ip, d := range ipsToProcess { 324 if d.updateOnly { 325 dnsttlinfolistUpdateOnly = append(dnsttlinfolistUpdateOnly, &dnsttlinfo{ipaddress: ip, ttl: d.ttl}) 326 } else { 327 ips = append(ips, ip) 328 dnsttlinfolist = append(dnsttlinfolist, &dnsttlinfo{ipaddress: ip, ttl: d.ttl}) 329 } 330 } 331 332 // update the cache/map for this FQDN with new 333 s.updateFQDNWithIPs(s.contextID, policyName, ips) 334 335 // add or update only if required the expiry entries 336 s.handleTTLInfoList(s.contextID, policyName, dnsttlinfolist, false) 337 s.handleTTLInfoList(s.contextID, policyName, dnsttlinfolistUpdateOnly, true) 338 } 339 340 configureDependentServices(pctx, r.Question[0].Name, ipsRaw) 341 342 // write the DNS reply back to the client 343 if _, err = w.Write(dnsReply); err != nil { 344 pctx.Counters().IncrementCounter(counters.ErrDNSResponseFailed) 345 zap.L().Error("dnsproxy: writing DNS response back to the client returned error", zap.String("contextID", s.contextID), zap.Error(err)) 346 } 347 } 348 349 // TODO: this does not work yet - will come in a separate PR 350 // func checkIfACLExists(pctx *pucontext.PUContext, pol policy.PortProtocolPolicy, ipStr string) bool { 351 // ip := net.ParseIP(ipStr) 352 // if ip == nil { 353 // return false 354 // } 355 // for _, protoStr := range pol.Protocols { 356 // proto, err := strconv.Atoi(protoStr) 357 // if err != nil { 358 // continue 359 // } 360 // for _, portStr := range pol.Ports { 361 // // it could be a range definition 362 // var ports []uint16 363 // if strings.Contains(portStr, ":") { 364 // tmp := strings.SplitN(portStr, ":", 2) 365 // if len(tmp) != 2 { 366 // continue 367 // } 368 // startPort, err := strconv.Atoi(tmp[0]) 369 // if err != nil { 370 // continue 371 // } 372 // endPort, err := strconv.Atoi(tmp[1]) 373 // if err != nil { 374 // continue 375 // } 376 // ports = make([]uint16, 0, endPort-startPort+1) 377 // for i := startPort; i <= endPort; i++ { 378 // ports = append(ports, uint16(i)) 379 // } 380 // } else { 381 // port, err := strconv.Atoi(portStr) 382 // if err != nil { 383 // continue 384 // } 385 // ports = []uint16{uint16(port)} 386 // } 387 388 // for _, port := range ports { 389 // reportPol, actionPol, err := pctx.ApplicationACLPolicyFromAddr(ip, port, uint8(proto)) 390 // if err == nil && (reportPol != nil || actionPol != nil) { 391 // zap.L().Debug("dnsproxy: ACL already found for IP", 392 // zap.String("contextID", pctx.ManagementID()), 393 // zap.String("ipaddress", ipStr), 394 // zap.Any("reportPol", reportPol), 395 // zap.Any("actionPol", actionPol), 396 // ) 397 // return true 398 // } 399 // } 400 // } 401 // } 402 // return false 403 // } 404 405 func (p *Proxy) handleTTLInfoList(contextID, fqdn string, dnsttlinfolist []*dnsttlinfo, updateOnly bool) { 406 for _, dnsinfo := range dnsttlinfolist { 407 zap.L().Debug("handleTTLInfoList", zap.String("fqdn", fqdn), zap.String("ipaddress", dnsinfo.ipaddress), zap.Bool("updateOnly", updateOnly)) 408 p.handleTTLInfo(contextID, fqdn, dnsinfo, updateOnly) 409 } 410 } 411 412 func (p *Proxy) handleTTLInfo(contextID, fqdn string, dnsinfo *dnsttlinfo, updateOnly bool) { 413 newEntryExpiryTime := time.Now().Add(time.Duration(dnsinfo.ttl) * time.Second) 414 ul := p.IPToTTLLocks.Lock(dnsinfo.ipaddress) 415 defer ul.Unlock() 416 ttlInfoRaw, err := p.IPToTTL.Get(dnsinfo.ipaddress) 417 if err != nil { 418 // if we are supposed to be updating only 419 // then skip this entry 420 if updateOnly { 421 return 422 } 423 424 // otherwise add a new entry 425 newEntry := iptottlinfo{ 426 ipaddress: dnsinfo.ipaddress, 427 expiryTime: newEntryExpiryTime, 428 contextIDs: map[string]struct{}{contextID: {}}, 429 fqdns: map[string]struct{}{fqdn: {}}, 430 } 431 // NOTE: the dnsinfo.ipaddress is in a for loop 432 // so we need to make sure the IP address is on the stack when the callback is called 433 // hence the anonymous function wrapping of the timer 434 func(ipaddress string) { 435 newEntry.timer = time.AfterFunc(time.Duration(dnsinfo.ttl)*time.Second, func() { 436 if p.removeExpiredEntry != nil { 437 p.removeExpiredEntry(ipaddress) 438 } 439 }) 440 }(dnsinfo.ipaddress) 441 if err := p.IPToTTL.Add(dnsinfo.ipaddress, newEntry); err != nil { 442 zap.L().Debug("dnsproxy: failed to add entry to IPToTTL cache", zap.String("contextID", contextID), zap.Any("iptottlinfo", newEntry), zap.Error(err)) 443 return 444 } 445 } else { 446 // update TTL info and reset timer if necessary 447 ttlInfo := ttlInfoRaw.(iptottlinfo) 448 if newEntryExpiryTime.After(ttlInfo.expiryTime) { 449 ttlInfo.timer.Reset(time.Duration(dnsinfo.ttl) * time.Second) 450 } 451 ttlInfo.expiryTime = newEntryExpiryTime 452 ttlInfo.contextIDs[contextID] = struct{}{} 453 ttlInfo.fqdns[fqdn] = struct{}{} 454 p.IPToTTL.AddOrUpdate(dnsinfo.ipaddress, ttlInfo) 455 } 456 } 457 458 func (p *Proxy) defaultRemoveExpiredEntry(ipaddress string) { 459 // retrieve the IPtoTTLInfo 460 ul := p.IPToTTLLocks.Lock(ipaddress) 461 defer ul.Unlock() 462 ttlInfoRaw, err := p.IPToTTL.Get(ipaddress) 463 if err != nil { 464 zap.L().Debug("dnsproxy: entry already gone from IPToTTL cache", zap.String("ipaddress", ipaddress)) 465 return 466 } 467 ttlInfo := ttlInfoRaw.(iptottlinfo) 468 469 for contextID := range ttlInfo.contextIDs { 470 pctxRaw, err := p.puFromID.Get(contextID) 471 if err != nil { 472 zap.L().Error("dnsproxy: context not found for the PU with ID", zap.String("contextID", contextID)) 473 continue 474 } 475 pctx := pctxRaw.(*pucontext.PUContext) 476 477 for fqdn := range ttlInfo.fqdns { 478 policies, policyName, err := pctx.GetPolicyFromFQDN(fqdn) 479 if err != nil { 480 continue 481 } 482 483 // remove IP address from ipsets 484 for _, pol := range policies { 485 zap.L().Debug("dnsproxy: ipset: removing IP address", zap.String("contextID", contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.String("ipaddress", ipaddress)) 486 ipsetmanager.V4().DeleteEntryFromIPset([]string{ipaddress}, pol.Policy.ServiceID) 487 ipsetmanager.V6().DeleteEntryFromIPset([]string{ipaddress}, pol.Policy.ServiceID) 488 489 // remove IP address from enforcer ApplicationACLs 490 zap.L().Debug("dnsproxy: removing IP address from enforcer ApplicationACL", zap.String("contextID", contextID), zap.String("ipaddress", ipaddress)) 491 if err := pctx.RemoveApplicationACL(ipaddress, pol.Protocols, pol.Ports, pol.Policy); err != nil { 492 zap.L().Debug("dnsproxy: RemoveApplicationACL failed", zap.String("contextID", contextID), zap.String("serviceID", pol.Policy.ServiceID), zap.String("ipaddress", ipaddress), zap.Error(err)) 493 } 494 } 495 496 p.removeIPfromFQDN(contextID, policyName, ipaddress) 497 } 498 } 499 500 // clean up after ourselves and remove ourselves from the cache 501 if err := p.IPToTTL.Remove(ipaddress); err != nil { 502 zap.L().Debug("dnsproxy: failed to remove entry from IPToTTL cache", zap.String("ipaddress", ipaddress), zap.Error(err)) 503 } 504 p.IPToTTLLocks.Remove(ipaddress) 505 506 } 507 508 // StartDNSServer starts the dns server on the port provided for contextID 509 func (p *Proxy) StartDNSServer(ctx context.Context, contextID, port string) error { 510 netPacketConn, err := listenUDP(ctx, "udp", "127.0.0.1:"+port) 511 if err != nil { 512 return err 513 } 514 515 var server *dns.Server 516 517 storeInMap := func() { 518 p.Lock() 519 defer p.Unlock() 520 521 p.contextIDToServer[contextID] = server 522 } 523 524 server = &dns.Server{NotifyStartedFunc: storeInMap, PacketConn: netPacketConn, Handler: &serveDNS{contextID, p}} 525 526 go func() { 527 if err := server.ActivateAndServe(); err != nil { 528 zap.L().Error("dnsproxy: could not start DNS proxy server", zap.String("contextID", contextID), zap.Error(err)) 529 } 530 }() 531 532 return nil 533 } 534 535 // shutdownDNS shuts down the dns server for contextID 536 func (p *Proxy) shutdownDNS(contextID string) { 537 538 if s, ok := p.contextIDToServer[contextID]; ok { 539 if err := s.Shutdown(); err != nil { 540 zap.L().Error("dnsproxy: shutdown of DNS server returned error", zap.String("contextID", contextID), zap.Error(err)) 541 } 542 delete(p.contextIDToServer, contextID) 543 } 544 } 545 546 // New creates an instance of the dns proxy 547 func New(ctx context.Context, puFromID cache.DataStore, conntrack flowtracking.FlowClient, c collector.EventCollector) *Proxy { 548 ch := make(chan dnsReport) 549 p := &Proxy{ 550 chreports: ch, 551 puFromID: puFromID, 552 collector: c, 553 conntrack: conntrack, 554 contextIDToServer: map[string]*dns.Server{}, 555 contextIDToDNSNames: cache.NewCache("contextIDtoDNSNames"), 556 contextIDToDNSNamesLocks: newMutexMap(), 557 IPToTTL: cache.NewCache("IPToTTL"), 558 IPToTTLLocks: newMutexMap(), 559 } 560 p.removeExpiredEntry = p.defaultRemoveExpiredEntry 561 go p.reportDNSRequests(ctx, ch) 562 return p 563 } 564 565 // SyncWithPlatformCache is only needed in Windows currently 566 func (p *Proxy) SyncWithPlatformCache(ctx context.Context, pctx *pucontext.PUContext) error { 567 return nil 568 } 569 570 // HandleDNSResponsePacket is only needed in Windows currently 571 func (p *Proxy) HandleDNSResponsePacket(dnsPacketData []byte, sourceIP net.IP, sourcePort uint16, destIP net.IP, destPort uint16, puFromContextID func(string) (*pucontext.PUContext, error)) error { 572 return nil 573 } 574 575 // Enforce starts enforcing policies for the given policy.PUInfo. 576 func (p *Proxy) Enforce(ctx context.Context, contextID string, puInfo *policy.PUInfo) error { 577 // during the first Enforce call, we still need to initialize map 578 // we do that and return 579 ul := p.contextIDToDNSNamesLocks.Lock(contextID) 580 defer ul.Unlock() 581 tmp, err := p.contextIDToDNSNames.Get(contextID) 582 if err != nil { 583 // this means that the map is not initialized yet, do so now 584 return p.doHandleCreate(ctx, contextID, puInfo) 585 } 586 587 // during a policy refresh, we will enter this part here: 588 // - iterate over all DNSACLs for this PU 589 // - for all already learned IPs for all DNS names: program ipsets and enforcer ApplicationACLs 590 dnsNames := tmp.(*dnsNamesToIP).Copy() 591 for fqdn, policies := range puInfo.Policy.DNSACLs { 592 ips, ok := dnsNames.nameToIP[fqdn] 593 if ok { 594 // we have already learned those DNS names 595 // make sure to reprogram ipsets and ApplicationACLs in the enforcer 596 // on a policy refresh 597 for _, pol := range policies { 598 zap.L().Debug("dnsproxy: ipset: adding IP addresses after policy refresh", zap.String("contextID", contextID), zap.String("fqdn", fqdn), zap.String("serviceID", pol.Policy.ServiceID), zap.Strings("ipaddresses", ips)) 599 ipsetmanager.V4().UpdateACLIPsets(ips, pol.Policy.ServiceID) 600 ipsetmanager.V6().UpdateACLIPsets(ips, pol.Policy.ServiceID) 601 602 // makes sure to update the ApplicationACLs inside of the enforcer 603 // this matches the case when the destination is overlapping with the target networks, and the decision is made inside 604 // of the enforcer, and not with ipsets 605 data, err := p.puFromID.Get(contextID) 606 if err != nil { 607 zap.L().Error("dnsproxy: context not found for the PU with ID", zap.String("contextID", contextID)) 608 continue 609 } 610 pctx := data.(*pucontext.PUContext) 611 if err1 := pctx.UpdateApplicationACLs(policy.IPRuleList{{ 612 Addresses: ips, 613 Ports: pol.Ports, 614 Protocols: pol.Protocols, 615 Policy: pol.Policy, 616 }}); err1 != nil { 617 zap.L().Error("dnsproxy: adding IP rule returned error after policy refresh", zap.String("contextID", contextID), zap.Error(err1)) 618 } 619 } 620 continue 621 } 622 // This is a new fqdn. DNS proxy will fix these IPs as it learns them 623 dnsNames.nameToIP[fqdn] = []string{} 624 } 625 // this is only necessary to add new FQDNs to the map - which is essential for the DNS proxy to know about 626 p.contextIDToDNSNames.AddOrUpdate(contextID, dnsNames) 627 return nil 628 } 629 630 func (p *Proxy) doHandleCreate(_ context.Context, contextID string, puInfo *policy.PUInfo) error { 631 nameToIP := &dnsNamesToIP{ 632 nameToIP: map[string][]string{}, 633 } 634 for name := range puInfo.Policy.DNSACLs { 635 nameToIP.nameToIP[name] = []string{} 636 } 637 if err := p.contextIDToDNSNames.Add(contextID, nameToIP); err != nil { 638 zap.L().Error("dnsproxy: contextID already enforced", zap.String("contextID", contextID)) 639 } 640 641 return nil 642 } 643 644 // Unenforce stops enforcing policy for the given IP. 645 func (p *Proxy) Unenforce(_ context.Context, contextID string) error { 646 p.Lock() 647 defer p.Unlock() 648 ul := p.contextIDToDNSNamesLocks.Lock(contextID) 649 if err := p.contextIDToDNSNames.Remove(contextID); err != nil { 650 zap.L().Error("dnsproxy: contextID already removed/unenforced", zap.String("contextID", contextID)) 651 } 652 p.contextIDToDNSNamesLocks.Remove(contextID) 653 ul.Unlock() 654 p.shutdownDNS(contextID) 655 return nil 656 } 657 658 func (d *dnsNamesToIP) Copy() *dnsNamesToIP { 659 d.dnsNamesLock.Lock() 660 defer d.dnsNamesLock.Unlock() 661 newdns := &dnsNamesToIP{ 662 nameToIP: make(map[string][]string, len(d.nameToIP)), 663 } 664 for key, value := range d.nameToIP { 665 newvalue := make([]string, len(value)) 666 copy(newvalue, value) 667 newdns.nameToIP[key] = newvalue 668 } 669 670 return newdns 671 } 672 673 // updateFQDNWithIPs will add any new IPs in `ips` and add it to the internal map of `contextIDToDNSNames` for our s.contextID 674 func (p *Proxy) updateFQDNWithIPs(contextID, fqdn string, ips []string) { 675 ul := p.contextIDToDNSNamesLocks.Lock(contextID) 676 defer ul.Unlock() 677 tmp, err := p.contextIDToDNSNames.Get(contextID) 678 if err != nil { 679 zap.L().Debug("dnsproxy: failed to get fqdn map for contextID in updateFQDNWithIPs", zap.String("contextID", contextID)) 680 return 681 } 682 fqdntoIPs := tmp.(*dnsNamesToIP).Copy() 683 existingIPsMap := make(map[string]struct{}, len(fqdntoIPs.nameToIP[fqdn])) 684 for _, e := range fqdntoIPs.nameToIP[fqdn] { 685 existingIPsMap[e] = struct{}{} 686 } 687 toAdd := make([]string, 0, len(ips)) 688 for _, newIP := range ips { 689 if _, ok := existingIPsMap[newIP]; ok { 690 continue 691 } 692 toAdd = append(toAdd, newIP) 693 } 694 fqdntoIPs.nameToIP[fqdn] = append(fqdntoIPs.nameToIP[fqdn], toAdd...) 695 _ = p.contextIDToDNSNames.AddOrUpdate(contextID, fqdntoIPs) 696 zap.L().Debug("dnsproxy: updating FQDN map after IP addresses were added", zap.String("contextID", contextID), zap.Any("fqdntoIPs", fqdntoIPs.nameToIP)) 697 } 698 699 func (p *Proxy) removeIPfromFQDN(contextID, fqdn string, ipAddress string) { 700 ul := p.contextIDToDNSNamesLocks.Lock(contextID) 701 defer ul.Unlock() 702 tmp, err := p.contextIDToDNSNames.Get(contextID) 703 if err != nil { 704 zap.L().Debug("dnsproxy: failed to get fqdn map for contextID in removeIPfromFQDN", zap.String("contextID", contextID)) 705 return 706 } 707 fqdntoIPs := tmp.(*dnsNamesToIP).Copy() 708 existingIPsMap := make(map[string]struct{}, len(fqdntoIPs.nameToIP[fqdn])) 709 for _, e := range fqdntoIPs.nameToIP[fqdn] { 710 existingIPsMap[e] = struct{}{} 711 } 712 713 // remove IP from map/cache 714 if len(fqdntoIPs.nameToIP[fqdn]) > 0 { 715 ips := make([]string, 0, len(fqdntoIPs.nameToIP[fqdn])-1) 716 var found bool 717 for _, ip := range fqdntoIPs.nameToIP[fqdn] { 718 if ip == ipAddress { 719 found = true 720 continue 721 } 722 ips = append(ips, ip) 723 } 724 if !found { 725 zap.L().Debug("dnsproxy: ipaddress was already removed from list", zap.String("contextID", contextID), zap.String("fqdn", fqdn), zap.String("ipaddress", ipAddress)) 726 } 727 fqdntoIPs.nameToIP[fqdn] = ips 728 } 729 730 zap.L().Debug("dnsproxy: updating FQDN map after IP address was deleted", zap.String("contextID", contextID), zap.Any("iplist", fqdntoIPs.nameToIP)) 731 _ = p.contextIDToDNSNames.AddOrUpdate(contextID, fqdntoIPs) // nolint 732 }