github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/nfq_windows.go (about) 1 // +build windows 2 3 package nfqdatapath 4 5 import ( 6 "context" 7 "fmt" 8 "strconv" 9 "unsafe" 10 11 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/nfqdatapath/afinetrawsocket" 12 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/nfqdatapath/nflog" 13 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/connection" 14 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet" 15 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext" 16 "go.aporeto.io/enforcerd/trireme-lib/utils/frontman" 17 "go.uber.org/zap" 18 ) 19 20 func (d *Datapath) startFrontmanPacketFilter(_ context.Context, nflogger nflog.NFLogger) error { 21 22 nflogWin := nflogger.(*nflog.NfLogWindows) 23 24 packetCallback := func(packetInfoPtr, dataPtr uintptr) uintptr { 25 26 packetInfo := *(*frontman.PacketInfo)(unsafe.Pointer(packetInfoPtr)) // nolint:govet 27 packetBytes := (*[1 << 30]byte)(unsafe.Pointer(dataPtr))[:packetInfo.PacketSize:packetInfo.PacketSize] // nolint:govet 28 29 var packetType int 30 if packetInfo.Outbound != 0 { 31 packetType = packet.PacketTypeApplication 32 } else { 33 packetType = packet.PacketTypeNetwork 34 } 35 36 // Parse the packet 37 mark := int(packetInfo.Mark) 38 parsedPacket, err := packet.New(uint64(packetType), packetBytes, strconv.Itoa(mark), true) 39 40 if parsedPacket.IPProto() == packet.IPProtocolUDP && parsedPacket.SourcePort() == 53 { 41 // notify PUs of DNS results 42 err := d.dnsProxy.HandleDNSResponsePacket(parsedPacket.GetUDPData(), parsedPacket.SourceAddress(), parsedPacket.SourcePort(), parsedPacket.DestinationAddress(), parsedPacket.DestPort(), func(id string) (*pucontext.PUContext, error) { 43 puCtx, err1 := d.puFromContextID.Get(id) 44 if err1 != nil { 45 return nil, err1 46 } 47 return puCtx.(*pucontext.PUContext), nil 48 }) 49 if err != nil { 50 zap.L().Debug("Failed to handle DNS response", zap.Error(err)) 51 } 52 // forward packet 53 err = frontman.Wrapper.PacketFilterForward(&packetInfo, packetBytes) 54 if err != nil { 55 zap.L().Error("failed to forward packet", zap.Error(err)) 56 } 57 return 0 58 } 59 60 parsedPacket.PlatformMetadata = &afinetrawsocket.WindowPlatformMetadata{ 61 PacketInfo: packetInfo, 62 IgnoreFlow: false, 63 Drop: false, 64 SetMark: 0, 65 } 66 67 var processError error 68 var tcpConn *connection.TCPConnection 69 var udpConn *connection.UDPConnection 70 var f func() 71 72 if err != nil { 73 parsedPacket.Print(packet.PacketFailureCreate, d.packetLogs) 74 } else if parsedPacket.IPProto() == packet.IPProtocolTCP { 75 if packetType == packet.PacketTypeNetwork { 76 tcpConn, f, processError = d.processNetworkTCPPackets(parsedPacket) 77 if f != nil { 78 f() 79 } 80 } else { 81 tcpConn, processError = d.processApplicationTCPPackets(parsedPacket) 82 } 83 } else if parsedPacket.IPProto() == packet.IPProtocolUDP { 84 // process udp packet 85 if packetType == packet.PacketTypeNetwork { 86 udpConn, processError = d.ProcessNetworkUDPPacket(parsedPacket) 87 } else { 88 udpConn, processError = d.ProcessApplicationUDPPacket(parsedPacket) 89 } 90 } else { 91 processError = fmt.Errorf("invalid ip protocol: %d", parsedPacket.IPProto()) 92 } 93 94 if processError != nil { 95 if parsedPacket.IPProto() == packet.IPProtocolTCP { 96 d.collectTCPPacket(&debugpacketmessage{ 97 Mark: mark, 98 p: parsedPacket, 99 tcpConn: tcpConn, 100 udpConn: nil, 101 err: processError, 102 network: packetType == packet.PacketTypeNetwork, 103 }) 104 } else if parsedPacket.IPProto() == packet.IPProtocolUDP { 105 d.collectUDPPacket(&debugpacketmessage{ 106 Mark: mark, 107 p: parsedPacket, 108 tcpConn: nil, 109 udpConn: udpConn, 110 err: processError, 111 network: packetType == packet.PacketTypeNetwork, 112 }) 113 } 114 // drop packet by not forwarding it 115 return 0 116 } 117 118 // accept the (modified) packet by forwarding it 119 modifiedPacketBytes := parsedPacket.GetBuffer(0) 120 packetInfo.PacketSize = uint32(parsedPacket.IPTotalLen()) 121 122 platformMetadata := parsedPacket.PlatformMetadata.(*afinetrawsocket.WindowPlatformMetadata) 123 if platformMetadata.IgnoreFlow { 124 packetInfo.IgnoreFlow = 1 125 } else if platformMetadata.DropFlow { 126 packetInfo.DropFlow = 1 127 } 128 if platformMetadata.Drop { 129 packetInfo.Drop = 1 130 } 131 if platformMetadata.SetMark != 0 { 132 packetInfo.SetMark = 1 133 packetInfo.SetMarkValue = platformMetadata.SetMark 134 } 135 136 if err := frontman.Wrapper.PacketFilterForward(&packetInfo, modifiedPacketBytes); err != nil { 137 zap.L().Error("failed to forward packet", zap.Error(err)) 138 } 139 140 if parsedPacket.IPProto() == packet.IPProtocolTCP { 141 d.collectTCPPacket(&debugpacketmessage{ 142 Mark: mark, 143 p: parsedPacket, 144 tcpConn: tcpConn, 145 udpConn: nil, 146 err: nil, 147 network: packetType == packet.PacketTypeNetwork, 148 }) 149 } else if parsedPacket.IPProto() == packet.IPProtocolUDP { 150 d.collectUDPPacket(&debugpacketmessage{ 151 Mark: mark, 152 p: parsedPacket, 153 tcpConn: nil, 154 udpConn: udpConn, 155 err: nil, 156 network: packetType == packet.PacketTypeNetwork, 157 }) 158 } 159 160 return 0 161 } 162 163 logCallback := func(logPacketInfoPtr, dataPtr uintptr) uintptr { 164 165 logPacketInfo := *(*frontman.LogPacketInfo)(unsafe.Pointer(logPacketInfoPtr)) // nolint:govet 166 packetHeaderBytes := (*[1 << 30]byte)(unsafe.Pointer(dataPtr))[:logPacketInfo.PacketSize:logPacketInfo.PacketSize] 167 168 err := nflogWin.NfLogHandler(&logPacketInfo, packetHeaderBytes) 169 if err != nil { 170 zap.L().Error("error in log callback", zap.Error(err)) 171 } 172 173 return 0 174 } 175 176 if err := frontman.Wrapper.PacketFilterStart("Aporeto Enforcer", packetCallback, logCallback); err != nil { 177 return err 178 } 179 180 return nil 181 } 182 183 // cleanupPlatform for windows is needed to stop the frontman threads and permit the enforcerd app to shut down 184 func (d *Datapath) cleanupPlatform() { 185 186 if err := frontman.Wrapper.PacketFilterClose(); err != nil { 187 zap.L().Error("Failed to close packet proxy", zap.Error(err)) 188 } 189 190 }