github.com/slackhq/nebula@v1.9.0/udp/udp_rio_windows.go (about)

     1  //go:build !e2e_testing
     2  // +build !e2e_testing
     3  
     4  // Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go
     5  
     6  package udp
     7  
import (
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"syscall"
	"time"
	"unsafe"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/firewall"
	"github.com/slackhq/nebula/header"

	"golang.org/x/sys/windows"
	"golang.zx2c4.com/wireguard/conn/winrio"
)
    26  
    27  // Assert we meet the standard conn interface
    28  var _ Conn = &RIOConn{}
    29  
    30  //go:linkname procyield runtime.procyield
    31  func procyield(cycles uint32)
    32  
    33  const (
    34  	packetsPerRing = 1024
    35  	bytesPerPacket = 2048 - 32
    36  	receiveSpins   = 15
    37  )
    38  
    39  type ringPacket struct {
    40  	addr windows.RawSockaddrInet6
    41  	data [bytesPerPacket]byte
    42  }
    43  
    44  type ringBuffer struct {
    45  	packets    uintptr
    46  	head, tail uint32
    47  	id         winrio.BufferId
    48  	iocp       windows.Handle
    49  	isFull     bool
    50  	cq         winrio.Cq
    51  	mu         sync.Mutex
    52  	overlapped windows.Overlapped
    53  }
    54  
    55  type RIOConn struct {
    56  	isOpen  atomic.Bool
    57  	l       *logrus.Logger
    58  	sock    windows.Handle
    59  	rx, tx  ringBuffer
    60  	rq      winrio.Rq
    61  	results [packetsPerRing]winrio.Result
    62  }
    63  
    64  func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
    65  	if !winrio.Initialize() {
    66  		return nil, errors.New("could not initialize winrio")
    67  	}
    68  
    69  	u := &RIOConn{l: l}
    70  
    71  	addr := [16]byte{}
    72  	copy(addr[:], ip.To16())
    73  	err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
    74  	if err != nil {
    75  		return nil, fmt.Errorf("bind: %w", err)
    76  	}
    77  
    78  	for i := 0; i < packetsPerRing; i++ {
    79  		err = u.insertReceiveRequest()
    80  		if err != nil {
    81  			return nil, fmt.Errorf("init rx ring: %w", err)
    82  		}
    83  	}
    84  
    85  	u.isOpen.Store(true)
    86  	return u, nil
    87  }
    88  
    89  func (u *RIOConn) bind(sa windows.Sockaddr) error {
    90  	var err error
    91  	u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	// Enable v4 for this socket
    97  	syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
    98  
    99  	err = u.rx.Open()
   100  	if err != nil {
   101  		return err
   102  	}
   103  
   104  	err = u.tx.Open()
   105  	if err != nil {
   106  		return err
   107  	}
   108  
   109  	u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
   110  	if err != nil {
   111  		return err
   112  	}
   113  
   114  	err = windows.Bind(u.sock, sa)
   115  	if err != nil {
   116  		return err
   117  	}
   118  
   119  	return nil
   120  }
   121  
   122  func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
   123  	plaintext := make([]byte, MTU)
   124  	buffer := make([]byte, MTU)
   125  	h := &header.H{}
   126  	fwPacket := &firewall.Packet{}
   127  	udpAddr := &Addr{IP: make([]byte, 16)}
   128  	nb := make([]byte, 12, 12)
   129  
   130  	for {
   131  		// Just read one packet at a time
   132  		n, rua, err := u.receive(buffer)
   133  		if err != nil {
   134  			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
   135  			return
   136  		}
   137  
   138  		udpAddr.IP = rua.Addr[:]
   139  		p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
   140  		p[0] = byte(rua.Port >> 8)
   141  		p[1] = byte(rua.Port)
   142  		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
   143  	}
   144  }
   145  
   146  func (u *RIOConn) insertReceiveRequest() error {
   147  	packet := u.rx.Push()
   148  	dataBuffer := &winrio.Buffer{
   149  		Id:     u.rx.id,
   150  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
   151  		Length: uint32(len(packet.data)),
   152  	}
   153  	addressBuffer := &winrio.Buffer{
   154  		Id:     u.rx.id,
   155  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
   156  		Length: uint32(unsafe.Sizeof(packet.addr)),
   157  	}
   158  
   159  	return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
   160  }
   161  
   162  func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
   163  	if !u.isOpen.Load() {
   164  		return 0, windows.RawSockaddrInet6{}, net.ErrClosed
   165  	}
   166  
   167  	u.rx.mu.Lock()
   168  	defer u.rx.mu.Unlock()
   169  
   170  	var err error
   171  	var count uint32
   172  	var results [1]winrio.Result
   173  
   174  retry:
   175  	count = 0
   176  	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
   177  		if tries > 0 {
   178  			if !u.isOpen.Load() {
   179  				return 0, windows.RawSockaddrInet6{}, net.ErrClosed
   180  			}
   181  			procyield(1)
   182  		}
   183  
   184  		count = winrio.DequeueCompletion(u.rx.cq, results[:])
   185  	}
   186  
   187  	if count == 0 {
   188  		err = winrio.Notify(u.rx.cq)
   189  		if err != nil {
   190  			return 0, windows.RawSockaddrInet6{}, err
   191  		}
   192  		var bytes uint32
   193  		var key uintptr
   194  		var overlapped *windows.Overlapped
   195  		err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
   196  		if err != nil {
   197  			return 0, windows.RawSockaddrInet6{}, err
   198  		}
   199  
   200  		if !u.isOpen.Load() {
   201  			return 0, windows.RawSockaddrInet6{}, net.ErrClosed
   202  		}
   203  
   204  		count = winrio.DequeueCompletion(u.rx.cq, results[:])
   205  		if count == 0 {
   206  			return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
   207  
   208  		}
   209  	}
   210  
   211  	u.rx.Return(1)
   212  	err = u.insertReceiveRequest()
   213  	if err != nil {
   214  		return 0, windows.RawSockaddrInet6{}, err
   215  	}
   216  
   217  	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
   218  	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
   219  	// attacker bandwidth, just like the rest of the receive path.
   220  	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
   221  		goto retry
   222  	}
   223  
   224  	if results[0].Status != 0 {
   225  		return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
   226  	}
   227  
   228  	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
   229  	ep := packet.addr
   230  	n := copy(buf, packet.data[:results[0].BytesTransferred])
   231  	return n, ep, nil
   232  }
   233  
   234  func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
   235  	if !u.isOpen.Load() {
   236  		return net.ErrClosed
   237  	}
   238  
   239  	if len(buf) > bytesPerPacket {
   240  		return io.ErrShortBuffer
   241  	}
   242  
   243  	u.tx.mu.Lock()
   244  	defer u.tx.mu.Unlock()
   245  
   246  	count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
   247  	if count == 0 && u.tx.isFull {
   248  		err := winrio.Notify(u.tx.cq)
   249  		if err != nil {
   250  			return err
   251  		}
   252  
   253  		var bytes uint32
   254  		var key uintptr
   255  		var overlapped *windows.Overlapped
   256  		err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
   257  		if err != nil {
   258  			return err
   259  		}
   260  
   261  		if !u.isOpen.Load() {
   262  			return net.ErrClosed
   263  		}
   264  
   265  		count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
   266  		if count == 0 {
   267  			return io.ErrNoProgress
   268  		}
   269  	}
   270  
   271  	if count > 0 {
   272  		u.tx.Return(count)
   273  	}
   274  
   275  	packet := u.tx.Push()
   276  	packet.addr.Family = windows.AF_INET6
   277  	p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
   278  	p[0] = byte(addr.Port >> 8)
   279  	p[1] = byte(addr.Port)
   280  	copy(packet.addr.Addr[:], addr.IP.To16())
   281  	copy(packet.data[:], buf)
   282  
   283  	dataBuffer := &winrio.Buffer{
   284  		Id:     u.tx.id,
   285  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
   286  		Length: uint32(len(buf)),
   287  	}
   288  
   289  	addressBuffer := &winrio.Buffer{
   290  		Id:     u.tx.id,
   291  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
   292  		Length: uint32(unsafe.Sizeof(packet.addr)),
   293  	}
   294  
   295  	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
   296  }
   297  
   298  func (u *RIOConn) LocalAddr() (*Addr, error) {
   299  	sa, err := windows.Getsockname(u.sock)
   300  	if err != nil {
   301  		return nil, err
   302  	}
   303  
   304  	v6 := sa.(*windows.SockaddrInet6)
   305  	return &Addr{
   306  		IP:   v6.Addr[:],
   307  		Port: uint16(v6.Port),
   308  	}, nil
   309  }
   310  
   311  func (u *RIOConn) Rebind() error {
   312  	return nil
   313  }
   314  
   315  func (u *RIOConn) ReloadConfig(*config.C) {}
   316  
   317  func (u *RIOConn) Close() error {
   318  	if !u.isOpen.CompareAndSwap(true, false) {
   319  		return nil
   320  	}
   321  
   322  	windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil)
   323  	windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil)
   324  
   325  	u.rx.CloseAndZero()
   326  	u.tx.CloseAndZero()
   327  	if u.sock != 0 {
   328  		windows.CloseHandle(u.sock)
   329  	}
   330  	return nil
   331  }
   332  
   333  func (ring *ringBuffer) Push() *ringPacket {
   334  	for ring.isFull {
   335  		panic("ring is full")
   336  	}
   337  	ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
   338  	ring.tail += 1
   339  	if ring.tail%packetsPerRing == ring.head%packetsPerRing {
   340  		ring.isFull = true
   341  	}
   342  	return ret
   343  }
   344  
   345  func (ring *ringBuffer) Return(count uint32) {
   346  	if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull {
   347  		return
   348  	}
   349  	ring.head += count
   350  	ring.isFull = false
   351  }
   352  
   353  func (ring *ringBuffer) CloseAndZero() {
   354  	if ring.cq != 0 {
   355  		winrio.CloseCompletionQueue(ring.cq)
   356  		ring.cq = 0
   357  	}
   358  
   359  	if ring.iocp != 0 {
   360  		windows.CloseHandle(ring.iocp)
   361  		ring.iocp = 0
   362  	}
   363  
   364  	if ring.id != 0 {
   365  		winrio.DeregisterBuffer(ring.id)
   366  		ring.id = 0
   367  	}
   368  
   369  	if ring.packets != 0 {
   370  		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
   371  		ring.packets = 0
   372  	}
   373  
   374  	ring.head = 0
   375  	ring.tail = 0
   376  	ring.isFull = false
   377  }
   378  
   379  func (ring *ringBuffer) Open() error {
   380  	var err error
   381  	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
   382  	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
   383  	if err != nil {
   384  		return err
   385  	}
   386  
   387  	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
   388  	if err != nil {
   389  		return err
   390  	}
   391  
   392  	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
   393  	if err != nil {
   394  		return err
   395  	}
   396  
   397  	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
   398  	if err != nil {
   399  		return err
   400  	}
   401  
   402  	return nil
   403  }