github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/datapath_windows.go (about)

     1  // +build windows
     2  
     3  package nfqdatapath
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"net"
     9  	"syscall"
    10  	"unsafe"
    11  
    12  	"github.com/ghedo/go.pkt/layers"
    13  	gpacket "github.com/ghedo/go.pkt/packet"
    14  	"github.com/ghedo/go.pkt/packet/ipv4"
    15  	"github.com/ghedo/go.pkt/packet/tcp"
    16  	"github.com/pkg/errors"
    17  	"go.aporeto.io/enforcerd/trireme-lib/controller/constants"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/nfqdatapath/afinetrawsocket"
    19  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/connection"
    20  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet"
    21  	"go.aporeto.io/enforcerd/trireme-lib/utils/frontman"
    22  	"go.uber.org/zap"
    23  	"golang.org/x/sys/windows"
    24  )
    25  
    26  func adjustConntrack(mode constants.ModeType) {
    27  }
    28  
    29  func (d *Datapath) reverseFlow(pkt *packet.Packet) error {
    30  	windata, ok := pkt.PlatformMetadata.(*afinetrawsocket.WindowPlatformMetadata)
    31  	if !ok {
    32  		return errors.New("no WindowPlatformMetadata for reverseFlow")
    33  	}
    34  
    35  	address := windata.PacketInfo.RemoteAddr
    36  	windata.PacketInfo.RemoteAddr = windata.PacketInfo.LocalAddr
    37  	windata.PacketInfo.LocalAddr = address
    38  
    39  	port := windata.PacketInfo.RemotePort
    40  	windata.PacketInfo.RemotePort = windata.PacketInfo.LocalPort
    41  	windata.PacketInfo.LocalPort = port
    42  
    43  	return nil
    44  }
    45  
    46  func (d *Datapath) drop(pkt *packet.Packet) error {
    47  	windata, ok := pkt.PlatformMetadata.(*afinetrawsocket.WindowPlatformMetadata)
    48  	if !ok {
    49  		return errors.New("no WindowPlatformMetadata for drop")
    50  	}
    51  	windata.Drop = true
    52  	return nil
    53  }
    54  
    55  func (d *Datapath) setMark(pkt *packet.Packet, mark uint32) error {
    56  	windata, ok := pkt.PlatformMetadata.(*afinetrawsocket.WindowPlatformMetadata)
    57  	if !ok {
    58  		return errors.New("no WindowPlatformMetadata for setMark")
    59  	}
    60  	windata.SetMark = mark
    61  	return nil
    62  }
    63  
    64  // ignoreFlow is for Windows, because we need a way to explicitly notify of an 'ignore flow' condition,
    65  // without going through flowtracking, to be called synchronously in datapath processing
    66  func (d *Datapath) ignoreFlow(pkt *packet.Packet) error {
    67  	windata, ok := pkt.PlatformMetadata.(*afinetrawsocket.WindowPlatformMetadata)
    68  	if !ok {
    69  		return errors.New("no WindowPlatformMetadata for ignoreFlow")
    70  	}
    71  	windata.IgnoreFlow = true
    72  	return nil
    73  }
    74  
    75  // dropFlow will tell the windows driver to continue to drop packets for this flow.
    76  func (d *Datapath) dropFlow(pkt *packet.Packet) error {
    77  	windata, ok := pkt.PlatformMetadata.(*afinetrawsocket.WindowPlatformMetadata)
    78  	if !ok {
    79  		return errors.New("no WindowPlatformMetadata for dropFlow")
    80  	}
    81  	windata.DropFlow = true
    82  	return nil
    83  }
    84  
    85  // setFlowState will not send the packet but will tell the Windows driver to either accept or drop the flow.
    86  func (d *Datapath) setFlowState(pkt *packet.Packet, accepted bool) error {
    87  	windata, ok := pkt.PlatformMetadata.(*afinetrawsocket.WindowPlatformMetadata)
    88  	if !ok {
    89  		return errors.New("no WindowPlatformMetadata for setFlowState")
    90  	}
    91  
    92  	buf := pkt.GetBuffer(0)
    93  	packetInfo := windata.PacketInfo
    94  	packetInfo.NewPacket = 1
    95  	packetInfo.Drop = 1
    96  	packetInfo.IgnoreFlow = 0
    97  	packetInfo.DropFlow = 0
    98  	if accepted {
    99  		packetInfo.IgnoreFlow = 1
   100  	} else {
   101  		packetInfo.DropFlow = 1
   102  	}
   103  	packetInfo.PacketSize = uint32(len(buf))
   104  	if err := frontman.Wrapper.PacketFilterForward(&packetInfo, buf); err != nil {
   105  		return err
   106  	}
   107  	return nil
   108  }
   109  
   110  func (d *Datapath) startInterceptors(ctx context.Context) {
   111  	err := d.startFrontmanPacketFilter(ctx, d.nflogger)
   112  	if err != nil {
   113  		zap.L().Fatal("Unable to initialize windows packet proxy", zap.Error(err))
   114  	}
   115  }
   116  
   117  type pingConn struct {
   118  	SourceIP   net.IP
   119  	DestIP     net.IP
   120  	SourcePort uint16
   121  	DestPort   uint16
   122  }
   123  
   124  func dialIP(srcIP, dstIP net.IP) (PingConn, error) {
   125  	return &pingConn{}, nil
   126  }
   127  
   128  // Close not implemented.
   129  func (p *pingConn) Close() error {
   130  	return nil
   131  }
   132  
   133  // Write sends the packet to network.
   134  func (p *pingConn) Write(data []byte) (int, error) {
   135  
   136  	ipv4 := uint8(0)
   137  	if len(p.SourceIP) == net.IPv4len {
   138  		ipv4 = 1
   139  	}
   140  
   141  	packetInfo := frontman.PacketInfo{
   142  		Ipv4:             ipv4,
   143  		Protocol:         windows.IPPROTO_TCP,
   144  		Outbound:         1,
   145  		NewPacket:        1,
   146  		NoPidMatchOnFlow: 1,
   147  		LocalPort:        p.SourcePort,
   148  		RemotePort:       p.DestPort,
   149  		LocalAddr:        convertToDriverFormat(p.SourceIP),
   150  		RemoteAddr:       convertToDriverFormat(p.DestIP),
   151  		PacketSize:       uint32(len(data)),
   152  	}
   153  
   154  	dllRet, err := frontman.Driver.PacketFilterForward(uintptr(unsafe.Pointer(&packetInfo)), uintptr(unsafe.Pointer(&data[0])))
   155  	if dllRet == 0 && err != nil {
   156  		return 0, err
   157  	}
   158  
   159  	return len(data), nil
   160  }
   161  
   162  // ConstructWirePacket returns IP packet with given TCP and payload in wire format.
   163  func (p *pingConn) ConstructWirePacket(srcIP, dstIP net.IP, transport gpacket.Packet, payload gpacket.Packet) ([]byte, error) {
   164  
   165  	ipPacket := ipv4.Make()
   166  	ipPacket.SrcAddr = srcIP
   167  	ipPacket.DstAddr = dstIP
   168  	ipPacket.Protocol = ipv4.TCP
   169  
   170  	// pack the layers together.
   171  	buf, err := layers.Pack(ipPacket, transport, payload)
   172  	if err != nil {
   173  		return nil, fmt.Errorf("unable to encode packet to wire format: %v", err)
   174  	}
   175  
   176  	tcpPacket := transport.(*tcp.Packet)
   177  
   178  	p.SourceIP = srcIP
   179  	p.DestIP = dstIP
   180  	p.SourcePort = tcpPacket.SrcPort
   181  	p.DestPort = tcpPacket.DstPort
   182  
   183  	return buf, nil
   184  }
   185  
   186  func bindRandomPort(tcpConn *connection.TCPConnection) (uint16, error) {
   187  
   188  	fd, err := windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
   189  	if err != nil {
   190  		return 0, fmt.Errorf("unable to open socket, fd: %d : %s", fd, err)
   191  	}
   192  
   193  	addr := windows.SockaddrInet4{Port: 0}
   194  	copy(addr.Addr[:], net.ParseIP("127.0.0.1").To4())
   195  	if err = windows.Bind(fd, &addr); err != nil {
   196  		windows.CloseHandle(fd) // nolint: errcheck
   197  		return 0, fmt.Errorf("unable to bind socket: %s", err)
   198  	}
   199  
   200  	sockAddr, err := windows.Getsockname(fd)
   201  	if err != nil {
   202  		windows.CloseHandle(fd) // nolint: errcheck
   203  		return 0, fmt.Errorf("unable to get socket address: %s", err)
   204  	}
   205  
   206  	ip4Addr, ok := sockAddr.(*windows.SockaddrInet4)
   207  	if !ok {
   208  		windows.CloseHandle(fd) // nolint: errcheck
   209  		return 0, fmt.Errorf("invalid socket address: %T", sockAddr)
   210  	}
   211  
   212  	tcpConn.PingConfig.SetSocketFd(uintptr(fd))
   213  	return uint16(ip4Addr.Port), nil
   214  }
   215  
   216  func closeRandomPort(tcpConn *connection.TCPConnection) error {
   217  
   218  	fd := tcpConn.PingConfig.SocketFd()
   219  	tcpConn.PingConfig.SetSocketClosed(true)
   220  
   221  	return windows.CloseHandle(windows.Handle(fd))
   222  }
   223  
   224  func convertToDriverFormat(ip net.IP) [4]uint32 {
   225  	var addr [4]uint32
   226  	byteAddr := (*[16]byte)(unsafe.Pointer(&addr))
   227  	copy(byteAddr[:], ip)
   228  	return addr
   229  }
   230  
   231  func isAddrInUseErrno(errNo syscall.Errno) bool {
   232  	return errNo == windows.WSAEADDRINUSE
   233  }