github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns_test.go (about) 1 // +build linux 2 3 package dnsproxy 4 5 import ( 6 "context" 7 "net" 8 "sync" 9 "testing" 10 "time" 11 12 "github.com/magiconair/properties/assert" 13 "go.aporeto.io/trireme-lib/collector" 14 provider "go.aporeto.io/trireme-lib/controller/pkg/aclprovider" 15 "go.aporeto.io/trireme-lib/controller/pkg/ipsetmanager" 16 "go.aporeto.io/trireme-lib/controller/pkg/pucontext" 17 "go.aporeto.io/trireme-lib/policy" 18 "go.aporeto.io/trireme-lib/utils/cache" 19 ) 20 21 type flowClientDummy struct { 22 } 23 24 func (c *flowClientDummy) Close() error { 25 return nil 26 } 27 28 func (c *flowClientDummy) UpdateMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32, network bool) error { 29 return nil 30 } 31 32 func (c *flowClientDummy) UpdateNetworkFlowMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32) error { 33 return nil 34 } 35 36 func (c *flowClientDummy) UpdateApplicationFlowMark(ipSrc, ipDst net.IP, protonum uint8, srcport, dstport uint16, newmark uint32) error { 37 return nil 38 } 39 40 func (c *flowClientDummy) GetOriginalDest(ipSrc, ipDst net.IP, srcport, dstport uint16, protonum uint8) (net.IP, uint16, uint32, error) { 41 return net.ParseIP("8.8.8.8"), 53, 100, nil 42 } 43 44 func addDNSNamePolicy(context *pucontext.PUContext) { 45 context.DNSACLs = policy.DNSRuleList{ 46 "www.google.com": []policy.PortProtocolPolicy{ 47 {Ports: []string{"80"}, 48 Protocols: []string{"tcp"}, 49 Policy: &policy.FlowPolicy{ 50 Action: policy.Accept, 51 PolicyID: "2", 52 }}, 53 }, 54 } 55 } 56 57 func CustomDialer(ctx context.Context, network, address string) (net.Conn, error) { 58 d := net.Dialer{} 59 return d.DialContext(ctx, "udp", "127.0.0.1:53001") 60 } 61 62 func createCustomResolver() *net.Resolver { 63 r := &net.Resolver{ 64 PreferGo: true, 65 Dial: CustomDialer, 66 } 67 68 return r 69 } 70 71 // DNSCollector implements a default collector infrastructure to syslog 72 type DNSCollector struct{} 73 74 // CollectFlowEvent is part of the EventCollector interface. 75 func (d *DNSCollector) CollectFlowEvent(record *collector.FlowRecord) {} 76 77 // CollectContainerEvent is part of the EventCollector interface. 78 func (d *DNSCollector) CollectContainerEvent(record *collector.ContainerRecord) {} 79 80 // CollectUserEvent is part of the EventCollector interface. 81 func (d *DNSCollector) CollectUserEvent(record *collector.UserRecord) {} 82 83 // CollectTraceEvent collects iptables trace events 84 func (d *DNSCollector) CollectTraceEvent(records []string) {} 85 86 // CollectPingEvent collects ping events 87 func (d *DNSCollector) CollectPingEvent(report *collector.PingReport) {} 88 89 // CollectPacketEvent collects packet events from the datapath 90 func (d *DNSCollector) CollectPacketEvent(report *collector.PacketReport) {} 91 92 // CollectCounterEvent collect counters from the datapath 93 func (d *DNSCollector) CollectCounterEvent(report *collector.CounterReport) {} 94 95 var r collector.DNSRequestReport 96 var l sync.Mutex 97 98 // CollectDNSRequests collect counters from the datapath 99 func (d *DNSCollector) CollectDNSRequests(report *collector.DNSRequestReport) { 100 l.Lock() 101 r = *report 102 l.Unlock() 103 } 104 105 func TestDNS(t *testing.T) { 106 puIDcache := cache.NewCache("puFromContextID") 107 108 fp := &policy.PUInfo{ 109 Runtime: policy.NewPURuntimeWithDefaults(), 110 Policy: policy.NewPUPolicyWithDefaults(), 111 } 112 pu, _ := pucontext.NewPU("pu1", fp, 24*time.Hour) // nolint 113 114 addDNSNamePolicy(pu) 115 116 puIDcache.AddOrUpdate("pu1", pu) 117 conntrack := &flowClientDummy{} 118 collector := &DNSCollector{} 119 120 ips := provider.NewTestIpsetProvider() 121 proxy := New(puIDcache, conntrack, collector, ipsetmanager.CreateIPsetManager(ips, ips)) 122 123 err := proxy.StartDNSServer("pu1", "53001") 124 assert.Equal(t, err == nil, true, "start dns server") 125 126 resolver := createCustomResolver() 127 ctx := context.Background() 128 waitTimeBeforeReport = 3 * time.Second 129 resolver.LookupIPAddr(ctx, "www.google.com") //nolint 130 resolver.LookupIPAddr(ctx, "www.google.com") //nolint 131 132 assert.Equal(t, err == nil, true, "err should be nil") 133 134 time.Sleep(5 * time.Second) 135 l.Lock() 136 assert.Equal(t, r.NameLookup == "www.google.com.", true, "lookup should be www.google.com") 137 assert.Equal(t, r.Count >= 2 && r.Count <= 10, true, "count should be 2") 138 l.Unlock() 139 proxy.ShutdownDNS("pu1") 140 }