github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/nfq_windows_test.go (about) 1 // +build windows 2 3 package nfqdatapath 4 5 import ( 6 "context" 7 "encoding/hex" 8 "fmt" 9 "strconv" 10 "sync" 11 "syscall" 12 "testing" 13 "unsafe" 14 15 "github.com/golang/mock/gomock" 16 . "github.com/smartystreets/goconvey/convey" 17 "go.aporeto.io/enforcerd/trireme-lib/collector" 18 "go.aporeto.io/enforcerd/trireme-lib/controller/constants" 19 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/packetgen" 20 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet" 21 "go.aporeto.io/enforcerd/trireme-lib/policy" 22 "go.aporeto.io/enforcerd/trireme-lib/utils/frontman" 23 ) 24 25 // Declare function pointer so that it can be overridden by unit test. 26 // This is not actually needed in Windows, but we need the declaration and the empty function for tests. 27 var procSetValuePtr func(procName string, value int) error = procSetValueMock 28 29 type forwardedPacket struct { 30 outbound, drop, ignoreFlow bool 31 mark int 32 packetBytes []byte 33 } 34 35 // fakeWrapper is the mock for frontman.Wrapper. 36 // We mock frontman.Wrapper and not frontman.Driver because we need to save the go funcs passed to PacketFilterStart. 37 type fakeWrapper struct { 38 receiveCallback, loggingCallback func(uintptr, uintptr) uintptr 39 forwardedPackets []*forwardedPacket 40 sync.Mutex 41 } 42 43 func (w *fakeWrapper) queuePacket(p *forwardedPacket) { 44 w.Lock() 45 defer w.Unlock() 46 w.forwardedPackets = append(w.forwardedPackets, p) 47 } 48 49 func (w *fakeWrapper) GetForwardedPackets() []*forwardedPacket { 50 w.Lock() 51 defer w.Unlock() 52 result := w.forwardedPackets 53 w.forwardedPackets = nil 54 return result 55 } 56 57 func (w *fakeWrapper) PacketFilterStart(firewallName string, receiveCallback, loggingCallback func(uintptr, uintptr) uintptr) error { 58 w.receiveCallback = receiveCallback 59 w.loggingCallback = loggingCallback 60 return nil 61 } 62 63 func (w *fakeWrapper) PacketFilterForward(info *frontman.PacketInfo, packetBytes []byte) error { 64 p := &forwardedPacket{ 65 outbound: info.Outbound != 0, 66 drop: info.Drop != 0, 67 ignoreFlow: info.IgnoreFlow != 0, 68 mark: int(info.Mark), 69 packetBytes: make([]byte, info.PacketSize), 70 } 71 if n := copy(p.packetBytes, packetBytes); n != int(info.PacketSize) { 72 return fmt.Errorf("%d bytes copied for packet, but expected %d", n, info.PacketSize) 73 } 74 w.queuePacket(p) 75 return nil 76 } 77 78 func Test_WindowsPacketCallbacks(t *testing.T) { 79 80 // unused in Windows 81 _ = testDstIP 82 _ = debug 83 84 Convey("Given I create a new enforcer instance for Windows and have a valid processing unit context", t, func() { 85 86 wrapper := &fakeWrapper{} 87 frontman.Wrapper = wrapper 88 89 Convey("Given I create a two processing unit instances", func() { 90 91 ctrl := gomock.NewController(t) 92 defer ctrl.Finish() 93 94 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 95 96 err := enforcer.startFrontmanPacketFilter(context.Background(), enforcer.nflogger) 97 So(err, ShouldBeNil) 98 99 Convey("When I pass a syn packet through the enforcer", func() { 100 101 PacketFlow := packetgen.NewTemplateFlow() 102 _, err = PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 103 So(err, ShouldBeNil) 104 tcpPacketFromFlow, err := PacketFlow.GetFirstSynPacket().ToBytes() 105 So(err, ShouldBeNil) 106 mark := 12345 107 tcpPacket, err := packet.New(0, tcpPacketFromFlow, strconv.Itoa(mark), true) 108 if err == nil && tcpPacket != nil { 109 tcpPacket.UpdateIPv4Checksum() 110 tcpPacket.UpdateTCPChecksum() 111 } 112 So(err, ShouldBeNil) 113 So(tcpPacket.Mark, ShouldEqual, strconv.Itoa(mark)) 114 115 packetBytes := tcpPacket.GetTCPBytes() 116 packetInfo := &frontman.PacketInfo{ 117 Ipv4: 1, 118 Protocol: tcpPacket.IPProto(), 119 PacketSize: uint32(len(packetBytes)), 120 Mark: uint32(mark), 121 } 122 if tcpPacket.SourceAddress().String() == testSrcIP { 123 packetInfo.Outbound = 1 124 } 125 ret := wrapper.receiveCallback(uintptr(unsafe.Pointer(packetInfo)), uintptr(unsafe.Pointer(&packetBytes[0]))) 126 So(ret, ShouldBeZeroValue) 127 128 oldPacket := tcpPacket 129 forwardedPackets := wrapper.GetForwardedPackets() 130 So(forwardedPackets, ShouldHaveLength, 1) 131 tcpPacket, err = packet.New(0, forwardedPackets[0].packetBytes, strconv.Itoa(mark), true) 132 So(err, ShouldBeNil) 133 134 // In our 3 way security handshake syn and syn-ack packet should grow in length 135 So(tcpPacket.GetTCPFlags()&packet.TCPSynMask, ShouldNotBeZeroValue) 136 So(tcpPacket.IPTotalLen(), ShouldBeGreaterThan, oldPacket.IPTotalLen()) 137 138 // reverse it and strip identity 139 packetInfo.Outbound ^= 1 140 packetBytes = tcpPacket.GetTCPBytes() 141 packetInfo.PacketSize = uint32(len(packetBytes)) 142 ret = wrapper.receiveCallback(uintptr(unsafe.Pointer(packetInfo)), uintptr(unsafe.Pointer(&packetBytes[0]))) 143 So(ret, ShouldBeZeroValue) 144 forwardedPackets = wrapper.GetForwardedPackets() 145 So(forwardedPackets, ShouldHaveLength, 1) 146 tcpPacket, err = packet.New(0, forwardedPackets[0].packetBytes, strconv.Itoa(mark), true) 147 So(err, ShouldBeNil) 148 So(tcpPacket.IPTotalLen(), ShouldEqual, oldPacket.IPTotalLen()) 149 }) 150 151 Convey("When I pass a synack packet for non-PU traffic", func() { 152 153 PacketFlow := packetgen.NewTemplateFlow() 154 _, err = PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 155 So(err, ShouldBeNil) 156 tcpPacketFromFlow, err := PacketFlow.GetFirstSynAckPacket().ToBytes() 157 So(err, ShouldBeNil) 158 mark := 12345 159 tcpPacket, err := packet.New(0, tcpPacketFromFlow, strconv.Itoa(mark), true) 160 if err == nil && tcpPacket != nil { 161 tcpPacket.UpdateIPv4Checksum() 162 tcpPacket.UpdateTCPChecksum() 163 } 164 So(err, ShouldBeNil) 165 So(tcpPacket.Mark, ShouldEqual, strconv.Itoa(mark)) 166 167 packetBytes := tcpPacket.GetTCPBytes() 168 packetInfo := &frontman.PacketInfo{ 169 Ipv4: 1, 170 Protocol: tcpPacket.IPProto(), 171 PacketSize: uint32(len(packetBytes)), 172 Mark: uint32(mark), 173 } 174 if tcpPacket.SourceAddress().String() == testSrcIP { 175 packetInfo.Outbound = 1 176 } 177 ret := wrapper.receiveCallback(uintptr(unsafe.Pointer(packetInfo)), uintptr(unsafe.Pointer(&packetBytes[0]))) 178 So(ret, ShouldBeZeroValue) 179 180 forwardedPackets := wrapper.GetForwardedPackets() 181 So(forwardedPackets, ShouldHaveLength, 1) 182 tcpPacket, err = packet.New(0, forwardedPackets[0].packetBytes, strconv.Itoa(mark), true) 183 So(err, ShouldBeNil) 184 So(tcpPacket, ShouldNotBeNil) 185 // IgnoreFlow flag should be set 186 So(forwardedPackets[0].ignoreFlow, ShouldNotBeZeroValue) 187 }) 188 189 Convey("When I say to log that a packet is rejected", func() { 190 191 puHash, err := policy.Fnv32Hash("SomeProcessingUnitId1") 192 So(err, ShouldBeNil) 193 194 dnsRequestPacket, err := hex.DecodeString("450000380542000080110000c0a8446dc0a84401ebe60035002409f5df510100000100000000000006676f6f676c6503636f6d0000010001") 195 So(err, ShouldBeNil) 196 dnsPacket, err := packet.New(0, dnsRequestPacket, "0", true) 197 So(err, ShouldBeNil) 198 199 packetHeaderBytes := dnsPacket.GetBuffer(0)[:dnsPacket.IPHeaderLen()+packet.UDPDataPos] 200 logPacketInfo := &frontman.LogPacketInfo{ 201 Ipv4: 1, 202 Protocol: dnsPacket.IPProto(), 203 PacketSize: uint32(len(packetHeaderBytes)), 204 GroupID: 11, 205 } 206 207 copy(logPacketInfo.LogPrefix[:], syscall.StringToUTF16(puHash+":5d6044b9e99572000149d650:5d60448a884e46000145cf67:6")) // nolint:staticcheck 208 209 flowRecord := CreateFlowRecord(1, "192.168.68.109", "192.168.68.1", 0, 53, policy.Reject|policy.Log, collector.PolicyDrop) 210 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 211 212 ret := wrapper.loggingCallback(uintptr(unsafe.Pointer(logPacketInfo)), uintptr(unsafe.Pointer(&packetHeaderBytes[0]))) 213 So(ret, ShouldBeZeroValue) 214 }) 215 }) 216 }) 217 } 218 219 // Empty interface implementations 220 221 func (w *fakeWrapper) GetDestInfo(socket uintptr, destInfo *frontman.DestInfo) error { 222 return nil 223 } 224 225 func (w *fakeWrapper) ApplyDestHandle(socket, destHandle uintptr) error { 226 return nil 227 } 228 229 func (w *fakeWrapper) FreeDestHandle(destHandle uintptr) error { 230 return nil 231 } 232 233 func (w *fakeWrapper) NewIpset(name, ipsetType string) (uintptr, error) { 234 return 1, nil 235 } 236 237 func (w *fakeWrapper) GetIpset(name string) (uintptr, error) { 238 return 1, nil 239 } 240 241 func (w *fakeWrapper) DestroyAllIpsets(prefix string) error { 242 return nil 243 } 244 245 func (w *fakeWrapper) ListIpsets() ([]string, error) { 246 return nil, nil 247 } 248 249 func (w *fakeWrapper) ListIpsetsDetail(format int) (string, error) { 250 return "", nil 251 } 252 253 func (w *fakeWrapper) IpsetAdd(ipsetHandle uintptr, entry string, timeout int) error { 254 return nil 255 } 256 257 func (w *fakeWrapper) IpsetAddOption(ipsetHandle uintptr, entry, option string, timeout int) error { 258 return nil 259 } 260 261 func (w *fakeWrapper) IpsetDelete(ipsetHandle uintptr, entry string) error { 262 return nil 263 } 264 265 func (w *fakeWrapper) IpsetDestroy(ipsetHandle uintptr, name string) error { 266 return nil 267 } 268 269 func (w *fakeWrapper) IpsetFlush(ipsetHandle uintptr) error { 270 return nil 271 } 272 273 func (w *fakeWrapper) IpsetTest(ipsetHandle uintptr, entry string) (bool, error) { 274 return true, nil 275 } 276 277 func (w *fakeWrapper) AppendFilter(outbound bool, filterName string, isGotoFilter bool) error { 278 return nil 279 } 280 281 func (w *fakeWrapper) InsertFilter(outbound bool, priority int, filterName string, isGotoFilter bool) error { 282 return nil 283 } 284 285 func (w *fakeWrapper) DestroyFilter(filterName string) error { 286 return nil 287 } 288 289 func (w *fakeWrapper) EmptyFilter(filterName string) error { 290 return nil 291 } 292 293 func (w *fakeWrapper) GetFilterList(outbound bool) ([]string, error) { 294 return nil, nil 295 } 296 297 func (w *fakeWrapper) AppendFilterCriteria(filterName, criteriaName string, ruleSpec *frontman.RuleSpec, ipsetRuleSpecs []frontman.IpsetRuleSpec) error { 298 return nil 299 } 300 301 func (w *fakeWrapper) DeleteFilterCriteria(filterName, criteriaName string) error { 302 return nil 303 } 304 305 func (w *fakeWrapper) GetCriteriaList(format int) (string, error) { 306 return "", nil 307 } 308 309 func (w *fakeWrapper) PacketFilterClose() error { 310 return nil 311 }