github.com/cawidtu/notwireguard-go/conn@v0.0.0-20230523131112-68e8e5ce9cdf/bind_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "encoding/binary" 10 "io" 11 "net" 12 "net/netip" 13 "strconv" 14 "sync" 15 "sync/atomic" 16 "unsafe" 17 18 "golang.org/x/sys/windows" 19 20 "github.com/cawidtu/notwireguard-go/conn/winrio" 21 ) 22 23 const ( 24 packetsPerRing = 1024 25 bytesPerPacket = 2048 - 32 26 receiveSpins = 15 27 ) 28 29 type ringPacket struct { 30 addr WinRingEndpoint 31 data [bytesPerPacket]byte 32 } 33 34 type ringBuffer struct { 35 packets uintptr 36 head, tail uint32 37 id winrio.BufferId 38 iocp windows.Handle 39 isFull bool 40 cq winrio.Cq 41 mu sync.Mutex 42 overlapped windows.Overlapped 43 } 44 45 func (rb *ringBuffer) Push() *ringPacket { 46 for rb.isFull { 47 panic("ring is full") 48 } 49 ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) 50 rb.tail += 1 51 if rb.tail%packetsPerRing == rb.head%packetsPerRing { 52 rb.isFull = true 53 } 54 return ret 55 } 56 57 func (rb *ringBuffer) Return(count uint32) { 58 if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { 59 return 60 } 61 rb.head += count 62 rb.isFull = false 63 } 64 65 type afWinRingBind struct { 66 sock windows.Handle 67 rx, tx ringBuffer 68 rq winrio.Rq 69 mu sync.Mutex 70 blackhole bool 71 } 72 73 // WinRingBind uses Windows registered I/O for fast ring buffered networking. 74 type WinRingBind struct { 75 v4, v6 afWinRingBind 76 mu sync.RWMutex 77 isOpen uint32 78 } 79 80 func NewDefaultBind() Bind { return NewWinRingBind() } 81 82 func NewWinRingBind() Bind { 83 if !winrio.Initialize() { 84 return NewStdNetBind() 85 } 86 return new(WinRingBind) 87 } 88 89 type WinRingEndpoint struct { 90 family uint16 91 data [30]byte 92 } 93 94 var ( 95 _ Bind = (*WinRingBind)(nil) 96 _ Endpoint = (*WinRingEndpoint)(nil) 97 ) 98 99 func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { 100 host, port, err := net.SplitHostPort(s) 101 if err != nil { 102 return nil, err 103 } 104 host16, err := windows.UTF16PtrFromString(host) 105 if err != nil { 106 return nil, err 107 } 108 port16, err := windows.UTF16PtrFromString(port) 109 if err != nil { 110 return nil, err 111 } 112 hints := windows.AddrinfoW{ 113 Flags: windows.AI_NUMERICHOST, 114 Family: windows.AF_UNSPEC, 115 Socktype: windows.SOCK_DGRAM, 116 Protocol: windows.IPPROTO_UDP, 117 } 118 var addrinfo *windows.AddrinfoW 119 err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) 120 if err != nil { 121 return nil, err 122 } 123 defer windows.FreeAddrInfoW(addrinfo) 124 if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { 125 return nil, windows.ERROR_INVALID_ADDRESS 126 } 127 var dst [unsafe.Sizeof(WinRingEndpoint{})]byte 128 copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) 129 return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil 130 } 131 132 func (*WinRingEndpoint) ClearSrc() {} 133 134 func (e *WinRingEndpoint) DstIP() netip.Addr { 135 switch e.family { 136 case windows.AF_INET: 137 return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) 138 case windows.AF_INET6: 139 return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) 140 } 141 return netip.Addr{} 142 } 143 144 func (e *WinRingEndpoint) SrcIP() netip.Addr { 145 return netip.Addr{} // not supported 146 } 147 148 func (e *WinRingEndpoint) DstToBytes() []byte { 149 switch e.family { 150 case windows.AF_INET: 151 b := make([]byte, 0, 6) 152 b = append(b, e.data[2:6]...) 153 b = append(b, e.data[1], e.data[0]) 154 return b 155 case windows.AF_INET6: 156 b := make([]byte, 0, 18) 157 b = append(b, e.data[6:22]...) 158 b = append(b, e.data[1], e.data[0]) 159 return b 160 } 161 return nil 162 } 163 164 func (e *WinRingEndpoint) DstToString() string { 165 switch e.family { 166 case windows.AF_INET: 167 netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() 168 case windows.AF_INET6: 169 var zone string 170 if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { 171 zone = strconv.FormatUint(uint64(scope), 10) 172 } 173 return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() 174 } 175 return "" 176 } 177 178 func (e *WinRingEndpoint) SrcToString() string { 179 return "" 180 } 181 182 func (ring *ringBuffer) CloseAndZero() { 183 if ring.cq != 0 { 184 winrio.CloseCompletionQueue(ring.cq) 185 ring.cq = 0 186 } 187 if ring.iocp != 0 { 188 windows.CloseHandle(ring.iocp) 189 ring.iocp = 0 190 } 191 if ring.id != 0 { 192 winrio.DeregisterBuffer(ring.id) 193 ring.id = 0 194 } 195 if ring.packets != 0 { 196 windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) 197 ring.packets = 0 198 } 199 ring.head = 0 200 ring.tail = 0 201 ring.isFull = false 202 } 203 204 func (bind *afWinRingBind) CloseAndZero() { 205 bind.rx.CloseAndZero() 206 bind.tx.CloseAndZero() 207 if bind.sock != 0 { 208 windows.CloseHandle(bind.sock) 209 bind.sock = 0 210 } 211 bind.blackhole = false 212 } 213 214 func (bind *WinRingBind) closeAndZero() { 215 atomic.StoreUint32(&bind.isOpen, 0) 216 bind.v4.CloseAndZero() 217 bind.v6.CloseAndZero() 218 } 219 220 func (ring *ringBuffer) Open() error { 221 var err error 222 packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing 223 ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) 224 if err != nil { 225 return err 226 } 227 ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) 228 if err != nil { 229 return err 230 } 231 ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 232 if err != nil { 233 return err 234 } 235 ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) 236 if err != nil { 237 return err 238 } 239 return nil 240 } 241 242 func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { 243 var err error 244 bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 245 if err != nil { 246 return nil, err 247 } 248 err = bind.rx.Open() 249 if err != nil { 250 return nil, err 251 } 252 err = bind.tx.Open() 253 if err != nil { 254 return nil, err 255 } 256 bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) 257 if err != nil { 258 return nil, err 259 } 260 err = windows.Bind(bind.sock, sa) 261 if err != nil { 262 return nil, err 263 } 264 sa, err = windows.Getsockname(bind.sock) 265 if err != nil { 266 return nil, err 267 } 268 return sa, nil 269 } 270 271 func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { 272 bind.mu.Lock() 273 defer bind.mu.Unlock() 274 defer func() { 275 if err != nil { 276 bind.closeAndZero() 277 } 278 }() 279 if atomic.LoadUint32(&bind.isOpen) != 0 { 280 return nil, 0, ErrBindAlreadyOpen 281 } 282 var sa windows.Sockaddr 283 sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) 284 if err != nil { 285 return nil, 0, err 286 } 287 sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) 288 if err != nil { 289 return nil, 0, err 290 } 291 selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) 292 for i := 0; i < packetsPerRing; i++ { 293 err = bind.v4.InsertReceiveRequest() 294 if err != nil { 295 return nil, 0, err 296 } 297 err = bind.v6.InsertReceiveRequest() 298 if err != nil { 299 return nil, 0, err 300 } 301 } 302 atomic.StoreUint32(&bind.isOpen, 1) 303 return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err 304 } 305 306 func (bind *WinRingBind) Close() error { 307 bind.mu.RLock() 308 if atomic.LoadUint32(&bind.isOpen) != 1 { 309 bind.mu.RUnlock() 310 return nil 311 } 312 atomic.StoreUint32(&bind.isOpen, 2) 313 windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) 314 windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) 315 windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) 316 windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) 317 bind.mu.RUnlock() 318 bind.mu.Lock() 319 defer bind.mu.Unlock() 320 bind.closeAndZero() 321 return nil 322 } 323 324 func (bind *WinRingBind) SetMark(mark uint32) error { 325 return nil 326 } 327 328 func (bind *afWinRingBind) InsertReceiveRequest() error { 329 packet := bind.rx.Push() 330 dataBuffer := &winrio.Buffer{ 331 Id: bind.rx.id, 332 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), 333 Length: uint32(len(packet.data)), 334 } 335 addressBuffer := &winrio.Buffer{ 336 Id: bind.rx.id, 337 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), 338 Length: uint32(unsafe.Sizeof(packet.addr)), 339 } 340 bind.mu.Lock() 341 defer bind.mu.Unlock() 342 return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) 343 } 344 345 //go:linkname procyield runtime.procyield 346 func procyield(cycles uint32) 347 348 func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { 349 if atomic.LoadUint32(isOpen) != 1 { 350 return 0, nil, net.ErrClosed 351 } 352 bind.rx.mu.Lock() 353 defer bind.rx.mu.Unlock() 354 355 var err error 356 var count uint32 357 var results [1]winrio.Result 358 retry: 359 count = 0 360 for tries := 0; count == 0 && tries < receiveSpins; tries++ { 361 if tries > 0 { 362 if atomic.LoadUint32(isOpen) != 1 { 363 return 0, nil, net.ErrClosed 364 } 365 procyield(1) 366 } 367 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 368 } 369 if count == 0 { 370 err = winrio.Notify(bind.rx.cq) 371 if err != nil { 372 return 0, nil, err 373 } 374 var bytes uint32 375 var key uintptr 376 var overlapped *windows.Overlapped 377 err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 378 if err != nil { 379 return 0, nil, err 380 } 381 if atomic.LoadUint32(isOpen) != 1 { 382 return 0, nil, net.ErrClosed 383 } 384 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 385 if count == 0 { 386 return 0, nil, io.ErrNoProgress 387 } 388 } 389 bind.rx.Return(1) 390 err = bind.InsertReceiveRequest() 391 if err != nil { 392 return 0, nil, err 393 } 394 // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us 395 // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to 396 // attacker bandwidth, just like the rest of the receive path. 397 if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { 398 if atomic.LoadUint32(isOpen) != 1 { 399 return 0, nil, net.ErrClosed 400 } 401 goto retry 402 } 403 if results[0].Status != 0 { 404 return 0, nil, windows.Errno(results[0].Status) 405 } 406 packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) 407 ep := packet.addr 408 n := copy(buf, packet.data[:results[0].BytesTransferred]) 409 return n, &ep, nil 410 } 411 412 func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { 413 bind.mu.RLock() 414 defer bind.mu.RUnlock() 415 return bind.v4.Receive(buf, &bind.isOpen) 416 } 417 418 func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { 419 bind.mu.RLock() 420 defer bind.mu.RUnlock() 421 return bind.v6.Receive(buf, &bind.isOpen) 422 } 423 424 func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { 425 if atomic.LoadUint32(isOpen) != 1 { 426 return net.ErrClosed 427 } 428 if len(buf) > bytesPerPacket { 429 return io.ErrShortBuffer 430 } 431 bind.tx.mu.Lock() 432 defer bind.tx.mu.Unlock() 433 var results [packetsPerRing]winrio.Result 434 count := winrio.DequeueCompletion(bind.tx.cq, results[:]) 435 if count == 0 && bind.tx.isFull { 436 err := winrio.Notify(bind.tx.cq) 437 if err != nil { 438 return err 439 } 440 var bytes uint32 441 var key uintptr 442 var overlapped *windows.Overlapped 443 err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 444 if err != nil { 445 return err 446 } 447 if atomic.LoadUint32(isOpen) != 1 { 448 return net.ErrClosed 449 } 450 count = winrio.DequeueCompletion(bind.tx.cq, results[:]) 451 if count == 0 { 452 return io.ErrNoProgress 453 } 454 } 455 if count > 0 { 456 bind.tx.Return(count) 457 } 458 packet := bind.tx.Push() 459 packet.addr = *nend 460 copy(packet.data[:], buf) 461 dataBuffer := &winrio.Buffer{ 462 Id: bind.tx.id, 463 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), 464 Length: uint32(len(buf)), 465 } 466 addressBuffer := &winrio.Buffer{ 467 Id: bind.tx.id, 468 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), 469 Length: uint32(unsafe.Sizeof(packet.addr)), 470 } 471 bind.mu.Lock() 472 defer bind.mu.Unlock() 473 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) 474 } 475 476 func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { 477 nend, ok := endpoint.(*WinRingEndpoint) 478 if !ok { 479 return ErrWrongEndpointType 480 } 481 bind.mu.RLock() 482 defer bind.mu.RUnlock() 483 switch nend.family { 484 case windows.AF_INET: 485 if bind.v4.blackhole { 486 return nil 487 } 488 return bind.v4.Send(buf, nend, &bind.isOpen) 489 case windows.AF_INET6: 490 if bind.v6.blackhole { 491 return nil 492 } 493 return bind.v6.Send(buf, nend, &bind.isOpen) 494 } 495 return nil 496 } 497 498 func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 499 bind.mu.Lock() 500 defer bind.mu.Unlock() 501 sysconn, err := bind.ipv4.SyscallConn() 502 if err != nil { 503 return err 504 } 505 err2 := sysconn.Control(func(fd uintptr) { 506 err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) 507 }) 508 if err2 != nil { 509 return err2 510 } 511 if err != nil { 512 return err 513 } 514 bind.blackhole4 = blackhole 515 return nil 516 } 517 518 func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 519 bind.mu.Lock() 520 defer bind.mu.Unlock() 521 sysconn, err := bind.ipv6.SyscallConn() 522 if err != nil { 523 return err 524 } 525 err2 := sysconn.Control(func(fd uintptr) { 526 err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) 527 }) 528 if err2 != nil { 529 return err2 530 } 531 if err != nil { 532 return err 533 } 534 bind.blackhole6 = blackhole 535 return nil 536 } 537 538 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 539 bind.mu.RLock() 540 defer bind.mu.RUnlock() 541 if atomic.LoadUint32(&bind.isOpen) != 1 { 542 return net.ErrClosed 543 } 544 err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) 545 if err != nil { 546 return err 547 } 548 bind.v4.blackhole = blackhole 549 return nil 550 } 551 552 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 553 bind.mu.RLock() 554 defer bind.mu.RUnlock() 555 if atomic.LoadUint32(&bind.isOpen) != 1 { 556 return net.ErrClosed 557 } 558 err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) 559 if err != nil { 560 return err 561 } 562 bind.v6.blackhole = blackhole 563 return nil 564 } 565 566 func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { 567 const IP_UNICAST_IF = 31 568 /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ 569 var bytes [4]byte 570 binary.BigEndian.PutUint32(bytes[:], interfaceIndex) 571 interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) 572 err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) 573 if err != nil { 574 return err 575 } 576 return nil 577 } 578 579 func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { 580 const IPV6_UNICAST_IF = 31 581 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) 582 }