github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/conn/winrio/rio_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package winrio
     7  
     8  import (
     9  	"log"
    10  	"sync"
    11  	"syscall"
    12  	"unsafe"
    13  
    14  	"golang.org/x/sys/windows"
    15  )
    16  
    17  const (
    18  	MsgDontNotify = 1
    19  	MsgDefer      = 2
    20  	MsgWaitAll    = 4
    21  	MsgCommitOnly = 8
    22  
    23  	MaxCqSize = 0x8000000
    24  
    25  	invalidBufferId = 0xFFFFFFFF
    26  	invalidCq       = 0
    27  	invalidRq       = 0
    28  	corruptCq       = 0xFFFFFFFF
    29  )
    30  
    31  var extensionFunctionTable struct {
    32  	cbSize                   uint32
    33  	rioReceive               uintptr
    34  	rioReceiveEx             uintptr
    35  	rioSend                  uintptr
    36  	rioSendEx                uintptr
    37  	rioCloseCompletionQueue  uintptr
    38  	rioCreateCompletionQueue uintptr
    39  	rioCreateRequestQueue    uintptr
    40  	rioDequeueCompletion     uintptr
    41  	rioDeregisterBuffer      uintptr
    42  	rioNotify                uintptr
    43  	rioRegisterBuffer        uintptr
    44  	rioResizeCompletionQueue uintptr
    45  	rioResizeRequestQueue    uintptr
    46  }
    47  
    48  type Cq uintptr
    49  
    50  type Rq uintptr
    51  
    52  type BufferId uintptr
    53  
    54  type Buffer struct {
    55  	Id     BufferId
    56  	Offset uint32
    57  	Length uint32
    58  }
    59  
    60  type Result struct {
    61  	Status           int32
    62  	BytesTransferred uint32
    63  	SocketContext    uint64
    64  	RequestContext   uint64
    65  }
    66  
    67  type notificationCompletionType uint32
    68  
    69  const (
    70  	eventCompletion notificationCompletionType = 1
    71  	iocpCompletion  notificationCompletionType = 2
    72  )
    73  
    74  type eventNotificationCompletion struct {
    75  	completionType notificationCompletionType
    76  	event          windows.Handle
    77  	notifyReset    uint32
    78  }
    79  
    80  type iocpNotificationCompletion struct {
    81  	completionType notificationCompletionType
    82  	iocp           windows.Handle
    83  	key            uintptr
    84  	overlapped     *windows.Overlapped
    85  }
    86  
    87  var initialized sync.Once
    88  var available bool
    89  
    90  func Initialize() bool {
    91  	initialized.Do(func() {
    92  		var (
    93  			err    error
    94  			socket windows.Handle
    95  			cq     Cq
    96  		)
    97  		defer func() {
    98  			if err == nil {
    99  				return
   100  			}
   101  			if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
   102  				return
   103  			}
   104  			log.Printf("Registered I/O is unavailable: %v", err)
   105  		}()
   106  		socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
   107  		if err != nil {
   108  			return
   109  		}
   110  		defer windows.CloseHandle(socket)
   111  		var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
   112  		const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
   113  		ob := uint32(0)
   114  		err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
   115  			(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
   116  			(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
   117  			&ob, nil, 0)
   118  		if err != nil {
   119  			return
   120  		}
   121  
   122  		// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
   123  		// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
   124  		var iocp windows.Handle
   125  		iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
   126  		if err != nil {
   127  			return
   128  		}
   129  		defer windows.CloseHandle(iocp)
   130  		var overlapped windows.Overlapped
   131  		cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
   132  		if err != nil {
   133  			return
   134  		}
   135  		defer CloseCompletionQueue(cq)
   136  		_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
   137  		if err != nil {
   138  			return
   139  		}
   140  		available = true
   141  	})
   142  	return available
   143  }
   144  
   145  func Socket(af, typ, proto int32) (windows.Handle, error) {
   146  	return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
   147  }
   148  
   149  func CloseCompletionQueue(cq Cq) {
   150  	_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
   151  }
   152  
   153  func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
   154  	notificationCompletion := &eventNotificationCompletion{
   155  		completionType: eventCompletion,
   156  		event:          event,
   157  	}
   158  	if notifyReset {
   159  		notificationCompletion.notifyReset = 1
   160  	}
   161  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
   162  	if ret == invalidCq {
   163  		return 0, err
   164  	}
   165  	return Cq(ret), nil
   166  }
   167  
   168  func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
   169  	notificationCompletion := &iocpNotificationCompletion{
   170  		completionType: iocpCompletion,
   171  		iocp:           iocp,
   172  		key:            key,
   173  		overlapped:     overlapped,
   174  	}
   175  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
   176  	if ret == invalidCq {
   177  		return 0, err
   178  	}
   179  	return Cq(ret), nil
   180  }
   181  
   182  func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
   183  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
   184  	if ret == invalidCq {
   185  		return 0, err
   186  	}
   187  	return Cq(ret), nil
   188  }
   189  
   190  func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
   191  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
   192  	if ret == invalidRq {
   193  		return 0, err
   194  	}
   195  	return Rq(ret), nil
   196  }
   197  
   198  func DequeueCompletion(cq Cq, results []Result) uint32 {
   199  	var array uintptr
   200  	if len(results) > 0 {
   201  		array = uintptr(unsafe.Pointer(&results[0]))
   202  	}
   203  	ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
   204  	if ret == corruptCq {
   205  		panic("cq is corrupt")
   206  	}
   207  	return uint32(ret)
   208  }
   209  
   210  func DeregisterBuffer(id BufferId) {
   211  	_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
   212  }
   213  
   214  func RegisterBuffer(buffer []byte) (BufferId, error) {
   215  	var buf unsafe.Pointer
   216  	if len(buffer) > 0 {
   217  		buf = unsafe.Pointer(&buffer[0])
   218  	}
   219  	return RegisterPointer(buf, uint32(len(buffer)))
   220  }
   221  
   222  func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
   223  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
   224  	if ret == invalidBufferId {
   225  		return 0, err
   226  	}
   227  	return BufferId(ret), nil
   228  }
   229  
   230  func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
   231  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
   232  	if ret == 0 {
   233  		return err
   234  	}
   235  	return nil
   236  }
   237  
   238  func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
   239  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
   240  	if ret == 0 {
   241  		return err
   242  	}
   243  	return nil
   244  }
   245  
   246  func Notify(cq Cq) error {
   247  	ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
   248  	if ret != 0 {
   249  		return windows.Errno(ret)
   250  	}
   251  	return nil
   252  }