github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/conn/winrio/rio_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package winrio 7 8 import ( 9 "log" 10 "sync" 11 "syscall" 12 "unsafe" 13 14 "golang.org/x/sys/windows" 15 ) 16 17 const ( 18 MsgDontNotify = 1 19 MsgDefer = 2 20 MsgWaitAll = 4 21 MsgCommitOnly = 8 22 23 MaxCqSize = 0x8000000 24 25 invalidBufferId = 0xFFFFFFFF 26 invalidCq = 0 27 invalidRq = 0 28 corruptCq = 0xFFFFFFFF 29 ) 30 31 var extensionFunctionTable struct { 32 cbSize uint32 33 rioReceive uintptr 34 rioReceiveEx uintptr 35 rioSend uintptr 36 rioSendEx uintptr 37 rioCloseCompletionQueue uintptr 38 rioCreateCompletionQueue uintptr 39 rioCreateRequestQueue uintptr 40 rioDequeueCompletion uintptr 41 rioDeregisterBuffer uintptr 42 rioNotify uintptr 43 rioRegisterBuffer uintptr 44 rioResizeCompletionQueue uintptr 45 rioResizeRequestQueue uintptr 46 } 47 48 type Cq uintptr 49 50 type Rq uintptr 51 52 type BufferId uintptr 53 54 type Buffer struct { 55 Id BufferId 56 Offset uint32 57 Length uint32 58 } 59 60 type Result struct { 61 Status int32 62 BytesTransferred uint32 63 SocketContext uint64 64 RequestContext uint64 65 } 66 67 type notificationCompletionType uint32 68 69 const ( 70 eventCompletion notificationCompletionType = 1 71 iocpCompletion notificationCompletionType = 2 72 ) 73 74 type eventNotificationCompletion struct { 75 completionType notificationCompletionType 76 event windows.Handle 77 notifyReset uint32 78 } 79 80 type iocpNotificationCompletion struct { 81 completionType notificationCompletionType 82 iocp windows.Handle 83 key uintptr 84 overlapped *windows.Overlapped 85 } 86 87 var initialized sync.Once 88 var available bool 89 90 func Initialize() bool { 91 initialized.Do(func() { 92 var ( 93 err error 94 socket windows.Handle 95 cq Cq 96 ) 97 defer func() { 98 if err == nil { 99 return 100 } 101 if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 { 102 return 103 } 104 log.Printf("Registered I/O is unavailable: %v", err) 105 }() 106 socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 107 if err != nil { 108 return 109 } 110 defer windows.CloseHandle(socket) 111 var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} 112 const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 113 ob := uint32(0) 114 err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, 115 (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), 116 (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), 117 &ob, nil, 0) 118 if err != nil { 119 return 120 } 121 122 // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes 123 // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. 124 var iocp windows.Handle 125 iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 126 if err != nil { 127 return 128 } 129 defer windows.CloseHandle(iocp) 130 var overlapped windows.Overlapped 131 cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped) 132 if err != nil { 133 return 134 } 135 defer CloseCompletionQueue(cq) 136 _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) 137 if err != nil { 138 return 139 } 140 available = true 141 }) 142 return available 143 } 144 145 func Socket(af, typ, proto int32) (windows.Handle, error) { 146 return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) 147 } 148 149 func CloseCompletionQueue(cq Cq) { 150 _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) 151 } 152 153 func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { 154 notificationCompletion := &eventNotificationCompletion{ 155 completionType: eventCompletion, 156 event: event, 157 } 158 if notifyReset { 159 notificationCompletion.notifyReset = 1 160 } 161 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) 162 if ret == invalidCq { 163 return 0, err 164 } 165 return Cq(ret), nil 166 } 167 168 func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { 169 notificationCompletion := &iocpNotificationCompletion{ 170 completionType: iocpCompletion, 171 iocp: iocp, 172 key: key, 173 overlapped: overlapped, 174 } 175 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) 176 if ret == invalidCq { 177 return 0, err 178 } 179 return Cq(ret), nil 180 } 181 182 func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { 183 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) 184 if ret == invalidCq { 185 return 0, err 186 } 187 return Cq(ret), nil 188 } 189 190 func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { 191 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) 192 if ret == invalidRq { 193 return 0, err 194 } 195 return Rq(ret), nil 196 } 197 198 func DequeueCompletion(cq Cq, results []Result) uint32 { 199 var array uintptr 200 if len(results) > 0 { 201 array = uintptr(unsafe.Pointer(&results[0])) 202 } 203 ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) 204 if ret == corruptCq { 205 panic("cq is corrupt") 206 } 207 return uint32(ret) 208 } 209 210 func DeregisterBuffer(id BufferId) { 211 _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) 212 } 213 214 func RegisterBuffer(buffer []byte) (BufferId, error) { 215 var buf unsafe.Pointer 216 if len(buffer) > 0 { 217 buf = unsafe.Pointer(&buffer[0]) 218 } 219 return RegisterPointer(buf, uint32(len(buffer))) 220 } 221 222 func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { 223 ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) 224 if ret == invalidBufferId { 225 return 0, err 226 } 227 return BufferId(ret), nil 228 } 229 230 func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { 231 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) 232 if ret == 0 { 233 return err 234 } 235 return nil 236 } 237 238 func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { 239 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) 240 if ret == 0 { 241 return err 242 } 243 return nil 244 } 245 246 func Notify(cq Cq) error { 247 ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) 248 if ret != 0 { 249 return windows.Errno(ret) 250 } 251 return nil 252 }