github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/conn/bind_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"encoding/binary"
    10  	"io"
    11  	"net"
    12  	"strconv"
    13  	"sync"
    14  	"sync/atomic"
    15  	"unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  
    19  	"github.com/tailscale/wireguard-go/conn/winrio"
    20  )
    21  
const (
	// packetsPerRing is the number of packet slots in each ring buffer. The
	// same value sizes the completion queues and the request queue.
	packetsPerRing = 1024
	// bytesPerPacket is the payload capacity of a single ring slot.
	// NOTE(review): the 32 subtracted here presumably leaves room for the
	// 32-byte WinRingEndpoint stored next to the payload so that a ringPacket
	// occupies 2048 bytes — confirm.
	bytesPerPacket = 2048 - 32
	// receiveSpins is the maximum number of completion-queue polls Receive
	// performs before it blocks on the I/O completion port.
	receiveSpins   = 15
)
    27  
// ringPacket is one slot of a ring buffer: an endpoint address followed by the
// datagram payload. Slots are addressed by offset into the registered buffer,
// so the field order is significant.
type ringPacket struct {
	addr WinRingEndpoint      // remote address for this datagram
	data [bytesPerPacket]byte // datagram payload
}
    32  
// ringBuffer tracks a block of packetsPerRing ringPacket slots together with
// the registered-I/O objects used to post and complete requests against it.
type ringBuffer struct {
	packets    uintptr            // base address of the slot array (set by Open via VirtualAlloc)
	head, tail uint32             // Return advances head, Push advances tail; both are used modulo packetsPerRing
	id         winrio.BufferId    // registration id of the slot array
	iocp       windows.Handle     // completion port the completion queue notifies
	isFull     bool               // set by Push when tail catches up with head
	cq         winrio.Cq          // completion queue for requests using this ring
	mu         sync.Mutex         // held by Receive/Send while they operate on this ring
	overlapped windows.Overlapped // passed to CreateIOCPCompletionQueue
}
    43  
    44  func (rb *ringBuffer) Push() *ringPacket {
    45  	for rb.isFull {
    46  		panic("ring is full")
    47  	}
    48  	ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
    49  	rb.tail += 1
    50  	if rb.tail%packetsPerRing == rb.head%packetsPerRing {
    51  		rb.isFull = true
    52  	}
    53  	return ret
    54  }
    55  
    56  func (rb *ringBuffer) Return(count uint32) {
    57  	if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
    58  		return
    59  	}
    60  	rb.head += count
    61  	rb.isFull = false
    62  }
    63  
// afWinRingBind holds the per-address-family state of a WinRingBind: one
// socket, its receive and transmit rings, and the request queue tying them
// together.
type afWinRingBind struct {
	sock      windows.Handle // UDP socket created by Open
	rx, tx    ringBuffer     // receive and transmit rings
	rq        winrio.Rq      // request queue bound to sock, rx.cq and tx.cq
	mu        sync.Mutex     // held while posting ReceiveEx/SendEx on rq
	blackhole bool           // when true, WinRingBind.Send drops packets for this family
}
    71  
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
type WinRingBind struct {
	v4, v6 afWinRingBind // IPv4 and IPv6 halves of the bind
	mu     sync.RWMutex  // write-locked by Open and the teardown in Close; read-locked by the other methods
	isOpen uint32        // accessed atomically: 0 = closed, 1 = open, 2 = Close in progress
}
    78  
    79  func NewDefaultBind() Bind { return NewWinRingBind() }
    80  
    81  func NewWinRingBind() Bind {
    82  	if !winrio.Initialize() {
    83  		return NewStdNetBind()
    84  	}
    85  	return new(WinRingBind)
    86  }
    87  
// WinRingEndpoint is a fixed-size socket address: ParseEndpoint fills it by
// copying the raw address returned by GetAddrInfoW over it, and it is passed
// to registered I/O as the address buffer.
//
// As read by the accessor methods, data[0:2] is the big-endian port; for
// AF_INET data[2:6] is the address, and for AF_INET6 data[6:22] is the address
// and data[22:26] the scope id.
type WinRingEndpoint struct {
	family uint16   // windows.AF_INET or windows.AF_INET6
	data   [30]byte // remainder of the socket address, laid out as described above
}
    92  
    93  var _ Bind = (*WinRingBind)(nil)
    94  var _ Endpoint = (*WinRingEndpoint)(nil)
    95  
    96  func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
    97  	host, port, err := net.SplitHostPort(s)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	host16, err := windows.UTF16PtrFromString(host)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	port16, err := windows.UTF16PtrFromString(port)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	hints := windows.AddrinfoW{
   110  		Flags:    windows.AI_NUMERICHOST,
   111  		Family:   windows.AF_UNSPEC,
   112  		Socktype: windows.SOCK_DGRAM,
   113  		Protocol: windows.IPPROTO_UDP,
   114  	}
   115  	var addrinfo *windows.AddrinfoW
   116  	err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	defer windows.FreeAddrInfoW(addrinfo)
   121  	if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
   122  		return nil, windows.ERROR_INVALID_ADDRESS
   123  	}
   124  	var src []byte
   125  	var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
   126  	unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen))
   127  	copy(dst[:], src)
   128  	return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
   129  }
   130  
   131  func (*WinRingEndpoint) ClearSrc() {}
   132  
   133  func (e *WinRingEndpoint) DstIP() net.IP {
   134  	switch e.family {
   135  	case windows.AF_INET:
   136  		return append([]byte{}, e.data[2:6]...)
   137  	case windows.AF_INET6:
   138  		return append([]byte{}, e.data[6:22]...)
   139  	}
   140  	return nil
   141  }
   142  
   143  func (e *WinRingEndpoint) SrcIP() net.IP {
   144  	return nil // not supported
   145  }
   146  
   147  func (e *WinRingEndpoint) DstToBytes() []byte {
   148  	switch e.family {
   149  	case windows.AF_INET:
   150  		b := make([]byte, 0, 6)
   151  		b = append(b, e.data[2:6]...)
   152  		b = append(b, e.data[1], e.data[0])
   153  		return b
   154  	case windows.AF_INET6:
   155  		b := make([]byte, 0, 18)
   156  		b = append(b, e.data[6:22]...)
   157  		b = append(b, e.data[1], e.data[0])
   158  		return b
   159  	}
   160  	return nil
   161  }
   162  
   163  func (e *WinRingEndpoint) DstToString() string {
   164  	switch e.family {
   165  	case windows.AF_INET:
   166  		addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
   167  		return addr.String()
   168  	case windows.AF_INET6:
   169  		var zone string
   170  		if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
   171  			zone = strconv.FormatUint(uint64(scope), 10)
   172  		}
   173  		addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
   174  		return addr.String()
   175  	}
   176  	return ""
   177  }
   178  
   179  func (e *WinRingEndpoint) SrcToString() string {
   180  	return ""
   181  }
   182  
   183  func (ring *ringBuffer) CloseAndZero() {
   184  	if ring.cq != 0 {
   185  		winrio.CloseCompletionQueue(ring.cq)
   186  		ring.cq = 0
   187  	}
   188  	if ring.iocp != 0 {
   189  		windows.CloseHandle(ring.iocp)
   190  		ring.iocp = 0
   191  	}
   192  	if ring.id != 0 {
   193  		winrio.DeregisterBuffer(ring.id)
   194  		ring.id = 0
   195  	}
   196  	if ring.packets != 0 {
   197  		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
   198  		ring.packets = 0
   199  	}
   200  	ring.head = 0
   201  	ring.tail = 0
   202  	ring.isFull = false
   203  }
   204  
   205  func (bind *afWinRingBind) CloseAndZero() {
   206  	bind.rx.CloseAndZero()
   207  	bind.tx.CloseAndZero()
   208  	if bind.sock != 0 {
   209  		windows.CloseHandle(bind.sock)
   210  		bind.sock = 0
   211  	}
   212  	bind.blackhole = false
   213  }
   214  
   215  func (bind *WinRingBind) closeAndZero() {
   216  	atomic.StoreUint32(&bind.isOpen, 0)
   217  	bind.v4.CloseAndZero()
   218  	bind.v6.CloseAndZero()
   219  }
   220  
   221  func (ring *ringBuffer) Open() error {
   222  	var err error
   223  	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
   224  	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
   225  	if err != nil {
   226  		return err
   227  	}
   228  	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
   229  	if err != nil {
   230  		return err
   231  	}
   232  	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
   237  	if err != nil {
   238  		return err
   239  	}
   240  	return nil
   241  }
   242  
   243  func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
   244  	var err error
   245  	bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	err = bind.rx.Open()
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	err = bind.tx.Open()
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  	bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  	err = windows.Bind(bind.sock, sa)
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  	sa, err = windows.Getsockname(bind.sock)
   266  	if err != nil {
   267  		return nil, err
   268  	}
   269  	return sa, nil
   270  }
   271  
   272  func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
   273  	bind.mu.Lock()
   274  	defer bind.mu.Unlock()
   275  	defer func() {
   276  		if err != nil {
   277  			bind.closeAndZero()
   278  		}
   279  	}()
   280  	if atomic.LoadUint32(&bind.isOpen) != 0 {
   281  		return nil, 0, ErrBindAlreadyOpen
   282  	}
   283  	var sa windows.Sockaddr
   284  	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
   285  	if err != nil {
   286  		return nil, 0, err
   287  	}
   288  	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
   289  	if err != nil {
   290  		return nil, 0, err
   291  	}
   292  	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
   293  	for i := 0; i < packetsPerRing; i++ {
   294  		err = bind.v4.InsertReceiveRequest()
   295  		if err != nil {
   296  			return nil, 0, err
   297  		}
   298  		err = bind.v6.InsertReceiveRequest()
   299  		if err != nil {
   300  			return nil, 0, err
   301  		}
   302  	}
   303  	atomic.StoreUint32(&bind.isOpen, 1)
   304  	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
   305  }
   306  
   307  func (bind *WinRingBind) Close() error {
   308  	bind.mu.RLock()
   309  	if atomic.LoadUint32(&bind.isOpen) != 1 {
   310  		bind.mu.RUnlock()
   311  		return nil
   312  	}
   313  	atomic.StoreUint32(&bind.isOpen, 2)
   314  	windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
   315  	windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
   316  	windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
   317  	windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
   318  	bind.mu.RUnlock()
   319  	bind.mu.Lock()
   320  	defer bind.mu.Unlock()
   321  	bind.closeAndZero()
   322  	return nil
   323  }
   324  
   325  func (bind *WinRingBind) SetMark(mark uint32) error {
   326  	return nil
   327  }
   328  
   329  func (bind *afWinRingBind) InsertReceiveRequest() error {
   330  	packet := bind.rx.Push()
   331  	dataBuffer := &winrio.Buffer{
   332  		Id:     bind.rx.id,
   333  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
   334  		Length: uint32(len(packet.data)),
   335  	}
   336  	addressBuffer := &winrio.Buffer{
   337  		Id:     bind.rx.id,
   338  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
   339  		Length: uint32(unsafe.Sizeof(packet.addr)),
   340  	}
   341  	bind.mu.Lock()
   342  	defer bind.mu.Unlock()
   343  	return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
   344  }
   345  
// procyield is bound by the linkname directive below to runtime.procyield and
// therefore has no body here. Receive calls it between completion-queue polls.
// NOTE(review): presumably a short CPU-level spin for the given number of
// cycles — confirm against the Go runtime.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
   348  
// Receive waits for one datagram on this address family's receive ring, copies
// its payload into buf, and returns the number of bytes copied together with
// the sender's endpoint.
//
// isOpen points at the owning bind's state word; net.ErrClosed is returned
// whenever it is observed to be anything other than 1. A completion carrying
// WSAEMSGSIZE is skipped and the wait starts over; any other non-zero status
// is returned as a windows.Errno. io.ErrNoProgress is returned if the
// completion port wakes up but the completion queue is still empty.
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
	if atomic.LoadUint32(isOpen) != 1 {
		return 0, nil, net.ErrClosed
	}
	bind.rx.mu.Lock()
	defer bind.rx.mu.Unlock()

	var err error
	var count uint32
	var results [1]winrio.Result
retry:
	count = 0
	// Poll the completion queue up to receiveSpins times, yielding between
	// attempts and bailing out if the bind stops being open.
	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
		if tries > 0 {
			if atomic.LoadUint32(isOpen) != 1 {
				return 0, nil, net.ErrClosed
			}
			procyield(1)
		}
		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
	}
	if count == 0 {
		// Nothing arrived while spinning: request a notification from the
		// completion queue and block on the completion port.
		err = winrio.Notify(bind.rx.cq)
		if err != nil {
			return 0, nil, err
		}
		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return 0, nil, err
		}
		// Close posts to this port to wake waiters, so re-check the state
		// before touching the queue.
		if atomic.LoadUint32(isOpen) != 1 {
			return 0, nil, net.ErrClosed
		}
		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
		if count == 0 {
			return 0, nil, io.ErrNoProgress

		}
	}
	// Release the completed slot and immediately post a fresh receive request.
	bind.rx.Return(1)
	err = bind.InsertReceiveRequest()
	if err != nil {
		return 0, nil, err
	}
	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
	// attacker bandwidth, just like the rest of the receive path.
	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
		if atomic.LoadUint32(isOpen) != 1 {
			return 0, nil, net.ErrClosed
		}
		goto retry
	}
	if results[0].Status != 0 {
		return 0, nil, windows.Errno(results[0].Status)
	}
	// RequestContext is the ringPacket pointer that InsertReceiveRequest
	// attached when the request was posted.
	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
	// Copy the address by value so the returned endpoint does not point into
	// ring memory.
	ep := packet.addr
	n := copy(buf, packet.data[:results[0].BytesTransferred])
	return n, &ep, nil
}
   413  
   414  func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
   415  	bind.mu.RLock()
   416  	defer bind.mu.RUnlock()
   417  	return bind.v4.Receive(buf, &bind.isOpen)
   418  }
   419  
   420  func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
   421  	bind.mu.RLock()
   422  	defer bind.mu.RUnlock()
   423  	return bind.v6.Receive(buf, &bind.isOpen)
   424  }
   425  
   426  func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
   427  	if atomic.LoadUint32(isOpen) != 1 {
   428  		return net.ErrClosed
   429  	}
   430  	if len(buf) > bytesPerPacket {
   431  		return io.ErrShortBuffer
   432  	}
   433  	bind.tx.mu.Lock()
   434  	defer bind.tx.mu.Unlock()
   435  	var results [packetsPerRing]winrio.Result
   436  	count := winrio.DequeueCompletion(bind.tx.cq, results[:])
   437  	if count == 0 && bind.tx.isFull {
   438  		err := winrio.Notify(bind.tx.cq)
   439  		if err != nil {
   440  			return err
   441  		}
   442  		var bytes uint32
   443  		var key uintptr
   444  		var overlapped *windows.Overlapped
   445  		err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
   446  		if err != nil {
   447  			return err
   448  		}
   449  		if atomic.LoadUint32(isOpen) != 1 {
   450  			return net.ErrClosed
   451  		}
   452  		count = winrio.DequeueCompletion(bind.tx.cq, results[:])
   453  		if count == 0 {
   454  			return io.ErrNoProgress
   455  		}
   456  	}
   457  	if count > 0 {
   458  		bind.tx.Return(count)
   459  	}
   460  	packet := bind.tx.Push()
   461  	packet.addr = *nend
   462  	copy(packet.data[:], buf)
   463  	dataBuffer := &winrio.Buffer{
   464  		Id:     bind.tx.id,
   465  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
   466  		Length: uint32(len(buf)),
   467  	}
   468  	addressBuffer := &winrio.Buffer{
   469  		Id:     bind.tx.id,
   470  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
   471  		Length: uint32(unsafe.Sizeof(packet.addr)),
   472  	}
   473  	bind.mu.Lock()
   474  	defer bind.mu.Unlock()
   475  	return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
   476  }
   477  
   478  func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
   479  	nend, ok := endpoint.(*WinRingEndpoint)
   480  	if !ok {
   481  		return ErrWrongEndpointType
   482  	}
   483  	bind.mu.RLock()
   484  	defer bind.mu.RUnlock()
   485  	switch nend.family {
   486  	case windows.AF_INET:
   487  		if bind.v4.blackhole {
   488  			return nil
   489  		}
   490  		return bind.v4.Send(buf, nend, &bind.isOpen)
   491  	case windows.AF_INET6:
   492  		if bind.v6.blackhole {
   493  			return nil
   494  		}
   495  		return bind.v6.Send(buf, nend, &bind.isOpen)
   496  	}
   497  	return nil
   498  }
   499  
   500  func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
   501  	bind.mu.Lock()
   502  	defer bind.mu.Unlock()
   503  	sysconn, err := bind.ipv4.SyscallConn()
   504  	if err != nil {
   505  		return err
   506  	}
   507  	err2 := sysconn.Control(func(fd uintptr) {
   508  		err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
   509  	})
   510  	if err2 != nil {
   511  		return err2
   512  	}
   513  	if err != nil {
   514  		return err
   515  	}
   516  	bind.blackhole4 = blackhole
   517  	return nil
   518  }
   519  
   520  func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
   521  	bind.mu.Lock()
   522  	defer bind.mu.Unlock()
   523  	sysconn, err := bind.ipv6.SyscallConn()
   524  	if err != nil {
   525  		return err
   526  	}
   527  	err2 := sysconn.Control(func(fd uintptr) {
   528  		err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
   529  	})
   530  	if err2 != nil {
   531  		return err2
   532  	}
   533  	if err != nil {
   534  		return err
   535  	}
   536  	bind.blackhole6 = blackhole
   537  	return nil
   538  }
   539  func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
   540  	bind.mu.RLock()
   541  	defer bind.mu.RUnlock()
   542  	if atomic.LoadUint32(&bind.isOpen) != 1 {
   543  		return net.ErrClosed
   544  	}
   545  	err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
   546  	if err != nil {
   547  		return err
   548  	}
   549  	bind.v4.blackhole = blackhole
   550  	return nil
   551  }
   552  
   553  func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
   554  	bind.mu.RLock()
   555  	defer bind.mu.RUnlock()
   556  	if atomic.LoadUint32(&bind.isOpen) != 1 {
   557  		return net.ErrClosed
   558  	}
   559  	err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
   560  	if err != nil {
   561  		return err
   562  	}
   563  	bind.v6.blackhole = blackhole
   564  	return nil
   565  }
   566  
   567  func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
   568  	const IP_UNICAST_IF = 31
   569  	/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
   570  	var bytes [4]byte
   571  	binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
   572  	interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
   573  	err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
   574  	if err != nil {
   575  		return err
   576  	}
   577  	return nil
   578  }
   579  
   580  func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
   581  	const IPV6_UNICAST_IF = 31
   582  	return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
   583  }
   584  
   585  // unsafeSlice updates the slice slicePtr to be a slice
   586  // referencing the provided data with its length & capacity set to
   587  // lenCap.
   588  //
   589  // TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
   590  // update callers to use unsafe.Slice instead of this.
   591  func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
   592  	type sliceHeader struct {
   593  		Data unsafe.Pointer
   594  		Len  int
   595  		Cap  int
   596  	}
   597  	h := (*sliceHeader)(slicePtr)
   598  	h.Data = data
   599  	h.Len = lenCap
   600  	h.Cap = lenCap
   601  }