github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns_windows_test.go (about) 1 // +build windows 2 3 package dnsproxy 4 5 import ( 6 "context" 7 "encoding/hex" 8 "errors" 9 "net" 10 "sync" 11 "testing" 12 "time" 13 14 "github.com/magiconair/properties/assert" 15 "go.aporeto.io/enforcerd/trireme-lib/collector" 16 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/ipsetmanager" 17 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet" 18 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext" 19 "go.aporeto.io/enforcerd/trireme-lib/policy" 20 "go.aporeto.io/enforcerd/trireme-lib/utils/cache" 21 ) 22 23 func addDNSNamePolicy(context *pucontext.PUContext) { 24 context.DNSACLs = policy.DNSRuleList{ 25 "google.com.": []policy.PortProtocolPolicy{ 26 {Ports: []string{"80"}, 27 Protocols: []string{"6"}, 28 Policy: &policy.FlowPolicy{ 29 Action: policy.Accept, 30 PolicyID: "2", 31 }}, 32 }, 33 } 34 } 35 36 // DNSCollector implements a default collector infrastructure to syslog 37 type DNSCollector struct{} 38 39 // CollectFlowEvent is part of the EventCollector interface. 40 func (d *DNSCollector) CollectFlowEvent(record *collector.FlowRecord) {} 41 42 // CollectContainerEvent is part of the EventCollector interface. 43 func (d *DNSCollector) CollectContainerEvent(record *collector.ContainerRecord) {} 44 45 // CollectUserEvent is part of the EventCollector interface. 46 func (d *DNSCollector) CollectUserEvent(record *collector.UserRecord) {} 47 48 // CollectTraceEvent collects iptables trace events 49 func (d *DNSCollector) CollectTraceEvent(records []string) {} 50 51 // CollectPacketEvent collects packet events from the datapath 52 func (d *DNSCollector) CollectPacketEvent(report *collector.PacketReport) {} 53 54 // CollectCounterEvent collect counters from the datapath 55 func (d *DNSCollector) CollectCounterEvent(report *collector.CounterReport) {} 56 57 // CollectPingEvent collects ping events from the datapath 58 func (d *DNSCollector) CollectPingEvent(report *collector.PingReport) {} 59 60 // CollectConnectionExceptionReport collects the connection exception report 61 func (d *DNSCollector) CollectConnectionExceptionReport(_ *collector.ConnectionExceptionReport) {} 62 63 var r collector.DNSRequestReport 64 var l sync.Mutex 65 66 // CollectDNSRequests collect counters from the datapath 67 func (d *DNSCollector) CollectDNSRequests(report *collector.DNSRequestReport) { 68 l.Lock() 69 r = *report 70 l.Unlock() 71 } 72 73 const ( 74 dnsResponseHex1 = "45200048d22f00006a11ad8f08080808c0a8000e0035e7560034385a00088180000100010000000006676f6f676c6503636f6d0000010001c00c000100010000009d0004acd90d0e" 75 dnsResponseHex2 = "45200054eb6700006a11944b08080808c0a8000e0035e7570040863400098180000100010000000006676f6f676c6503636f6d00001c0001c00c001c00010000012b00102607f8b040020c030000000000000066" 76 ) 77 78 func TestDNS(t *testing.T) { 79 puIDcache := cache.NewCache("puFromContextID") 80 81 dnsResponsePacket1, _ := hex.DecodeString(dnsResponseHex1) 82 dnsResponsePacket2, _ := hex.DecodeString(dnsResponseHex2) 83 84 parsedPacket1, _ := packet.New(uint64(packet.PacketTypeNetwork), dnsResponsePacket1, "83", true) 85 parsedPacket2, _ := packet.New(uint64(packet.PacketTypeNetwork), dnsResponsePacket2, "83", true) 86 87 fp := &policy.PUInfo{ 88 Runtime: policy.NewPURuntimeWithDefaults(), 89 Policy: policy.NewPUPolicyWithDefaults(), 90 } 91 pu, _ := pucontext.NewPU("pu1", fp, nil, 24*time.Hour) // nolint 92 93 findPU := func(id string) (*pucontext.PUContext, error) { 94 if id == "pu1" { 95 return pu, nil 96 } 97 return nil, errors.New("unknown PU") 98 } 99 100 addDNSNamePolicy(pu) 101 102 puIDcache.AddOrUpdate("pu1", pu) 103 collector := &DNSCollector{} 104 105 ips := ipsetmanager.NewTestIpsetProvider() 106 ipsetmanager.SetIpsetTestInstance(ips) 107 proxy := New(context.Background(), puIDcache, nil, collector) 108 109 err := proxy.StartDNSServer(context.Background(), "pu1", "53001") 110 assert.Equal(t, err == nil, true, "start dns server") 111 112 err = proxy.HandleDNSResponsePacket(parsedPacket1.GetUDPData(), parsedPacket1.SourceAddress(), parsedPacket1.SourcePort(), parsedPacket1.DestinationAddress(), parsedPacket1.DestPort(), findPU) 113 assert.Equal(t, err == nil, true, "dns packet 1 failed") 114 115 err = proxy.HandleDNSResponsePacket(parsedPacket2.GetUDPData(), parsedPacket2.SourceAddress(), parsedPacket2.SourcePort(), parsedPacket2.DestinationAddress(), parsedPacket2.DestPort(), findPU) 116 assert.Equal(t, err == nil, true, "dns packet 2 failed") 117 118 // wait a sec for report delivered via channel, and then expect one report since the next will be time-delayed 119 time.Sleep(1 * time.Second) 120 l.Lock() 121 assert.Equal(t, r.NameLookup == "google.com.", true, "lookup should be google.com") 122 assert.Equal(t, r.Count == 1, true, "count should be 1") 123 l.Unlock() 124 125 defaultFlowPolicy := &policy.FlowPolicy{Action: policy.Reject | policy.Log, PolicyID: "default", ServiceID: "default"} 126 127 // test acls updated 128 rpt, pkt, err := pu.ApplicationACLs.GetMatchingAction(net.ParseIP("172.217.13.14"), 80, packet.IPProtocolTCP, defaultFlowPolicy) 129 assert.Equal(t, err == nil, true, "GetMatchingAction failed") 130 assert.Equal(t, rpt.Action.Accepted(), true, "should be accepted (report)") 131 assert.Equal(t, pkt.Action.Accepted(), true, "should be accepted (packet)") 132 rpt, pkt, err = pu.ApplicationACLs.GetMatchingAction(net.ParseIP("2607:f8b0:4002:c03::66"), 80, packet.IPProtocolTCP, defaultFlowPolicy) 133 assert.Equal(t, err == nil, true, "GetMatchingAction failed") 134 assert.Equal(t, rpt.Action.Accepted(), true, "should be accepted (report)") 135 assert.Equal(t, pkt.Action.Accepted(), true, "should be accepted (packet)") 136 137 // test SyncWithPlatformCache 138 clearWindowsDNSCacheFunc = func() error { 139 return errors.New("error from unit test") 140 } 141 defer func() { 142 clearWindowsDNSCacheFunc = clearWindowsDNSCache 143 }() 144 err = proxy.SyncWithPlatformCache(context.Background(), pu) 145 assert.Equal(t, err != nil, true, "clearWindowsDNSCache not called with DNSACLs present") 146 assert.Matches(t, err.Error(), "error from unit test") 147 pu.DNSACLs = policy.DNSRuleList{} 148 err = proxy.SyncWithPlatformCache(context.Background(), pu) 149 assert.Equal(t, err == nil, true, "clearWindowsDNSCache called without DNSACLs present") 150 151 proxy.ShutdownDNS("pu1") 152 }