github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/conn/bind_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"encoding/binary"
    10  	"io"
    11  	"net"
    12  	"strconv"
    13  	"sync"
    14  	"sync/atomic"
    15  	"unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  
    19  	"github.com/bugfan/wireguard-go/conn/winrio"
    20  )
    21  
const (
	// packetsPerRing is the number of slots in each RIO ring (rx and tx). It is
	// also the depth used for the completion queues and the request queue.
	packetsPerRing = 1024
	// bytesPerPacket is the payload capacity of one ring slot. The 32 bytes
	// subtracted are occupied by the WinRingEndpoint stored alongside the
	// payload in each ringPacket, making every slot 2048 bytes in total.
	bytesPerPacket = 2048 - 32
	// receiveSpins is how many times Receive polls the rx completion queue
	// before arming a notification and blocking on the IOCP.
	receiveSpins   = 15
)
    27  
// ringPacket is one slot of a ringBuffer. Slots live in memory registered with
// RIO, and both fields are handed to the kernel by their offset into that
// memory: addr holds the remote socket address and data holds the datagram
// payload. The layout (32-byte address + bytesPerPacket payload = 2048 bytes)
// is relied upon by the offset arithmetic in InsertReceiveRequest and Send.
type ringPacket struct {
	addr WinRingEndpoint
	data [bytesPerPacket]byte
}
    32  
// ringBuffer is a fixed-size FIFO of packetsPerRing ringPackets allocated in
// RIO-registered memory, together with the completion queue and IOCP through
// which the kernel reports that a slot's request has finished.
type ringBuffer struct {
	packets    uintptr            // base address of the VirtualAlloc'd slot array
	head, tail uint32             // free-running counters; slot index is counter % packetsPerRing
	id         winrio.BufferId    // RIO registration covering the packets array
	iocp       windows.Handle     // completion port notified by cq
	isFull     bool               // tells "full" from "empty" when head and tail share a slot
	cq         winrio.Cq          // RIO completion queue for requests using this ring
	mu         sync.Mutex         // held by Receive (rx ring) / Send (tx ring) while using the ring
	overlapped windows.Overlapped // passed to CreateIOCPCompletionQueue for notifications
}
    43  
    44  func (rb *ringBuffer) Push() *ringPacket {
    45  	for rb.isFull {
    46  		panic("ring is full")
    47  	}
    48  	ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
    49  	rb.tail += 1
    50  	if rb.tail%packetsPerRing == rb.head%packetsPerRing {
    51  		rb.isFull = true
    52  	}
    53  	return ret
    54  }
    55  
    56  func (rb *ringBuffer) Return(count uint32) {
    57  	if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
    58  		return
    59  	}
    60  	rb.head += count
    61  	rb.isFull = false
    62  }
    63  
// afWinRingBind holds the registered-I/O state for a single address family:
// one UDP socket, the RIO request queue attached to it, and separate rings
// for received (rx) and transmitted (tx) datagrams.
type afWinRingBind struct {
	sock      windows.Handle
	rx, tx    ringBuffer
	rq        winrio.Rq
	mu        sync.Mutex // serializes the ReceiveEx/SendEx calls made on rq
	blackhole bool       // when true, WinRingBind.Send drops this family's packets
}
    71  
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
// It owns one afWinRingBind for IPv4 and one for IPv6.
type WinRingBind struct {
	v4, v6 afWinRingBind
	// mu is held for reading by the send/receive/interface-binding paths and
	// for writing by Open and by the teardown half of Close.
	mu     sync.RWMutex
	// isOpen is accessed atomically: 0 = closed, 1 = open, 2 = Close in progress.
	isOpen uint32
}
    78  
    79  func NewDefaultBind() Bind { return NewWinRingBind() }
    80  
    81  func NewWinRingBind() Bind {
    82  	if !winrio.Initialize() {
    83  		return NewStdNetBind()
    84  	}
    85  	return new(WinRingBind)
    86  }
    87  
// WinRingEndpoint holds a raw Windows socket address, byte-for-byte as
// produced by GetAddrInfoW (see ParseEndpoint) or written by RIO into a
// ringPacket. family is the address family (AF_INET or AF_INET6) and data is
// the remainder of the sockaddr. The accessors in this file read the
// big-endian port from data[0:2], an IPv4 address from data[2:6], and for
// IPv6 the address from data[6:22] and the scope ID from data[22:26].
// The struct is 32 bytes and its layout must not change.
type WinRingEndpoint struct {
	family uint16
	data   [30]byte
}
    92  
    93  var _ Bind = (*WinRingBind)(nil)
    94  var _ Endpoint = (*WinRingEndpoint)(nil)
    95  
    96  func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
    97  	host, port, err := net.SplitHostPort(s)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	host16, err := windows.UTF16PtrFromString(host)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	port16, err := windows.UTF16PtrFromString(port)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	hints := windows.AddrinfoW{
   110  		Flags:    windows.AI_NUMERICHOST,
   111  		Family:   windows.AF_UNSPEC,
   112  		Socktype: windows.SOCK_DGRAM,
   113  		Protocol: windows.IPPROTO_UDP,
   114  	}
   115  	var addrinfo *windows.AddrinfoW
   116  	err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	defer windows.FreeAddrInfoW(addrinfo)
   121  	if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
   122  		return nil, windows.ERROR_INVALID_ADDRESS
   123  	}
   124  	var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
   125  	copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
   126  	return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
   127  }
   128  
   129  func (*WinRingEndpoint) ClearSrc() {}
   130  
   131  func (e *WinRingEndpoint) DstIP() net.IP {
   132  	switch e.family {
   133  	case windows.AF_INET:
   134  		return append([]byte{}, e.data[2:6]...)
   135  	case windows.AF_INET6:
   136  		return append([]byte{}, e.data[6:22]...)
   137  	}
   138  	return nil
   139  }
   140  
   141  func (e *WinRingEndpoint) SrcIP() net.IP {
   142  	return nil // not supported
   143  }
   144  
   145  func (e *WinRingEndpoint) DstToBytes() []byte {
   146  	switch e.family {
   147  	case windows.AF_INET:
   148  		b := make([]byte, 0, 6)
   149  		b = append(b, e.data[2:6]...)
   150  		b = append(b, e.data[1], e.data[0])
   151  		return b
   152  	case windows.AF_INET6:
   153  		b := make([]byte, 0, 18)
   154  		b = append(b, e.data[6:22]...)
   155  		b = append(b, e.data[1], e.data[0])
   156  		return b
   157  	}
   158  	return nil
   159  }
   160  
   161  func (e *WinRingEndpoint) DstToString() string {
   162  	switch e.family {
   163  	case windows.AF_INET:
   164  		addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
   165  		return addr.String()
   166  	case windows.AF_INET6:
   167  		var zone string
   168  		if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
   169  			zone = strconv.FormatUint(uint64(scope), 10)
   170  		}
   171  		addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
   172  		return addr.String()
   173  	}
   174  	return ""
   175  }
   176  
   177  func (e *WinRingEndpoint) SrcToString() string {
   178  	return ""
   179  }
   180  
   181  func (ring *ringBuffer) CloseAndZero() {
   182  	if ring.cq != 0 {
   183  		winrio.CloseCompletionQueue(ring.cq)
   184  		ring.cq = 0
   185  	}
   186  	if ring.iocp != 0 {
   187  		windows.CloseHandle(ring.iocp)
   188  		ring.iocp = 0
   189  	}
   190  	if ring.id != 0 {
   191  		winrio.DeregisterBuffer(ring.id)
   192  		ring.id = 0
   193  	}
   194  	if ring.packets != 0 {
   195  		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
   196  		ring.packets = 0
   197  	}
   198  	ring.head = 0
   199  	ring.tail = 0
   200  	ring.isFull = false
   201  }
   202  
   203  func (bind *afWinRingBind) CloseAndZero() {
   204  	bind.rx.CloseAndZero()
   205  	bind.tx.CloseAndZero()
   206  	if bind.sock != 0 {
   207  		windows.CloseHandle(bind.sock)
   208  		bind.sock = 0
   209  	}
   210  	bind.blackhole = false
   211  }
   212  
   213  func (bind *WinRingBind) closeAndZero() {
   214  	atomic.StoreUint32(&bind.isOpen, 0)
   215  	bind.v4.CloseAndZero()
   216  	bind.v6.CloseAndZero()
   217  }
   218  
   219  func (ring *ringBuffer) Open() error {
   220  	var err error
   221  	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
   222  	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
   227  	if err != nil {
   228  		return err
   229  	}
   230  	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
   231  	if err != nil {
   232  		return err
   233  	}
   234  	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
   235  	if err != nil {
   236  		return err
   237  	}
   238  	return nil
   239  }
   240  
   241  func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
   242  	var err error
   243  	bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	err = bind.rx.Open()
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	err = bind.tx.Open()
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  	bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
   256  	if err != nil {
   257  		return nil, err
   258  	}
   259  	err = windows.Bind(bind.sock, sa)
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  	sa, err = windows.Getsockname(bind.sock)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	return sa, nil
   268  }
   269  
   270  func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
   271  	bind.mu.Lock()
   272  	defer bind.mu.Unlock()
   273  	defer func() {
   274  		if err != nil {
   275  			bind.closeAndZero()
   276  		}
   277  	}()
   278  	if atomic.LoadUint32(&bind.isOpen) != 0 {
   279  		return nil, 0, ErrBindAlreadyOpen
   280  	}
   281  	var sa windows.Sockaddr
   282  	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
   283  	if err != nil {
   284  		return nil, 0, err
   285  	}
   286  	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
   287  	if err != nil {
   288  		return nil, 0, err
   289  	}
   290  	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
   291  	for i := 0; i < packetsPerRing; i++ {
   292  		err = bind.v4.InsertReceiveRequest()
   293  		if err != nil {
   294  			return nil, 0, err
   295  		}
   296  		err = bind.v6.InsertReceiveRequest()
   297  		if err != nil {
   298  			return nil, 0, err
   299  		}
   300  	}
   301  	atomic.StoreUint32(&bind.isOpen, 1)
   302  	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
   303  }
   304  
   305  func (bind *WinRingBind) Close() error {
   306  	bind.mu.RLock()
   307  	if atomic.LoadUint32(&bind.isOpen) != 1 {
   308  		bind.mu.RUnlock()
   309  		return nil
   310  	}
   311  	atomic.StoreUint32(&bind.isOpen, 2)
   312  	windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
   313  	windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
   314  	windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
   315  	windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
   316  	bind.mu.RUnlock()
   317  	bind.mu.Lock()
   318  	defer bind.mu.Unlock()
   319  	bind.closeAndZero()
   320  	return nil
   321  }
   322  
   323  func (bind *WinRingBind) SetMark(mark uint32) error {
   324  	return nil
   325  }
   326  
   327  func (bind *afWinRingBind) InsertReceiveRequest() error {
   328  	packet := bind.rx.Push()
   329  	dataBuffer := &winrio.Buffer{
   330  		Id:     bind.rx.id,
   331  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
   332  		Length: uint32(len(packet.data)),
   333  	}
   334  	addressBuffer := &winrio.Buffer{
   335  		Id:     bind.rx.id,
   336  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
   337  		Length: uint32(unsafe.Sizeof(packet.addr)),
   338  	}
   339  	bind.mu.Lock()
   340  	defer bind.mu.Unlock()
   341  	return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
   342  }
   343  
// procyield is the Go runtime's internal spin-wait helper, linked in directly
// from package runtime. Receive calls it between completion-queue polls to
// back off briefly before falling back to a blocking wait.
//
// NOTE(review): this depends on an unexported runtime symbol; confirm it is
// still linkable with the Go toolchain in use.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
   346  
   347  func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
   348  	if atomic.LoadUint32(isOpen) != 1 {
   349  		return 0, nil, net.ErrClosed
   350  	}
   351  	bind.rx.mu.Lock()
   352  	defer bind.rx.mu.Unlock()
   353  
   354  	var err error
   355  	var count uint32
   356  	var results [1]winrio.Result
   357  retry:
   358  	count = 0
   359  	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
   360  		if tries > 0 {
   361  			if atomic.LoadUint32(isOpen) != 1 {
   362  				return 0, nil, net.ErrClosed
   363  			}
   364  			procyield(1)
   365  		}
   366  		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
   367  	}
   368  	if count == 0 {
   369  		err = winrio.Notify(bind.rx.cq)
   370  		if err != nil {
   371  			return 0, nil, err
   372  		}
   373  		var bytes uint32
   374  		var key uintptr
   375  		var overlapped *windows.Overlapped
   376  		err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
   377  		if err != nil {
   378  			return 0, nil, err
   379  		}
   380  		if atomic.LoadUint32(isOpen) != 1 {
   381  			return 0, nil, net.ErrClosed
   382  		}
   383  		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
   384  		if count == 0 {
   385  			return 0, nil, io.ErrNoProgress
   386  
   387  		}
   388  	}
   389  	bind.rx.Return(1)
   390  	err = bind.InsertReceiveRequest()
   391  	if err != nil {
   392  		return 0, nil, err
   393  	}
   394  	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
   395  	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
   396  	// attacker bandwidth, just like the rest of the receive path.
   397  	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
   398  		if atomic.LoadUint32(isOpen) != 1 {
   399  			return 0, nil, net.ErrClosed
   400  		}
   401  		goto retry
   402  	}
   403  	if results[0].Status != 0 {
   404  		return 0, nil, windows.Errno(results[0].Status)
   405  	}
   406  	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
   407  	ep := packet.addr
   408  	n := copy(buf, packet.data[:results[0].BytesTransferred])
   409  	return n, &ep, nil
   410  }
   411  
   412  func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
   413  	bind.mu.RLock()
   414  	defer bind.mu.RUnlock()
   415  	return bind.v4.Receive(buf, &bind.isOpen)
   416  }
   417  
   418  func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
   419  	bind.mu.RLock()
   420  	defer bind.mu.RUnlock()
   421  	return bind.v6.Receive(buf, &bind.isOpen)
   422  }
   423  
   424  func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
   425  	if atomic.LoadUint32(isOpen) != 1 {
   426  		return net.ErrClosed
   427  	}
   428  	if len(buf) > bytesPerPacket {
   429  		return io.ErrShortBuffer
   430  	}
   431  	bind.tx.mu.Lock()
   432  	defer bind.tx.mu.Unlock()
   433  	var results [packetsPerRing]winrio.Result
   434  	count := winrio.DequeueCompletion(bind.tx.cq, results[:])
   435  	if count == 0 && bind.tx.isFull {
   436  		err := winrio.Notify(bind.tx.cq)
   437  		if err != nil {
   438  			return err
   439  		}
   440  		var bytes uint32
   441  		var key uintptr
   442  		var overlapped *windows.Overlapped
   443  		err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
   444  		if err != nil {
   445  			return err
   446  		}
   447  		if atomic.LoadUint32(isOpen) != 1 {
   448  			return net.ErrClosed
   449  		}
   450  		count = winrio.DequeueCompletion(bind.tx.cq, results[:])
   451  		if count == 0 {
   452  			return io.ErrNoProgress
   453  		}
   454  	}
   455  	if count > 0 {
   456  		bind.tx.Return(count)
   457  	}
   458  	packet := bind.tx.Push()
   459  	packet.addr = *nend
   460  	copy(packet.data[:], buf)
   461  	dataBuffer := &winrio.Buffer{
   462  		Id:     bind.tx.id,
   463  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
   464  		Length: uint32(len(buf)),
   465  	}
   466  	addressBuffer := &winrio.Buffer{
   467  		Id:     bind.tx.id,
   468  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
   469  		Length: uint32(unsafe.Sizeof(packet.addr)),
   470  	}
   471  	bind.mu.Lock()
   472  	defer bind.mu.Unlock()
   473  	return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
   474  }
   475  
   476  func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
   477  	nend, ok := endpoint.(*WinRingEndpoint)
   478  	if !ok {
   479  		return ErrWrongEndpointType
   480  	}
   481  	bind.mu.RLock()
   482  	defer bind.mu.RUnlock()
   483  	switch nend.family {
   484  	case windows.AF_INET:
   485  		if bind.v4.blackhole {
   486  			return nil
   487  		}
   488  		return bind.v4.Send(buf, nend, &bind.isOpen)
   489  	case windows.AF_INET6:
   490  		if bind.v6.blackhole {
   491  			return nil
   492  		}
   493  		return bind.v6.Send(buf, nend, &bind.isOpen)
   494  	}
   495  	return nil
   496  }
   497  
   498  func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
   499  	bind.mu.Lock()
   500  	defer bind.mu.Unlock()
   501  	sysconn, err := bind.ipv4.SyscallConn()
   502  	if err != nil {
   503  		return err
   504  	}
   505  	err2 := sysconn.Control(func(fd uintptr) {
   506  		err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
   507  	})
   508  	if err2 != nil {
   509  		return err2
   510  	}
   511  	if err != nil {
   512  		return err
   513  	}
   514  	bind.blackhole4 = blackhole
   515  	return nil
   516  }
   517  
   518  func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
   519  	bind.mu.Lock()
   520  	defer bind.mu.Unlock()
   521  	sysconn, err := bind.ipv6.SyscallConn()
   522  	if err != nil {
   523  		return err
   524  	}
   525  	err2 := sysconn.Control(func(fd uintptr) {
   526  		err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
   527  	})
   528  	if err2 != nil {
   529  		return err2
   530  	}
   531  	if err != nil {
   532  		return err
   533  	}
   534  	bind.blackhole6 = blackhole
   535  	return nil
   536  }
   537  func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
   538  	bind.mu.RLock()
   539  	defer bind.mu.RUnlock()
   540  	if atomic.LoadUint32(&bind.isOpen) != 1 {
   541  		return net.ErrClosed
   542  	}
   543  	err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
   544  	if err != nil {
   545  		return err
   546  	}
   547  	bind.v4.blackhole = blackhole
   548  	return nil
   549  }
   550  
   551  func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
   552  	bind.mu.RLock()
   553  	defer bind.mu.RUnlock()
   554  	if atomic.LoadUint32(&bind.isOpen) != 1 {
   555  		return net.ErrClosed
   556  	}
   557  	err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
   558  	if err != nil {
   559  		return err
   560  	}
   561  	bind.v6.blackhole = blackhole
   562  	return nil
   563  }
   564  
   565  func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
   566  	const IP_UNICAST_IF = 31
   567  	/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
   568  	var bytes [4]byte
   569  	binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
   570  	interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
   571  	err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
   572  	if err != nil {
   573  		return err
   574  	}
   575  	return nil
   576  }
   577  
   578  func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
   579  	const IPV6_UNICAST_IF = 31
   580  	return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
   581  }