github.com/amnezia-vpn/amnezia-wg@v0.1.8/conn/bind_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "encoding/binary" 10 "io" 11 "net" 12 "net/netip" 13 "strconv" 14 "sync" 15 "sync/atomic" 16 "unsafe" 17 18 "golang.org/x/sys/windows" 19 20 "github.com/amnezia-vpn/amnezia-wg/conn/winrio" 21 ) 22 23 const ( 24 packetsPerRing = 1024 25 bytesPerPacket = 2048 - 32 26 receiveSpins = 15 27 ) 28 29 type ringPacket struct { 30 addr WinRingEndpoint 31 data [bytesPerPacket]byte 32 } 33 34 type ringBuffer struct { 35 packets uintptr 36 head, tail uint32 37 id winrio.BufferId 38 iocp windows.Handle 39 isFull bool 40 cq winrio.Cq 41 mu sync.Mutex 42 overlapped windows.Overlapped 43 } 44 45 func (rb *ringBuffer) Push() *ringPacket { 46 for rb.isFull { 47 panic("ring is full") 48 } 49 ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) 50 rb.tail += 1 51 if rb.tail%packetsPerRing == rb.head%packetsPerRing { 52 rb.isFull = true 53 } 54 return ret 55 } 56 57 func (rb *ringBuffer) Return(count uint32) { 58 if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { 59 return 60 } 61 rb.head += count 62 rb.isFull = false 63 } 64 65 type afWinRingBind struct { 66 sock windows.Handle 67 rx, tx ringBuffer 68 rq winrio.Rq 69 mu sync.Mutex 70 blackhole bool 71 } 72 73 // WinRingBind uses Windows registered I/O for fast ring buffered networking. 74 type WinRingBind struct { 75 v4, v6 afWinRingBind 76 mu sync.RWMutex 77 isOpen atomic.Uint32 // 0, 1, or 2 78 } 79 80 func NewDefaultBind() Bind { return NewWinRingBind() } 81 82 func NewWinRingBind() Bind { 83 if !winrio.Initialize() { 84 return NewStdNetBind() 85 } 86 return new(WinRingBind) 87 } 88 89 type WinRingEndpoint struct { 90 family uint16 91 data [30]byte 92 } 93 94 var ( 95 _ Bind = (*WinRingBind)(nil) 96 _ Endpoint = (*WinRingEndpoint)(nil) 97 ) 98 99 func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { 100 host, port, err := net.SplitHostPort(s) 101 if err != nil { 102 return nil, err 103 } 104 host16, err := windows.UTF16PtrFromString(host) 105 if err != nil { 106 return nil, err 107 } 108 port16, err := windows.UTF16PtrFromString(port) 109 if err != nil { 110 return nil, err 111 } 112 hints := windows.AddrinfoW{ 113 Flags: windows.AI_NUMERICHOST, 114 Family: windows.AF_UNSPEC, 115 Socktype: windows.SOCK_DGRAM, 116 Protocol: windows.IPPROTO_UDP, 117 } 118 var addrinfo *windows.AddrinfoW 119 err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) 120 if err != nil { 121 return nil, err 122 } 123 defer windows.FreeAddrInfoW(addrinfo) 124 if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { 125 return nil, windows.ERROR_INVALID_ADDRESS 126 } 127 var dst [unsafe.Sizeof(WinRingEndpoint{})]byte 128 copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) 129 return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil 130 } 131 132 func (*WinRingEndpoint) ClearSrc() {} 133 134 func (e *WinRingEndpoint) DstIP() netip.Addr { 135 switch e.family { 136 case windows.AF_INET: 137 return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) 138 case windows.AF_INET6: 139 return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) 140 } 141 return netip.Addr{} 142 } 143 144 func (e *WinRingEndpoint) SrcIP() netip.Addr { 145 return netip.Addr{} // not supported 146 } 147 148 func (e *WinRingEndpoint) DstToBytes() []byte { 149 switch e.family { 150 case windows.AF_INET: 151 b := make([]byte, 0, 6) 152 b = append(b, e.data[2:6]...) 153 b = append(b, e.data[1], e.data[0]) 154 return b 155 case windows.AF_INET6: 156 b := make([]byte, 0, 18) 157 b = append(b, e.data[6:22]...) 158 b = append(b, e.data[1], e.data[0]) 159 return b 160 } 161 return nil 162 } 163 164 func (e *WinRingEndpoint) DstToString() string { 165 switch e.family { 166 case windows.AF_INET: 167 return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() 168 case windows.AF_INET6: 169 var zone string 170 if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { 171 zone = strconv.FormatUint(uint64(scope), 10) 172 } 173 return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() 174 } 175 return "" 176 } 177 178 func (e *WinRingEndpoint) SrcToString() string { 179 return "" 180 } 181 182 func (ring *ringBuffer) CloseAndZero() { 183 if ring.cq != 0 { 184 winrio.CloseCompletionQueue(ring.cq) 185 ring.cq = 0 186 } 187 if ring.iocp != 0 { 188 windows.CloseHandle(ring.iocp) 189 ring.iocp = 0 190 } 191 if ring.id != 0 { 192 winrio.DeregisterBuffer(ring.id) 193 ring.id = 0 194 } 195 if ring.packets != 0 { 196 windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) 197 ring.packets = 0 198 } 199 ring.head = 0 200 ring.tail = 0 201 ring.isFull = false 202 } 203 204 func (bind *afWinRingBind) CloseAndZero() { 205 bind.rx.CloseAndZero() 206 bind.tx.CloseAndZero() 207 if bind.sock != 0 { 208 windows.CloseHandle(bind.sock) 209 bind.sock = 0 210 } 211 bind.blackhole = false 212 } 213 214 func (bind *WinRingBind) closeAndZero() { 215 bind.isOpen.Store(0) 216 bind.v4.CloseAndZero() 217 bind.v6.CloseAndZero() 218 } 219 220 func (ring *ringBuffer) Open() error { 221 var err error 222 packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing 223 ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) 224 if err != nil { 225 return err 226 } 227 ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) 228 if err != nil { 229 return err 230 } 231 ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 232 if err != nil { 233 return err 234 } 235 ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) 236 if err != nil { 237 return err 238 } 239 return nil 240 } 241 242 func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { 243 var err error 244 bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 245 if err != nil { 246 return nil, err 247 } 248 err = bind.rx.Open() 249 if err != nil { 250 return nil, err 251 } 252 err = bind.tx.Open() 253 if err != nil { 254 return nil, err 255 } 256 bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) 257 if err != nil { 258 return nil, err 259 } 260 err = windows.Bind(bind.sock, sa) 261 if err != nil { 262 return nil, err 263 } 264 sa, err = windows.Getsockname(bind.sock) 265 if err != nil { 266 return nil, err 267 } 268 return sa, nil 269 } 270 271 func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { 272 bind.mu.Lock() 273 defer bind.mu.Unlock() 274 defer func() { 275 if err != nil { 276 bind.closeAndZero() 277 } 278 }() 279 if bind.isOpen.Load() != 0 { 280 return nil, 0, ErrBindAlreadyOpen 281 } 282 var sa windows.Sockaddr 283 sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) 284 if err != nil { 285 return nil, 0, err 286 } 287 sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) 288 if err != nil { 289 return nil, 0, err 290 } 291 selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) 292 for i := 0; i < packetsPerRing; i++ { 293 err = bind.v4.InsertReceiveRequest() 294 if err != nil { 295 return nil, 0, err 296 } 297 err = bind.v6.InsertReceiveRequest() 298 if err != nil { 299 return nil, 0, err 300 } 301 } 302 bind.isOpen.Store(1) 303 return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err 304 } 305 306 func (bind *WinRingBind) Close() error { 307 bind.mu.RLock() 308 if bind.isOpen.Load() != 1 { 309 bind.mu.RUnlock() 310 return nil 311 } 312 bind.isOpen.Store(2) 313 windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) 314 windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) 315 windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) 316 windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) 317 bind.mu.RUnlock() 318 bind.mu.Lock() 319 defer bind.mu.Unlock() 320 bind.closeAndZero() 321 return nil 322 } 323 324 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and 325 // rename the IdealBatchSize constant to BatchSize. 326 func (bind *WinRingBind) BatchSize() int { 327 // TODO: implement batching in and out of the ring 328 return 1 329 } 330 331 func (bind *WinRingBind) SetMark(mark uint32) error { 332 return nil 333 } 334 335 func (bind *afWinRingBind) InsertReceiveRequest() error { 336 packet := bind.rx.Push() 337 dataBuffer := &winrio.Buffer{ 338 Id: bind.rx.id, 339 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), 340 Length: uint32(len(packet.data)), 341 } 342 addressBuffer := &winrio.Buffer{ 343 Id: bind.rx.id, 344 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), 345 Length: uint32(unsafe.Sizeof(packet.addr)), 346 } 347 bind.mu.Lock() 348 defer bind.mu.Unlock() 349 return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) 350 } 351 352 //go:linkname procyield runtime.procyield 353 func procyield(cycles uint32) 354 355 func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { 356 if isOpen.Load() != 1 { 357 return 0, nil, net.ErrClosed 358 } 359 bind.rx.mu.Lock() 360 defer bind.rx.mu.Unlock() 361 362 var err error 363 var count uint32 364 var results [1]winrio.Result 365 retry: 366 count = 0 367 for tries := 0; count == 0 && tries < receiveSpins; tries++ { 368 if tries > 0 { 369 if isOpen.Load() != 1 { 370 return 0, nil, net.ErrClosed 371 } 372 procyield(1) 373 } 374 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 375 } 376 if count == 0 { 377 err = winrio.Notify(bind.rx.cq) 378 if err != nil { 379 return 0, nil, err 380 } 381 var bytes uint32 382 var key uintptr 383 var overlapped *windows.Overlapped 384 err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 385 if err != nil { 386 return 0, nil, err 387 } 388 if isOpen.Load() != 1 { 389 return 0, nil, net.ErrClosed 390 } 391 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 392 if count == 0 { 393 return 0, nil, io.ErrNoProgress 394 } 395 } 396 bind.rx.Return(1) 397 err = bind.InsertReceiveRequest() 398 if err != nil { 399 return 0, nil, err 400 } 401 // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us 402 // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to 403 // attacker bandwidth, just like the rest of the receive path. 404 if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { 405 if isOpen.Load() != 1 { 406 return 0, nil, net.ErrClosed 407 } 408 goto retry 409 } 410 if results[0].Status != 0 { 411 return 0, nil, windows.Errno(results[0].Status) 412 } 413 packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) 414 ep := packet.addr 415 n := copy(buf, packet.data[:results[0].BytesTransferred]) 416 return n, &ep, nil 417 } 418 419 func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { 420 bind.mu.RLock() 421 defer bind.mu.RUnlock() 422 n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) 423 sizes[0] = n 424 eps[0] = ep 425 return 1, err 426 } 427 428 func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { 429 bind.mu.RLock() 430 defer bind.mu.RUnlock() 431 n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) 432 sizes[0] = n 433 eps[0] = ep 434 return 1, err 435 } 436 437 func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { 438 if isOpen.Load() != 1 { 439 return net.ErrClosed 440 } 441 if len(buf) > bytesPerPacket { 442 return io.ErrShortBuffer 443 } 444 bind.tx.mu.Lock() 445 defer bind.tx.mu.Unlock() 446 var results [packetsPerRing]winrio.Result 447 count := winrio.DequeueCompletion(bind.tx.cq, results[:]) 448 if count == 0 && bind.tx.isFull { 449 err := winrio.Notify(bind.tx.cq) 450 if err != nil { 451 return err 452 } 453 var bytes uint32 454 var key uintptr 455 var overlapped *windows.Overlapped 456 err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 457 if err != nil { 458 return err 459 } 460 if isOpen.Load() != 1 { 461 return net.ErrClosed 462 } 463 count = winrio.DequeueCompletion(bind.tx.cq, results[:]) 464 if count == 0 { 465 return io.ErrNoProgress 466 } 467 } 468 if count > 0 { 469 bind.tx.Return(count) 470 } 471 packet := bind.tx.Push() 472 packet.addr = *nend 473 copy(packet.data[:], buf) 474 dataBuffer := &winrio.Buffer{ 475 Id: bind.tx.id, 476 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), 477 Length: uint32(len(buf)), 478 } 479 addressBuffer := &winrio.Buffer{ 480 Id: bind.tx.id, 481 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), 482 Length: uint32(unsafe.Sizeof(packet.addr)), 483 } 484 bind.mu.Lock() 485 defer bind.mu.Unlock() 486 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) 487 } 488 489 func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { 490 nend, ok := endpoint.(*WinRingEndpoint) 491 if !ok { 492 return ErrWrongEndpointType 493 } 494 bind.mu.RLock() 495 defer bind.mu.RUnlock() 496 for _, buf := range bufs { 497 switch nend.family { 498 case windows.AF_INET: 499 if bind.v4.blackhole { 500 continue 501 } 502 if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { 503 return err 504 } 505 case windows.AF_INET6: 506 if bind.v6.blackhole { 507 continue 508 } 509 if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { 510 return err 511 } 512 } 513 } 514 return nil 515 } 516 517 func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 518 s.mu.Lock() 519 defer s.mu.Unlock() 520 sysconn, err := s.ipv4.SyscallConn() 521 if err != nil { 522 return err 523 } 524 err2 := sysconn.Control(func(fd uintptr) { 525 err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) 526 }) 527 if err2 != nil { 528 return err2 529 } 530 if err != nil { 531 return err 532 } 533 s.blackhole4 = blackhole 534 return nil 535 } 536 537 func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 538 s.mu.Lock() 539 defer s.mu.Unlock() 540 sysconn, err := s.ipv6.SyscallConn() 541 if err != nil { 542 return err 543 } 544 err2 := sysconn.Control(func(fd uintptr) { 545 err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) 546 }) 547 if err2 != nil { 548 return err2 549 } 550 if err != nil { 551 return err 552 } 553 s.blackhole6 = blackhole 554 return nil 555 } 556 557 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 558 bind.mu.RLock() 559 defer bind.mu.RUnlock() 560 if bind.isOpen.Load() != 1 { 561 return net.ErrClosed 562 } 563 err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) 564 if err != nil { 565 return err 566 } 567 bind.v4.blackhole = blackhole 568 return nil 569 } 570 571 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 572 bind.mu.RLock() 573 defer bind.mu.RUnlock() 574 if bind.isOpen.Load() != 1 { 575 return net.ErrClosed 576 } 577 err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) 578 if err != nil { 579 return err 580 } 581 bind.v6.blackhole = blackhole 582 return nil 583 } 584 585 func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { 586 const IP_UNICAST_IF = 31 587 /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ 588 var bytes [4]byte 589 binary.BigEndian.PutUint32(bytes[:], interfaceIndex) 590 interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) 591 err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) 592 if err != nil { 593 return err 594 } 595 return nil 596 } 597 598 func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { 599 const IPV6_UNICAST_IF = 31 600 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) 601 }