
     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     6  package winrio
     8  import (
     9  	"log"
    10  	"sync"
    11  	"syscall"
    12  	"unsafe"
    14  	""
    15  )
    17  const (
    18  	MsgDontNotify = 1
    19  	MsgDefer      = 2
    20  	MsgWaitAll    = 4
    21  	MsgCommitOnly = 8
    23  	MaxCqSize = 0x8000000
    25  	invalidBufferId = 0xFFFFFFFF
    26  	invalidCq       = 0
    27  	invalidRq       = 0
    28  	corruptCq       = 0xFFFFFFFF
    29  )
// extensionFunctionTable holds the Registered I/O function pointers.
// Initialize fills it in place with WSAIoctl, passing its address and
// unsafe.Sizeof as the output buffer. Every wrapper in this file calls through
// one of its fields with syscall.Syscall/Syscall9. Until that WSAIoctl call
// succeeds, all fields are zero.
//
// NOTE(review): WSAIoctl writes raw bytes into this struct, so the field order
// and types presumably must mirror the Winsock RIO_EXTENSION_FUNCTION_TABLE
// layout exactly. Do not reorder, retype, or insert fields.
var extensionFunctionTable struct {
	// Presumably the table size in bytes as reported by Winsock (confirm).
	cbSize                   uint32
	rioReceive               uintptr
	rioReceiveEx             uintptr
	rioSend                  uintptr
	rioSendEx                uintptr
	rioCloseCompletionQueue  uintptr
	rioCreateCompletionQueue uintptr
	rioCreateRequestQueue    uintptr
	rioDequeueCompletion     uintptr
	rioDeregisterBuffer      uintptr
	rioNotify                uintptr
	rioRegisterBuffer        uintptr
	rioResizeCompletionQueue uintptr
	rioResizeRequestQueue    uintptr
}
    48  type Cq uintptr
    50  type Rq uintptr
    52  type BufferId uintptr
// Buffer describes a slice of a registered buffer. SendEx and ReceiveEx take
// *Buffer arguments and hand the pointers to RIO unchanged.
//
// NOTE(review): because RIO reads this struct directly, its layout presumably
// must match the Windows RIO_BUF struct. Do not reorder or retype the fields.
type Buffer struct {
	// Id is the value returned by RegisterBuffer/RegisterPointer.
	Id BufferId
	// Offset and Length presumably select a byte range within that
	// registered region (confirm against the RIO_BUF documentation).
	Offset uint32
	Length uint32
}
// Result is one completion entry. DequeueCompletion passes the address of the
// caller's []Result to RIO, which writes the entries directly into it.
//
// NOTE(review): the layout therefore presumably must match the Windows
// RIORESULT struct. Do not reorder or retype the fields.
type Result struct {
	// Status is presumably the Winsock status code of the operation.
	Status int32
	// BytesTransferred is the number of bytes moved by the operation.
	BytesTransferred uint32
	// SocketContext and RequestContext presumably echo the socketContext
	// passed to CreateRequestQueue and the requestContext passed to
	// SendEx/ReceiveEx. Verify against callers.
	SocketContext  uint64
	RequestContext uint64
}
    67  type notificationCompletionType uint32
    69  const (
    70  	eventCompletion notificationCompletionType = 1
    71  	iocpCompletion  notificationCompletionType = 2
    72  )
// eventNotificationCompletion is the notification descriptor that
// CreateEventCompletionQueue passes by pointer to RIO's create-completion-queue
// function. It requests completion signalling through an event handle.
//
// NOTE(review): RIO reads this memory directly, so the layout presumably must
// match the event variant of the Windows RIO_NOTIFICATION_COMPLETION struct.
// Do not reorder or retype the fields.
type eventNotificationCompletion struct {
	// completionType is set to eventCompletion by CreateEventCompletionQueue.
	completionType notificationCompletionType
	event          windows.Handle
	// notifyReset is 1 when the caller asked for notifyReset, else 0 (a C BOOL).
	notifyReset uint32
}
// iocpNotificationCompletion is the notification descriptor that
// CreateIOCPCompletionQueue passes by pointer to RIO's create-completion-queue
// function. It requests completion notifications through an I/O completion
// port.
//
// NOTE(review): RIO reads this memory directly, so the layout presumably must
// match the IOCP variant of the Windows RIO_NOTIFICATION_COMPLETION struct.
// Do not reorder or retype the fields.
type iocpNotificationCompletion struct {
	// completionType is set to iocpCompletion by CreateIOCPCompletionQueue.
	completionType notificationCompletionType
	iocp           windows.Handle
	key            uintptr
	// overlapped is handed to RIO as-is. The caller presumably has to keep it
	// alive for the lifetime of the queue (confirm).
	overlapped *windows.Overlapped
}
    87  var (
    88  	initialized sync.Once
    89  	available   bool
    90  )
    92  func Initialize() bool {
    93  	initialized.Do(func() {
    94  		var (
    95  			err    error
    96  			socket windows.Handle
    97  			cq     Cq
    98  		)
    99  		defer func() {
   100  			if err == nil {
   101  				return
   102  			}
   103  			if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
   104  				return
   105  			}
   106  			log.Printf("Registered I/O is unavailable: %v", err)
   107  		}()
   108  		socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
   109  		if err != nil {
   110  			return
   111  		}
   112  		defer windows.CloseHandle(socket)
   113  		WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
   115  		ob := uint32(0)
   116  		err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
   117  			(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
   118  			(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
   119  			&ob, nil, 0)
   120  		if err != nil {
   121  			return
   122  		}
   124  		// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
   125  		// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
   126  		var iocp windows.Handle
   127  		iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
   128  		if err != nil {
   129  			return
   130  		}
   131  		defer windows.CloseHandle(iocp)
   132  		var overlapped windows.Overlapped
   133  		cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
   134  		if err != nil {
   135  			return
   136  		}
   137  		defer CloseCompletionQueue(cq)
   138  		_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
   139  		if err != nil {
   140  			return
   141  		}
   142  		available = true
   143  	})
   144  	return available
   145  }
   147  func Socket(af, typ, proto int32) (windows.Handle, error) {
   148  	return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
   149  }
   151  func CloseCompletionQueue(cq Cq) {
   152  	_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
   153  }
   155  func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
   156  	notificationCompletion := &eventNotificationCompletion{
   157  		completionType: eventCompletion,
   158  		event:          event,
   159  	}
   160  	if notifyReset {
   161  		notificationCompletion.notifyReset = 1
   162  	}
   163  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
   164  	if ret == invalidCq {
   165  		return 0, err
   166  	}
   167  	return Cq(ret), nil
   168  }
   170  func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
   171  	notificationCompletion := &iocpNotificationCompletion{
   172  		completionType: iocpCompletion,
   173  		iocp:           iocp,
   174  		key:            key,
   175  		overlapped:     overlapped,
   176  	}
   177  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
   178  	if ret == invalidCq {
   179  		return 0, err
   180  	}
   181  	return Cq(ret), nil
   182  }
   184  func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
   185  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
   186  	if ret == invalidCq {
   187  		return 0, err
   188  	}
   189  	return Cq(ret), nil
   190  }
   192  func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
   193  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
   194  	if ret == invalidRq {
   195  		return 0, err
   196  	}
   197  	return Rq(ret), nil
   198  }
   200  func DequeueCompletion(cq Cq, results []Result) uint32 {
   201  	var array uintptr
   202  	if len(results) > 0 {
   203  		array = uintptr(unsafe.Pointer(&results[0]))
   204  	}
   205  	ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
   206  	if ret == corruptCq {
   207  		panic("cq is corrupt")
   208  	}
   209  	return uint32(ret)
   210  }
   212  func DeregisterBuffer(id BufferId) {
   213  	_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
   214  }
   216  func RegisterBuffer(buffer []byte) (BufferId, error) {
   217  	var buf unsafe.Pointer
   218  	if len(buffer) > 0 {
   219  		buf = unsafe.Pointer(&buffer[0])
   220  	}
   221  	return RegisterPointer(buf, uint32(len(buffer)))
   222  }
   224  func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
   225  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
   226  	if ret == invalidBufferId {
   227  		return 0, err
   228  	}
   229  	return BufferId(ret), nil
   230  }
   232  func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
   233  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
   234  	if ret == 0 {
   235  		return err
   236  	}
   237  	return nil
   238  }
   240  func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
   241  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
   242  	if ret == 0 {
   243  		return err
   244  	}
   245  	return nil
   246  }
   248  func Notify(cq Cq) error {
   249  	ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
   250  	if ret != 0 {
   251  		return windows.Errno(ret)
   252  	}
   253  	return nil
   254  }