github.com/amnezia-vpn/amneziawg-go@v0.2.8/conn/bind_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"encoding/binary"
    10  	"io"
    11  	"net"
    12  	"net/netip"
    13  	"strconv"
    14  	"sync"
    15  	"sync/atomic"
    16  	"unsafe"
    17  
    18  	"golang.org/x/sys/windows"
    19  
    20  	"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
    21  )
    22  
    23  const (
    24  	packetsPerRing = 1024
    25  	bytesPerPacket = 2048 - 32
    26  	receiveSpins   = 15
    27  )
    28  
    29  type ringPacket struct {
    30  	addr WinRingEndpoint
    31  	data [bytesPerPacket]byte
    32  }
    33  
    34  type ringBuffer struct {
    35  	packets    uintptr
    36  	head, tail uint32
    37  	id         winrio.BufferId
    38  	iocp       windows.Handle
    39  	isFull     bool
    40  	cq         winrio.Cq
    41  	mu         sync.Mutex
    42  	overlapped windows.Overlapped
    43  }
    44  
    45  func (rb *ringBuffer) Push() *ringPacket {
    46  	for rb.isFull {
    47  		panic("ring is full")
    48  	}
    49  	ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
    50  	rb.tail += 1
    51  	if rb.tail%packetsPerRing == rb.head%packetsPerRing {
    52  		rb.isFull = true
    53  	}
    54  	return ret
    55  }
    56  
    57  func (rb *ringBuffer) Return(count uint32) {
    58  	if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
    59  		return
    60  	}
    61  	rb.head += count
    62  	rb.isFull = false
    63  }
    64  
    65  type afWinRingBind struct {
    66  	sock      windows.Handle
    67  	rx, tx    ringBuffer
    68  	rq        winrio.Rq
    69  	mu        sync.Mutex
    70  	blackhole bool
    71  }
    72  
    73  // WinRingBind uses Windows registered I/O for fast ring buffered networking.
    74  type WinRingBind struct {
    75  	v4, v6 afWinRingBind
    76  	mu     sync.RWMutex
    77  	isOpen atomic.Uint32 // 0, 1, or 2
    78  }
    79  
    80  func NewDefaultBind() Bind { return NewWinRingBind() }
    81  
    82  func NewWinRingBind() Bind {
    83  	if !winrio.Initialize() {
    84  		return NewStdNetBind()
    85  	}
    86  	return new(WinRingBind)
    87  }
    88  
    89  type WinRingEndpoint struct {
    90  	family uint16
    91  	data   [30]byte
    92  }
    93  
    94  var (
    95  	_ Bind     = (*WinRingBind)(nil)
    96  	_ Endpoint = (*WinRingEndpoint)(nil)
    97  )
    98  
    99  func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
   100  	host, port, err := net.SplitHostPort(s)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  	host16, err := windows.UTF16PtrFromString(host)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	port16, err := windows.UTF16PtrFromString(port)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	hints := windows.AddrinfoW{
   113  		Flags:    windows.AI_NUMERICHOST,
   114  		Family:   windows.AF_UNSPEC,
   115  		Socktype: windows.SOCK_DGRAM,
   116  		Protocol: windows.IPPROTO_UDP,
   117  	}
   118  	var addrinfo *windows.AddrinfoW
   119  	err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	defer windows.FreeAddrInfoW(addrinfo)
   124  	if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
   125  		return nil, windows.ERROR_INVALID_ADDRESS
   126  	}
   127  	var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
   128  	copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
   129  	return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
   130  }
   131  
   132  func (*WinRingEndpoint) ClearSrc() {}
   133  
   134  func (e *WinRingEndpoint) DstIP() netip.Addr {
   135  	switch e.family {
   136  	case windows.AF_INET:
   137  		return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
   138  	case windows.AF_INET6:
   139  		return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
   140  	}
   141  	return netip.Addr{}
   142  }
   143  
   144  func (e *WinRingEndpoint) SrcIP() netip.Addr {
   145  	return netip.Addr{} // not supported
   146  }
   147  
   148  func (e *WinRingEndpoint) DstToBytes() []byte {
   149  	switch e.family {
   150  	case windows.AF_INET:
   151  		b := make([]byte, 0, 6)
   152  		b = append(b, e.data[2:6]...)
   153  		b = append(b, e.data[1], e.data[0])
   154  		return b
   155  	case windows.AF_INET6:
   156  		b := make([]byte, 0, 18)
   157  		b = append(b, e.data[6:22]...)
   158  		b = append(b, e.data[1], e.data[0])
   159  		return b
   160  	}
   161  	return nil
   162  }
   163  
   164  func (e *WinRingEndpoint) DstToString() string {
   165  	switch e.family {
   166  	case windows.AF_INET:
   167  		return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
   168  	case windows.AF_INET6:
   169  		var zone string
   170  		if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
   171  			zone = strconv.FormatUint(uint64(scope), 10)
   172  		}
   173  		return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
   174  	}
   175  	return ""
   176  }
   177  
   178  func (e *WinRingEndpoint) SrcToString() string {
   179  	return ""
   180  }
   181  
   182  func (ring *ringBuffer) CloseAndZero() {
   183  	if ring.cq != 0 {
   184  		winrio.CloseCompletionQueue(ring.cq)
   185  		ring.cq = 0
   186  	}
   187  	if ring.iocp != 0 {
   188  		windows.CloseHandle(ring.iocp)
   189  		ring.iocp = 0
   190  	}
   191  	if ring.id != 0 {
   192  		winrio.DeregisterBuffer(ring.id)
   193  		ring.id = 0
   194  	}
   195  	if ring.packets != 0 {
   196  		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
   197  		ring.packets = 0
   198  	}
   199  	ring.head = 0
   200  	ring.tail = 0
   201  	ring.isFull = false
   202  }
   203  
   204  func (bind *afWinRingBind) CloseAndZero() {
   205  	bind.rx.CloseAndZero()
   206  	bind.tx.CloseAndZero()
   207  	if bind.sock != 0 {
   208  		windows.CloseHandle(bind.sock)
   209  		bind.sock = 0
   210  	}
   211  	bind.blackhole = false
   212  }
   213  
   214  func (bind *WinRingBind) closeAndZero() {
   215  	bind.isOpen.Store(0)
   216  	bind.v4.CloseAndZero()
   217  	bind.v6.CloseAndZero()
   218  }
   219  
   220  func (ring *ringBuffer) Open() error {
   221  	var err error
   222  	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
   223  	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
   224  	if err != nil {
   225  		return err
   226  	}
   227  	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
   228  	if err != nil {
   229  		return err
   230  	}
   231  	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
   232  	if err != nil {
   233  		return err
   234  	}
   235  	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
   236  	if err != nil {
   237  		return err
   238  	}
   239  	return nil
   240  }
   241  
   242  func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
   243  	var err error
   244  	bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  	err = bind.rx.Open()
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  	err = bind.tx.Open()
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  	bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  	err = windows.Bind(bind.sock, sa)
   261  	if err != nil {
   262  		return nil, err
   263  	}
   264  	sa, err = windows.Getsockname(bind.sock)
   265  	if err != nil {
   266  		return nil, err
   267  	}
   268  	return sa, nil
   269  }
   270  
   271  func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
   272  	bind.mu.Lock()
   273  	defer bind.mu.Unlock()
   274  	defer func() {
   275  		if err != nil {
   276  			bind.closeAndZero()
   277  		}
   278  	}()
   279  	if bind.isOpen.Load() != 0 {
   280  		return nil, 0, ErrBindAlreadyOpen
   281  	}
   282  	var sa windows.Sockaddr
   283  	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
   284  	if err != nil {
   285  		return nil, 0, err
   286  	}
   287  	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
   288  	if err != nil {
   289  		return nil, 0, err
   290  	}
   291  	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
   292  	for i := 0; i < packetsPerRing; i++ {
   293  		err = bind.v4.InsertReceiveRequest()
   294  		if err != nil {
   295  			return nil, 0, err
   296  		}
   297  		err = bind.v6.InsertReceiveRequest()
   298  		if err != nil {
   299  			return nil, 0, err
   300  		}
   301  	}
   302  	bind.isOpen.Store(1)
   303  	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
   304  }
   305  
   306  func (bind *WinRingBind) Close() error {
   307  	bind.mu.RLock()
   308  	if bind.isOpen.Load() != 1 {
   309  		bind.mu.RUnlock()
   310  		return nil
   311  	}
   312  	bind.isOpen.Store(2)
   313  	windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
   314  	windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
   315  	windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
   316  	windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
   317  	bind.mu.RUnlock()
   318  	bind.mu.Lock()
   319  	defer bind.mu.Unlock()
   320  	bind.closeAndZero()
   321  	return nil
   322  }
   323  
   324  // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
   325  // rename the IdealBatchSize constant to BatchSize.
   326  func (bind *WinRingBind) BatchSize() int {
   327  	// TODO: implement batching in and out of the ring
   328  	return 1
   329  }
   330  
   331  func (bind *WinRingBind) GetOffloadInfo() string {
   332  	return ""
   333  }
   334  
   335  func (bind *WinRingBind) SetMark(mark uint32) error {
   336  	return nil
   337  }
   338  
   339  func (bind *afWinRingBind) InsertReceiveRequest() error {
   340  	packet := bind.rx.Push()
   341  	dataBuffer := &winrio.Buffer{
   342  		Id:     bind.rx.id,
   343  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
   344  		Length: uint32(len(packet.data)),
   345  	}
   346  	addressBuffer := &winrio.Buffer{
   347  		Id:     bind.rx.id,
   348  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
   349  		Length: uint32(unsafe.Sizeof(packet.addr)),
   350  	}
   351  	bind.mu.Lock()
   352  	defer bind.mu.Unlock()
   353  	return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
   354  }
   355  
// procyield has no body here; the linkname directive below binds it to
// runtime.procyield. Receive calls it with cycles = 1 between polls of the
// completion queue.
// NOTE(review): presumably a short CPU spin/pause hint; confirm against the
// Go runtime source.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
   358  
   359  func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
   360  	if isOpen.Load() != 1 {
   361  		return 0, nil, net.ErrClosed
   362  	}
   363  	bind.rx.mu.Lock()
   364  	defer bind.rx.mu.Unlock()
   365  
   366  	var err error
   367  	var count uint32
   368  	var results [1]winrio.Result
   369  retry:
   370  	count = 0
   371  	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
   372  		if tries > 0 {
   373  			if isOpen.Load() != 1 {
   374  				return 0, nil, net.ErrClosed
   375  			}
   376  			procyield(1)
   377  		}
   378  		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
   379  	}
   380  	if count == 0 {
   381  		err = winrio.Notify(bind.rx.cq)
   382  		if err != nil {
   383  			return 0, nil, err
   384  		}
   385  		var bytes uint32
   386  		var key uintptr
   387  		var overlapped *windows.Overlapped
   388  		err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
   389  		if err != nil {
   390  			return 0, nil, err
   391  		}
   392  		if isOpen.Load() != 1 {
   393  			return 0, nil, net.ErrClosed
   394  		}
   395  		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
   396  		if count == 0 {
   397  			return 0, nil, io.ErrNoProgress
   398  		}
   399  	}
   400  	bind.rx.Return(1)
   401  	err = bind.InsertReceiveRequest()
   402  	if err != nil {
   403  		return 0, nil, err
   404  	}
   405  	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
   406  	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
   407  	// attacker bandwidth, just like the rest of the receive path.
   408  	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
   409  		if isOpen.Load() != 1 {
   410  			return 0, nil, net.ErrClosed
   411  		}
   412  		goto retry
   413  	}
   414  	if results[0].Status != 0 {
   415  		return 0, nil, windows.Errno(results[0].Status)
   416  	}
   417  	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
   418  	ep := packet.addr
   419  	n := copy(buf, packet.data[:results[0].BytesTransferred])
   420  	return n, &ep, nil
   421  }
   422  
   423  func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
   424  	bind.mu.RLock()
   425  	defer bind.mu.RUnlock()
   426  	n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
   427  	sizes[0] = n
   428  	eps[0] = ep
   429  	return 1, err
   430  }
   431  
   432  func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
   433  	bind.mu.RLock()
   434  	defer bind.mu.RUnlock()
   435  	n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
   436  	sizes[0] = n
   437  	eps[0] = ep
   438  	return 1, err
   439  }
   440  
   441  func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
   442  	if isOpen.Load() != 1 {
   443  		return net.ErrClosed
   444  	}
   445  	if len(buf) > bytesPerPacket {
   446  		return io.ErrShortBuffer
   447  	}
   448  	bind.tx.mu.Lock()
   449  	defer bind.tx.mu.Unlock()
   450  	var results [packetsPerRing]winrio.Result
   451  	count := winrio.DequeueCompletion(bind.tx.cq, results[:])
   452  	if count == 0 && bind.tx.isFull {
   453  		err := winrio.Notify(bind.tx.cq)
   454  		if err != nil {
   455  			return err
   456  		}
   457  		var bytes uint32
   458  		var key uintptr
   459  		var overlapped *windows.Overlapped
   460  		err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
   461  		if err != nil {
   462  			return err
   463  		}
   464  		if isOpen.Load() != 1 {
   465  			return net.ErrClosed
   466  		}
   467  		count = winrio.DequeueCompletion(bind.tx.cq, results[:])
   468  		if count == 0 {
   469  			return io.ErrNoProgress
   470  		}
   471  	}
   472  	if count > 0 {
   473  		bind.tx.Return(count)
   474  	}
   475  	packet := bind.tx.Push()
   476  	packet.addr = *nend
   477  	copy(packet.data[:], buf)
   478  	dataBuffer := &winrio.Buffer{
   479  		Id:     bind.tx.id,
   480  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
   481  		Length: uint32(len(buf)),
   482  	}
   483  	addressBuffer := &winrio.Buffer{
   484  		Id:     bind.tx.id,
   485  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
   486  		Length: uint32(unsafe.Sizeof(packet.addr)),
   487  	}
   488  	bind.mu.Lock()
   489  	defer bind.mu.Unlock()
   490  	return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
   491  }
   492  
   493  func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
   494  	nend, ok := endpoint.(*WinRingEndpoint)
   495  	if !ok {
   496  		return ErrWrongEndpointType
   497  	}
   498  	bind.mu.RLock()
   499  	defer bind.mu.RUnlock()
   500  	for _, buf := range bufs {
   501  		switch nend.family {
   502  		case windows.AF_INET:
   503  			if bind.v4.blackhole {
   504  				continue
   505  			}
   506  			if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
   507  				return err
   508  			}
   509  		case windows.AF_INET6:
   510  			if bind.v6.blackhole {
   511  				continue
   512  			}
   513  			if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
   514  				return err
   515  			}
   516  		}
   517  	}
   518  	return nil
   519  }
   520  
   521  func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
   522  	s.mu.Lock()
   523  	defer s.mu.Unlock()
   524  	sysconn, err := s.ipv4.SyscallConn()
   525  	if err != nil {
   526  		return err
   527  	}
   528  	err2 := sysconn.Control(func(fd uintptr) {
   529  		err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
   530  	})
   531  	if err2 != nil {
   532  		return err2
   533  	}
   534  	if err != nil {
   535  		return err
   536  	}
   537  	s.blackhole4 = blackhole
   538  	return nil
   539  }
   540  
   541  func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
   542  	s.mu.Lock()
   543  	defer s.mu.Unlock()
   544  	sysconn, err := s.ipv6.SyscallConn()
   545  	if err != nil {
   546  		return err
   547  	}
   548  	err2 := sysconn.Control(func(fd uintptr) {
   549  		err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
   550  	})
   551  	if err2 != nil {
   552  		return err2
   553  	}
   554  	if err != nil {
   555  		return err
   556  	}
   557  	s.blackhole6 = blackhole
   558  	return nil
   559  }
   560  
   561  func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
   562  	bind.mu.RLock()
   563  	defer bind.mu.RUnlock()
   564  	if bind.isOpen.Load() != 1 {
   565  		return net.ErrClosed
   566  	}
   567  	err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
   568  	if err != nil {
   569  		return err
   570  	}
   571  	bind.v4.blackhole = blackhole
   572  	return nil
   573  }
   574  
   575  func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
   576  	bind.mu.RLock()
   577  	defer bind.mu.RUnlock()
   578  	if bind.isOpen.Load() != 1 {
   579  		return net.ErrClosed
   580  	}
   581  	err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
   582  	if err != nil {
   583  		return err
   584  	}
   585  	bind.v6.blackhole = blackhole
   586  	return nil
   587  }
   588  
   589  func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
   590  	const IP_UNICAST_IF = 31
   591  	/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
   592  	var bytes [4]byte
   593  	binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
   594  	interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
   595  	err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
   596  	if err != nil {
   597  		return err
   598  	}
   599  	return nil
   600  }
   601  
   602  func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
   603  	const IPV6_UNICAST_IF = 31
   604  	return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
   605  }