github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/nfq_windows_test.go (about)

     1  // +build windows
     2  
     3  package nfqdatapath
     4  
     5  import (
     6  	"context"
     7  	"encoding/hex"
     8  	"fmt"
     9  	"strconv"
    10  	"sync"
    11  	"syscall"
    12  	"testing"
    13  	"unsafe"
    14  
    15  	"github.com/golang/mock/gomock"
    16  	. "github.com/smartystreets/goconvey/convey"
    17  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/constants"
    19  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/packetgen"
    20  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet"
    21  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    22  	"go.aporeto.io/enforcerd/trireme-lib/utils/frontman"
    23  )
    24  
    25  // Declare function pointer so that it can be overridden by unit test.
    26  // This is not actually needed in Windows, but we need the declaration and the empty function for tests.
    27  var procSetValuePtr func(procName string, value int) error = procSetValueMock
    28  
    29  type forwardedPacket struct {
    30  	outbound, drop, ignoreFlow bool
    31  	mark                       int
    32  	packetBytes                []byte
    33  }
    34  
    35  // fakeWrapper is the mock for frontman.Wrapper.
    36  // We mock frontman.Wrapper and not frontman.Driver because we need to save the go funcs passed to PacketFilterStart.
    37  type fakeWrapper struct {
    38  	receiveCallback, loggingCallback func(uintptr, uintptr) uintptr
    39  	forwardedPackets                 []*forwardedPacket
    40  	sync.Mutex
    41  }
    42  
    43  func (w *fakeWrapper) queuePacket(p *forwardedPacket) {
    44  	w.Lock()
    45  	defer w.Unlock()
    46  	w.forwardedPackets = append(w.forwardedPackets, p)
    47  }
    48  
    49  func (w *fakeWrapper) GetForwardedPackets() []*forwardedPacket {
    50  	w.Lock()
    51  	defer w.Unlock()
    52  	result := w.forwardedPackets
    53  	w.forwardedPackets = nil
    54  	return result
    55  }
    56  
    57  func (w *fakeWrapper) PacketFilterStart(firewallName string, receiveCallback, loggingCallback func(uintptr, uintptr) uintptr) error {
    58  	w.receiveCallback = receiveCallback
    59  	w.loggingCallback = loggingCallback
    60  	return nil
    61  }
    62  
    63  func (w *fakeWrapper) PacketFilterForward(info *frontman.PacketInfo, packetBytes []byte) error {
    64  	p := &forwardedPacket{
    65  		outbound:    info.Outbound != 0,
    66  		drop:        info.Drop != 0,
    67  		ignoreFlow:  info.IgnoreFlow != 0,
    68  		mark:        int(info.Mark),
    69  		packetBytes: make([]byte, info.PacketSize),
    70  	}
    71  	if n := copy(p.packetBytes, packetBytes); n != int(info.PacketSize) {
    72  		return fmt.Errorf("%d bytes copied for packet, but expected %d", n, info.PacketSize)
    73  	}
    74  	w.queuePacket(p)
    75  	return nil
    76  }
    77  
    78  func Test_WindowsPacketCallbacks(t *testing.T) {
    79  
    80  	// unused in Windows
    81  	_ = testDstIP
    82  	_ = debug
    83  
    84  	Convey("Given I create a new enforcer instance for Windows and have a valid processing unit context", t, func() {
    85  
    86  		wrapper := &fakeWrapper{}
    87  		frontman.Wrapper = wrapper
    88  
    89  		Convey("Given I create a two processing unit instances", func() {
    90  
    91  			ctrl := gomock.NewController(t)
    92  			defer ctrl.Finish()
    93  
    94  			enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
    95  
    96  			err := enforcer.startFrontmanPacketFilter(context.Background(), enforcer.nflogger)
    97  			So(err, ShouldBeNil)
    98  
    99  			Convey("When I pass a syn packet through the enforcer", func() {
   100  
   101  				PacketFlow := packetgen.NewTemplateFlow()
   102  				_, err = PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   103  				So(err, ShouldBeNil)
   104  				tcpPacketFromFlow, err := PacketFlow.GetFirstSynPacket().ToBytes()
   105  				So(err, ShouldBeNil)
   106  				mark := 12345
   107  				tcpPacket, err := packet.New(0, tcpPacketFromFlow, strconv.Itoa(mark), true)
   108  				if err == nil && tcpPacket != nil {
   109  					tcpPacket.UpdateIPv4Checksum()
   110  					tcpPacket.UpdateTCPChecksum()
   111  				}
   112  				So(err, ShouldBeNil)
   113  				So(tcpPacket.Mark, ShouldEqual, strconv.Itoa(mark))
   114  
   115  				packetBytes := tcpPacket.GetTCPBytes()
   116  				packetInfo := &frontman.PacketInfo{
   117  					Ipv4:       1,
   118  					Protocol:   tcpPacket.IPProto(),
   119  					PacketSize: uint32(len(packetBytes)),
   120  					Mark:       uint32(mark),
   121  				}
   122  				if tcpPacket.SourceAddress().String() == testSrcIP {
   123  					packetInfo.Outbound = 1
   124  				}
   125  				ret := wrapper.receiveCallback(uintptr(unsafe.Pointer(packetInfo)), uintptr(unsafe.Pointer(&packetBytes[0])))
   126  				So(ret, ShouldBeZeroValue)
   127  
   128  				oldPacket := tcpPacket
   129  				forwardedPackets := wrapper.GetForwardedPackets()
   130  				So(forwardedPackets, ShouldHaveLength, 1)
   131  				tcpPacket, err = packet.New(0, forwardedPackets[0].packetBytes, strconv.Itoa(mark), true)
   132  				So(err, ShouldBeNil)
   133  
   134  				// In our 3 way security handshake syn and syn-ack packet should grow in length
   135  				So(tcpPacket.GetTCPFlags()&packet.TCPSynMask, ShouldNotBeZeroValue)
   136  				So(tcpPacket.IPTotalLen(), ShouldBeGreaterThan, oldPacket.IPTotalLen())
   137  
   138  				// reverse it and strip identity
   139  				packetInfo.Outbound ^= 1
   140  				packetBytes = tcpPacket.GetTCPBytes()
   141  				packetInfo.PacketSize = uint32(len(packetBytes))
   142  				ret = wrapper.receiveCallback(uintptr(unsafe.Pointer(packetInfo)), uintptr(unsafe.Pointer(&packetBytes[0])))
   143  				So(ret, ShouldBeZeroValue)
   144  				forwardedPackets = wrapper.GetForwardedPackets()
   145  				So(forwardedPackets, ShouldHaveLength, 1)
   146  				tcpPacket, err = packet.New(0, forwardedPackets[0].packetBytes, strconv.Itoa(mark), true)
   147  				So(err, ShouldBeNil)
   148  				So(tcpPacket.IPTotalLen(), ShouldEqual, oldPacket.IPTotalLen())
   149  			})
   150  
   151  			Convey("When I pass a synack packet for non-PU traffic", func() {
   152  
   153  				PacketFlow := packetgen.NewTemplateFlow()
   154  				_, err = PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   155  				So(err, ShouldBeNil)
   156  				tcpPacketFromFlow, err := PacketFlow.GetFirstSynAckPacket().ToBytes()
   157  				So(err, ShouldBeNil)
   158  				mark := 12345
   159  				tcpPacket, err := packet.New(0, tcpPacketFromFlow, strconv.Itoa(mark), true)
   160  				if err == nil && tcpPacket != nil {
   161  					tcpPacket.UpdateIPv4Checksum()
   162  					tcpPacket.UpdateTCPChecksum()
   163  				}
   164  				So(err, ShouldBeNil)
   165  				So(tcpPacket.Mark, ShouldEqual, strconv.Itoa(mark))
   166  
   167  				packetBytes := tcpPacket.GetTCPBytes()
   168  				packetInfo := &frontman.PacketInfo{
   169  					Ipv4:       1,
   170  					Protocol:   tcpPacket.IPProto(),
   171  					PacketSize: uint32(len(packetBytes)),
   172  					Mark:       uint32(mark),
   173  				}
   174  				if tcpPacket.SourceAddress().String() == testSrcIP {
   175  					packetInfo.Outbound = 1
   176  				}
   177  				ret := wrapper.receiveCallback(uintptr(unsafe.Pointer(packetInfo)), uintptr(unsafe.Pointer(&packetBytes[0])))
   178  				So(ret, ShouldBeZeroValue)
   179  
   180  				forwardedPackets := wrapper.GetForwardedPackets()
   181  				So(forwardedPackets, ShouldHaveLength, 1)
   182  				tcpPacket, err = packet.New(0, forwardedPackets[0].packetBytes, strconv.Itoa(mark), true)
   183  				So(err, ShouldBeNil)
   184  				So(tcpPacket, ShouldNotBeNil)
   185  				// IgnoreFlow flag should be set
   186  				So(forwardedPackets[0].ignoreFlow, ShouldNotBeZeroValue)
   187  			})
   188  
   189  			Convey("When I say to log that a packet is rejected", func() {
   190  
   191  				puHash, err := policy.Fnv32Hash("SomeProcessingUnitId1")
   192  				So(err, ShouldBeNil)
   193  
   194  				dnsRequestPacket, err := hex.DecodeString("450000380542000080110000c0a8446dc0a84401ebe60035002409f5df510100000100000000000006676f6f676c6503636f6d0000010001")
   195  				So(err, ShouldBeNil)
   196  				dnsPacket, err := packet.New(0, dnsRequestPacket, "0", true)
   197  				So(err, ShouldBeNil)
   198  
   199  				packetHeaderBytes := dnsPacket.GetBuffer(0)[:dnsPacket.IPHeaderLen()+packet.UDPDataPos]
   200  				logPacketInfo := &frontman.LogPacketInfo{
   201  					Ipv4:       1,
   202  					Protocol:   dnsPacket.IPProto(),
   203  					PacketSize: uint32(len(packetHeaderBytes)),
   204  					GroupID:    11,
   205  				}
   206  
   207  				copy(logPacketInfo.LogPrefix[:], syscall.StringToUTF16(puHash+":5d6044b9e99572000149d650:5d60448a884e46000145cf67:6")) // nolint:staticcheck
   208  
   209  				flowRecord := CreateFlowRecord(1, "192.168.68.109", "192.168.68.1", 0, 53, policy.Reject|policy.Log, collector.PolicyDrop)
   210  				mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
   211  
   212  				ret := wrapper.loggingCallback(uintptr(unsafe.Pointer(logPacketInfo)), uintptr(unsafe.Pointer(&packetHeaderBytes[0])))
   213  				So(ret, ShouldBeZeroValue)
   214  			})
   215  		})
   216  	})
   217  }
   218  
   219  // Empty interface implementations
   220  
   221  func (w *fakeWrapper) GetDestInfo(socket uintptr, destInfo *frontman.DestInfo) error {
   222  	return nil
   223  }
   224  
   225  func (w *fakeWrapper) ApplyDestHandle(socket, destHandle uintptr) error {
   226  	return nil
   227  }
   228  
   229  func (w *fakeWrapper) FreeDestHandle(destHandle uintptr) error {
   230  	return nil
   231  }
   232  
   233  func (w *fakeWrapper) NewIpset(name, ipsetType string) (uintptr, error) {
   234  	return 1, nil
   235  }
   236  
   237  func (w *fakeWrapper) GetIpset(name string) (uintptr, error) {
   238  	return 1, nil
   239  }
   240  
   241  func (w *fakeWrapper) DestroyAllIpsets(prefix string) error {
   242  	return nil
   243  }
   244  
   245  func (w *fakeWrapper) ListIpsets() ([]string, error) {
   246  	return nil, nil
   247  }
   248  
   249  func (w *fakeWrapper) ListIpsetsDetail(format int) (string, error) {
   250  	return "", nil
   251  }
   252  
   253  func (w *fakeWrapper) IpsetAdd(ipsetHandle uintptr, entry string, timeout int) error {
   254  	return nil
   255  }
   256  
   257  func (w *fakeWrapper) IpsetAddOption(ipsetHandle uintptr, entry, option string, timeout int) error {
   258  	return nil
   259  }
   260  
   261  func (w *fakeWrapper) IpsetDelete(ipsetHandle uintptr, entry string) error {
   262  	return nil
   263  }
   264  
   265  func (w *fakeWrapper) IpsetDestroy(ipsetHandle uintptr, name string) error {
   266  	return nil
   267  }
   268  
   269  func (w *fakeWrapper) IpsetFlush(ipsetHandle uintptr) error {
   270  	return nil
   271  }
   272  
   273  func (w *fakeWrapper) IpsetTest(ipsetHandle uintptr, entry string) (bool, error) {
   274  	return true, nil
   275  }
   276  
   277  func (w *fakeWrapper) AppendFilter(outbound bool, filterName string, isGotoFilter bool) error {
   278  	return nil
   279  }
   280  
   281  func (w *fakeWrapper) InsertFilter(outbound bool, priority int, filterName string, isGotoFilter bool) error {
   282  	return nil
   283  }
   284  
   285  func (w *fakeWrapper) DestroyFilter(filterName string) error {
   286  	return nil
   287  }
   288  
   289  func (w *fakeWrapper) EmptyFilter(filterName string) error {
   290  	return nil
   291  }
   292  
   293  func (w *fakeWrapper) GetFilterList(outbound bool) ([]string, error) {
   294  	return nil, nil
   295  }
   296  
   297  func (w *fakeWrapper) AppendFilterCriteria(filterName, criteriaName string, ruleSpec *frontman.RuleSpec, ipsetRuleSpecs []frontman.IpsetRuleSpec) error {
   298  	return nil
   299  }
   300  
   301  func (w *fakeWrapper) DeleteFilterCriteria(filterName, criteriaName string) error {
   302  	return nil
   303  }
   304  
   305  func (w *fakeWrapper) GetCriteriaList(format int) (string, error) {
   306  	return "", nil
   307  }
   308  
   309  func (w *fakeWrapper) PacketFilterClose() error {
   310  	return nil
   311  }