github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns_test.go (about)

     1  // +build linux
     2  
     3  package dnsproxy
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/magiconair/properties/assert"
    13  	"go.aporeto.io/trireme-lib/collector"
    14  	provider "go.aporeto.io/trireme-lib/controller/pkg/aclprovider"
    15  	"go.aporeto.io/trireme-lib/controller/pkg/ipsetmanager"
    16  	"go.aporeto.io/trireme-lib/controller/pkg/pucontext"
    17  	"go.aporeto.io/trireme-lib/policy"
    18  	"go.aporeto.io/trireme-lib/utils/cache"
    19  )
    20  
    // flowClientDummy is a no-op stand-in for the conntrack/flow client that
// TestDNS hands to New. Close and every Update* method succeed without doing
// anything; GetOriginalDest ignores the flow it is asked about and always
// answers with the same fixed destination.
type flowClientDummy struct{}

// Close does nothing and reports success.
func (*flowClientDummy) Close() error { return nil }

// UpdateMark ignores all of its arguments and reports success.
func (*flowClientDummy) UpdateMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32, network bool) error {
	return nil
}

// UpdateNetworkFlowMark ignores all of its arguments and reports success.
func (*flowClientDummy) UpdateNetworkFlowMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32) error {
	return nil
}

// UpdateApplicationFlowMark ignores all of its arguments and reports success.
func (*flowClientDummy) UpdateApplicationFlowMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32) error {
	return nil
}

// GetOriginalDest disregards the flow tuple and always returns IP 8.8.8.8,
// port 53, the value 100 and a nil error.
func (*flowClientDummy) GetOriginalDest(ipSrc, ipDst net.IP, srcport, dstport uint16, protonum uint8) (net.IP, uint16, uint32, error) {
	upstream := net.ParseIP("8.8.8.8")
	return upstream, 53, 100, nil
}
    43  
    44  func addDNSNamePolicy(context *pucontext.PUContext) {
    45  	context.DNSACLs = policy.DNSRuleList{
    46  		"www.google.com": []policy.PortProtocolPolicy{
    47  			{Ports: []string{"80"},
    48  				Protocols: []string{"tcp"},
    49  				Policy: &policy.FlowPolicy{
    50  					Action:   policy.Accept,
    51  					PolicyID: "2",
    52  				}},
    53  		},
    54  	}
    55  }
    56  
    // CustomDialer is a net.Resolver Dial hook. It ignores the network and
// address the resolver asks for and always opens a UDP connection to
// 127.0.0.1:53001 (port 53001 is the one TestDNS passes to StartDNSServer).
func CustomDialer(ctx context.Context, network, address string) (net.Conn, error) {
	const proxyAddr = "127.0.0.1:53001"

	var dialer net.Dialer
	return dialer.DialContext(ctx, "udp", proxyAddr)
}
    61  
    62  func createCustomResolver() *net.Resolver {
    63  	r := &net.Resolver{
    64  		PreferGo: true,
    65  		Dial:     CustomDialer,
    66  	}
    67  
    68  	return r
    69  }
    70  
    71  // DNSCollector implements a default collector infrastructure to syslog
    72  type DNSCollector struct{}
    73  
    74  // CollectFlowEvent is part of the EventCollector interface.
    75  func (d *DNSCollector) CollectFlowEvent(record *collector.FlowRecord) {}
    76  
    77  // CollectContainerEvent is part of the EventCollector interface.
    78  func (d *DNSCollector) CollectContainerEvent(record *collector.ContainerRecord) {}
    79  
    80  // CollectUserEvent is part of the EventCollector interface.
    81  func (d *DNSCollector) CollectUserEvent(record *collector.UserRecord) {}
    82  
    83  // CollectTraceEvent collects iptables trace events
    84  func (d *DNSCollector) CollectTraceEvent(records []string) {}
    85  
    86  // CollectPingEvent collects ping events
    87  func (d *DNSCollector) CollectPingEvent(report *collector.PingReport) {}
    88  
    89  // CollectPacketEvent collects packet events from the datapath
    90  func (d *DNSCollector) CollectPacketEvent(report *collector.PacketReport) {}
    91  
    92  // CollectCounterEvent collect counters from the datapath
    93  func (d *DNSCollector) CollectCounterEvent(report *collector.CounterReport) {}
    94  
    95  var r collector.DNSRequestReport
    96  var l sync.Mutex
    97  
    98  // CollectDNSRequests collect counters from the datapath
    99  func (d *DNSCollector) CollectDNSRequests(report *collector.DNSRequestReport) {
   100  	l.Lock()
   101  	r = *report
   102  	l.Unlock()
   103  }
   104  
   105  func TestDNS(t *testing.T) {
   106  	puIDcache := cache.NewCache("puFromContextID")
   107  
   108  	fp := &policy.PUInfo{
   109  		Runtime: policy.NewPURuntimeWithDefaults(),
   110  		Policy:  policy.NewPUPolicyWithDefaults(),
   111  	}
   112  	pu, _ := pucontext.NewPU("pu1", fp, 24*time.Hour) // nolint
   113  
   114  	addDNSNamePolicy(pu)
   115  
   116  	puIDcache.AddOrUpdate("pu1", pu)
   117  	conntrack := &flowClientDummy{}
   118  	collector := &DNSCollector{}
   119  
   120  	ips := provider.NewTestIpsetProvider()
   121  	proxy := New(puIDcache, conntrack, collector, ipsetmanager.CreateIPsetManager(ips, ips))
   122  
   123  	err := proxy.StartDNSServer("pu1", "53001")
   124  	assert.Equal(t, err == nil, true, "start dns server")
   125  
   126  	resolver := createCustomResolver()
   127  	ctx := context.Background()
   128  	waitTimeBeforeReport = 3 * time.Second
   129  	resolver.LookupIPAddr(ctx, "www.google.com") //nolint
   130  	resolver.LookupIPAddr(ctx, "www.google.com") //nolint
   131  
   132  	assert.Equal(t, err == nil, true, "err should be nil")
   133  
   134  	time.Sleep(5 * time.Second)
   135  	l.Lock()
   136  	assert.Equal(t, r.NameLookup == "www.google.com.", true, "lookup should be www.google.com")
   137  	assert.Equal(t, r.Count >= 2 && r.Count <= 10, true, "count should be 2")
   138  	l.Unlock()
   139  	proxy.ShutdownDNS("pu1")
   140  }