github.com/amnezia-vpn/amneziawg-go@v0.2.8/conn/winrio/rio_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package winrio 7 8 import ( 9 "log" 10 "sync" 11 "syscall" 12 "unsafe" 13 14 "golang.org/x/sys/windows" 15 ) 16 17 const ( 18 MsgDontNotify = 1 19 MsgDefer = 2 20 MsgWaitAll = 4 21 MsgCommitOnly = 8 22 23 MaxCqSize = 0x8000000 24 25 invalidBufferId = 0xFFFFFFFF 26 invalidCq = 0 27 invalidRq = 0 28 corruptCq = 0xFFFFFFFF 29 ) 30 31 var extensionFunctionTable struct { 32 cbSize uint32 33 rioReceive uintptr 34 rioReceiveEx uintptr 35 rioSend uintptr 36 rioSendEx uintptr 37 rioCloseCompletionQueue uintptr 38 rioCreateCompletionQueue uintptr 39 rioCreateRequestQueue uintptr 40 rioDequeueCompletion uintptr 41 rioDeregisterBuffer uintptr 42 rioNotify uintptr 43 rioRegisterBuffer uintptr 44 rioResizeCompletionQueue uintptr 45 rioResizeRequestQueue uintptr 46 } 47 48 type Cq uintptr 49 50 type Rq uintptr 51 52 type BufferId uintptr 53 54 type Buffer struct { 55 Id BufferId 56 Offset uint32 57 Length uint32 58 } 59 60 type Result struct { 61 Status int32 62 BytesTransferred uint32 63 SocketContext uint64 64 RequestContext uint64 65 } 66 67 type notificationCompletionType uint32 68 69 const ( 70 eventCompletion notificationCompletionType = 1 71 iocpCompletion notificationCompletionType = 2 72 ) 73 74 type eventNotificationCompletion struct { 75 completionType notificationCompletionType 76 event windows.Handle 77 notifyReset uint32 78 } 79 80 type iocpNotificationCompletion struct { 81 completionType notificationCompletionType 82 iocp windows.Handle 83 key uintptr 84 overlapped *windows.Overlapped 85 } 86 87 var ( 88 initialized sync.Once 89 available bool 90 ) 91 92 func Initialize() bool { 93 initialized.Do(func() { 94 var ( 95 err error 96 socket windows.Handle 97 cq Cq 98 ) 99 defer func() { 100 if err == nil { 101 return 102 } 103 if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 { 104 return 105 } 106 log.Printf("Registered I/O is unavailable: %v", err) 107 }() 108 socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 109 if err != nil { 110 return 111 } 112 defer windows.CloseHandle(socket) 113 WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} 114 const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 115 ob := uint32(0) 116 err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, 117 (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), 118 (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), 119 &ob, nil, 0) 120 if err != nil { 121 return 122 } 123 124 // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes 125 // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. 126 var iocp windows.Handle 127 iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 128 if err != nil { 129 return 130 } 131 defer windows.CloseHandle(iocp) 132 var overlapped windows.Overlapped 133 cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped) 134 if err != nil { 135 return 136 } 137 defer CloseCompletionQueue(cq) 138 _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) 139 if err != nil { 140 return 141 } 142 available = true 143 }) 144 return available 145 } 146 147 func Socket(af, typ, proto int32) (windows.Handle, error) { 148 return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) 149 } 150 151 func CloseCompletionQueue(cq Cq) { 152 _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) 153 } 154 155 func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { 156 notificationCompletion := &eventNotificationCompletion{ 157 completionType: eventCompletion, 158 event: event, 159 } 160 if notifyReset { 161 notificationCompletion.notifyReset = 1 162 } 163 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) 164 if ret == invalidCq { 165 return 0, err 166 } 167 return Cq(ret), nil 168 } 169 170 func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { 171 notificationCompletion := &iocpNotificationCompletion{ 172 completionType: iocpCompletion, 173 iocp: iocp, 174 key: key, 175 overlapped: overlapped, 176 } 177 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) 178 if ret == invalidCq { 179 return 0, err 180 } 181 return Cq(ret), nil 182 } 183 184 func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { 185 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) 186 if ret == invalidCq { 187 return 0, err 188 } 189 return Cq(ret), nil 190 } 191 192 func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { 193 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) 194 if ret == invalidRq { 195 return 0, err 196 } 197 return Rq(ret), nil 198 } 199 200 func DequeueCompletion(cq Cq, results []Result) uint32 { 201 var array uintptr 202 if len(results) > 0 { 203 array = uintptr(unsafe.Pointer(&results[0])) 204 } 205 ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) 206 if ret == corruptCq { 207 panic("cq is corrupt") 208 } 209 return uint32(ret) 210 } 211 212 func DeregisterBuffer(id BufferId) { 213 _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) 214 } 215 216 func RegisterBuffer(buffer []byte) (BufferId, error) { 217 var buf unsafe.Pointer 218 if len(buffer) > 0 { 219 buf = unsafe.Pointer(&buffer[0]) 220 } 221 return RegisterPointer(buf, uint32(len(buffer))) 222 } 223 224 func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { 225 ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) 226 if ret == invalidBufferId { 227 return 0, err 228 } 229 return BufferId(ret), nil 230 } 231 232 func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { 233 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) 234 if ret == 0 { 235 return err 236 } 237 return nil 238 } 239 240 func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { 241 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) 242 if ret == 0 { 243 return err 244 } 245 return nil 246 } 247 248 func Notify(cq Cq) error { 249 ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) 250 if ret != 0 { 251 return windows.Errno(ret) 252 } 253 return nil 254 }