github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/dnsproxy/dns_windows_test.go (about)

     1  // +build windows
     2  
     3  package dnsproxy
     4  
     5  import (
     6  	"context"
     7  	"encoding/hex"
     8  	"errors"
     9  	"net"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/magiconair/properties/assert"
    15  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    16  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/ipsetmanager"
    17  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext"
    19  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    20  	"go.aporeto.io/enforcerd/trireme-lib/utils/cache"
    21  )
    22  
    23  func addDNSNamePolicy(context *pucontext.PUContext) {
    24  	context.DNSACLs = policy.DNSRuleList{
    25  		"google.com.": []policy.PortProtocolPolicy{
    26  			{Ports: []string{"80"},
    27  				Protocols: []string{"6"},
    28  				Policy: &policy.FlowPolicy{
    29  					Action:   policy.Accept,
    30  					PolicyID: "2",
    31  				}},
    32  		},
    33  	}
    34  }
    35  
    36  // DNSCollector implements a default collector infrastructure to syslog
    37  type DNSCollector struct{}
    38  
    39  // CollectFlowEvent is part of the EventCollector interface.
    40  func (d *DNSCollector) CollectFlowEvent(record *collector.FlowRecord) {}
    41  
    42  // CollectContainerEvent is part of the EventCollector interface.
    43  func (d *DNSCollector) CollectContainerEvent(record *collector.ContainerRecord) {}
    44  
    45  // CollectUserEvent is part of the EventCollector interface.
    46  func (d *DNSCollector) CollectUserEvent(record *collector.UserRecord) {}
    47  
    48  // CollectTraceEvent collects iptables trace events
    49  func (d *DNSCollector) CollectTraceEvent(records []string) {}
    50  
    51  // CollectPacketEvent collects packet events from the datapath
    52  func (d *DNSCollector) CollectPacketEvent(report *collector.PacketReport) {}
    53  
    54  // CollectCounterEvent collect counters from the datapath
    55  func (d *DNSCollector) CollectCounterEvent(report *collector.CounterReport) {}
    56  
    57  // CollectPingEvent collects ping events from the datapath
    58  func (d *DNSCollector) CollectPingEvent(report *collector.PingReport) {}
    59  
    60  // CollectConnectionExceptionReport collects the connection exception report
    61  func (d *DNSCollector) CollectConnectionExceptionReport(_ *collector.ConnectionExceptionReport) {}
    62  
    63  var r collector.DNSRequestReport
    64  var l sync.Mutex
    65  
    66  // CollectDNSRequests collect counters from the datapath
    67  func (d *DNSCollector) CollectDNSRequests(report *collector.DNSRequestReport) {
    68  	l.Lock()
    69  	r = *report
    70  	l.Unlock()
    71  }
    72  
    73  const (
    74  	dnsResponseHex1 = "45200048d22f00006a11ad8f08080808c0a8000e0035e7560034385a00088180000100010000000006676f6f676c6503636f6d0000010001c00c000100010000009d0004acd90d0e"
    75  	dnsResponseHex2 = "45200054eb6700006a11944b08080808c0a8000e0035e7570040863400098180000100010000000006676f6f676c6503636f6d00001c0001c00c001c00010000012b00102607f8b040020c030000000000000066"
    76  )
    77  
    78  func TestDNS(t *testing.T) {
    79  	puIDcache := cache.NewCache("puFromContextID")
    80  
    81  	dnsResponsePacket1, _ := hex.DecodeString(dnsResponseHex1)
    82  	dnsResponsePacket2, _ := hex.DecodeString(dnsResponseHex2)
    83  
    84  	parsedPacket1, _ := packet.New(uint64(packet.PacketTypeNetwork), dnsResponsePacket1, "83", true)
    85  	parsedPacket2, _ := packet.New(uint64(packet.PacketTypeNetwork), dnsResponsePacket2, "83", true)
    86  
    87  	fp := &policy.PUInfo{
    88  		Runtime: policy.NewPURuntimeWithDefaults(),
    89  		Policy:  policy.NewPUPolicyWithDefaults(),
    90  	}
    91  	pu, _ := pucontext.NewPU("pu1", fp, nil, 24*time.Hour) // nolint
    92  
    93  	findPU := func(id string) (*pucontext.PUContext, error) {
    94  		if id == "pu1" {
    95  			return pu, nil
    96  		}
    97  		return nil, errors.New("unknown PU")
    98  	}
    99  
   100  	addDNSNamePolicy(pu)
   101  
   102  	puIDcache.AddOrUpdate("pu1", pu)
   103  	collector := &DNSCollector{}
   104  
   105  	ips := ipsetmanager.NewTestIpsetProvider()
   106  	ipsetmanager.SetIpsetTestInstance(ips)
   107  	proxy := New(context.Background(), puIDcache, nil, collector)
   108  
   109  	err := proxy.StartDNSServer(context.Background(), "pu1", "53001")
   110  	assert.Equal(t, err == nil, true, "start dns server")
   111  
   112  	err = proxy.HandleDNSResponsePacket(parsedPacket1.GetUDPData(), parsedPacket1.SourceAddress(), parsedPacket1.SourcePort(), parsedPacket1.DestinationAddress(), parsedPacket1.DestPort(), findPU)
   113  	assert.Equal(t, err == nil, true, "dns packet 1 failed")
   114  
   115  	err = proxy.HandleDNSResponsePacket(parsedPacket2.GetUDPData(), parsedPacket2.SourceAddress(), parsedPacket2.SourcePort(), parsedPacket2.DestinationAddress(), parsedPacket2.DestPort(), findPU)
   116  	assert.Equal(t, err == nil, true, "dns packet 2 failed")
   117  
   118  	// wait a sec for report delivered via channel, and then expect one report since the next will be time-delayed
   119  	time.Sleep(1 * time.Second)
   120  	l.Lock()
   121  	assert.Equal(t, r.NameLookup == "google.com.", true, "lookup should be google.com")
   122  	assert.Equal(t, r.Count == 1, true, "count should be 1")
   123  	l.Unlock()
   124  
   125  	defaultFlowPolicy := &policy.FlowPolicy{Action: policy.Reject | policy.Log, PolicyID: "default", ServiceID: "default"}
   126  
   127  	// test acls updated
   128  	rpt, pkt, err := pu.ApplicationACLs.GetMatchingAction(net.ParseIP("172.217.13.14"), 80, packet.IPProtocolTCP, defaultFlowPolicy)
   129  	assert.Equal(t, err == nil, true, "GetMatchingAction failed")
   130  	assert.Equal(t, rpt.Action.Accepted(), true, "should be accepted (report)")
   131  	assert.Equal(t, pkt.Action.Accepted(), true, "should be accepted (packet)")
   132  	rpt, pkt, err = pu.ApplicationACLs.GetMatchingAction(net.ParseIP("2607:f8b0:4002:c03::66"), 80, packet.IPProtocolTCP, defaultFlowPolicy)
   133  	assert.Equal(t, err == nil, true, "GetMatchingAction failed")
   134  	assert.Equal(t, rpt.Action.Accepted(), true, "should be accepted (report)")
   135  	assert.Equal(t, pkt.Action.Accepted(), true, "should be accepted (packet)")
   136  
   137  	// test SyncWithPlatformCache
   138  	clearWindowsDNSCacheFunc = func() error {
   139  		return errors.New("error from unit test")
   140  	}
   141  	defer func() {
   142  		clearWindowsDNSCacheFunc = clearWindowsDNSCache
   143  	}()
   144  	err = proxy.SyncWithPlatformCache(context.Background(), pu)
   145  	assert.Equal(t, err != nil, true, "clearWindowsDNSCache not called with DNSACLs present")
   146  	assert.Matches(t, err.Error(), "error from unit test")
   147  	pu.DNSACLs = policy.DNSRuleList{}
   148  	err = proxy.SyncWithPlatformCache(context.Background(), pu)
   149  	assert.Equal(t, err == nil, true, "clearWindowsDNSCache called without DNSACLs present")
   150  
   151  	proxy.ShutdownDNS("pu1")
   152  }