github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns_linux_test.go (about)

     1  // +build linux
     2  
     3  package dnsproxy
     4  
     5  import (
     6  	"bufio"
     7  	"context"
     8  	"fmt"
     9  	"net"
    10  	"os"
    11  	"reflect"
    12  	"regexp"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/magiconair/properties/assert"
    19  	"github.com/miekg/dns"
    20  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    21  	"go.aporeto.io/enforcerd/trireme-lib/controller/constants"
    22  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/acls"
    23  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext"
    24  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    25  	"go.aporeto.io/enforcerd/trireme-lib/utils/cache"
    26  )
    27  
    28  type flowClientDummy struct {
    29  }
    30  
    31  func (c *flowClientDummy) Close() error {
    32  	return nil
    33  }
    34  
    35  func (c *flowClientDummy) UpdateMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32, network bool) error {
    36  	return nil
    37  }
    38  
    39  func (c *flowClientDummy) UpdateNetworkFlowMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32) error {
    40  	return nil
    41  }
    42  
    43  func (c *flowClientDummy) UpdateApplicationFlowMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32) error {
    44  	return nil
    45  }
    46  
    47  func findDNSServerIP() net.IP {
    48  
    49  	file, err := os.Open("/etc/resolv.conf")
    50  
    51  	if err != nil {
    52  		return net.ParseIP("8.8.8.8")
    53  	}
    54  
    55  	scanner := bufio.NewScanner(file)
    56  
    57  	// this regex is doing a whole word search
    58  	s := "\\b" + "nameserver" + "\\b"
    59  	match := regexp.MustCompile(s)
    60  
    61  	for scanner.Scan() {
    62  		line := scanner.Text()
    63  		if match.MatchString(line) {
    64  			return net.ParseIP(strings.Fields(line)[1])
    65  		}
    66  	}
    67  
    68  	return net.ParseIP("8.8.8.8")
    69  }
    70  
    71  func (c *flowClientDummy) GetOriginalDest(ipSrc, ipDst net.IP, srcport, dstport uint16, protonum uint8) (net.IP, uint16, uint32, error) {
    72  
    73  	dnsServerIP := findDNSServerIP()
    74  	fmt.Println("using DNS Server IP", dnsServerIP)
    75  	return dnsServerIP, 53, 100, nil
    76  }
    77  
    78  func addDNSNamePolicy(context *pucontext.PUContext) {
    79  	context.DNSACLs = policy.DNSRuleList{
    80  		"www.google.com.": []policy.PortProtocolPolicy{
    81  			{Ports: []string{"80"},
    82  				Protocols: []string{"tcp"},
    83  				Policy: &policy.FlowPolicy{
    84  					Action:   policy.Accept,
    85  					PolicyID: "2",
    86  				}},
    87  		},
    88  	}
    89  }
    90  
    91  func CustomDialer(ctx context.Context, network, address string) (net.Conn, error) {
    92  	d := net.Dialer{}
    93  	return d.DialContext(ctx, "udp", "127.0.0.1:53001")
    94  }
    95  
    96  func createCustomResolver() *net.Resolver {
    97  	r := &net.Resolver{
    98  		PreferGo: true,
    99  		Dial:     CustomDialer,
   100  	}
   101  
   102  	return r
   103  }
   104  
   105  // DNSCollector implements a default collector infrastructure to syslog
   106  type DNSCollector struct{}
   107  
   108  // CollectFlowEvent is part of the EventCollector interface.
   109  func (d *DNSCollector) CollectFlowEvent(record *collector.FlowRecord) {}
   110  
   111  // CollectContainerEvent is part of the EventCollector interface.
   112  func (d *DNSCollector) CollectContainerEvent(record *collector.ContainerRecord) {}
   113  
   114  // CollectUserEvent is part of the EventCollector interface.
   115  func (d *DNSCollector) CollectUserEvent(record *collector.UserRecord) {}
   116  
   117  // CollectTraceEvent collects iptables trace events
   118  func (d *DNSCollector) CollectTraceEvent(records []string) {}
   119  
   120  // CollectPingEvent collects ping events
   121  func (d *DNSCollector) CollectPingEvent(report *collector.PingReport) {}
   122  
   123  // CollectPacketEvent collects packet events from the datapath
   124  func (d *DNSCollector) CollectPacketEvent(report *collector.PacketReport) {}
   125  
   126  // CollectCounterEvent collect counters from the datapath
   127  func (d *DNSCollector) CollectCounterEvent(report *collector.CounterReport) {}
   128  
   129  // CollectConnectionExceptionReport collects the connection exception report
   130  func (d *DNSCollector) CollectConnectionExceptionReport(_ *collector.ConnectionExceptionReport) {
   131  }
   132  
   133  var r collector.DNSRequestReport
   134  var l sync.Mutex
   135  
   136  // CollectDNSRequests collect counters from the datapath
   137  func (d *DNSCollector) CollectDNSRequests(report *collector.DNSRequestReport) {
   138  	l.Lock()
   139  	r = *report
   140  	l.Unlock()
   141  }
   142  
   143  func TestDNS(t *testing.T) {
   144  	ctx, cancel := context.WithCancel(context.Background())
   145  	defer cancel()
   146  	puIDcache := cache.NewCache("puFromContextID")
   147  
   148  	fp := &policy.PUInfo{
   149  		Runtime: policy.NewPURuntimeWithDefaults(),
   150  		Policy:  policy.NewPUPolicyWithDefaults(),
   151  	}
   152  	pu, _ := pucontext.NewPU("pu1", fp, nil, 24*time.Hour) // nolint
   153  
   154  	addDNSNamePolicy(pu)
   155  
   156  	puIDcache.AddOrUpdate("pu1", pu)
   157  	conntrack := &flowClientDummy{}
   158  	collector := &DNSCollector{}
   159  
   160  	proxy := New(ctx, puIDcache, conntrack, collector)
   161  
   162  	err := proxy.StartDNSServer(ctx, "pu1", "53001")
   163  	assert.Equal(t, err == nil, true, "start dns server")
   164  
   165  	resolver := createCustomResolver()
   166  	waitTimeBeforeReport = 3 * time.Second
   167  	resolver.LookupIPAddr(ctx, "www.google.com") // nolint
   168  	resolver.LookupIPAddr(ctx, "www.google.com") // nolint
   169  
   170  	assert.Equal(t, err == nil, true, "err should be nil")
   171  
   172  	time.Sleep(5 * time.Second)
   173  	l.Lock()
   174  	assert.Equal(t, r.NameLookup == "www.google.com.", true, "lookup should be www.google.com")
   175  	assert.Equal(t, r.Count >= 2 && r.Count <= 10, true, fmt.Sprintf("count should be 2, got %d", r.Count))
   176  	l.Unlock()
   177  	proxy.Unenforce(ctx, "pu1") // nolint
   178  }
   179  
   180  const (
   181  	contextID   = "host"
   182  	serviceID   = "serviceID"
   183  	port80      = "80"
   184  	fqdn        = "www.example.com."
   185  	fqdnTwo     = "two.example.com."
   186  	fqdnKeep    = "keep.example.com."
   187  	ip192_0_2_1 = "192.0.2.1"
   188  	ip192_0_2_2 = "192.0.2.2"
   189  	ip192_0_2_3 = "192.0.2.3"
   190  )
   191  
   192  func TestProxy_removeIPfromFQDN(t *testing.T) {
   193  	type args struct {
   194  		contextID string
   195  		fqdn      string
   196  		ipAddress string
   197  	}
   198  	tests := []struct {
   199  		name             string
   200  		args             args
   201  		existing         map[string]*dnsNamesToIP
   202  		wantContextEntry bool
   203  		want             map[string][]string
   204  	}{
   205  		{
   206  			name: "context not in cache",
   207  			args: args{
   208  				contextID: "does not exist",
   209  				fqdn:      fqdn,
   210  				ipAddress: "",
   211  			},
   212  			wantContextEntry: false,
   213  		},
   214  		{
   215  			name: "nothing to remove from empty list",
   216  			args: args{
   217  				contextID: contextID,
   218  				fqdn:      fqdn,
   219  				ipAddress: ip192_0_2_1,
   220  			},
   221  			wantContextEntry: true,
   222  			existing: map[string]*dnsNamesToIP{
   223  				contextID: {
   224  					nameToIP: map[string][]string{
   225  						fqdnKeep: {ip192_0_2_1},
   226  						fqdn:     {},
   227  					},
   228  				},
   229  			},
   230  			want: map[string][]string{
   231  				fqdnKeep: {ip192_0_2_1},
   232  				fqdn:     {},
   233  			},
   234  		},
   235  		{
   236  			name: "IP does not match from existing list",
   237  			args: args{
   238  				contextID: contextID,
   239  				fqdn:      fqdn,
   240  				ipAddress: ip192_0_2_1,
   241  			},
   242  			wantContextEntry: true,
   243  			existing: map[string]*dnsNamesToIP{
   244  				contextID: {
   245  					nameToIP: map[string][]string{
   246  						fqdnKeep: {ip192_0_2_1},
   247  						fqdn:     {ip192_0_2_2},
   248  					},
   249  				},
   250  			},
   251  			want: map[string][]string{
   252  				fqdnKeep: {ip192_0_2_1},
   253  				fqdn:     {ip192_0_2_2},
   254  			},
   255  		},
   256  		{
   257  			name: "IP successfully being removed from list",
   258  			args: args{
   259  				contextID: contextID,
   260  				fqdn:      fqdn,
   261  				ipAddress: ip192_0_2_1,
   262  			},
   263  			wantContextEntry: true,
   264  			existing: map[string]*dnsNamesToIP{
   265  				contextID: {
   266  					nameToIP: map[string][]string{
   267  						fqdnKeep: {ip192_0_2_1},
   268  						fqdn:     {ip192_0_2_1},
   269  					},
   270  				},
   271  			},
   272  			want: map[string][]string{
   273  				fqdnKeep: {ip192_0_2_1},
   274  				fqdn:     {},
   275  			},
   276  		},
   277  		{
   278  			name: "IP successfully being removed from list with other entries",
   279  			args: args{
   280  				contextID: contextID,
   281  				fqdn:      fqdn,
   282  				ipAddress: ip192_0_2_2,
   283  			},
   284  			wantContextEntry: true,
   285  			existing: map[string]*dnsNamesToIP{
   286  				contextID: {
   287  					nameToIP: map[string][]string{
   288  						fqdnKeep: {ip192_0_2_1},
   289  						fqdn:     {ip192_0_2_1, ip192_0_2_2, ip192_0_2_3},
   290  					},
   291  				},
   292  			},
   293  			want: map[string][]string{
   294  				fqdnKeep: {ip192_0_2_1},
   295  				fqdn:     {ip192_0_2_1, ip192_0_2_3},
   296  			},
   297  		},
   298  	}
   299  	for _, tt := range tests {
   300  		t.Run(tt.name, func(t *testing.T) {
   301  			p := &Proxy{
   302  				contextIDToDNSNames:      cache.NewCache("contextIDtoDNSNames"),
   303  				contextIDToDNSNamesLocks: newMutexMap(),
   304  			}
   305  			for k, v := range tt.existing {
   306  				p.contextIDToDNSNames.AddOrUpdate(k, v)
   307  			}
   308  			p.removeIPfromFQDN(tt.args.contextID, tt.args.fqdn, tt.args.ipAddress)
   309  
   310  			val, err := p.contextIDToDNSNames.Get(tt.args.contextID)
   311  			if (err == nil) != tt.wantContextEntry {
   312  				t.Errorf("entry for context %q does not exist", tt.args.contextID)
   313  			}
   314  			if err == nil {
   315  				m := val.(*dnsNamesToIP)
   316  				if !reflect.DeepEqual(m.nameToIP, tt.want) {
   317  					t.Errorf("want %#v, have %#v", tt.want, m.nameToIP)
   318  				}
   319  			}
   320  		})
   321  	}
   322  }
   323  
   324  func TestProxy_updateFQDNWithIPs(t *testing.T) {
   325  	type args struct {
   326  		contextID   string
   327  		fqdn        string
   328  		ipAddresses []string
   329  	}
   330  	tests := []struct {
   331  		name             string
   332  		args             args
   333  		existing         map[string]*dnsNamesToIP
   334  		wantContextEntry bool
   335  		want             map[string][]string
   336  	}{
   337  		{
   338  			name: "context not in cache",
   339  			args: args{
   340  				contextID:   "does not exist",
   341  				fqdn:        fqdn,
   342  				ipAddresses: nil,
   343  			},
   344  			wantContextEntry: false,
   345  		},
   346  		{
   347  			name: "adding an empty list",
   348  			args: args{
   349  				contextID:   contextID,
   350  				fqdn:        fqdn,
   351  				ipAddresses: []string{},
   352  			},
   353  			wantContextEntry: true,
   354  			existing: map[string]*dnsNamesToIP{
   355  				contextID: {
   356  					nameToIP: map[string][]string{
   357  						fqdnKeep: {ip192_0_2_1},
   358  						fqdn:     {},
   359  					},
   360  				},
   361  			},
   362  			want: map[string][]string{
   363  				fqdnKeep: {ip192_0_2_1},
   364  				fqdn:     {},
   365  			},
   366  		},
   367  		{
   368  			name: "adding IPs to an empty list",
   369  			args: args{
   370  				contextID:   contextID,
   371  				fqdn:        fqdn,
   372  				ipAddresses: []string{ip192_0_2_2, ip192_0_2_3},
   373  			},
   374  			wantContextEntry: true,
   375  			existing: map[string]*dnsNamesToIP{
   376  				contextID: {
   377  					nameToIP: map[string][]string{
   378  						fqdnKeep: {ip192_0_2_1},
   379  						fqdn:     {},
   380  					},
   381  				},
   382  			},
   383  			want: map[string][]string{
   384  				fqdnKeep: {ip192_0_2_1},
   385  				fqdn:     {ip192_0_2_2, ip192_0_2_3},
   386  			},
   387  		},
   388  		{
   389  			name: "adding an existing IP to the list",
   390  			args: args{
   391  				contextID:   contextID,
   392  				fqdn:        fqdn,
   393  				ipAddresses: []string{ip192_0_2_2},
   394  			},
   395  			wantContextEntry: true,
   396  			existing: map[string]*dnsNamesToIP{
   397  				contextID: {
   398  					nameToIP: map[string][]string{
   399  						fqdnKeep: {ip192_0_2_1},
   400  						fqdn:     {ip192_0_2_1, ip192_0_2_2},
   401  					},
   402  				},
   403  			},
   404  			want: map[string][]string{
   405  				fqdnKeep: {ip192_0_2_1},
   406  				fqdn:     {ip192_0_2_1, ip192_0_2_2},
   407  			},
   408  		},
   409  		{
   410  			name: "adding an existing IP and a new IP to an existing list",
   411  			args: args{
   412  				contextID:   contextID,
   413  				fqdn:        fqdn,
   414  				ipAddresses: []string{ip192_0_2_2, ip192_0_2_3},
   415  			},
   416  			wantContextEntry: true,
   417  			existing: map[string]*dnsNamesToIP{
   418  				contextID: {
   419  					nameToIP: map[string][]string{
   420  						fqdnKeep: {ip192_0_2_1},
   421  						fqdn:     {ip192_0_2_1, ip192_0_2_2},
   422  					},
   423  				},
   424  			},
   425  			want: map[string][]string{
   426  				fqdnKeep: {ip192_0_2_1},
   427  				fqdn:     {ip192_0_2_1, ip192_0_2_2, ip192_0_2_3},
   428  			},
   429  		},
   430  	}
   431  	for _, tt := range tests {
   432  		t.Run(tt.name, func(t *testing.T) {
   433  			p := &Proxy{
   434  				contextIDToDNSNames:      cache.NewCache("contextIDtoDNSNames"),
   435  				contextIDToDNSNamesLocks: newMutexMap(),
   436  			}
   437  			for k, v := range tt.existing {
   438  				p.contextIDToDNSNames.AddOrUpdate(k, v)
   439  			}
   440  			p.updateFQDNWithIPs(tt.args.contextID, tt.args.fqdn, tt.args.ipAddresses)
   441  
   442  			val, err := p.contextIDToDNSNames.Get(tt.args.contextID)
   443  			if (err == nil) != tt.wantContextEntry {
   444  				t.Errorf("entry for context %q does not exist", tt.args.contextID)
   445  			}
   446  			if err == nil {
   447  				m := val.(*dnsNamesToIP)
   448  				if !reflect.DeepEqual(m.nameToIP, tt.want) {
   449  					t.Errorf("want %#v, have %#v", tt.want, m.nameToIP)
   450  				}
   451  			}
   452  		})
   453  	}
   454  }
   455  
   456  func TestProxy_defaultRemoveExpiredEntry(t *testing.T) {
   457  	type args struct {
   458  		ipaddress string
   459  	}
   460  	type existing struct {
   461  		pus   map[string]*pucontext.PUContext
   462  		fqdns map[string]*dnsNamesToIP
   463  		ips   map[string]iptottlinfo
   464  	}
   465  	type want struct {
   466  		removedFromIPToTTL bool
   467  		fqdns              map[string]*dnsNamesToIP
   468  	}
   469  	tests := []struct {
   470  		name     string
   471  		args     args
   472  		existing existing
   473  		want     want
   474  	}{
   475  		{
   476  			name: "TTL info for IP does not exist in cache",
   477  			args: args{ipaddress: ip192_0_2_1},
   478  			existing: existing{
   479  				ips: map[string]iptottlinfo{
   480  					ip192_0_2_2: {
   481  						contextIDs: map[string]struct{}{
   482  							contextID: {},
   483  						},
   484  						fqdns: map[string]struct{}{
   485  							fqdn:    {},
   486  							fqdnTwo: {},
   487  						},
   488  					},
   489  				},
   490  				fqdns: map[string]*dnsNamesToIP{
   491  					contextID: {
   492  						nameToIP: map[string][]string{},
   493  					},
   494  				},
   495  			},
   496  			want: want{
   497  				removedFromIPToTTL: true,
   498  				fqdns: map[string]*dnsNamesToIP{
   499  					contextID: {
   500  						nameToIP: map[string][]string{},
   501  					},
   502  				},
   503  			},
   504  		},
   505  		{
   506  			name: "PUContext in TTL info for IP does not exist in cache",
   507  			args: args{ipaddress: ip192_0_2_1},
   508  			existing: existing{
   509  				ips: map[string]iptottlinfo{
   510  					ip192_0_2_1: {
   511  						contextIDs: map[string]struct{}{
   512  							contextID: {},
   513  						},
   514  						fqdns: map[string]struct{}{
   515  							fqdn:    {},
   516  							fqdnTwo: {},
   517  						},
   518  					},
   519  				},
   520  				fqdns: map[string]*dnsNamesToIP{
   521  					contextID: {
   522  						nameToIP: map[string][]string{},
   523  					},
   524  				},
   525  				pus: map[string]*pucontext.PUContext{},
   526  			},
   527  			want: want{
   528  				removedFromIPToTTL: true,
   529  				fqdns: map[string]*dnsNamesToIP{
   530  					contextID: {
   531  						nameToIP: map[string][]string{},
   532  					},
   533  				},
   534  			},
   535  		},
   536  		{
   537  			name: "successfully remove 192_0_2_1",
   538  			args: args{ipaddress: ip192_0_2_1},
   539  			existing: existing{
   540  				pus: map[string]*pucontext.PUContext{
   541  					contextID: {
   542  						RWMutex:         sync.RWMutex{},
   543  						ApplicationACLs: acls.NewACLCache(),
   544  						DNSACLs: map[string][]policy.PortProtocolPolicy{
   545  							fqdn: {
   546  								{
   547  									Ports:     []string{port80},
   548  									Protocols: []string{constants.TCPProtoNum},
   549  									Policy: &policy.FlowPolicy{
   550  										ServiceID: serviceID,
   551  									},
   552  								},
   553  							},
   554  						},
   555  					},
   556  				},
   557  				fqdns: map[string]*dnsNamesToIP{
   558  					contextID: {
   559  						nameToIP: map[string][]string{
   560  							fqdn: {ip192_0_2_1, ip192_0_2_2},
   561  						},
   562  					},
   563  				},
   564  				ips: map[string]iptottlinfo{
   565  					ip192_0_2_1: {
   566  						contextIDs: map[string]struct{}{
   567  							contextID: {},
   568  						},
   569  						fqdns: map[string]struct{}{
   570  							fqdn:    {},
   571  							fqdnTwo: {},
   572  						},
   573  					},
   574  				},
   575  			},
   576  			want: want{
   577  				removedFromIPToTTL: true,
   578  				fqdns: map[string]*dnsNamesToIP{
   579  					contextID: {
   580  						nameToIP: map[string][]string{
   581  							fqdn: {ip192_0_2_2},
   582  						},
   583  					},
   584  				},
   585  			},
   586  		},
   587  	}
   588  	for _, tt := range tests {
   589  		t.Run(tt.name, func(t *testing.T) {
   590  			p := &Proxy{
   591  				puFromID:                 cache.NewCache("puFromContextID"),
   592  				contextIDToServer:        map[string]*dns.Server{},
   593  				contextIDToDNSNames:      cache.NewCache("contextIDtoDNSNames"),
   594  				contextIDToDNSNamesLocks: newMutexMap(),
   595  				IPToTTL:                  cache.NewCache("IPToTTL"),
   596  				IPToTTLLocks:             newMutexMap(),
   597  			}
   598  			for k, v := range tt.existing.pus {
   599  				p.puFromID.AddOrUpdate(k, v)
   600  			}
   601  			for k, v := range tt.existing.fqdns {
   602  				p.contextIDToDNSNames.AddOrUpdate(k, v)
   603  			}
   604  			for k, v := range tt.existing.ips {
   605  				p.IPToTTL.AddOrUpdate(k, v)
   606  			}
   607  			p.defaultRemoveExpiredEntry(tt.args.ipaddress)
   608  
   609  			// we check to see if the entries got removed correctly
   610  			if _, err := p.IPToTTL.Get(tt.args.ipaddress); (err != nil) != tt.want.removedFromIPToTTL {
   611  				t.Errorf("entry exists in IPToTTL cache: %v, want.removedFromIPToTTL %v", (err != nil), tt.want.removedFromIPToTTL)
   612  			}
   613  			for contextID, wantFqdns := range tt.want.fqdns {
   614  				existingFqdnsRaw, err := p.contextIDToDNSNames.Get(contextID)
   615  				if err != nil {
   616  					t.Errorf("no map for context %q in contextIDToDNSNames any longer", contextID)
   617  				}
   618  				existingFqdns := existingFqdnsRaw.(*dnsNamesToIP)
   619  				if !reflect.DeepEqual(wantFqdns.nameToIP, existingFqdns.nameToIP) {
   620  					t.Errorf("context %q: have fqdns %#v - want fqdns %#v", contextID, existingFqdns.nameToIP, wantFqdns.nameToIP)
   621  				}
   622  			}
   623  		})
   624  	}
   625  }
   626  
   627  func TestProxy_handleTTLInfoList(t *testing.T) {
   628  	type args struct {
   629  		contextID      string
   630  		fqdn           string
   631  		dnsttlinfolist []*dnsttlinfo
   632  		updateOnly     bool
   633  	}
   634  	type existing struct {
   635  		ips map[string]iptottlinfo
   636  	}
   637  	type want struct {
   638  		mustFireExpiry               bool
   639  		newEntry                     bool
   640  		existingEntryIncreasedExpiry bool
   641  	}
   642  	tests := []struct {
   643  		name     string
   644  		args     args
   645  		existing existing
   646  		want     want
   647  	}{
   648  		{
   649  			name: "creating new TTL info in the cache for new IP when updateOnly is false",
   650  			args: args{
   651  				contextID: contextID,
   652  				fqdn:      fqdn,
   653  				dnsttlinfolist: []*dnsttlinfo{
   654  					{
   655  						ipaddress: ip192_0_2_1,
   656  						ttl:       1,
   657  					},
   658  				},
   659  				updateOnly: false,
   660  			},
   661  			existing: existing{
   662  				ips: map[string]iptottlinfo{},
   663  			},
   664  			want: want{
   665  				mustFireExpiry: true,
   666  				newEntry:       true,
   667  			},
   668  		},
   669  		{
   670  			name: "not creating new TTL info in the cache for new IP when updateOnly is true",
   671  			args: args{
   672  				contextID: contextID,
   673  				fqdn:      fqdn,
   674  				dnsttlinfolist: []*dnsttlinfo{
   675  					{
   676  						ipaddress: ip192_0_2_1,
   677  						ttl:       1,
   678  					},
   679  				},
   680  				updateOnly: true,
   681  			},
   682  			existing: existing{
   683  				ips: map[string]iptottlinfo{},
   684  			},
   685  			want: want{
   686  				mustFireExpiry: false,
   687  				newEntry:       false,
   688  			},
   689  		},
   690  		{
   691  			name: "updating existing TTL info in the cache",
   692  			args: args{
   693  				contextID: contextID,
   694  				fqdn:      fqdn,
   695  				dnsttlinfolist: []*dnsttlinfo{
   696  					{
   697  						ipaddress: ip192_0_2_1,
   698  						ttl:       1,
   699  					},
   700  				},
   701  				updateOnly: true,
   702  			},
   703  			existing: existing{
   704  				ips: map[string]iptottlinfo{
   705  					ip192_0_2_1: {
   706  						ipaddress:  ip192_0_2_1,
   707  						expiryTime: time.Now(),
   708  						contextIDs: map[string]struct{}{},
   709  						fqdns:      map[string]struct{}{},
   710  					},
   711  				},
   712  			},
   713  			want: want{
   714  				mustFireExpiry:               true,
   715  				existingEntryIncreasedExpiry: true,
   716  			},
   717  		},
   718  	}
   719  	for _, tt := range tests {
   720  		t.Run(tt.name, func(t *testing.T) {
   721  			p := &Proxy{
   722  				puFromID:                 cache.NewCache("puFromContextID"),
   723  				contextIDToServer:        map[string]*dns.Server{},
   724  				contextIDToDNSNames:      cache.NewCache("contextIDtoDNSNames"),
   725  				contextIDToDNSNamesLocks: newMutexMap(),
   726  				IPToTTL:                  cache.NewCache("IPToTTL"),
   727  				IPToTTLLocks:             newMutexMap(),
   728  			}
   729  
   730  			// exercises the expiry trigger
   731  			var wg sync.WaitGroup
   732  			if tt.want.mustFireExpiry {
   733  				wg.Add(1)
   734  			}
   735  			p.removeExpiredEntry = func(ipaddress string) {
   736  				// this is just to exercise the expiry trigger
   737  				t.Logf("removeExpiredEntry %q", ipaddress)
   738  				wg.Done()
   739  			}
   740  
   741  			for k, v := range tt.existing.ips {
   742  				// TODO: this is awful, not sure how to better mock this at this point
   743  				v.timer = time.AfterFunc(time.Second, func() {
   744  					if p.removeExpiredEntry != nil {
   745  						p.removeExpiredEntry(k)
   746  					}
   747  				})
   748  				p.IPToTTL.AddOrUpdate(k, v)
   749  			}
   750  			p.handleTTLInfoList(tt.args.contextID, tt.args.fqdn, tt.args.dnsttlinfolist, tt.args.updateOnly)
   751  			wg.Wait()
   752  
   753  			if tt.want.newEntry {
   754  				for _, i := range tt.args.dnsttlinfolist {
   755  					if _, err := p.IPToTTL.Get(i.ipaddress); err != nil {
   756  						t.Errorf("no new entry for IP %q", i.ipaddress)
   757  					}
   758  				}
   759  			}
   760  
   761  			if tt.want.existingEntryIncreasedExpiry {
   762  				for _, i := range tt.args.dnsttlinfolist {
   763  					iptottlExisting, ok := tt.existing.ips[i.ipaddress]
   764  					if !ok {
   765  						t.Errorf("not in the existing map %q", i.ipaddress)
   766  					}
   767  					iptottlRaw, err := p.IPToTTL.Get(i.ipaddress)
   768  					if err != nil {
   769  						t.Errorf("no entry for IP %q", i.ipaddress)
   770  					}
   771  					iptottlUpdated := iptottlRaw.(iptottlinfo)
   772  
   773  					if !iptottlUpdated.expiryTime.After(iptottlExisting.expiryTime) {
   774  						t.Errorf("expiry time not updated")
   775  					}
   776  				}
   777  			}
   778  		})
   779  	}
   780  }
   781  
   782  func TestProxy_Enforce(t *testing.T) {
   783  	type args struct {
   784  		contextID string
   785  		puInfo    *policy.PUInfo
   786  	}
   787  	type existing struct {
   788  		fqdns map[string]*dnsNamesToIP
   789  		pus   map[string]*pucontext.PUContext
   790  	}
   791  	type want struct {
   792  		fqdns map[string]*dnsNamesToIP
   793  	}
   794  	tests := []struct {
   795  		name     string
   796  		args     args
   797  		existing existing
   798  		want     want
   799  		wantErr  bool
   800  	}{
   801  		{
   802  			name: "if this is the first enforce call on a PU, simply initialize the data structures, with FQDNs from policy",
   803  			args: args{
   804  				contextID: contextID,
   805  				puInfo: &policy.PUInfo{
   806  					Runtime:   nil,
   807  					ContextID: contextID,
   808  					Policy: &policy.PUPolicy{
   809  						DNSACLs: map[string][]policy.PortProtocolPolicy{
   810  							fqdn: {
   811  								{
   812  									Ports:     []string{port80},
   813  									Protocols: []string{constants.TCPProtoNum},
   814  									Policy: &policy.FlowPolicy{
   815  										ServiceID: serviceID,
   816  									},
   817  								},
   818  							},
   819  							fqdnTwo: {
   820  								{
   821  									Ports:     []string{port80},
   822  									Protocols: []string{constants.TCPProtoNum},
   823  									Policy: &policy.FlowPolicy{
   824  										ServiceID: serviceID,
   825  									},
   826  								},
   827  							},
   828  						},
   829  					},
   830  				},
   831  			},
   832  			existing: existing{
   833  				fqdns: map[string]*dnsNamesToIP{},
   834  				pus:   map[string]*pucontext.PUContext{},
   835  			},
   836  			want: want{
   837  				fqdns: map[string]*dnsNamesToIP{
   838  					contextID: {
   839  						nameToIP: map[string][]string{
   840  							fqdn:    {},
   841  							fqdnTwo: {},
   842  						},
   843  					},
   844  				},
   845  			},
   846  		},
   847  		{
   848  			name: "new FQDNs from a policy simply register with the DNS proxy",
   849  			args: args{
   850  				contextID: contextID,
   851  				puInfo: &policy.PUInfo{
   852  					Runtime:   nil,
   853  					ContextID: contextID,
   854  					Policy: &policy.PUPolicy{
   855  						DNSACLs: map[string][]policy.PortProtocolPolicy{
   856  							fqdn: {
   857  								{
   858  									Ports:     []string{port80},
   859  									Protocols: []string{constants.TCPProtoNum},
   860  									Policy: &policy.FlowPolicy{
   861  										ServiceID: serviceID,
   862  									},
   863  								},
   864  							},
   865  							fqdnTwo: {
   866  								{
   867  									Ports:     []string{port80},
   868  									Protocols: []string{constants.TCPProtoNum},
   869  									Policy: &policy.FlowPolicy{
   870  										ServiceID: serviceID,
   871  									},
   872  								},
   873  							},
   874  						},
   875  					},
   876  				},
   877  			},
   878  			existing: existing{
   879  				fqdns: map[string]*dnsNamesToIP{
   880  					contextID: {
   881  						nameToIP: map[string][]string{},
   882  					},
   883  				},
   884  				pus: map[string]*pucontext.PUContext{},
   885  			},
   886  			want: want{
   887  				fqdns: map[string]*dnsNamesToIP{
   888  					contextID: {
   889  						nameToIP: map[string][]string{
   890  							fqdn:    {},
   891  							fqdnTwo: {},
   892  						},
   893  					},
   894  				},
   895  			},
   896  		},
   897  		{
   898  			name: "existing FQDNs update ipsets and ApplicationACLs",
   899  			args: args{
   900  				contextID: contextID,
   901  				puInfo: &policy.PUInfo{
   902  					Runtime:   nil,
   903  					ContextID: contextID,
   904  					Policy: &policy.PUPolicy{
   905  						DNSACLs: map[string][]policy.PortProtocolPolicy{
   906  							fqdn: {
   907  								{
   908  									Ports:     []string{port80},
   909  									Protocols: []string{constants.TCPProtoNum},
   910  									Policy: &policy.FlowPolicy{
   911  										ServiceID: serviceID,
   912  									},
   913  								},
   914  							},
   915  							fqdnTwo: {
   916  								{
   917  									Ports:     []string{port80},
   918  									Protocols: []string{constants.TCPProtoNum},
   919  									Policy: &policy.FlowPolicy{
   920  										ServiceID: serviceID,
   921  									},
   922  								},
   923  							},
   924  						},
   925  					},
   926  				},
   927  			},
   928  			existing: existing{
   929  				fqdns: map[string]*dnsNamesToIP{
   930  					contextID: {
   931  						nameToIP: map[string][]string{
   932  							fqdn:    {ip192_0_2_1},
   933  							fqdnTwo: {ip192_0_2_2},
   934  						},
   935  					},
   936  				},
   937  				pus: map[string]*pucontext.PUContext{
   938  					contextID: {
   939  						RWMutex:         sync.RWMutex{},
   940  						ApplicationACLs: acls.NewACLCache(),
   941  						DNSACLs: map[string][]policy.PortProtocolPolicy{
   942  							fqdn: {
   943  								{
   944  									Ports:     []string{port80},
   945  									Protocols: []string{constants.TCPProtoNum},
   946  									Policy: &policy.FlowPolicy{
   947  										ServiceID: serviceID,
   948  									},
   949  								},
   950  							},
   951  						},
   952  					},
   953  				},
   954  			},
   955  			want: want{
   956  				fqdns: map[string]*dnsNamesToIP{
   957  					contextID: {
   958  						nameToIP: map[string][]string{
   959  							fqdn:    {ip192_0_2_1},
   960  							fqdnTwo: {ip192_0_2_2},
   961  						},
   962  					},
   963  				},
   964  			},
   965  		},
   966  		{
   967  			name: "existing FQDNs do not update ipsets and ApplicationACLs if PU context does not exist",
   968  			args: args{
   969  				contextID: contextID,
   970  				puInfo: &policy.PUInfo{
   971  					Runtime:   nil,
   972  					ContextID: contextID,
   973  					Policy: &policy.PUPolicy{
   974  						DNSACLs: map[string][]policy.PortProtocolPolicy{
   975  							fqdn: {
   976  								{
   977  									Ports:     []string{port80},
   978  									Protocols: []string{constants.TCPProtoNum},
   979  									Policy: &policy.FlowPolicy{
   980  										ServiceID: serviceID,
   981  									},
   982  								},
   983  							},
   984  							fqdnTwo: {
   985  								{
   986  									Ports:     []string{port80},
   987  									Protocols: []string{constants.TCPProtoNum},
   988  									Policy: &policy.FlowPolicy{
   989  										ServiceID: serviceID,
   990  									},
   991  								},
   992  							},
   993  						},
   994  					},
   995  				},
   996  			},
   997  			existing: existing{
   998  				fqdns: map[string]*dnsNamesToIP{
   999  					contextID: {
  1000  						nameToIP: map[string][]string{
  1001  							fqdn:    {ip192_0_2_1},
  1002  							fqdnTwo: {ip192_0_2_2},
  1003  						},
  1004  					},
  1005  				},
  1006  				pus: map[string]*pucontext.PUContext{},
  1007  			},
  1008  			want: want{
  1009  				fqdns: map[string]*dnsNamesToIP{
  1010  					contextID: {
  1011  						nameToIP: map[string][]string{
  1012  							fqdn:    {ip192_0_2_1},
  1013  							fqdnTwo: {ip192_0_2_2},
  1014  						},
  1015  					},
  1016  				},
  1017  			},
  1018  		},
  1019  	}
  1020  	for _, tt := range tests {
  1021  		t.Run(tt.name, func(t *testing.T) {
  1022  			ctx, cancel := context.WithCancel(context.Background())
  1023  			defer cancel()
  1024  			p := &Proxy{
  1025  				puFromID:                 cache.NewCache("puFromContextID"),
  1026  				contextIDToServer:        map[string]*dns.Server{},
  1027  				contextIDToDNSNames:      cache.NewCache("contextIDtoDNSNames"),
  1028  				contextIDToDNSNamesLocks: newMutexMap(),
  1029  			}
  1030  			for k, v := range tt.existing.fqdns {
  1031  				p.contextIDToDNSNames.AddOrUpdate(k, v)
  1032  			}
  1033  			for k, v := range tt.existing.pus {
  1034  				p.puFromID.AddOrUpdate(k, v)
  1035  			}
  1036  			if err := p.Enforce(ctx, tt.args.contextID, tt.args.puInfo); (err != nil) != tt.wantErr {
  1037  				t.Errorf("Proxy.Enforce() error = %v, wantErr %v", err, tt.wantErr)
  1038  			}
  1039  			for contextID, wantFqdns := range tt.want.fqdns {
  1040  				existingFqdnsRaw, err := p.contextIDToDNSNames.Get(contextID)
  1041  				if err != nil {
  1042  					t.Errorf("no map for context %q in contextIDToDNSNames any longer", contextID)
  1043  				}
  1044  				existingFqdns := existingFqdnsRaw.(*dnsNamesToIP)
  1045  				if !reflect.DeepEqual(wantFqdns.nameToIP, existingFqdns.nameToIP) {
  1046  					t.Errorf("context %q: have fqdns %#v - want fqdns %#v", contextID, existingFqdns.nameToIP, wantFqdns.nameToIP)
  1047  				}
  1048  			}
  1049  		})
  1050  	}
  1051  }